Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions daemon/src/daemon/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import annotations

# Package marker for daemon.tools; no public names are re-exported yet.
__all__: list[str] = []
61 changes: 61 additions & 0 deletions daemon/src/daemon/tools/annotation_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from pydantic import Field, model_validator

from daemon.contracts._base import ContractModel

if TYPE_CHECKING:
from pathlib import Path

# Closed set of labels staff may assign to an annotated audio event.
EventType = Literal["bell", "multi_voice", "quiet", "speech", "noise"]


class AnnotationEvent(ContractModel):  # type: ignore[explicit-any]
    """A single staff-annotated event within a recording.

    ``end_s`` is optional: point events carry only ``time_s``, while
    interval events must satisfy ``end_s > time_s``.
    """

    time_s: float = Field(ge=0.0)  # event start, seconds from recording start
    end_s: float | None = None  # optional event end; None for point events
    type: EventType
    notes: str = ""

    @model_validator(mode="after")
    def _check_interval(self) -> AnnotationEvent:
        """Reject intervals whose end does not come strictly after their start."""
        if self.end_s is None:
            return self
        if self.end_s > self.time_s:
            return self
        raise ValueError("end_s must be strictly greater than time_s")


class RecordingAnnotation(ContractModel):  # type: ignore[explicit-any]
    """Staff-annotated ground truth for a single recording session.

    ``events`` may be empty when a recording has no annotated events
    (e.g. a silence-only baseline). Consumers must handle zero-length
    lists gracefully to avoid divide-by-zero in metric aggregation.
    """

    recording_id: str  # identifier linking this annotation to its recording
    events: list[AnnotationEvent]  # event ordering is not validated here
Comment thread
ToaruPen marked this conversation as resolved.


def save_annotation(annotation: RecordingAnnotation, path: Path) -> None:
"""Serialize a RecordingAnnotation to a JSON file."""
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(annotation.model_dump_json(indent=2), encoding="utf-8")


def load_annotation(path: Path) -> RecordingAnnotation:
    """Read and validate a RecordingAnnotation from a JSON file on disk."""
    if not path.exists():
        raise FileNotFoundError(f"Annotation file not found: {path}")
    raw = path.read_text(encoding="utf-8")
    return RecordingAnnotation.model_validate_json(raw)


# Explicit public API for star-imports and documentation tooling.
__all__ = [
    "AnnotationEvent",
    "EventType",
    "RecordingAnnotation",
    "load_annotation",
    "save_annotation",
]
229 changes: 229 additions & 0 deletions daemon/src/daemon/tools/evaluate_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
from __future__ import annotations

import argparse
import logging
import sys
import wave
from collections.abc import Callable
from pathlib import Path

import numpy as np
import numpy.typing as npt
from pydantic import Field

from daemon.contracts._base import ContractModel
from daemon.tools.annotation_schema import RecordingAnnotation, load_annotation

# Alias for float64 numpy arrays holding normalised audio samples.
FloatArray = npt.NDArray[np.float64]
# Annotation event types that count as "should be detected" ground truth.
POSITIVE_EVENT_TYPES = frozenset({"bell", "speech", "multi_voice"})
# 8-bit WAV PCM is unsigned; subtract then divide by 128 to map to [-1, 1).
UINT8_OFFSET = 128.0
PCM16_SAMPWIDTH = 2  # bytes per sample for 16-bit PCM
PCM32_SAMPWIDTH = 4  # bytes per sample for 32-bit PCM
PCM16_SCALE = 32768.0  # 2**15, full-scale magnitude of int16
PCM32_SCALE = 2147483648.0  # 2**31, full-scale magnitude of int32
LOGGER = logging.getLogger(__name__)


class DetectionMetrics(ContractModel):  # type: ignore[explicit-any]
    """Binary-classification metrics; each value is constrained to [0, 1]."""

    precision: float = Field(ge=0.0, le=1.0)
    recall: float = Field(ge=0.0, le=1.0)
    f1: float = Field(ge=0.0, le=1.0)
    false_positive_rate: float = Field(ge=0.0, le=1.0)


