From 08e3ef27a4f8f2057ff10779ddb305f8d6add70c Mon Sep 17 00:00:00 2001 From: SebastianJanampa Date: Wed, 25 Mar 2026 12:35:37 -0600 Subject: [PATCH 1/2] keypoint support for map --- src/torchmetrics/detection/helpers.py | 158 ++++++++++++++++--- src/torchmetrics/detection/mean_ap.py | 32 +++- src/torchmetrics/functional/detection/map.py | 19 ++- tests/unittests/detection/__init__.py | 3 + tests/unittests/detection/test_map.py | 57 +++++-- 5 files changed, 219 insertions(+), 50 deletions(-) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index 0b46b698638..04593f77b15 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -41,14 +41,14 @@ def _input_validator( preds: Sequence[dict[str, Tensor]], targets: Sequence[dict[str, Tensor]], - iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = "bbox", + iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox", ignore_score: bool = False, ) -> None: """Ensure the correct input format of `preds` and `targets`.""" if isinstance(iou_type, str): iou_type = (iou_type,) - name_map = {"bbox": "boxes", "segm": "masks"} + name_map = {"bbox": "boxes", "segm": "masks", "keypoints": "keypoints"} if any(tp not in name_map for tp in iou_type): raise Exception(f"IOU type {iou_type} is not supported") item_val_name = [name_map[tp] for tp in iou_type] @@ -110,16 +110,18 @@ def _fix_empty_tensors(boxes: Tensor) -> Tensor: def _validate_iou_type_arg( - iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = "bbox", -) -> tuple[Literal["bbox", "segm"], ...]: + iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox", +) -> tuple[Literal["bbox", "segm", "keypoints"], ...]: """Validate that iou type argument is correct.""" - allowed_iou_types = ("segm", "bbox") + allowed_iou_types = ("segm", "bbox", "keypoints") if isinstance(iou_type, str): iou_type = (iou_type,) if any(tp not in allowed_iou_types for tp in iou_type): raise ValueError( f"Expected argument `iou_type` to be one of {allowed_iou_types} or a tuple of, but got {iou_type}" ) + if "segm" in iou_type and "keypoints" in iou_type: + raise NotImplementedError() return iou_type @@ -194,11 +196,13 @@ def _get_coco_datasets( groundtruth_labels: List[Tensor], groundtruth_box: List[Tensor], groundtruth_mask: List[Tensor], + groundtruth_keypoint: List[Tensor], groundtruth_crowds: List[Tensor], groundtruth_area: List[Tensor], detection_labels: List[Tensor], detection_box: List[Tensor], detection_mask: List[Tensor], + detection_keypoint: List[Tensor], detection_scores: List[Tensor], iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = ("bbox",), average: Literal["macro", "micro"] = "micro", @@ -221,6 +225,7 @@ def _get_coco_datasets( labels=groundtruth_labels, boxes=groundtruth_box if len(groundtruth_box) > 0 else None, masks=groundtruth_mask if len(groundtruth_mask) > 0 else None, + keypoints=groundtruth_keypoint if len(groundtruth_keypoint) > 0 else None, crowds=groundtruth_crowds, area=groundtruth_area, iou_type=iou_type, @@ -231,6 +236,7 @@ def _get_coco_datasets( labels=detection_labels, boxes=detection_box if len(detection_box) > 0 else None, masks=detection_mask if len(detection_mask) > 0 else None, + keypoints=detection_keypoint if len(detection_keypoint) > 0 else None, scores=detection_scores, iou_type=iou_type, all_labels=all_labels, @@ -244,24 +250,38 @@ def _get_coco_datasets( return coco_preds, coco_target def _coco_stats_to_tensor_dict( - self, stats: list[float], prefix: str, max_detection_thresholds: list[int] + self, stats: list[float], prefix: str, max_detection_thresholds: list[int], iou_type: str ) -> dict[str, Tensor]: """Converts the output of COCOeval.stats to a dict of tensors.""" mdt = max_detection_thresholds - return { - f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), - f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), - f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), - f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32), - f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32), - f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32), - f"{prefix}mar_{mdt[0]}": torch.tensor([stats[6]], dtype=torch.float32), - f"{prefix}mar_{mdt[1]}": torch.tensor([stats[7]], dtype=torch.float32), - f"{prefix}mar_{mdt[2]}": torch.tensor([stats[8]], dtype=torch.float32), - f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32), - f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32), - f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32), - } + if iou_type == "keypoints": + return { + f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), + f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), + f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), + f"{prefix}map_medium": torch.tensor([stats[3]], dtype=torch.float32), + f"{prefix}map_large": torch.tensor([stats[4]], dtype=torch.float32), + f"{prefix}mar_{mdt[0]}": torch.tensor([stats[5]], dtype=torch.float32), + f"{prefix}mar_{mdt[1]}": torch.tensor([stats[6]], dtype=torch.float32), + f"{prefix}mar_{mdt[2]}": torch.tensor([stats[7]], dtype=torch.float32), + f"{prefix}mar_medium": torch.tensor([stats[8]], dtype=torch.float32), + f"{prefix}mar_large": torch.tensor([stats[9]], dtype=torch.float32), + } + else: + return { + f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), + f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), + f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), + f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32), + f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32), + f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32), + f"{prefix}mar_{mdt[0]}": torch.tensor([stats[6]], dtype=torch.float32), + f"{prefix}mar_{mdt[1]}": torch.tensor([stats[7]], dtype=torch.float32), + f"{prefix}mar_{mdt[2]}": torch.tensor([stats[8]], dtype=torch.float32), + f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32), + f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32), + f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32), + } @staticmethod def coco_to_tm( @@ -319,11 +339,21 @@ def coco_to_tm( target[t["image_id"]]["boxes"] = [] if "segm" in iou_type: target[t["image_id"]]["masks"] = [] + if "keypoints" in iou_type: + if "bbox" not in iou_type and hasattr(t, "bbox"): + target[t["image_id"]]["boxes"] = [] + target[t["image_id"]]["keypoints"] = [] + target[t["image_id"]]["num_keypoints"] = [] if "bbox" in iou_type: target[t["image_id"]]["boxes"].append(t["bbox"]) if "segm" in iou_type: target[t["image_id"]]["masks"].append(gt.annToMask(t)) + if "keypoints" in iou_type: + if "bbox" not in iou_type and hasattr(t, "bbox"): + target[t["image_id"]]["boxes"].append(t["bbox"]) + target[t["image_id"]]["keypoints"].append(t['keypoints']) + target[t["image_id"]]["num_keypoints"].append(len(t["keypoints"]) // 3) target[t["image_id"]]["labels"].append(t["category_id"]) target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) target[t["image_id"]]["area"].append(t["area"]) @@ -336,12 +366,23 @@ def coco_to_tm( preds[p["image_id"]]["boxes"] = [] if "segm" in iou_type: preds[p["image_id"]]["masks"] = [] + if "keypoints" in iou_type: + if "bbox" not in iou_type and hasattr(p, "bbox"): + preds[p["image_id"]]["boxes"] = [] + preds[p["image_id"]]["keypoints"] = [] + preds[p["image_id"]]["num_keypoints"] = [] if "bbox" in iou_type: preds[p["image_id"]]["boxes"].append(p["bbox"]) if "segm" in iou_type: preds[p["image_id"]]["masks"].append(gt.annToMask(p)) + if "keypoints" in iou_type: + if "bbox" not in iou_type and hasattr(p, "bbox"): + preds[p["image_id"]]["boxes"].append(p["bbox"]) + preds[p["image_id"]]["keypoints"].append(p["keypoints"]) + preds[p["image_id"]]["num_keypoints"].append(len(p["keypoints"]) // 3) preds[p["image_id"]]["scores"].append(p["score"]) preds[p["image_id"]]["labels"].append(p["category_id"]) + for k in target: # add empty predictions for images without predictions if k not in preds: preds[k] = {"scores": [], "labels": []} @@ -349,6 +390,11 @@ def coco_to_tm( preds[k]["boxes"] = [] if "segm" in iou_type: preds[k]["masks"] = [] + if "keypoints" in iou_type: + preds[k]["keypoints"] = [] + preds[k]["num_keypoints"] = [] + if "bbox" not in iou_type and hasattr(preds[k], "bbox"): + preds[k]["boxes"] = [] batched_preds, batched_target = [], [] for key in target: @@ -360,6 +406,12 @@ def coco_to_tm( bp["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) if "segm" in iou_type: bp["masks"] = torch.tensor(np.array(preds[key]["masks"]), dtype=torch.uint8) + if "keypoints" in iou_type: + if "bbox" not in iou_type and hasattr(bp, "bbox"): + bp["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) + bp["keypoints"] = torch.tensor(np.array(preds[key]["keypoints"]), dtype=torch.float32) + bp["num_keypoints"] = torch.tensor(np.array(preds[key]["num_keypoints"]), dtype=torch.int32) + batched_preds.append(bp) bt = { @@ -371,6 +423,11 @@ def coco_to_tm( bt["boxes"] = torch.tensor(target[key]["boxes"], dtype=torch.float32) if "segm" in iou_type: bt["masks"] = torch.tensor(np.array(target[key]["masks"]), dtype=torch.uint8) + if "keypoints" in iou_type: + if "bbox" not in iou_type and hasattr(bp, "bbox"): + bt["boxes"] = torch.tensor(np.array(preds[key]["boxes"]), dtype=torch.float32) + bt["keypoints"] = torch.tensor(np.array(target[key]["keypoints"]), dtype=torch.float32) + bt["num_keypoints"] = torch.tensor(np.array(target[key]["num_keypoints"]), dtype=torch.int32) batched_target.append(bt) return batched_preds, batched_target @@ -380,11 +437,13 @@ def tm_to_coco( groundtruth_labels: List[Tensor], groundtruth_box: List[Tensor], groundtruth_mask: List[Tensor], + groundtruth_keypoint: List[Tensor], groundtruth_crowds: List[Tensor], groundtruth_area: List[Tensor], detection_labels: List[Tensor], detection_box: List[Tensor], detection_mask: List[Tensor], + detection_keypoint: List[Tensor], detection_scores: List[Tensor], name: str = "tm_map_input", iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = ("bbox",), @@ -400,11 +459,13 @@ def tm_to_coco( groundtruth_labels: List of tensors containing the ground truth labels groundtruth_box: List of tensors containing the ground truth bounding boxes groundtruth_mask: List of tensors containing the ground truth segmentation masks + groundtruth_keypoint: List of tensors containing the ground truth keypoints groundtruth_crowds: List of tensors indicating whether ground truth annotations are crowd annotations groundtruth_area: List of tensors containing the area of ground truth annotations detection_labels: List of tensors containing the predicted labels detection_box: List of tensors containing the predicted bounding boxes detection_mask: List of tensors containing the predicted segmentation masks + detection_keypoint: List of tensors containing the predicted keypoints detection_scores: List of tensors containing the confidence scores for predictions name: Name of the output file, which will be appended with "_preds.json" and "_target.json" iou_type: Type of IoU calculation to use. Can be either "bbox" for bounding box or "segm" for segmentation @@ -440,6 +501,7 @@ def tm_to_coco( labels=groundtruth_labels, boxes=groundtruth_box if len(groundtruth_box) > 0 else None, masks=groundtruth_mask if len(groundtruth_mask) > 0 else None, + keypoints=groundtruth_keypoint if len(groundtruth_keypoint) > 0 else None, crowds=groundtruth_crowds, area=groundtruth_area, all_labels=all_labels, @@ -450,6 +512,7 @@ def tm_to_coco( labels=detection_labels, boxes=detection_box if len(detection_box) > 0 else None, masks=detection_mask if len(detection_mask) > 0 else None, + keypoints=detection_keypoint if len(detection_keypoint) > 0 else None, scores=detection_scores, all_labels=all_labels, iou_type=iou_type, @@ -485,6 +548,7 @@ def _get_coco_format( all_labels: List[Tensor], boxes: Optional[List[Tensor]] = None, masks: Optional[List[Tensor]] = None, + keypoints: Optional[List[Tensor]] = None, scores: Optional[List[Tensor]] = None, crowds: Optional[List[Tensor]] = None, area: Optional[List[Tensor]] = None, @@ -509,6 +573,15 @@ def _get_coco_format( image_masks = masks[image_id] if len(image_masks) == 0 and boxes is None: continue + if keypoints is not None: + if boxes is None: + kp = keypoints[image_id].view(-1, 17, 3) + v_mask = kp[..., 2:] > 0 # visible points + mins = torch.masked.amin(kp[..., :2], mask=v_mask, dim=1) # [N, 2] -> [xmin, ymin] + maxs = torch.masked.amax(kp[..., :2], mask=v_mask, dim=1) + image_boxes = torch.cat([mins, maxs - mins], dim=-1).nan_to_num(0).cpu().tolist() + image_keypoints = keypoints[image_id] + image_keypoints = image_keypoints.cpu().tolist() image_labels = image_labels.cpu().tolist() # type: ignore[assignment] images.append({"id": image_id}) @@ -521,12 +594,20 @@ def _get_coco_format( if masks is not None and len(image_masks) > 0: image_mask = image_masks[k] image_mask = {"size": image_mask[0], "counts": image_mask[1]} + if keypoints is not None and len(image_keypoints) > 0: + image_keypoint = image_keypoints[k] + image_box = image_boxes[k] if "bbox" in iou_type and len(image_box) != 4: raise ValueError( f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})" ) + if "keypoints" in iou_type and len(image_keypoint) != 51: + raise ValueError( + f"Invalid input keypoint of sample {image_id}, element {k} (expected 3 values, got {len(image_keypoint)})" + ) + if not isinstance(image_label, int): raise ValueError( f"Invalid input class of sample {image_id}, element {k}" @@ -535,13 +616,16 @@ def _get_coco_format( area_stat_box = None area_stat_mask = None + if area is not None and area[image_id][k].cpu().tolist() > 0: # type: ignore[operator] area_stat = area[image_id][k].cpu().tolist() else: area_stat = self.mask_utils.area(image_mask) if "segm" in iou_type else image_box[2] * image_box[3] if len(iou_type) > 1: - area_stat_box = image_box[2] * image_box[3] - area_stat_mask = self.mask_utils.area(image_mask) + if 'bbox' in iou_type: + area_stat_box = image_box[2] * image_box[3] + if 'segm' in iou_type: + area_stat_mask = self.mask_utils.area(image_mask) annotation = { "id": annotation_id, @@ -558,6 +642,10 @@ def _get_coco_format( annotation["bbox"] = image_box if masks is not None: annotation["segmentation"] = image_mask + if keypoints is not None: + annotation["bbox"] = image_box + annotation["keypoints"] = image_keypoint + annotation["num_keypoints"] = len(image_keypoint) // 3 if scores is not None: score = scores[image_id][k].cpu().tolist() @@ -596,6 +684,7 @@ def _warning_on_too_many_detections(limit: int) -> None: def _get_safe_item_values( iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"], ...]], box_format: str, + keypoint_format: str, max_detection_thresholds: List[int], coco_backend: CocoBackend, item: dict[str, Any], @@ -633,7 +722,7 @@ def _get_safe_item_values( """ from torchvision.ops import box_convert - output = [None, None] + output = [None, None, None] if "bbox" in iou_type: boxes = _fix_empty_tensors(item["boxes"]) if boxes.numel() > 0: @@ -645,6 +734,17 @@ def _get_safe_item_values( rle = coco_backend.mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) output[1] = tuple(masks) # type: ignore[call-overload] + if "keypoints" in iou_type: + def keypoint_convert(keypoints, in_fmt='xy', out_fmt="xyv"): + # the keypoint format is expected to be: + # [number_of_keypoints * 3] -> x_0,y_0,viz_0, x_1,y_1,viz_1, ..., x_n,y_n,viz_n + if in_fmt == "xy" and out_fmt == "xyv": + keypoints = torch.cat([keypoints, torch.ones_like(keypoints[..., :1])], dim=-1) + keypoints = keypoints.flatten(-2) # [num_keypoints, 3] --> [num_keypoints * 3] + return keypoints + + keypoints = _fix_empty_tensors(item["keypoints"]) + output[2] = keypoint_convert(keypoints, keypoint_format, out_fmt="xyv") def _valid_output_len(idx: int) -> bool: val = output[idx] @@ -668,13 +768,15 @@ def _calculate_map_with_coco( groundtruth_labels: List[Tensor], groundtruth_box: List[Tensor], groundtruth_mask: List[Tensor], + groundtruth_keypoint: List[Tensor], groundtruth_crowds: List[Tensor], groundtruth_area: List[Tensor], detection_labels: List[Tensor], detection_box: List[Tensor], detection_mask: List[Tensor], + detection_keypoint: List[Tensor], detection_scores: List[Tensor], - iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"], ...]], + iou_type: Union[Literal["bbox", "segm", "keypoints"], Tuple[Literal["bbox", "segm", "keypoints"], ...]], average: Literal["macro", "micro"], iou_thresholds: List[float], rec_thresholds: List[float], @@ -686,11 +788,13 @@ def _calculate_map_with_coco( groundtruth_labels, groundtruth_box, groundtruth_mask, + groundtruth_keypoint, groundtruth_crowds, groundtruth_area, detection_labels, detection_box, detection_mask, + detection_keypoint, detection_scores, iou_type, average=average, @@ -709,7 +813,7 @@ def _calculate_map_with_coco( if len(coco_preds.imgs) == 0 or len(coco_target.imgs) == 0: result_dict.update( coco_backend._coco_stats_to_tensor_dict( - 12 * [-1.0], prefix=prefix, max_detection_thresholds=max_detection_thresholds + 12 * [-1.0], prefix=prefix, max_detection_thresholds=max_detection_thresholds, iou_type=i_type ) ) else: @@ -724,7 +828,7 @@ def _calculate_map_with_coco( stats = coco_eval.stats result_dict.update( coco_backend._coco_stats_to_tensor_dict( - stats, prefix=prefix, max_detection_thresholds=max_detection_thresholds + stats, prefix=prefix, max_detection_thresholds=max_detection_thresholds, iou_type=i_type ) ) @@ -748,11 +852,13 @@ def _calculate_map_with_coco( groundtruth_labels, groundtruth_box, groundtruth_mask, + groundtruth_keypoint, groundtruth_crowds, groundtruth_area, detection_labels, detection_box, detection_mask, + detection_keypoint, detection_scores, iou_type, average="macro", diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 429766ffe25..440778b292c 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -158,8 +158,8 @@ class MeanAveragePrecision(Metric): width and height. iou_type: - Type of input (either masks or bounding-boxes) used for computing IOU. Supported IOU types are - ``"bbox"`` or ``"segm"`` or both as a tuple. + Type of input (either masks or bounding-boxes or keypoints) used for computing IOU. Supported IOU types are + ``"bbox"`` or ``"segm"`` or ``"kpt"`` or ``["bbox", "segm"]`` or ``["bbox", "kpt"]``. iou_thresholds: IoU thresholds for evaluation. If set to ``None`` it corresponds to the stepped range ``[0.5,...,0.95]`` with step ``0.05``. Else provide a list of floats. @@ -205,7 +205,7 @@ class MeanAveragePrecision(Metric): ValueError: If ``box_format`` is not one of ``"xyxy"``, ``"xywh"`` or ``"cxcywh"`` ValueError: - If ``iou_type`` is not one of ``"bbox"`` or ``"segm"`` + If ``iou_type`` is not one of ``"bbox"`` or ``"segm"`` or ``"keypoints"`` ValueError: If ``iou_thresholds`` is not None or a list of floats ValueError: @@ -319,10 +319,12 @@ class MeanAveragePrecision(Metric): detection_box: List[Tensor] detection_mask: List[Tensor] + detection_keypoint: List[Tensor] detection_scores: List[Tensor] detection_labels: List[Tensor] groundtruth_box: List[Tensor] groundtruth_mask: List[Tensor] + groundtruth_keypoint: List[Tensor] groundtruth_labels: List[Tensor] groundtruth_crowds: List[Tensor] groundtruth_area: List[Tensor] @@ -344,7 +346,8 @@ class MeanAveragePrecision(Metric): def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Union[Literal["bbox", "segm"], tuple[Literal["bbox", "segm"], ...]] = "bbox", + keypoint_format: Literal["xy", "xyv"] = "xy", + iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox", iou_thresholds: Optional[list[float]] = None, rec_thresholds: Optional[list[float]] = None, max_detection_thresholds: Optional[list[int]] = None, @@ -373,6 +376,11 @@ def __init__( raise ValueError(f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}") self.box_format = box_format + allowed_keypoint_formats = ("xy", "xyv") + if keypoint_format not in allowed_keypoint_formats: + raise ValueError(f"Expected argument `keypoint_format` to be one of {allowed_keypoint_formats} but got {keypoint_format}") + self.keypoint_format = keypoint_format + self.iou_type = _validate_iou_type_arg(iou_type) if iou_thresholds is not None and not isinstance(iou_thresholds, list): @@ -416,10 +424,12 @@ def __init__( self.add_state("detection_box", default=[], dist_reduce_fx=None) self.add_state("detection_mask", default=[], dist_reduce_fx=None) + self.add_state("detection_keypoint", default=[], dist_reduce_fx=None) self.add_state("detection_scores", default=[], dist_reduce_fx=None) self.add_state("detection_labels", default=[], dist_reduce_fx=None) self.add_state("groundtruth_box", default=[], dist_reduce_fx=None) self.add_state("groundtruth_mask", default=[], dist_reduce_fx=None) + self.add_state("groundtruth_keypoint", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) self.add_state("groundtruth_crowds", default=[], dist_reduce_fx=None) self.add_state("groundtruth_area", default=[], dist_reduce_fx=None) @@ -459,11 +469,13 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: self.groundtruth_labels, self.groundtruth_box, self.groundtruth_mask, + self.groundtruth_keypoint, self.groundtruth_crowds, self.groundtruth_area, self.detection_labels, self.detection_box, self.detection_mask, + self.detection_keypoint, self.detection_scores, name, self.iou_type, @@ -530,9 +542,10 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]] _input_validator(preds, target, iou_type=self.iou_type) for item in preds: - bbox_detection, mask_detection = _get_safe_item_values( + bbox_detection, mask_detection, keypoint_detection = _get_safe_item_values( iou_type=self.iou_type, box_format=self.box_format, + keypoint_format=self.keypoint_format, max_detection_thresholds=self.max_detection_thresholds, coco_backend=self._coco_backend, item=item, @@ -542,13 +555,16 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]] self.detection_box.append(bbox_detection) if mask_detection is not None: self.detection_mask.append(mask_detection) # type: ignore[arg-type] + if keypoint_detection is not None: + self.detection_keypoint.append(keypoint_detection) # type: ignore[arg-type] self.detection_labels.append(item["labels"]) self.detection_scores.append(item["scores"]) for item in target: - bbox_groundtruth, mask_groundtruth = _get_safe_item_values( + bbox_groundtruth, mask_groundtruth, keypoint_groundtruth = _get_safe_item_values( self.iou_type, self.box_format, + self.keypoint_format, self.max_detection_thresholds, self._coco_backend, item, @@ -557,6 +573,8 @@ def update(self, preds: list[dict[str, Tensor]], target: list[dict[str, Tensor]] self.groundtruth_box.append(bbox_groundtruth) if mask_groundtruth is not None: self.groundtruth_mask.append(mask_groundtruth) # type: ignore[arg-type] + if keypoint_groundtruth is not None: + self.groundtruth_keypoint.append(keypoint_groundtruth) # type: ignore[arg-type] self.groundtruth_labels.append(item["labels"]) self.groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"]))) self.groundtruth_area.append(item.get("area", torch.zeros_like(item["labels"]))) @@ -568,11 +586,13 @@ def compute(self) -> dict: self.groundtruth_labels, self.groundtruth_box, self.groundtruth_mask, + self.groundtruth_keypoint, self.groundtruth_crowds, self.groundtruth_area, self.detection_labels, self.detection_box, self.detection_mask, + self.detection_keypoint, self.detection_scores, self.iou_type, self.average, diff --git a/src/torchmetrics/functional/detection/map.py b/src/torchmetrics/functional/detection/map.py index 44d19f9baf4..c78b86c5d3b 100644 --- a/src/torchmetrics/functional/detection/map.py +++ b/src/torchmetrics/functional/detection/map.py @@ -40,7 +40,8 @@ def mean_average_precision( preds: List[Dict[str, Any]], target: List[Dict[str, Any]], box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", - iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"], ...]] = "bbox", + keypoint_format: Literal["xy", "xyv"] = "xy", + iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox", iou_thresholds: Optional[list[float]] = None, rec_thresholds: Optional[list[float]] = None, max_detection_thresholds: Optional[list[int]] = None, @@ -167,33 +168,39 @@ def mean_average_precision( coco_backend = CocoBackend(backend=backend) detection_box: List[Tensor] = [] + detection_keypoints: List[Tensor] = [] detection_labels: List[Tensor] = [] detection_scores: List[Tensor] = [] detection_mask: List[Tensor] = [] for item in preds: - bbox_detection, mask_detection = _get_safe_item_values( - iou_type, box_format, max_detection_thresholds, coco_backend, item, warn=warn_on_many_detections + bbox_detection, mask_detection, keypoint_detection = _get_safe_item_values( + iou_type, box_format, keypoint_format, max_detection_thresholds, coco_backend, item, warn=warn_on_many_detections ) if bbox_detection is not None: detection_box.append(bbox_detection) if mask_detection is not None: detection_mask.append(mask_detection) # type: ignore[arg-type] + if keypoint_detection is not None: + detection_keypoints.append(keypoint_detection) detection_labels.append(item["labels"]) detection_scores.append(item["scores"]) groundtruth_box: List[Tensor] = [] groundtruth_mask: List[Tensor] = [] + groundtruth_keypoints: List[Tensor] = [] groundtruth_labels: List[Tensor] = [] groundtruth_crowds: List[Tensor] = [] groundtruth_area: List[Tensor] = [] for item in target: - bbox_groundtruth, mask_groundtruth = _get_safe_item_values( - iou_type, box_format, max_detection_thresholds, coco_backend, item + bbox_groundtruth, mask_groundtruth, keypoint_groundtruth = _get_safe_item_values( + iou_type, box_format, keypoint_format, max_detection_thresholds, coco_backend, item ) if bbox_groundtruth is not None: groundtruth_box.append(bbox_groundtruth) if mask_groundtruth is not None: groundtruth_mask.append(mask_groundtruth) # type: ignore[arg-type] + if keypoint_groundtruth is not None: + groundtruth_keypoints.append(keypoint_groundtruth) groundtruth_labels.append(item["labels"]) groundtruth_crowds.append(item.get("iscrowd", torch.zeros_like(item["labels"]))) groundtruth_area.append(item.get("area", torch.zeros_like(item["labels"]))) @@ -203,11 +210,13 @@ def mean_average_precision( groundtruth_labels, groundtruth_box, groundtruth_mask, + groundtruth_keypoints, groundtruth_crowds, groundtruth_area, detection_labels, detection_box, detection_mask, + detection_keypoints, detection_scores, iou_type, average, diff --git a/tests/unittests/detection/__init__.py b/tests/unittests/detection/__init__.py index af0589a2557..d72bd3975c9 100644 --- a/tests/unittests/detection/__init__.py +++ b/tests/unittests/detection/__init__.py @@ -8,3 +8,6 @@ _DETECTION_VAL = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "instances_val2014_100.json") _DETECTION_BBOX = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "instances_val2014_fakebbox100_results.json") _DETECTION_SEGM = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "instances_val2014_fakesegm100_results.json") + +_DETECTION_KEYPOINTS = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "person_keypoints_val2014_fakekeypoints100_results.json") +_DETECTION_VAL_KEYPOINTS = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "person_keypoints_val2014_100.json") diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 62e1ba8661f..84b9a1e68df 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -34,7 +34,7 @@ _PYCOCOTOOLS_AVAILABLE, ) from unittests._helpers.testers import MetricTester -from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL +from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL, _DETECTION_KEYPOINTS, _DETECTION_VAL_KEYPOINTS def _skip_if_faster_coco_eval_missing(backend): @@ -50,29 +50,45 @@ def _generate_coco_inputs(iou_type): and should therefore correspond directly to the result on the webpage """ + if iou_type == "bbox": + _DETECTION = _DETECTION_BBOX + _VAL_DETECTION = _DETECTION_VAL + elif iou_type == "segm": + _DETECTION = _DETECTION_SEGM + _VAL_DETECTION = _DETECTION_VAL + elif iou_type == "keypoints": + _DETECTION = _DETECTION_KEYPOINTS + _VAL_DETECTION = _DETECTION_VAL_KEYPOINTS batched_preds, batched_target = MeanAveragePrecision().coco_to_tm( - _DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM, _DETECTION_VAL, iou_type + _DETECTION, _VAL_DETECTION, iou_type ) # create 10 batches of 10 preds/targets each - batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)] - batched_target = [batched_target[10 * i : 10 * (i + 1)] for i in range(10)] + n = min(10, len(batched_preds) // 10) + batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(n)] + batched_target = [batched_target[10 * i : 10 * (i + 1)] for i in range(n)] return batched_preds, batched_target _coco_bbox_input = _generate_coco_inputs("bbox") _coco_segm_input = _generate_coco_inputs("segm") +_coco_keypoints_input = _generate_coco_inputs("keypoints") @pytest.mark.skipif( not _PYCOCOTOOLS_AVAILABLE, reason="test requires that torchvision=>0.8.0 and pycocotools is installed" ) -@pytest.mark.parametrize("iou_type", ["bbox", "segm"]) +@pytest.mark.parametrize("iou_type", ["bbox", "segm", "keypoints"]) @pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"]) def test_tm_to_coco(tmpdir, iou_type, backend): """Test that the conversion from TM to COCO format works.""" - preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input - metric = MeanAveragePrecision(iou_type=iou_type, backend=backend, box_format="xywh") + if iou_type == "bbox": + preds, target = _coco_bbox_input + elif iou_type == "segm": + preds, target = _coco_segm_input + elif iou_type == "keypoints": + preds, target = _coco_keypoints_input + metric = MeanAveragePrecision(iou_type=iou_type, backend=backend, box_format="xywh", keypoint_format="xyv") for bp, bt in zip(preds, target): metric.update(bp, bt) metric.tm_to_coco(f"{tmpdir}/tm_map_input") @@ -95,11 +111,16 @@ def test_tm_to_coco(tmpdir, iou_type, backend): sample1["masks"], sample2["masks"] ): sample_found = True - else: + elif iou_type == "bbox": if sample1["boxes"].shape == sample2["boxes"].shape and torch.allclose( sample1["boxes"], sample2["boxes"] ): sample_found = True + elif iou_type == "keypoints": + if sample1["keypoints"].shape == sample2["keypoints"].shape and torch.allclose( + sample1["keypoints"], sample2["keypoints"] + ): + sample_found = True assert sample_found, "preds not found" for sample1 in target: @@ -110,11 +131,16 @@ def test_tm_to_coco(tmpdir, iou_type, backend): sample1["masks"], sample2["masks"] ): sample_found = True - else: + elif iou_type == "bbox": if sample1["boxes"].shape == sample2["boxes"].shape and torch.allclose( sample1["boxes"], sample2["boxes"] ): sample_found = True + elif iou_type == "keypoints": + if sample1["keypoints"].shape == sample2["keypoints"].shape and torch.allclose( + sample1["keypoints"], sample2["keypoints"] + ): + sample_found = True assert sample_found, "target not found" @@ -965,19 +991,24 @@ def compare_with_class(functional_result, preds, target, **kwargs: Any): @pytest.mark.parametrize("backend", ["pycocotools", "faster_coco_eval"]) -@pytest.mark.parametrize("iou_type", ["bbox", "segm"]) +@pytest.mark.parametrize("iou_type", ["bbox", "segm", "keypoints"]) def test_mean_average_precision_iou_type_functional(backend, iou_type): """Test that the functional API returns a valid dictionary with the expected keys.""" - preds, target = _coco_bbox_input if iou_type == "bbox" else _coco_segm_input + if iou_type == "bbox": + preds, target = _coco_bbox_input + elif iou_type == "segm": + preds, target = _coco_segm_input + elif iou_type == "keypoints": + preds, target = _coco_keypoints_input preds_flat = [p for batch in preds for p in batch] target_flat = [t for batch in target for t in batch] functional_result = mean_average_precision( - preds_flat, target_flat, backend=backend, iou_type=iou_type, box_format="xywh" + preds_flat, target_flat, backend=backend, iou_type=iou_type, box_format="xywh", keypoint_format="xyv", ) compare_with_class( - functional_result, preds_flat, target_flat, backend=backend, iou_type=iou_type, box_format="xywh" + functional_result, preds_flat, target_flat, backend=backend, iou_type=iou_type, box_format="xywh", keypoint_format="xyv", ) From 8d37ac2c37b581ba4b83c97152ff45405d29a8a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Mar 2026 18:56:27 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/detection/helpers.py | 42 ++++++++++---------- src/torchmetrics/detection/mean_ap.py | 8 +++- src/torchmetrics/functional/detection/map.py | 10 ++++- tests/unittests/detection/__init__.py | 4 +- tests/unittests/detection/test_map.py | 27 ++++++++++--- 5 files changed, 59 insertions(+), 32 deletions(-) diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index 04593f77b15..1e9105ba7da 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -267,21 +267,20 @@ def _coco_stats_to_tensor_dict( f"{prefix}mar_medium": torch.tensor([stats[8]], dtype=torch.float32), f"{prefix}mar_large": torch.tensor([stats[9]], dtype=torch.float32), } - else: - return { - f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), - f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), - f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), - f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32), - f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32), - f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32), - f"{prefix}mar_{mdt[0]}": torch.tensor([stats[6]], dtype=torch.float32), - f"{prefix}mar_{mdt[1]}": torch.tensor([stats[7]], dtype=torch.float32), - f"{prefix}mar_{mdt[2]}": torch.tensor([stats[8]], dtype=torch.float32), - f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32), - f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32), - f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32), - } + return { + f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32), + f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32), + f"{prefix}map_75": torch.tensor([stats[2]], dtype=torch.float32), + f"{prefix}map_small": torch.tensor([stats[3]], dtype=torch.float32), + f"{prefix}map_medium": torch.tensor([stats[4]], dtype=torch.float32), + f"{prefix}map_large": torch.tensor([stats[5]], dtype=torch.float32), + f"{prefix}mar_{mdt[0]}": torch.tensor([stats[6]], dtype=torch.float32), + f"{prefix}mar_{mdt[1]}": torch.tensor([stats[7]], dtype=torch.float32), + f"{prefix}mar_{mdt[2]}": torch.tensor([stats[8]], dtype=torch.float32), + f"{prefix}mar_small": torch.tensor([stats[9]], dtype=torch.float32), + f"{prefix}mar_medium": torch.tensor([stats[10]], dtype=torch.float32), + f"{prefix}mar_large": torch.tensor([stats[11]], dtype=torch.float32), + } @staticmethod def coco_to_tm( @@ -352,7 +351,7 @@ def coco_to_tm( if "keypoints" in iou_type: if "bbox" not in iou_type and hasattr(t, "bbox"): target[t["image_id"]]["boxes"].append(t["bbox"]) - target[t["image_id"]]["keypoints"].append(t['keypoints']) + target[t["image_id"]]["keypoints"].append(t["keypoints"]) target[t["image_id"]]["num_keypoints"].append(len(t["keypoints"]) // 3) target[t["image_id"]]["labels"].append(t["category_id"]) target[t["image_id"]]["iscrowd"].append(t["iscrowd"]) @@ -576,7 +575,7 @@ def _get_coco_format( if keypoints is not None: if boxes is None: kp = keypoints[image_id].view(-1, 17, 3) - v_mask = kp[..., 2:] > 0 # visible points + v_mask = kp[..., 2:] > 0 # visible points mins = torch.masked.amin(kp[..., :2], mask=v_mask, dim=1) # [N, 2] -> [xmin, ymin] maxs = torch.masked.amax(kp[..., :2], mask=v_mask, dim=1) image_boxes = torch.cat([mins, maxs - mins], dim=-1).nan_to_num(0).cpu().tolist() @@ -622,9 +621,9 @@ def _get_coco_format( else: area_stat = self.mask_utils.area(image_mask) if "segm" in iou_type else image_box[2] * image_box[3] if len(iou_type) > 1: - if 'bbox' in iou_type: + if "bbox" in iou_type: area_stat_box = image_box[2] * image_box[3] - if 'segm' in iou_type: + if "segm" in iou_type: area_stat_mask = self.mask_utils.area(image_mask) annotation = { @@ -735,12 +734,13 @@ def _get_safe_item_values( masks.append((tuple(rle["size"]), rle["counts"])) output[1] = tuple(masks) # type: ignore[call-overload] if "keypoints" in iou_type: - def keypoint_convert(keypoints, in_fmt='xy', out_fmt="xyv"): + + def keypoint_convert(keypoints, in_fmt="xy", out_fmt="xyv"): # the keypoint format is expected to be: # [number_of_keypoints * 3] -> x_0,y_0,viz_0, x_1,y_1,viz_1, ..., x_n,y_n,viz_n if in_fmt == "xy" and out_fmt == "xyv": keypoints = torch.cat([keypoints, torch.ones_like(keypoints[..., :1])], dim=-1) - keypoints = keypoints.flatten(-2) # [num_keypoints, 3] --> [num_keypoints * 3] + keypoints = keypoints.flatten(-2) # [num_keypoints, 3] --> [num_keypoints * 3] return keypoints keypoints = _fix_empty_tensors(item["keypoints"]) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 440778b292c..2dc0e89d98f 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -347,7 +347,9 @@ def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", keypoint_format: Literal["xy", "xyv"] = "xy", - iou_type: Union[Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...]] = "bbox", + iou_type: Union[ + Literal["bbox", "segm", "keypoints"], tuple[Literal["bbox", "segm", "keypoints"], ...] + ] = "bbox", iou_thresholds: Optional[list[float]] = None, rec_thresholds: Optional[list[float]] = None, max_detection_thresholds: Optional[list[int]] = None, @@ -378,7 +380,9 @@ def __init__( allowed_keypoint_formats = ("xy", "xyv") if keypoint_format not in allowed_keypoint_formats: - raise ValueError(f"Expected argument `keypoint_format` to be one of {allowed_keypoint_formats} but got {keypoint_format}") + raise ValueError( + f"Expected argument `keypoint_format` to be one of {allowed_keypoint_formats} but got {keypoint_format}" + ) self.keypoint_format = keypoint_format self.iou_type = _validate_iou_type_arg(iou_type) diff --git a/src/torchmetrics/functional/detection/map.py b/src/torchmetrics/functional/detection/map.py index c78b86c5d3b..059d9b9c9a0 100644 --- a/src/torchmetrics/functional/detection/map.py +++ b/src/torchmetrics/functional/detection/map.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Union import torch from torch import Tensor @@ -174,7 +174,13 @@ def mean_average_precision( detection_mask: List[Tensor] = [] for item in preds: bbox_detection, mask_detection, keypoint_detection = _get_safe_item_values( - iou_type, box_format, keypoint_format, max_detection_thresholds, coco_backend, item, warn=warn_on_many_detections + iou_type, + box_format, + keypoint_format, + max_detection_thresholds, + coco_backend, + item, + warn=warn_on_many_detections, ) if bbox_detection is not None: detection_box.append(bbox_detection) diff --git a/tests/unittests/detection/__init__.py b/tests/unittests/detection/__init__.py index d72bd3975c9..b2f2422ff0c 100644 --- a/tests/unittests/detection/__init__.py +++ b/tests/unittests/detection/__init__.py @@ -9,5 +9,7 @@ _DETECTION_BBOX = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "instances_val2014_fakebbox100_results.json") _DETECTION_SEGM = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "instances_val2014_fakesegm100_results.json") -_DETECTION_KEYPOINTS = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "person_keypoints_val2014_fakekeypoints100_results.json") +_DETECTION_KEYPOINTS = os.path.join( + _PATH_ALL_TESTS, "_data", "detection", "person_keypoints_val2014_fakekeypoints100_results.json" +) _DETECTION_VAL_KEYPOINTS = os.path.join(_PATH_ALL_TESTS, "_data", "detection", "person_keypoints_val2014_100.json") diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 84b9a1e68df..20a8b388600 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -34,7 +34,13 @@ _PYCOCOTOOLS_AVAILABLE, ) from unittests._helpers.testers import MetricTester -from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL, _DETECTION_KEYPOINTS, _DETECTION_VAL_KEYPOINTS +from unittests.detection import ( + _DETECTION_BBOX, + _DETECTION_KEYPOINTS, + _DETECTION_SEGM, + _DETECTION_VAL, + _DETECTION_VAL_KEYPOINTS, +) def _skip_if_faster_coco_eval_missing(backend): @@ -59,9 +65,7 @@ def _generate_coco_inputs(iou_type): elif iou_type == "keypoints": _DETECTION = _DETECTION_KEYPOINTS _VAL_DETECTION = _DETECTION_VAL_KEYPOINTS - batched_preds, batched_target = MeanAveragePrecision().coco_to_tm( - _DETECTION, _VAL_DETECTION, iou_type - ) + batched_preds, batched_target = MeanAveragePrecision().coco_to_tm(_DETECTION, _VAL_DETECTION, iou_type) # create 10 batches of 10 preds/targets each n = min(10, len(batched_preds) // 10) @@ -1005,10 +1009,21 @@ def test_mean_average_precision_iou_type_functional(backend, iou_type): target_flat = [t for batch in target for t in batch] functional_result = mean_average_precision( - preds_flat, target_flat, backend=backend, iou_type=iou_type, box_format="xywh", keypoint_format="xyv", + preds_flat, + target_flat, + backend=backend, + iou_type=iou_type, + box_format="xywh", + keypoint_format="xyv", ) compare_with_class( - functional_result, preds_flat, target_flat, backend=backend, iou_type=iou_type, box_format="xywh", keypoint_format="xyv", + functional_result, + preds_flat, + target_flat, + backend=backend, + iou_type=iou_type, + box_format="xywh", + keypoint_format="xyv", )