Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions BenchmarkEvaluator_Demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "70868bca",
"metadata": {},
"source": [
"# 🎯 BenchmarkEvaluator Demo\n",
"\n",
"This notebook demonstrates how to use `BenchmarkEvaluator` to compute precision/recall metrics for object detection tasks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ee3b103",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from supervision.detection.core import Detections\n",
"from supervision.metrics.benchmark import BenchmarkEvaluator"
]
},
{
"cell_type": "markdown",
"id": "f806eff5",
"metadata": {},
"source": [
"## Step 1: Create Ground Truth and Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65183606",
"metadata": {},
"outputs": [],
"source": [
"# Ground truth with 2 boxes\n",
"gt = Detections(\n",
" xyxy=np.array([[10, 10, 100, 100], [150, 150, 300, 300]]), class_id=np.array([0, 1])\n",
")\n",
"\n",
"# Predictions: One perfect match, one wrong class\n",
"pred = Detections(\n",
" xyxy=np.array([[10, 10, 100, 100], [150, 150, 300, 300]]), class_id=np.array([0, 2])\n",
")"
]
},
{
"cell_type": "markdown",
"id": "529f0ef0",
"metadata": {},
"source": [
"## Step 2: Run BenchmarkEvaluator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5615d704",
"metadata": {},
"outputs": [],
"source": [
"evaluator = BenchmarkEvaluator(ground_truth=gt, predictions=pred)\n",
"metrics = evaluator.compute_precision_recall()\n",
"print(\"Precision:\", metrics[\"precision\"])\n",
"print(\"Recall:\", metrics[\"recall\"])"
]
},
{
"cell_type": "markdown",
"id": "9ab6f923",
"metadata": {},
"source": [
"## Step 3: Per-Class Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dde2bc49",
"metadata": {},
"outputs": [],
"source": [
"per_class = evaluator.compute_precision_recall_per_class()\n",
"for cls, metric in per_class.items():\n",
" print(\n",
" f\"Class {cls} - Precision: {metric['precision']:.2f}, Recall: {metric['recall']:.2f}\"\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "dfa1f1e5",
"metadata": {},
"source": [
"## Step 4: Visualize Bounding Boxes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6a6ce9d",
"metadata": {},
"outputs": [],
"source": [
"def draw_boxes(image, detections, color, label):\n",
" for box, cls in zip(detections.xyxy, detections.class_id):\n",
" x1, y1, x2, y2 = box.astype(int)\n",
" cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)\n",
" cv2.putText(\n",
" image,\n",
" f\"{label}:{cls}\",\n",
" (x1, y1 - 10),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.5,\n",
" color,\n",
" 2,\n",
" )\n",
"\n",
"\n",
"canvas = np.ones((350, 350, 3), dtype=np.uint8) * 255\n",
"draw_boxes(canvas, gt, (0, 255, 0), \"GT\")\n",
"draw_boxes(canvas, pred, (0, 0, 255), \"Pred\")\n",
"\n",
"plt.imshow(canvas[..., ::-1])\n",
"plt.title(\"Ground Truth (Green) vs Prediction (Red)\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7b3d6112",
"metadata": {},
"source": [
"🎉 That's it! You've run a complete object detection benchmark with precision/recall metrics and visualization."
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
40 changes: 40 additions & 0 deletions supervision/metrics/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# supervision/metrics/benchmark.py

from typing import Dict, Optional

from supervision.detection.core import Detections


class BenchmarkEvaluator:
    """Compute precision/recall metrics for object detection results.

    Predictions are matched to ground-truth boxes greedily: a prediction
    counts as a true positive when it shares the ground-truth class id and
    overlaps it with IoU >= ``iou_threshold``; each ground-truth box can be
    matched at most once.
    """

    def __init__(
        self,
        ground_truth: Detections,
        predictions: Detections,
        class_map: Optional[Dict[str, str]] = None,
        iou_threshold: float = 0.5,
    ):
        """Store the detections to compare.

        Args:
            ground_truth: Reference detections (``xyxy`` boxes + ``class_id``).
            predictions: Model-output detections to evaluate.
            class_map: Optional name remapping kept for API compatibility;
                currently unused by the matching logic.
            iou_threshold: Minimum IoU for a prediction to match a
                ground-truth box of the same class.
        """
        self.ground_truth = ground_truth
        self.predictions = predictions
        self.class_map = class_map or {}
        self.iou_threshold = iou_threshold

    @staticmethod
    def _iou(box_a, box_b) -> float:
        """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes."""
        ax1, ay1, ax2, ay2 = (float(v) for v in box_a)
        bx1, by1, bx2, by2 = (float(v) for v in box_b)
        # Intersection rectangle; clamp negative extents to zero overlap.
        inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
        inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
        inter = inter_w * inter_h
        if inter == 0.0:
            return 0.0
        union = (
            (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
        )
        return inter / union if union > 0.0 else 0.0

    def _match_counts(self, class_filter=None):
        """Greedily match predictions to ground truth.

        Args:
            class_filter: When given, restrict both sides to this class id.

        Returns:
            Tuple ``(tp, fp, fn)`` of true positives, false positives and
            false negatives under ``self.iou_threshold``.
        """
        gt_items = [
            (box, cls)
            for box, cls in zip(self.ground_truth.xyxy, self.ground_truth.class_id)
            if class_filter is None or cls == class_filter
        ]
        pred_items = [
            (box, cls)
            for box, cls in zip(self.predictions.xyxy, self.predictions.class_id)
            if class_filter is None or cls == class_filter
        ]
        gt_used = [False] * len(gt_items)
        tp = 0
        for pred_box, pred_cls in pred_items:
            # Pick the unmatched same-class ground-truth box with highest IoU.
            best_idx, best_iou = -1, 0.0
            for idx, (gt_box, gt_cls) in enumerate(gt_items):
                if gt_used[idx] or gt_cls != pred_cls:
                    continue
                iou = self._iou(pred_box, gt_box)
                if iou > best_iou:
                    best_idx, best_iou = idx, iou
            if best_idx >= 0 and best_iou >= self.iou_threshold:
                gt_used[best_idx] = True
                tp += 1
        # Unmatched predictions are false positives; unmatched GT are misses.
        return tp, len(pred_items) - tp, len(gt_items) - tp

    def compute_precision_recall(self) -> Dict[str, float]:
        """Compute overall precision and recall across all classes.

        Returns:
            Dict with ``"precision"`` and ``"recall"`` in ``[0, 1]``;
            each is 0 when its denominator is zero (no predictions / no GT).
        """
        tp, fp, fn = self._match_counts()
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        return {"precision": precision, "recall": recall}

    def compute_precision_recall_per_class(self) -> Dict[int, Dict[str, float]]:
        """Compute precision/recall independently for each class id.

        Returns:
            Mapping ``class_id -> {"precision": ..., "recall": ...}`` for
            every class present in either ground truth or predictions.
        """
        class_ids = sorted(
            {int(c) for c in self.ground_truth.class_id}
            | {int(c) for c in self.predictions.class_id}
        )
        per_class: Dict[int, Dict[str, float]] = {}
        for cls in class_ids:
            tp, fp, fn = self._match_counts(class_filter=cls)
            per_class[cls] = {
                "precision": tp / (tp + fp) if (tp + fp) > 0 else 0,
                "recall": tp / (tp + fn) if (tp + fn) > 0 else 0,
            }
        return per_class

    def summary(self) -> None:
        """Print the overall precision/recall to stdout."""
        metrics = self.compute_precision_recall()
        print("Benchmark Summary:")
        for k, v in metrics.items():
            print(f"{k}: {v:.4f}")
15 changes: 15 additions & 0 deletions tests/metrics/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np

from supervision.detection.core import Detections
from supervision.metrics.benchmark import BenchmarkEvaluator


def test_basic_precision_recall():
    """A single perfectly overlapping box with matching class scores 1.0/1.0."""
    box = np.array([[0, 0, 100, 100]])
    label = np.array([0])
    ground_truth = Detections(xyxy=box, class_id=label)
    predictions = Detections(xyxy=box.copy(), class_id=label.copy())

    result = BenchmarkEvaluator(
        ground_truth=ground_truth, predictions=predictions
    ).compute_precision_recall()

    assert result["precision"] == 1.0
    assert result["recall"] == 1.0