Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 118 additions & 12 deletions src/torchmetrics/detection/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
def _input_validator(
preds: Sequence[dict[str, Tensor]],
targets: Sequence[dict[str, Tensor]],
iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = "bbox",
iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox",
ignore_score: bool = False,
) -> None:
"""Ensure the correct input format of `preds` and `targets`."""
if isinstance(iou_type, str):
iou_type = (iou_type,)

name_map = {"bbox": "boxes", "segm": "masks"}
name_map = {"bbox": "boxes", "segm": "masks", "keypoints": "keypoints"}
if any(tp not in name_map for tp in iou_type):
raise Exception(f"IOU type {iou_type} is not supported")
item_val_name = [name_map[tp] for tp in iou_type]
Expand Down Expand Up @@ -110,16 +110,18 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor:


def _validate_iou_type_arg(
iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = "bbox",
) -> tuple[Literal["bbox", "segm"], ...]:
iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox",
) -> tuple[Literal["bbox", "segm", "keypoints"], ...]:
"""Validate that iou type argument is correct."""
allowed_iou_types = ("segm", "bbox")
allowed_iou_types = ("segm", "bbox", "keypoints")
if isinstance(iou_type, str):
iou_type = (iou_type,)
if any(tp not in allowed_iou_types for tp in iou_type):
raise ValueError(
f"Expected argument `iou_type` to be one of {allowed_iou_types} or a tuple of, but got {iou_type}"
)
if "segm" in iou_type and "keypoints" in iou_type:
raise NotImplementedError()
return iou_type


