Skip to content

Commit 4f6df07

Browse files
committed
lint
Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent 3fd7546 commit 4f6df07

12 files changed

Lines changed: 85 additions & 184 deletions

File tree

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -264,21 +264,16 @@ def _run_cmd(self, cmd: str, devices_info: str = "") -> subprocess.CompletedProc
264264
look_up_option(self.device_setting["MN_START_METHOD"], ["bcprun"])
265265
except ValueError as err:
266266
raise NotImplementedError(
267-
f"{self.device_setting['MN_START_METHOD']} is not supported yet."
268-
"Try modify BundleAlgo._run_cmd for your cluster."
267+
f"{self.device_setting['MN_START_METHOD']} is not supported yet. Try modifying BundleAlgo._run_cmd for your cluster."
269268
) from err
270269

271270
return _run_cmd_bcprun(cmd, n=self.device_setting["NUM_NODES"], p=self.device_setting["n_devices"])
272271
elif int(self.device_setting["n_devices"]) > 1:
273-
return _run_cmd_torchrun(
274-
cmd, nnodes=1, nproc_per_node=self.device_setting["n_devices"], env=ps_environ, check=True
275-
)
272+
return _run_cmd_torchrun(cmd, nnodes=1, nproc_per_node=self.device_setting["n_devices"], env=ps_environ, check=True)
276273
else:
277274
return run_cmd(cmd.split(), run_cmd_verbose=True, env=ps_environ, check=True)
278275