class EvaluationResult(ContractModel):  # type: ignore[explicit-any]
    """Aggregated outcome of evaluating one recording against its annotation."""

    recording_id: str
    metrics: DetectionMetrics
    event_count: int  # number of annotated events in the recording
    detection_latencies_ms: list[float] = Field(default_factory=list)
    false_positive_times_s: list[float] = Field(default_factory=list)
    vad_overlap_ratios: list[float] = Field(default_factory=list)


class DetectionEventResult(ContractModel):  # type: ignore[explicit-any]
    """Per-event detector output aligned with one annotated event."""

    detected: bool
    detected_time_s: float | None = Field(default=None, ge=0.0)
    confidence: float | None = Field(default=None, ge=0.0, le=1.0)
    # Index into the annotation's events list; None marks an unmatched
    # (spurious) detection — see evaluate_recording's false-positive list.
    matched_annotation_index: int | None = Field(default=None, ge=0)
    detection_latency_ms: float | None = Field(default=None, ge=0.0)
    vad_overlap_ratio: float | None = Field(default=None, ge=0.0, le=1.0)


# Detector signature: (samples, sample_rate, annotation) -> one result per
# annotated event (compute_metrics raises if the counts diverge).
DetectorCallable = Callable[[FloatArray, int, RecordingAnnotation], list[DetectionEventResult]]


def compute_metrics(predicted: list[bool], ground_truth: list[bool]) -> DetectionMetrics:
    """Derive precision, recall, F1, and false-positive rate.

    The two lists are aligned element-wise; a length mismatch is a caller
    bug and raises ``ValueError``. Ratios with a zero denominator are
    reported as 0.0 rather than raising.
    """
    if len(predicted) != len(ground_truth):
        msg = (
            f"predicted and ground_truth must have the same length, "
            f"got {len(predicted)} and {len(ground_truth)}"
        )
        raise ValueError(msg)

    # Tally the confusion matrix over the aligned pairs.
    pairs = list(zip(predicted, ground_truth, strict=True))
    tp = sum(1 for pred, truth in pairs if pred and truth)
    fp = sum(1 for pred, truth in pairs if pred and not truth)
    fn = sum(1 for pred, truth in pairs if not pred and truth)
    tn = len(pairs) - tp - fp - fn

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0

    return DetectionMetrics(
        precision=precision,
        recall=recall,
        f1=f1,
        false_positive_rate=fpr,
    )


def load_recording_pair(
    wav_path: Path,
    ann_path: Path,
) -> tuple[FloatArray, int, RecordingAnnotation]:
    """Load a WAV file and its annotation, returning samples, sample_rate, annotation.

    Samples are returned as float64 normalised to [-1.0, 1.0). Multi-channel
    recordings are downmixed to mono by averaging channels, so callers always
    receive a 1-D buffer whose length matches the file's frame count.

    Raises:
        FileNotFoundError: if ``wav_path`` does not exist (and, via
            ``load_annotation``, if ``ann_path`` does not exist).
        ValueError: for unsupported PCM sample widths (e.g. 24-bit).
    """
    if not wav_path.exists():
        msg = f"WAV file not found: {wav_path}"
        raise FileNotFoundError(msg)

    with wave.open(str(wav_path), "rb") as wf:
        sample_rate = wf.getframerate()
        n_channels = wf.getnchannels()
        n_frames = wf.getnframes()
        raw = wf.readframes(n_frames)
        sampwidth = wf.getsampwidth()

    if sampwidth == 1:
        # 8-bit WAV PCM is unsigned; re-centre around zero before scaling.
        pcm = (np.frombuffer(raw, dtype=np.uint8).astype(np.float64) - UINT8_OFFSET) / UINT8_OFFSET
    elif sampwidth == PCM16_SAMPWIDTH:
        pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float64) / PCM16_SCALE
    elif sampwidth == PCM32_SAMPWIDTH:
        pcm = np.frombuffer(raw, dtype=np.int32).astype(np.float64) / PCM32_SCALE
    else:
        msg = f"Unsupported WAV sample width: {sampwidth} bytes"
        raise ValueError(msg)

    if n_channels > 1:
        # Fix: the channel count was previously ignored, so interleaved
        # multi-channel audio was misread as a mono stream with n_channels
        # times the true sample count. Downmix to mono by averaging.
        pcm = pcm.reshape(-1, n_channels).mean(axis=1)

    annotation = load_annotation(ann_path)
    return pcm, sample_rate, annotation


