5959from monai .networks .nets import resnet
6060from monai .utils import BlendMode , PytorchPadMode , ensure_tuple_rep , optional_import
6161
62- BalancedPositiveNegativeSampler , _ = optional_import (
63- "torchvision.models.detection._utils" , name = "BalancedPositiveNegativeSampler"
64- )
62+ BalancedPositiveNegativeSampler , _ = optional_import ("torchvision.models.detection._utils" , name = "BalancedPositiveNegativeSampler" )
6563Matcher , _ = optional_import ("torchvision.models.detection._utils" , name = "Matcher" )
6664
6765
@@ -328,9 +326,7 @@ def set_box_regression_loss(self, box_loss: nn.Module, encode_gt: bool, decode_p
328326 self .encode_gt = encode_gt
329327 self .decode_pred = decode_pred
330328
331- def set_regular_matcher (
332- self , fg_iou_thresh : float , bg_iou_thresh : float , allow_low_quality_matches : bool = True
333- ) -> None :
329+ def set_regular_matcher (self , fg_iou_thresh : float , bg_iou_thresh : float , allow_low_quality_matches : bool = True ) -> None :
334330 """
335331 Using for training. Set torchvision matcher that matches anchors with ground truth boxes.
336332
@@ -342,12 +338,9 @@ def set_regular_matcher(
342338 """
343339 if fg_iou_thresh < bg_iou_thresh :
344340 raise ValueError (
345- "Require fg_iou_thresh >= bg_iou_thresh. "
346- f"Got fg_iou_thresh={ fg_iou_thresh } , bg_iou_thresh={ bg_iou_thresh } ."
341+ f"Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh={ fg_iou_thresh } , bg_iou_thresh={ bg_iou_thresh } ."
347342 )
348- self .proposal_matcher = Matcher (
349- fg_iou_thresh , bg_iou_thresh , allow_low_quality_matches = allow_low_quality_matches
350- )
343+ self .proposal_matcher = Matcher (fg_iou_thresh , bg_iou_thresh , allow_low_quality_matches = allow_low_quality_matches )
351344
352345 def set_atss_matcher (self , num_candidates : int = 4 , center_in_gt : bool = False ) -> None :
353346 """
@@ -496,9 +489,7 @@ def forward(
496489 """
497490 # 1. Check if input arguments are valid
498491 if self .training :
499- targets = check_training_targets (
500- input_images , targets , self .spatial_dims , self .target_label_key , self .target_box_key
501- )
492+ targets = check_training_targets (input_images , targets , self .spatial_dims , self .target_label_key , self .target_box_key )
502493 self ._check_detector_training_components ()
503494
504495 # 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -518,12 +509,8 @@ def forward(
518509 ensure_dict_value_to_list_ (head_outputs )
519510 else :
520511 if self .inferer is None :
521- raise ValueError (
522- "`self.inferer` is not defined." "Please refer to function self.set_sliding_window_inferer(*)."
523- )
524- head_outputs = predict_with_inferer (
525- images , self .network , keys = [self .cls_key , self .box_reg_key ], inferer = self .inferer
526- )
512+ raise ValueError ("`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*)." )
513+ head_outputs = predict_with_inferer (images , self .network , keys = [self .cls_key , self .box_reg_key ], inferer = self .inferer )
527514
528515 # 4. Generate anchors and store it in self.anchors: List[Tensor]
529516 self .generate_anchors (images , head_outputs )
@@ -545,7 +532,10 @@ def forward(
545532
546533 # 6(2). If during inference, return detection results
547534 detections = self .postprocess_detections (
548- head_outputs , self .anchors , image_sizes , num_anchor_locs_per_level # type: ignore
535+ head_outputs ,
536+ self .anchors ,
537+ image_sizes ,
538+ num_anchor_locs_per_level , # type: ignore
549539 )
550540 return detections
551541
@@ -554,9 +544,7 @@ def _check_detector_training_components(self):
554544 Check if self.proposal_matcher and self.fg_bg_sampler have been set for training.
555545 """
556546 if not hasattr (self , "proposal_matcher" ):
557- raise AttributeError (
558- "Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*)."
559- )
547+ raise AttributeError ("Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*)." )
560548 if self .fg_bg_sampler is None and self .debug :
561549 warnings .warn (
562550 "No balanced sampler is used. Negative samples are likely to "
@@ -653,9 +641,7 @@ def postprocess_detections(
653641 """
654642
655643 # recover level sizes, HWA or HWDA for each level
656- num_anchors_per_level = [
657- num_anchor_locs * self .num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level
658- ]
644+ num_anchors_per_level = [num_anchor_locs * self .num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level ]
659645
660646 # split outputs per level
661647 split_head_outputs : dict [str , list [Tensor ]] = {}
@@ -672,9 +658,7 @@ def postprocess_detections(
672658 detections : list [dict [str , Tensor ]] = []
673659
674660 for index in range (num_images ):
675- box_regression_per_image = [
676- br [index ] for br in box_regression
677- ] # List[Tensor], each sized (HWA, 2*spatial_dims)
661+ box_regression_per_image = [br [index ] for br in box_regression ] # List[Tensor], each sized (HWA, 2*spatial_dims)
678662 logits_per_image = [cl [index ] for cl in class_logits ] # List[Tensor], each sized (HWA, self.num_classes)
679663 anchors_per_image , img_spatial_size = split_anchors [index ], image_sizes [index ]
680664 # decode box regression into boxes
@@ -687,13 +671,11 @@ def postprocess_detections(
687671 boxes_per_image , logits_per_image , img_spatial_size
688672 )
689673
690- detections .append (
691- {
692- self .target_box_key : selected_boxes , # Tensor, sized (N, 2*spatial_dims)
693- self .pred_score_key : selected_scores , # Tensor, sized (N, )
694- self .target_label_key : selected_labels , # Tensor, sized (N, )
695- }
696- )
674+ detections .append ({
675+ self .target_box_key : selected_boxes , # Tensor, sized (N, 2*spatial_dims)
676+ self .pred_score_key : selected_scores , # Tensor, sized (N, )
677+ self .target_label_key : selected_labels , # Tensor, sized (N, )
678+ })
697679
698680 return detections
699681
@@ -722,9 +704,7 @@ def compute_loss(
722704 """
723705 matched_idxs = self .compute_anchor_matched_idxs (anchors , targets , num_anchor_locs_per_level )
724706 losses_cls = self .compute_cls_loss (head_outputs_reshape [self .cls_key ], targets , matched_idxs )
725- losses_box_regression = self .compute_box_loss (
726- head_outputs_reshape [self .box_reg_key ], targets , anchors , matched_idxs
727- )
707+ losses_box_regression = self .compute_box_loss (head_outputs_reshape [self .box_reg_key ], targets , anchors , matched_idxs )
728708 return {self .cls_key : losses_cls , self .box_reg_key : losses_box_regression }
729709
730710 def compute_anchor_matched_idxs (
@@ -757,9 +737,7 @@ def compute_anchor_matched_idxs(
757737 # anchors_per_image: Tensor, targets_per_image: Dice[str, Tensor]
758738 if targets_per_image [self .target_box_key ].numel () == 0 :
759739 # if no GT boxes
760- matched_idxs .append (
761- torch .full ((anchors_per_image .size (0 ),), - 1 , dtype = torch .int64 , device = anchors_per_image .device )
762- )
740+ matched_idxs .append (torch .full ((anchors_per_image .size (0 ),), - 1 , dtype = torch .int64 , device = anchors_per_image .device ))
763741 continue
764742
765743 # matched_idxs_per_image (Tensor[int64]): Tensor sized (sum(HWA),) or (sum(HWDA),)
@@ -787,7 +765,7 @@ def compute_anchor_matched_idxs(
787765 )
788766
789767 if self .debug :
790- print (f"Max box overlap between anchors and gt boxes: { torch .max (match_quality_matrix ,dim = 1 )[0 ]} ." )
768+ print (f"Max box overlap between anchors and gt boxes: { torch .max (match_quality_matrix , dim = 1 )[0 ]} ." )
791769
792770 if torch .max (matched_idxs_per_image ) < 0 :
793771 warnings .warn (
@@ -799,9 +777,7 @@ def compute_anchor_matched_idxs(
799777 matched_idxs .append (matched_idxs_per_image )
800778 return matched_idxs
801779
802- def compute_cls_loss (
803- self , cls_logits : Tensor , targets : list [dict [str , Tensor ]], matched_idxs : list [Tensor ]
804- ) -> Tensor :
780+ def compute_cls_loss (self , cls_logits : Tensor , targets : list [dict [str , Tensor ]], matched_idxs : list [Tensor ]) -> Tensor :
805781 """
806782 Compute classification losses.
807783
@@ -919,9 +895,7 @@ def get_cls_train_sample_per_image(
919895 gt_classes_target = torch .zeros_like (cls_logits_per_image ) # (sum(HW(D)A), self.num_classes)
920896 gt_classes_target [
921897 foreground_idxs_per_image , # fg anchor idx in
922- targets_per_image [self .target_label_key ][
923- matched_idxs_per_image [foreground_idxs_per_image ]
924- ], # fg class label
898+ targets_per_image [self .target_label_key ][matched_idxs_per_image [foreground_idxs_per_image ]], # fg class label
925899 ] = 1.0
926900
927901 if self .fg_bg_sampler is None :
@@ -993,9 +967,9 @@ def get_box_train_sample_per_image(
993967
994968 # select only the foreground boxes
995969 # matched GT boxes for foreground anchors
996- matched_gt_boxes_per_image = targets_per_image [self .target_box_key ][
997- matched_idxs_per_image [ foreground_idxs_per_image ]
998- ]. to ( box_regression_per_image . device )
970+ matched_gt_boxes_per_image = targets_per_image [self .target_box_key ][matched_idxs_per_image [ foreground_idxs_per_image ]]. to (
971+ box_regression_per_image . device
972+ )
999973 # predicted box regression for foreground anchors
1000974 box_regression_per_image = box_regression_per_image [foreground_idxs_per_image , :]
1001975 # foreground anchors
0 commit comments