279-
def train(
280-
self, train_params: None | dict = None, device_setting: None | dict = None
281-
) -> subprocess.CompletedProcess:
276+
def train(self, train_params: None | dict = None, device_setting: None | dict = None) -> subprocess.CompletedProcess:
282277
"""
283278
Load the run function in the training script of each model. Training parameter is predefined by the
284279
algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.
@@ -369,9 +364,7 @@ def get_output_path(self):
369364

370365

371366
# path to download the algo_templates
372-
default_algo_zip = (
373-
f"https://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/{ALGO_HASH}.tar.gz"
374-
)
367+
default_algo_zip = f"https://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/{ALGO_HASH}.tar.gz"
375368

376369
# default algorithms
377370
default_algos = {
@@ -396,7 +389,7 @@ def _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]:
396389
try:
397390
download_and_extract(url=url, filepath=algo_compressed_file, output_dir=os.path.dirname(at_path))
398391
except Exception as e:
399-
msg = f"Download and extract of {url} failed, attempt {i+1}/{download_attempts}."
392+
msg = f"Download and extract of {url} failed, attempt {i + 1}/{download_attempts}."
400393
if i < download_attempts - 1:
401394
warnings.warn(msg)
402395
time.sleep(i)
@@ -660,6 +653,7 @@ def generate(
660653
gen_algo.export_to_disk(output_folder, name, fold=f_id)
661654

662655
algo_to_pickle(gen_algo, template_path=algo.template_path)
663-
self.history.append(
664-
{AlgoKeys.ID: name, AlgoKeys.ALGO: gen_algo}
665-
) # track the previous, may create a persistent history
656+
self.history.append({
657+
AlgoKeys.ID: name,
658+
AlgoKeys.ALGO: gen_algo,
659+
}) # track the previous, may create a persistent history

monai/apps/detection/networks/retinanet_detector.py

Lines changed: 27 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@
5959
from monai.networks.nets import resnet
6060
from monai.utils import BlendMode, PytorchPadMode, ensure_tuple_rep, optional_import
6161

62-
BalancedPositiveNegativeSampler, _ = optional_import(
63-
"torchvision.models.detection._utils", name="BalancedPositiveNegativeSampler"
64-
)
62+
BalancedPositiveNegativeSampler, _ = optional_import("torchvision.models.detection._utils", name="BalancedPositiveNegativeSampler")
6563
Matcher, _ = optional_import("torchvision.models.detection._utils", name="Matcher")
6664

6765

@@ -328,9 +326,7 @@ def set_box_regression_loss(self, box_loss: nn.Module, encode_gt: bool, decode_p
328326
self.encode_gt = encode_gt
329327
self.decode_pred = decode_pred
330328

331-
def set_regular_matcher(
332-
self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches: bool = True
333-
) -> None:
329+
def set_regular_matcher(self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches: bool = True) -> None:
334330
"""
335331
Using for training. Set torchvision matcher that matches anchors with ground truth boxes.
336332
@@ -342,12 +338,9 @@ def set_regular_matcher(
342338
"""
343339
if fg_iou_thresh < bg_iou_thresh:
344340
raise ValueError(
345-
"Require fg_iou_thresh >= bg_iou_thresh. "
346-
f"Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
341+
f"Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
347342
)
348-
self.proposal_matcher = Matcher(
349-
fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches
350-
)
343+
self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches)
351344

352345
def set_atss_matcher(self, num_candidates: int = 4, center_in_gt: bool = False) -> None:
353346
"""
@@ -496,9 +489,7 @@ def forward(
496489
"""
497490
# 1. Check if input arguments are valid
498491
if self.training:
499-
targets = check_training_targets(
500-
input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
501-
)
492+
targets = check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
502493
self._check_detector_training_components()
503494

504495
# 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -518,12 +509,8 @@ def forward(
518509
ensure_dict_value_to_list_(head_outputs)
519510
else:
520511
if self.inferer is None:
521-
raise ValueError(
522-
"`self.inferer` is not defined." "Please refer to function self.set_sliding_window_inferer(*)."
523-
)
524-
head_outputs = predict_with_inferer(
525-
images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer
526-
)
512+
raise ValueError("`self.inferer` is not defined. Please refer to function self.set_sliding_window_inferer(*).")
513+
head_outputs = predict_with_inferer(images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer)
527514

528515
# 4. Generate anchors and store it in self.anchors: List[Tensor]
529516
self.generate_anchors(images, head_outputs)
@@ -545,7 +532,10 @@ def forward(
545532

546533
# 6(2). If during inference, return detection results
547534
detections = self.postprocess_detections(
548-
head_outputs, self.anchors, image_sizes, num_anchor_locs_per_level # type: ignore
535+
head_outputs,
536+
self.anchors,
537+
image_sizes,
538+
num_anchor_locs_per_level, # type: ignore
549539
)
550540
return detections
551541

@@ -554,9 +544,7 @@ def _check_detector_training_components(self):
554544
Check if self.proposal_matcher and self.fg_bg_sampler have been set for training.
555545
"""
556546
if not hasattr(self, "proposal_matcher"):
557-
raise AttributeError(
558-
"Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*)."
559-
)
547+
raise AttributeError("Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*).")
560548
if self.fg_bg_sampler is None and self.debug:
561549
warnings.warn(
562550
"No balanced sampler is used. Negative samples are likely to "
@@ -653,9 +641,7 @@ def postprocess_detections(
653641
"""
654642

655643
# recover level sizes, HWA or HWDA for each level
656-
num_anchors_per_level = [
657-
num_anchor_locs * self.num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level
658-
]
644+
num_anchors_per_level = [num_anchor_locs * self.num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level]
659645

660646
# split outputs per level
661647
split_head_outputs: dict[str, list[Tensor]] = {}
@@ -672,9 +658,7 @@ def postprocess_detections(
672658
detections: list[dict[str, Tensor]] = []
673659

674660
for index in range(num_images):
675-
box_regression_per_image = [
676-
br[index] for br in box_regression
677-
] # List[Tensor], each sized (HWA, 2*spatial_dims)
661+
box_regression_per_image = [br[index] for br in box_regression] # List[Tensor], each sized (HWA, 2*spatial_dims)
678662
logits_per_image = [cl[index] for cl in class_logits] # List[Tensor], each sized (HWA, self.num_classes)
679663
anchors_per_image, img_spatial_size = split_anchors[index], image_sizes[index]
680664
# decode box regression into boxes
@@ -687,13 +671,11 @@ def postprocess_detections(
687671
boxes_per_image, logits_per_image, img_spatial_size
688672
)
689673

690-
detections.append(
691-
{
692-
self.target_box_key: selected_boxes, # Tensor, sized (N, 2*spatial_dims)
693-
self.pred_score_key: selected_scores, # Tensor, sized (N, )
694-
self.target_label_key: selected_labels, # Tensor, sized (N, )
695-
}
696-
)
674+
detections.append({
675+
self.target_box_key: selected_boxes, # Tensor, sized (N, 2*spatial_dims)
676+
self.pred_score_key: selected_scores, # Tensor, sized (N, )
677+
self.target_label_key: selected_labels, # Tensor, sized (N, )
678+
})
697679

698680
return detections
699681

@@ -722,9 +704,7 @@ def compute_loss(
722704
"""
723705
matched_idxs = self.compute_anchor_matched_idxs(anchors, targets, num_anchor_locs_per_level)
724706
losses_cls = self.compute_cls_loss(head_outputs_reshape[self.cls_key], targets, matched_idxs)
725-
losses_box_regression = self.compute_box_loss(
726-
head_outputs_reshape[self.box_reg_key], targets, anchors, matched_idxs
727-
)
707+
losses_box_regression = self.compute_box_loss(head_outputs_reshape[self.box_reg_key], targets, anchors, matched_idxs)
728708
return {self.cls_key: losses_cls, self.box_reg_key: losses_box_regression}
729709

730710
def compute_anchor_matched_idxs(
@@ -757,9 +737,7 @@ def compute_anchor_matched_idxs(
757737
# anchors_per_image: Tensor, targets_per_image: Dice[str, Tensor]
758738
if targets_per_image[self.target_box_key].numel() == 0:
759739
# if no GT boxes
760-
matched_idxs.append(
761-
torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
762-
)
740+
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device))
763741
continue
764742

765743
# matched_idxs_per_image (Tensor[int64]): Tensor sized (sum(HWA),) or (sum(HWDA),)
@@ -787,7 +765,7 @@ def compute_anchor_matched_idxs(
787765
)
788766

789767
if self.debug:
790-
print(f"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix,dim=1)[0]}.")
768+
print(f"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix, dim=1)[0]}.")
791769

792770
if torch.max(matched_idxs_per_image) < 0:
793771
warnings.warn(
@@ -799,9 +777,7 @@ def compute_anchor_matched_idxs(
799777
matched_idxs.append(matched_idxs_per_image)
800778
return matched_idxs
801779

802-
def compute_cls_loss(
803-
self, cls_logits: Tensor, targets: list[dict[str, Tensor]], matched_idxs: list[Tensor]
804-
) -> Tensor:
780+
def compute_cls_loss(self, cls_logits: Tensor, targets: list[dict[str, Tensor]], matched_idxs: list[Tensor]) -> Tensor:
805781
"""
806782
Compute classification losses.
807783
@@ -919,9 +895,7 @@ def get_cls_train_sample_per_image(
919895
gt_classes_target = torch.zeros_like(cls_logits_per_image) # (sum(HW(D)A), self.num_classes)
920896
gt_classes_target[
921897
foreground_idxs_per_image, # fg anchor idx in
922-
targets_per_image[self.target_label_key][
923-
matched_idxs_per_image[foreground_idxs_per_image]
924-
], # fg class label
898+
targets_per_image[self.target_label_key][matched_idxs_per_image[foreground_idxs_per_image]], # fg class label
925899
] = 1.0
926900

927901
if self.fg_bg_sampler is None:
@@ -993,9 +967,9 @@ def get_box_train_sample_per_image(
993967

994968
# select only the foreground boxes
995969
# matched GT boxes for foreground anchors
996-
matched_gt_boxes_per_image = targets_per_image[self.target_box_key][
997-
matched_idxs_per_image[foreground_idxs_per_image]
998-
].to(box_regression_per_image.device)
970+
matched_gt_boxes_per_image = targets_per_image[self.target_box_key][matched_idxs_per_image[foreground_idxs_per_image]].to(
971+
box_regression_per_image.device
972+
)
999973
# predicted box regression for foreground anchors
1000974
box_regression_per_image = box_regression_per_image[foreground_idxs_per_image, :]
1001975
# foreground anchors

monai/apps/detection/utils/anchor_utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,7 @@ def __init__(
136136
self.indexing = look_up_option(indexing, ["ij", "xy"])
137137

138138
self.aspect_ratios = aspect_ratios
139-
self.cell_anchors = [
140-
self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)
141-
]
139+
self.cell_anchors = [self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)]
142140

143141
# This comment comes from torchvision.
144142
# TODO: https://github.com/pytorch/pytorch/issues/26792
@@ -174,13 +172,13 @@ def generate_anchors(
174172
if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):
175173
raise ValueError(
176174
f"In {self.spatial_dims}-D image, aspect_ratios for each level should be \
177-
{len(aspect_ratios_t.shape)-1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
175+
{len(aspect_ratios_t.shape) - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
178176
)
179177

180178
if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):
181179
raise ValueError(
182180
f"In {self.spatial_dims}-D image, aspect_ratios for each level should has \
183-
shape (_,{self.spatial_dims-1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
181+
shape (_,{self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
184182
)
185183

186184
# if 2d, w:h = 1:aspect_ratios
@@ -253,8 +251,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
253251
# compute anchor centers regarding to the image.
254252
# shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
255253
shifts_centers = [
256-
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
257-
for axis in range(self.spatial_dims)
254+
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] for axis in range(self.spatial_dims)
258255
]
259256

260257
# to support torchscript, cannot directly use torch.meshgrid(shifts_centers).
@@ -307,10 +304,7 @@ def forward(self, images: Tensor, feature_maps: list[Tensor]) -> list[Tensor]:
307304
batchsize = images.shape[0]
308305
dtype, device = feature_maps[0].dtype, feature_maps[0].device
309306
strides = [
310-
[
311-
torch.tensor(image_size[axis] // g[axis], dtype=torch.int64, device=device)
312-
for axis in range(self.spatial_dims)
313-
]
307+
[torch.tensor(image_size[axis] // g[axis], dtype=torch.int64, device=device) for axis in range(self.spatial_dims)]
314308
for g in grid_sizes
315309
]
316310

monai/apps/detection/utils/detector_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ def check_training_targets(
8080
for i in range(len(targets)):
8181
target = targets[i]
8282
if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
83-
raise ValueError(
84-
f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
85-
)
83+
raise ValueError(f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}.")
8684

8785
boxes = target[target_box_key]
8886
if not isinstance(boxes, torch.Tensor):
@@ -91,12 +89,10 @@ def check_training_targets(
9189
if boxes.numel() == 0:
9290
warnings.warn(
9391
f"Warning: Given target boxes has shape of {boxes.shape}. "
94-
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2* spatial_dims}])."
92+
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}])."
9593
)
9694
else:
97-
raise ValueError(
98-
f"Expected target boxes to be a tensor of shape [N, {2* spatial_dims}], got {boxes.shape}.)."
99-
)
95+
raise ValueError(f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}.).")
10096
if not torch.is_floating_point(boxes):
10197
raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
10298
targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims) # type: ignore

0 commit comments

Comments
 (0)