def evaluate_recording(
    wav_path: Path,
    ann_path: Path,
    *,
    detector: DetectorCallable,
) -> EvaluationResult:
    """Run *detector* over one recording and score it against its annotation.

    The detector must return one DetectionEventResult per annotated event,
    in annotation order (compute_metrics raises on a count mismatch).
    """
    samples, sample_rate, annotation = load_recording_pair(wav_path, ann_path)
    results = detector(samples, sample_rate, annotation)

    # Ground truth: only certain event types count as "should be detected".
    truth_flags = [event.type in POSITIVE_EVENT_TYPES for event in annotation.events]
    predicted_flags = [result.detected for result in results]
    metrics = compute_metrics(predicted_flags, truth_flags)

    # Collect the optional per-event diagnostics in a single pass.
    latencies_ms: list[float] = []
    fp_times_s: list[float] = []
    overlap_ratios: list[float] = []
    for result in results:
        if result.detection_latency_ms is not None:
            latencies_ms.append(result.detection_latency_ms)
        if result.vad_overlap_ratio is not None:
            overlap_ratios.append(result.vad_overlap_ratio)
        unmatched_detection = (
            result.detected
            and result.matched_annotation_index is None
            and result.detected_time_s is not None
        )
        if unmatched_detection:
            fp_times_s.append(result.detected_time_s)

    return EvaluationResult(
        recording_id=annotation.recording_id,
        metrics=metrics,
        event_count=len(annotation.events),
        detection_latencies_ms=latencies_ms,
        false_positive_times_s=fp_times_s,
        vad_overlap_ratios=overlap_ratios,
    )


def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Evaluate sound detection accuracy against annotated recordings."
)
parser.add_argument(
"--recordings",
type=Path,
default=Path("data/evaluation"),
help="Directory containing WAV and JSON annotation pairs (default: data/evaluation)",
)
return parser


def _main(argv: list[str] | None = None) -> int:
    """CLI entry point: evaluate every WAV/JSON pair in a directory.

    Emits one JSON line per successfully evaluated recording on stdout.
    Returns 0 when at least one pair was evaluated, 1 otherwise (no pairs
    found, or every evaluation failed).
    """
    args = _build_parser().parse_args(argv)
    recordings_dir: Path = args.recordings

    # Pair each WAV with a same-stem JSON annotation; skip orphan WAVs.
    pairs = [
        (wav, wav.with_suffix(".json"))
        for wav in sorted(recordings_dir.glob("*.wav"))
        if wav.with_suffix(".json").exists()
    ]
    if not pairs:
        return 1

    def _placeholder_detector(
        _samples: FloatArray,
        _sample_rate: int,
        annotation: RecordingAnnotation,
    ) -> list[DetectionEventResult]:
        # Trivial stand-in detector: reports every annotated event as missed.
        results: list[DetectionEventResult] = []
        for index in range(len(annotation.events)):
            results.append(DetectionEventResult(detected=False, matched_annotation_index=index))
        return results

    success_count = 0
    for wav_path, ann_path in pairs:
        try:
            result = evaluate_recording(wav_path, ann_path, detector=_placeholder_detector)
        except Exception:  # top-level boundary: log the failure, keep going
            LOGGER.exception("Failed to evaluate recording pair: %s / %s", wav_path, ann_path)
            continue
        sys.stdout.write(f"{result.model_dump_json()}\n")
        success_count += 1

    return 0 if success_count else 1
Comment thread
ToaruPen marked this conversation as resolved.


if __name__ == "__main__":  # pragma: no cover
    # Delegate to _main so the CLI surface stays importable and testable.
    sys.exit(_main())


# Explicit public API for star-imports and documentation tooling.
__all__ = [
    "DetectionEventResult",
    "DetectionMetrics",
    "DetectorCallable",
    "EvaluationResult",
    "compute_metrics",
    "evaluate_recording",
    "load_recording_pair",
]
Loading
Loading