Expand Down Expand Up @@ -194,11 +196,13 @@ def _get_coco_datasets(
groundtruth_labels: List[Tensor],
groundtruth_box: List[Tensor],
groundtruth_mask: List[Tensor],
groundtruth_keypoint: List[Tensor],
groundtruth_crowds: List[Tensor],
groundtruth_area: List[Tensor],
detection_labels: List[Tensor],
detection_box: List[Tensor],
detection_mask: List[Tensor],
detection_keypoint: List[Tensor],
detection_scores: List[Tensor],
iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = ("bbox",),
average: Literal["macro", "micro"] = "micro",
Expand All @@ -221,6 +225,7 @@ def _get_coco_datasets(
labels=groundtruth_labels,
boxes=groundtruth_box if len(groundtruth_box) > 0 else None,
masks=groundtruth_mask if len(groundtruth_mask) > 0 else None,
keypoints=groundtruth_keypoint if len(groundtruth_keypoint) > 0 else None,
crowds=groundtruth_crowds,
area=groundtruth_area,
iou_type=iou_type,
Expand All @@ -231,6 +236,7 @@ def _get_coco_datasets(
labels=detection_labels,
boxes=detection_box if len(detection_box) > 0 else None,
masks=detection_mask if len(detection_mask) > 0 else None,
keypoints=detection_keypoint if len(detection_keypoint) > 0 else None,
scores=detection_scores,
iou_type=iou_type,
all_labels=all_labels,
Expand All @@ -244,10 +250,23 @@ def _get_coco_datasets(
return coco_preds, coco_target

def _coco_stats_to_tensor_dict(
self, stats: list[float], prefix: str, max_detection_thresholds: list[int]
self, stats: list[float], prefix: str, max_detection_thresholds: list[int], iou_type: str
) -> dict[str, Tensor]:
"""Converts the output of COCOeval.stats to a dict of tensors."""
mdt = max_detection_thresholds
if iou_type == "keypoints":
return {
f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32),
f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32),
f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32),
f"{prefix}map_medium": torch.tensor([stats[3]], dtype=torch.float32),
f"{prefix}map_large": torch.tensor([stats[4]], dtype=torch.float32),
f"{prefix}mar_{mdt[0]}": torch.tensor([stats[5]], dtype=torch.float32),
f"{prefix}mar_{mdt[1]}": torch.tensor([stats[6]], dtype=torch.float32),
f"{prefix}mar_{mdt[2]}": torch.tensor([stats[7]], dtype=torch.float32),
f"{prefix}mar_medium": torch.tensor([stats[8]], dtype=torch.float32),
f"{prefix}mar_large": torch.tensor([stats[9]], dtype=torch.float32),
}
return {
f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32),
f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32),
Expand Down Expand Up @@ -319,11 +338,21 @@ def coco_to_tm(
target[t["image_id"]]["boxes"] = []
if "segm" in iou_type:
target[t["image_id"]]["masks"] = []
if "keypoints" in iou_type:
if "bbox" not in iou_type and hasattr(t, "bbox"):
target[t["image_id"]]["boxes"] = []
target[t["image_id"]]["keypoints"] = []
target[t["image_id"]]["num_keypoints"] = []

if "bbox" in iou_type:
target[t["image_id"]]["boxes"].append(t["bbox"])
if "segm" in iou_type:
target[t["image_id"]]["masks"].append(gt.annToMask(t))
if "keypoints" in iou_type:
if "bbox" not in iou_type and hasattr(t, "bbox"):
target[t["image_id"]]["boxes"].append(t["bbox"])
target[t["image_id"]]["keypoints"].append(t["keypoints"])
target[t["image_id"]]["num_keypoints"].append(len(t["keypoints"]) // 3)
target[t["image_id"]]["labels"].append(t["category_id"])
target[t["image_id"]]["iscrowd"].append(t["iscrowd"])
target[t["image_id"]]["area"].append(t["area"])
Expand All @@ -336,19 +365,35 @@ def coco_to_tm(
preds[p["image_id"]]["boxes"] = []
if "segm" in iou_type:
preds[p["image_id"]]["masks"] = []
if "keypoints" in iou_type:
if "bbox" not in iou_type and hasattr(p, "bbox"):
preds[p["image_id"]]["boxes"] = []
preds[p["image_id"]]["keypoints"] = []
preds[p["image_id"]]["num_keypoints"] = []
if "bbox" in iou_type:
preds[p["image_id"]]["boxes"].append(p["bbox"])
if "segm" in iou_type:
preds[p["image_id"]]["masks"].append(gt.annToMask(p))
if "keypoints" in iou_type:
if "bbox" not in iou_type and hasattr(p, "bbox"):
preds[p["image_id"]]["boxes"].append(p["bbox"])
preds[p["image_id"]]["keypoints"].append(p["keypoints"])
preds[p["image_id"]]["num_keypoints"].append(len(p["keypoints"]) // 3)
preds[p["image_id"]]["scores"].append(p["score"])
preds[p["image_id"]]["labels"].append(p["category_id"])

for k in target: # add empty predictions for images without predictions
if k not in preds:
preds[k] = {"scores": [], "labels": []}
if "bbox" in iou_type:
preds[k]["boxes"] = []
if "segm" in iou_type:
preds[k]["masks"] = []
if "keypoints" in iou_type:
preds[k]["keypoints"] = []
preds[k]["num_keypoints"] = []
if "bbox" not in iou_type and hasattr(preds[k], "bbox"):
preds[k]["boxes"] = []

batched_preds, batched_target = [], []
for key in target:
Expand All @@ -360,6 +405,12 @@ def coco_to_tm(
bp["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32)
if "segm" in iou_type:
bp["masks"] = torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8)
if "keypoints" in iou_type:
if "bbox" not in iou_type and hasattr(bp, "bbox"):
bp["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32)
bp["keypoints"] = torch.tensor(np.array(preds[key]["keypoints"]), dtype=torch.float32)
bp["num_keypoints"] = torch.tensor(np.array(preds[key]["num_keypoints"]), dtype=torch.int32)

batched_preds.append(bp)

bt = {
Expand All @@ -371,6 +422,11 @@ def coco_to_tm(
bt["boxes"] = torch.tensor(target[key]["boxes"], dtype=torch.float32)
if "segm" in iou_type:
bt["masks"] = torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8)
if "keypoints" in iou_type:
if "bbox" not in iou_type and hasattr(bp, "bbox"):
bt["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32)
bt["keypoints"] = torch.tensor(np.array(target[key]["keypoints"]), dtype=torch.float32)
bt["num_keypoints"] = torch.tensor(np.array(target[key]["num_keypoints"]), dtype=torch.int32)
batched_target.append(bt)

return batched_preds, batched_target
Expand All @@ -380,11 +436,13 @@ def tm_to_coco(
groundtruth_labels: List[Tensor],
groundtruth_box: List[Tensor],
groundtruth_mask: List[Tensor],
groundtruth_keypoint: List[Tensor],
groundtruth_crowds: List[Tensor],
groundtruth_area: List[Tensor],
detection_labels: List[Tensor],
detection_box: List[Tensor],
detection_mask: List[Tensor],
detection_keypoint: List[Tensor],
detection_scores: List[Tensor],
name: str = "tm_map_input",
iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = ("bbox",),
Expand All @@ -400,11 +458,13 @@ def tm_to_coco(
groundtruth_labels: List of tensors containing the ground truth labels
groundtruth_box: List of tensors containing the ground truth bounding boxes
groundtruth_mask: List of tensors containing the ground truth segmentation masks
groundtruth_keypoint: List of tensors containing the ground truth keypoints
groundtruth_crowds: List of tensors indicating whether ground truth annotations are crowd annotations
groundtruth_area: List of tensors containing the area of ground truth annotations
detection_labels: List of tensors containing the predicted labels
detection_box: List of tensors containing the predicted bounding boxes
detection_mask: List of tensors containing the predicted segmentation masks
detection_keypoint: List of tensors containing the predicted keypoints
detection_scores: List of tensors containing the confidence scores for predictions
name: Name of the output file, which will be appended with "_preds.json" and "_target.json"
iou_type: Type of IoU calculation to use. Can be either "bbox" for bounding box or "segm" for segmentation
Expand Down Expand Up @@ -440,6 +500,7 @@ def tm_to_coco(
labels=groundtruth_labels,
boxes=groundtruth_box if len(groundtruth_box) > 0 else None,
masks=groundtruth_mask if len(groundtruth_mask) > 0 else None,
keypoints=groundtruth_keypoint if len(groundtruth_keypoint) > 0 else None,
crowds=groundtruth_crowds,
area=groundtruth_area,
all_labels=all_labels,
Expand All @@ -450,6 +511,7 @@ def tm_to_coco(
labels=detection_labels,
boxes=detection_box if len(detection_box) > 0 else None,
masks=detection_mask if len(detection_mask) > 0 else None,
keypoints=detection_keypoint if len(detection_keypoint) > 0 else None,
scores=detection_scores,
all_labels=all_labels,
iou_type=iou_type,
Expand Down Expand Up @@ -485,6 +547,7 @@ def _get_coco_format(
all_labels: List[Tensor],
boxes: Optional[List[Tensor]] = None,
masks: Optional[List[Tensor]] = None,
keypoints: Optional[List[Tensor]] = None,
scores: Optional[List[Tensor]] = None,
crowds: Optional[List[Tensor]] = None,
area: Optional[List[Tensor]] = None,
Expand All @@ -509,6 +572,15 @@ def _get_coco_format(
image_masks = masks[image_id]
if len(image_masks) == 0 and boxes is None:
continue
if keypoints is not None:
if boxes is None:
kp = keypoints[image_id].view(-1, 17, 3)
v_mask = kp[..., 2:] > 0 # visible points
mins = torch.masked.amin(kp[..., :2], mask=v_mask, dim=1) # [N, 2] -> [xmin, ymin]
maxs = torch.masked.amax(kp[..., :2], mask=v_mask, dim=1)
image_boxes = torch.cat([mins, maxs - mins], dim=-1).nan_to_num(0).cpu().tolist()
image_keypoints = keypoints[image_id]
image_keypoints = image_keypoints.cpu().tolist()
image_labels = image_labels.cpu().tolist() # type: ignore[assignment]

images.append({"id": image_id})
Expand All @@ -521,12 +593,20 @@ def _get_coco_format(
if masks is not None and len(image_masks) > 0:
image_mask = image_masks[k]
image_mask = {"size": image_mask[0], "counts": image_mask[1]}
if keypoints is not None and len(image_keypoints) > 0:
image_keypoint = image_keypoints[k]
image_box = image_boxes[k]

if "bbox" in iou_type and len(image_box) != 4:
raise ValueError(
f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})"
)

if "keypoints" in iou_type and len(image_keypoint) != 51:
raise ValueError(
f"Invalid input keypoint of sample {image_id}, element {k} (expected 3 values, got {len(image_keypoint)})"
)

if not isinstance(image_label, int):
raise ValueError(
f"Invalid input class of sample {image_id}, element {k}"
Expand All @@ -535,13 +615,16 @@ def _get_coco_format(

area_stat_box = None
area_stat_mask = None

if area is not None and area[image_id][k].cpu().tolist() > 0: # type: ignore[operator]
area_stat = area[image_id][k].cpu().tolist()
else:
area_stat = self.mask_utils.area(image_mask) if "segm" in iou_type else image_box[2] * image_box[3]
if len(iou_type) > 1:
area_stat_box = image_box[2] * image_box[3]
area_stat_mask = self.mask_utils.area(image_mask)
if "bbox" in iou_type:
area_stat_box = image_box[2] * image_box[3]
if "segm" in iou_type:
area_stat_mask = self.mask_utils.area(image_mask)

annotation = {
"id": annotation_id,
Expand All @@ -558,6 +641,10 @@ def _get_coco_format(
annotation["bbox"] = image_box
if masks is not None:
annotation["segmentation"] = image_mask
if keypoints is not None:
annotation["bbox"] = image_box
annotation["keypoints"] = image_keypoint
annotation["num_keypoints"] = len(image_keypoint) // 3

if scores is not None:
score = scores[image_id][k].cpu().tolist()
Expand Down Expand Up @@ -596,6 +683,7 @@ def _warning_on_too_many_detections(limit: int) -> None:
def _get_safe_item_values(
iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"], ...]],
box_format: str,
keypoint_format: str,
max_detection_thresholds: List[int],
coco_backend: CocoBackend,
item: dict[str, Any],
Expand Down Expand Up @@ -633,7 +721,7 @@ def _get_safe_item_values(
"""
from torchvision.ops import box_convert

output = [None, None]
output = [None, None, None]
if "bbox" in iou_type:
boxes = _fix_empty_tensors(item["boxes"])
if boxes.numel() > 0:
Expand All @@ -645,6 +733,18 @@ def _get_safe_item_values(
rle = coco_backend.mask_utils.encode(np.asfortranarray(i))
masks.append((tuple(rle["size"]), rle["counts"]))
output[1] = tuple(masks) # type: ignore[call-overload]
if "keypoints" in iou_type:

def keypoint_convert(keypoints, in_fmt="xy", out_fmt="xyv"):
# the keypoint format is expected to be:
# [number_of_keypoints * 3] -> x_0,y_0,viz_0, x_1,y_1,viz_1, ..., x_n,y_n,viz_n
if in_fmt == "xy" and out_fmt == "xyv":
keypoints = torch.cat([keypoints, torch.ones_like(keypoints[..., :1])], dim=-1)
keypoints = keypoints.flatten(-2) # [num_keypoints, 3] --> [num_keypoints * 3]
return keypoints

keypoints = _fix_empty_tensors(item["keypoints"])
output[2] = keypoint_convert(keypoints, keypoint_format, out_fmt="xyv")

def _valid_output_len(idx: int) -> bool:
val = output[idx]
Expand All @@ -668,13 +768,15 @@ def _calculate_map_with_coco(
groundtruth_labels: List[Tensor],
groundtruth_box: List[Tensor],
groundtruth_mask: List[Tensor],
groundtruth_keypoint: List[Tensor],
groundtruth_crowds: List[Tensor],
groundtruth_area: List[Tensor],
detection_labels: List[Tensor],
detection_box: List[Tensor],
detection_mask: List[Tensor],
detection_keypoint: List[Tensor],
detection_scores: List[Tensor],
iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"], ...]],
iou_type: Union[Literal["bbox", "segm", "keypoints"], Tuple[Literal["bbox", "segm", "keypoints"], ...]],
average: Literal["macro", "micro"],
iou_thresholds: List[float],
rec_thresholds: List[float],
Expand All @@ -686,11 +788,13 @@ def _calculate_map_with_coco(
groundtruth_labels,
groundtruth_box,
groundtruth_mask,
groundtruth_keypoint,
groundtruth_crowds,
groundtruth_area,
detection_labels,
detection_box,
detection_mask,
detection_keypoint,
detection_scores,
iou_type,
average=average,
Expand All @@ -709,7 +813,7 @@ def _calculate_map_with_coco(
if len(coco_preds.imgs) == 0 or len(coco_target.imgs) == 0:
result_dict.update(
coco_backend._coco_stats_to_tensor_dict(
12 * [-1.0], prefix=prefix, max_detection_thresholds=max_detection_thresholds
12 * [-1.0], prefix=prefix, max_detection_thresholds=max_detection_thresholds, iou_type=i_type
)
)
else:
Expand All @@ -724,7 +828,7 @@ def _calculate_map_with_coco(
stats = coco_eval.stats
result_dict.update(
coco_backend._coco_stats_to_tensor_dict(
stats, prefix=prefix, max_detection_thresholds=max_detection_thresholds
stats, prefix=prefix, max_detection_thresholds=max_detection_thresholds, iou_type=i_type
)
)

Expand All @@ -748,11 +852,13 @@ def _calculate_map_with_coco(
groundtruth_labels,
groundtruth_box,
groundtruth_mask,
groundtruth_keypoint,
groundtruth_crowds,
groundtruth_area,
detection_labels,
detection_box,
detection_mask,
detection_keypoint,
detection_scores,
iou_type,
average="macro",
Expand Down
Loading
Loading