From 1a9053f63c3bb728edb1f36159c99216f3e6f26b Mon Sep 17 00:00:00 2001 From: AshAnand34 Date: Thu, 15 May 2025 00:43:44 -0700 Subject: [PATCH 1/4] feat(detection): add transform method to remap and filter detections --- supervision/detection/core.py | 34 ++++++++++++++++++++++ test/detection/test_transform.py | 49 ++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 test/detection/test_transform.py diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 5fa2b7b037..1bd23c3cc8 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1435,6 +1435,40 @@ def with_nmm( return Detections.merge(result) + def transform(self, dataset, class_mapping: Optional[dict] = None) -> Detections: + """ + Remap and filter detections to match a target dataset's class set. + + Args: + dataset: An object with a .classes attribute (list of class names). + class_mapping (dict, optional): Mapping from model class names to dataset class names. + + Returns: + Detections: A new Detections object with class names and IDs remapped and filtered. + """ + # Get class names for each detection + class_names = self.data.get("class_name") + if class_names is None: + raise ValueError("Detections must have 'class_name' in .data to use transform().") + class_names = np.array(class_names) + # Remap class names if mapping is provided + if class_mapping is not None: + class_names = np.array([class_mapping.get(name, name) for name in class_names]) + # Filter out detections whose class is not in dataset.classes + keep = np.isin(class_names, dataset.classes) + # Remap class_id to match dataset.classes + new_class_id = np.array([dataset.classes.index(name) for name in class_names[keep]]) + # Build new Detections object + return Detections( + xyxy=self.xyxy[keep], + mask=self.mask[keep] if self.mask is not None else None, + confidence=self.confidence[keep] if self.confidence is not None else None, + class_id=new_class_id, + tracker_id=self.tracker_id[keep] if self.tracker_id is not None else None, + data={k: (np.array(v)[keep] if isinstance(v, (list, np.ndarray)) and len(v) == len(self) else v) + for k, v in self.data.items()}, + metadata=self.metadata.copy(), + ) def merge_inner_detection_object_pair( detections_1: Detections, detections_2: Detections diff --git a/test/detection/test_transform.py b/test/detection/test_transform.py new file mode 100644 index 0000000000..ed66a5447b --- /dev/null +++ b/test/detection/test_transform.py @@ -0,0 +1,49 @@ +import numpy as np +import pytest +from types import SimpleNamespace +from supervision.detection.core import Detections + +def test_transform_remap_and_filter(): + # Simulate a model that predicts 'dog', 'cat', 'eagle', 'car' + det = Detections( + xyxy=np.array([[0,0,1,1],[1,1,2,2],[2,2,3,3],[3,3,4,4]]), + class_id=np.array([0,1,2,3]), + confidence=np.array([0.9,0.8,0.7,0.6]), + data={"class_name": np.array(["dog","cat","eagle","car"])} + ) + # Dataset expects 'animal', 'bird', 'car' (in that order) + dataset = SimpleNamespace(classes=["animal","bird","car"]) + class_mapping = {"dog": "animal", "cat": "animal", "eagle": "bird"} + det2 = det.transform(dataset, class_mapping=class_mapping) + # Only 'dog', 'cat', 'eagle', 'car' should remain, but 'dog' and 'cat' become 'animal', 'eagle' becomes 'bird' + assert set(det2.data["class_name"]) <= set(dataset.classes + ["car"]) + assert all([name in dataset.classes for name in det2.data["class_name"]]) + # class_id should be remapped to dataset.classes indices + for name, cid in zip(det2.data["class_name"], det2.class_id): + assert dataset.classes[cid] == name + # Only 'dog', 'cat', 'eagle', 'car' remain, but 'car' is already in dataset.classes + assert len(det2) == 4 + +def test_transform_no_class_mapping(): + det = Detections( + xyxy=np.array([[0,0,1,1],[1,1,2,2]]), + class_id=np.array([0,1]), + confidence=np.array([0.9,0.8]), + data={"class_name": np.array(["car","truck"])} + ) + dataset = SimpleNamespace(classes=["car"]) + det2 = det.transform(dataset) + assert len(det2) == 1 + assert det2.data["class_name"][0] == "car" + assert det2.class_id[0] == 0 + +def test_transform_raises_without_class_name(): + det = Detections( + xyxy=np.array([[0,0,1,1]]), + class_id=np.array([0]), + confidence=np.array([0.9]), + data={} + ) + dataset = SimpleNamespace(classes=["car"]) + with pytest.raises(ValueError): + det.transform(dataset) From 3e8238173cd119619f47347c188d2d3fd328973b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 May 2025 07:46:53 +0000 Subject: [PATCH 2/4] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/core.py | 23 ++++++++++++++++++----- test/detection/test_transform.py | 29 +++++++++++++++++------------ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 1bd23c3cc8..64f7d87476 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1449,15 +1449,21 @@ def transform(self, dataset, class_mapping: Optional[dict] = None) -> Detections # Get class names for each detection class_names = self.data.get("class_name") if class_names is None: - raise ValueError("Detections must have 'class_name' in .data to use transform().") + raise ValueError( + "Detections must have 'class_name' in .data to use transform()." + ) class_names = np.array(class_names) # Remap class names if mapping is provided if class_mapping is not None: - class_names = np.array([class_mapping.get(name, name) for name in class_names]) + class_names = np.array( + [class_mapping.get(name, name) for name in class_names] + ) # Filter out detections whose class is not in dataset.classes keep = np.isin(class_names, dataset.classes) # Remap class_id to match dataset.classes - new_class_id = np.array([dataset.classes.index(name) for name in class_names[keep]]) + new_class_id = np.array( + [dataset.classes.index(name) for name in class_names[keep]] + ) # Build new Detections object return Detections( xyxy=self.xyxy[keep], @@ -1465,11 +1471,18 @@ def transform(self, dataset, class_mapping: Optional[dict] = None) -> Detections confidence=self.confidence[keep] if self.confidence is not None else None, class_id=new_class_id, tracker_id=self.tracker_id[keep] if self.tracker_id is not None else None, - data={k: (np.array(v)[keep] if isinstance(v, (list, np.ndarray)) and len(v) == len(self) else v) - for k, v in self.data.items()}, + data={ + k: ( + np.array(v)[keep] + if isinstance(v, (list, np.ndarray)) and len(v) == len(self) + else v + ) + for k, v in self.data.items() + }, metadata=self.metadata.copy(), ) + def merge_inner_detection_object_pair( detections_1: Detections, detections_2: Detections ) -> Detections: diff --git a/test/detection/test_transform.py b/test/detection/test_transform.py index ed66a5447b..efdb75dd0c 100644 --- a/test/detection/test_transform.py +++ b/test/detection/test_transform.py @@ -1,18 +1,21 @@ +from types import SimpleNamespace + import numpy as np import pytest -from types import SimpleNamespace + from supervision.detection.core import Detections + def test_transform_remap_and_filter(): # Simulate a model that predicts 'dog', 'cat', 'eagle', 'car' det = Detections( - xyxy=np.array([[0,0,1,1],[1,1,2,2],[2,2,3,3],[3,3,4,4]]), - class_id=np.array([0,1,2,3]), - confidence=np.array([0.9,0.8,0.7,0.6]), - data={"class_name": np.array(["dog","cat","eagle","car"])} + xyxy=np.array([[0, 0, 1, 1], [1, 1, 2, 2], [2, 2, 3, 3], [3, 3, 4, 4]]), + class_id=np.array([0, 1, 2, 3]), + confidence=np.array([0.9, 0.8, 0.7, 0.6]), + data={"class_name": np.array(["dog", "cat", "eagle", "car"])}, ) # Dataset expects 'animal', 'bird', 'car' (in that order) - dataset = SimpleNamespace(classes=["animal","bird","car"]) + dataset = SimpleNamespace(classes=["animal", "bird", "car"]) class_mapping = {"dog": "animal", "cat": "animal", "eagle": "bird"} det2 = det.transform(dataset, class_mapping=class_mapping) # Only 'dog', 'cat', 'eagle', 'car' should remain, but 'dog' and 'cat' become 'animal', 'eagle' becomes 'bird' @@ -24,12 +27,13 @@ def test_transform_remap_and_filter(): # Only 'dog', 'cat', 'eagle', 'car' remain, but 'car' is already in dataset.classes assert len(det2) == 4 + def test_transform_no_class_mapping(): det = Detections( - xyxy=np.array([[0,0,1,1],[1,1,2,2]]), - class_id=np.array([0,1]), - confidence=np.array([0.9,0.8]), - data={"class_name": np.array(["car","truck"])} + xyxy=np.array([[0, 0, 1, 1], [1, 1, 2, 2]]), + class_id=np.array([0, 1]), + confidence=np.array([0.9, 0.8]), + data={"class_name": np.array(["car", "truck"])}, ) dataset = SimpleNamespace(classes=["car"]) det2 = det.transform(dataset) @@ -37,12 +41,13 @@ def test_transform_no_class_mapping(): assert det2.data["class_name"][0] == "car" assert det2.class_id[0] == 0 + def test_transform_raises_without_class_name(): det = Detections( - xyxy=np.array([[0,0,1,1]]), + xyxy=np.array([[0, 0, 1, 1]]), class_id=np.array([0]), confidence=np.array([0.9]), - data={} + data={}, ) dataset = SimpleNamespace(classes=["car"]) with pytest.raises(ValueError): From a7d6f42727d2b91f67b39d4ecfd1b3ac21ca8283 Mon Sep 17 00:00:00 2001 From: AshAnand34 Date: Thu, 15 May 2025 00:50:26 -0700 Subject: [PATCH 3/4] Fixing ruff errors --- supervision/detection/core.py | 6 ++++-- test/detection/test_transform.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 64f7d87476..b3d968632e 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -1441,10 +1441,12 @@ def transform(self, dataset, class_mapping: Optional[dict] = None) -> Detections Args: dataset: An object with a .classes attribute (list of class names). - class_mapping (dict, optional): Mapping from model class names to dataset class names. + class_mapping (dict, optional): Mapping from model class names to + dataset class names. Returns: - Detections: A new Detections object with class names and IDs remapped and filtered. + Detections: A new Detections object with class names and IDs + remapped and filtered. """ # Get class names for each detection class_names = self.data.get("class_name") diff --git a/test/detection/test_transform.py b/test/detection/test_transform.py index efdb75dd0c..89525897a3 100644 --- a/test/detection/test_transform.py +++ b/test/detection/test_transform.py @@ -18,8 +18,9 @@ def test_transform_remap_and_filter(): dataset = SimpleNamespace(classes=["animal", "bird", "car"]) class_mapping = {"dog": "animal", "cat": "animal", "eagle": "bird"} det2 = det.transform(dataset, class_mapping=class_mapping) - # Only 'dog', 'cat', 'eagle', 'car' should remain, but 'dog' and 'cat' become 'animal', 'eagle' becomes 'bird' - assert set(det2.data["class_name"]) <= set(dataset.classes + ["car"]) + # Only 'dog', 'cat', 'eagle', 'car' should remain, but 'dog' and 'cat' become 'animal', + # 'eagle' becomes 'bird' + assert set(det2.data["class_name"]) <= set([*dataset.classes, "car"]) assert all([name in dataset.classes for name in det2.data["class_name"]]) # class_id should be remapped to dataset.classes indices for name, cid in zip(det2.data["class_name"], det2.class_id): From 71abd9692e2b5287e93f2f2d69d2f561ee363425 Mon Sep 17 00:00:00 2001 From: AshAnand34 Date: Thu, 15 May 2025 00:51:47 -0700 Subject: [PATCH 4/4] Minor ruff error fix --- test/detection/test_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/detection/test_transform.py b/test/detection/test_transform.py index 89525897a3..ed3c9c6af7 100644 --- a/test/detection/test_transform.py +++ b/test/detection/test_transform.py @@ -18,8 +18,8 @@ def test_transform_remap_and_filter(): dataset = SimpleNamespace(classes=["animal", "bird", "car"]) class_mapping = {"dog": "animal", "cat": "animal", "eagle": "bird"} det2 = det.transform(dataset, class_mapping=class_mapping) - # Only 'dog', 'cat', 'eagle', 'car' should remain, but 'dog' and 'cat' become 'animal', - # 'eagle' becomes 'bird' + # Only 'dog', 'cat', 'eagle', 'car' should remain, + # but 'dog' and 'cat' become 'animal', 'eagle' becomes 'bird' assert set(det2.data["class_name"]) <= set([*dataset.classes, "car"]) assert all([name in dataset.classes for name in det2.data["class_name"]]) # class_id should be remapped to dataset.classes indices