diff --git a/src/supervision/detection/tools/csv_sink.py b/src/supervision/detection/tools/csv_sink.py index d06d896692..b242636245 100644 --- a/src/supervision/detection/tools/csv_sink.py +++ b/src/supervision/detection/tools/csv_sink.py @@ -144,7 +144,13 @@ def parse_detection_data( row[key] = value[i] if hasattr(value, "__getitem__") else value if custom_data: - row.update(custom_data) + for key, value in custom_data.items(): + if isinstance(value, np.ndarray) and value.ndim == 0: + row[key] = value + elif isinstance(value, np.ndarray): + row[key] = value[i] + else: + row[key] = value parsed_rows.append(row) return parsed_rows diff --git a/src/supervision/detection/tools/json_sink.py b/src/supervision/detection/tools/json_sink.py index 5a34e16eaa..4cdbb30354 100644 --- a/src/supervision/detection/tools/json_sink.py +++ b/src/supervision/detection/tools/json_sink.py @@ -118,7 +118,13 @@ def parse_detection_data( ) if custom_data: - row.update(custom_data) + for key, value in custom_data.items(): + if isinstance(value, np.ndarray) and value.ndim == 0: + row[key] = str(value) + elif isinstance(value, np.ndarray): + row[key] = str(value[i]) + else: + row[key] = value parsed_rows.append(row) return parsed_rows diff --git a/tests/detection/test_csv.py b/tests/detection/test_csv.py index 12d52058e7..127352e99a 100644 --- a/tests/detection/test_csv.py +++ b/tests/detection/test_csv.py @@ -2,6 +2,7 @@ import os from typing import Any +import numpy as np import pytest import supervision as sv @@ -193,6 +194,36 @@ ], ], ), # Complex Data + ( + _create_detections( + xyxy=[[10, 20, 30, 40], [50, 60, 70, 80]], + confidence=[0.9, 0.8], + class_id=[0, 1], + ), + {"area": np.array([400.0, 400.0])}, + _create_detections( + xyxy=[[15, 25, 35, 45]], + confidence=[0.7], + class_id=[2], + ), + {"area": np.array([400.0])}, + "test_detections_array_custom_data.csv", + [ + [ + "x_min", + "y_min", + "x_max", + "y_max", + "class_id", + "confidence", + "tracker_id", + "area", + ], + ["10.0", "20.0", "30.0", "40.0", "0", "0.9", "", "400.0"], + ["50.0", "60.0", "70.0", "80.0", "1", "0.8", "", "400.0"], + ["15.0", "25.0", "35.0", "45.0", "2", "0.7", "", "400.0"], + ], + ), # numpy array in custom_data sliced per detection row ], ) def test_csv_sink(