-
Notifications
You must be signed in to change notification settings - Fork 0
feat(tools): in-situ recording evaluation protocol #138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
2cb1cb1
feat(tools): in-situ recording evaluation protocol with annotation, m…
ToaruPen b9acb1c
fix(tools): address PR review findings — coverage, WAV handling, grou…
ToaruPen 99b3e4b
fix(tools): fail-fast on non-positive duration_s in record_audio
ToaruPen 2fb2b73
fix(tools): address remaining review findings
ToaruPen d73e9b3
fix(tools): align recording tool runtime paths
ToaruPen 51de3e8
ci: trigger CI re-run
ToaruPen baaf8bd
style: format test_record_evaluation.py
ToaruPen f3cee0a
fix(tools): remove test wrapper functions, add annotation docstring a…
ToaruPen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from __future__ import annotations | ||
|
|
||
| __all__: list[str] = [] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| from pydantic import Field, model_validator | ||
|
|
||
| from daemon.contracts._base import ContractModel | ||
|
|
||
| if TYPE_CHECKING: | ||
| from pathlib import Path | ||
|
|
||
| EventType = Literal["bell", "multi_voice", "quiet", "speech", "noise"] | ||
|
|
||
|
|
||
class AnnotationEvent(ContractModel):  # type: ignore[explicit-any]
    """A single labelled event on a recording's timeline.

    ``time_s`` is the non-negative start offset. ``end_s`` is optional;
    when present it must be strictly greater than ``time_s``.
    """

    time_s: float = Field(ge=0.0)
    end_s: float | None = None
    type: EventType
    notes: str = ""

    @model_validator(mode="after")
    def _validate_end_s(self) -> AnnotationEvent:
        """Reject events whose end is not strictly after their start."""
        # Open-ended events (no ``end_s``) have nothing to check.
        if self.end_s is None:
            return self
        if self.end_s <= self.time_s:
            msg = "end_s must be strictly greater than time_s"
            raise ValueError(msg)
        return self
|
|
||
|
|
||
class RecordingAnnotation(ContractModel):  # type: ignore[explicit-any]
    """Staff-annotated ground truth for a single recording session.

    ``events`` may be empty when a recording has no annotated events
    (e.g. a silence-only baseline). Consumers must handle zero-length
    lists gracefully to avoid divide-by-zero in metric aggregation.
    """

    # Identifier of the recording these annotations belong to.
    recording_id: str
    # Annotated events; required, but may be an empty list.
    events: list[AnnotationEvent]
|
|
||
|
|
||
def save_annotation(annotation: RecordingAnnotation, path: Path) -> None:
    """Write ``annotation`` to ``path`` as 2-space-indented UTF-8 JSON.

    Missing parent directories of ``path`` are created first.
    """
    destination_dir = path.parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    serialized = annotation.model_dump_json(indent=2)
    path.write_text(serialized, encoding="utf-8")
|
|
||
|
|
||
def load_annotation(path: Path) -> RecordingAnnotation:
    """Read and validate a RecordingAnnotation from the JSON file at ``path``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if path.exists():
        raw_json = path.read_text(encoding="utf-8")
        return RecordingAnnotation.model_validate_json(raw_json)

    msg = f"Annotation file not found: {path}"
    raise FileNotFoundError(msg)
|
|
||
|
|
||
| __all__ = [ | ||
| "AnnotationEvent", | ||
| "EventType", | ||
| "RecordingAnnotation", | ||
| "load_annotation", | ||
| "save_annotation", | ||
| ] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,229 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import logging | ||
| import sys | ||
| import wave | ||
| from collections.abc import Callable | ||
| from pathlib import Path | ||
|
|
||
| import numpy as np | ||
| import numpy.typing as npt | ||
| from pydantic import Field | ||
|
|
||
| from daemon.contracts._base import ContractModel | ||
| from daemon.tools.annotation_schema import RecordingAnnotation, load_annotation | ||
|
|
||
# NumPy array of float64 values; the sample-buffer type used throughout this module.
FloatArray = npt.NDArray[np.float64]
# Annotation event types that count as a positive in the ground truth.
POSITIVE_EVENT_TYPES = frozenset({"bell", "speech", "multi_voice"})
# 8-bit WAV samples are unsigned: subtract and divide by this value to normalise.
UINT8_OFFSET = 128.0
# Sample widths, in bytes, of 16-bit and 32-bit PCM.
PCM16_SAMPWIDTH = 2
PCM32_SAMPWIDTH = 4
# Divisors used to normalise signed 16-bit / 32-bit integer samples.
PCM16_SCALE = 32768.0
PCM32_SCALE = 2147483648.0
# Module-level logger.
LOGGER = logging.getLogger(__name__)
|
|
||
|
|
||
class DetectionMetrics(ContractModel):  # type: ignore[explicit-any]
    """Aggregate detection scores; every field is constrained to [0.0, 1.0]."""

    precision: float = Field(ge=0.0, le=1.0)
    recall: float = Field(ge=0.0, le=1.0)
    f1: float = Field(ge=0.0, le=1.0)
    false_positive_rate: float = Field(ge=0.0, le=1.0)
|
|
||
|
|
||
class EvaluationResult(ContractModel):  # type: ignore[explicit-any]
    """Outcome of evaluating a detector on one recording."""

    # Identifier copied from the recording's annotation.
    recording_id: str
    metrics: DetectionMetrics
    # Number of annotated events in the recording.
    event_count: int
    # The three lists below default to empty.
    detection_latencies_ms: list[float] = Field(default_factory=list)
    false_positive_times_s: list[float] = Field(default_factory=list)
    vad_overlap_ratios: list[float] = Field(default_factory=list)
|
|
||
|
|
||
class DetectionEventResult(ContractModel):  # type: ignore[explicit-any]
    """One result entry returned by a detector.

    Only ``detected`` is required; every other field defaults to ``None``.
    """

    detected: bool
    # Non-negative time of the detection in seconds, when known.
    detected_time_s: float | None = Field(default=None, ge=0.0)
    # Score in [0.0, 1.0], when the detector provides one.
    confidence: float | None = Field(default=None, ge=0.0, le=1.0)
    # Presumably an index into the annotation's events; ``None`` means unmatched.
    matched_annotation_index: int | None = Field(default=None, ge=0)
    detection_latency_ms: float | None = Field(default=None, ge=0.0)
    vad_overlap_ratio: float | None = Field(default=None, ge=0.0, le=1.0)
|
|
||
|
|
||
# Detector signature: (samples, sample_rate, annotation) -> list of per-event results.
DetectorCallable = Callable[[FloatArray, int, RecordingAnnotation], list[DetectionEventResult]]
|
|
||
|
|
||
def compute_metrics(predicted: list[bool], ground_truth: list[bool]) -> DetectionMetrics:
    """Derive precision, recall, F1 and false-positive rate from paired flags.

    ``predicted[i]`` is compared against ``ground_truth[i]``. Any ratio whose
    denominator is zero is reported as ``0.0``.

    Raises:
        ValueError: if the two lists differ in length.
    """
    n_predicted, n_truth = len(predicted), len(ground_truth)
    if n_predicted != n_truth:
        msg = (
            f"predicted and ground_truth must have the same length, "
            f"got {n_predicted} and {n_truth}"
        )
        raise ValueError(msg)

    # Tally the confusion-matrix cells; the last cell is whatever remains.
    paired = list(zip(predicted, ground_truth, strict=True))
    true_pos = sum(1 for guess, truth in paired if guess and truth)
    false_pos = sum(1 for guess, truth in paired if guess and not truth)
    false_neg = sum(1 for guess, truth in paired if truth and not guess)
    true_neg = len(paired) - true_pos - false_pos - false_neg

    predicted_positive = true_pos + false_pos
    actual_positive = true_pos + false_neg
    actual_negative = false_pos + true_neg

    precision = true_pos / predicted_positive if predicted_positive > 0 else 0.0
    recall = true_pos / actual_positive if actual_positive > 0 else 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    fpr = false_pos / actual_negative if actual_negative > 0 else 0.0

    return DetectionMetrics(
        precision=precision,
        recall=recall,
        f1=f1,
        false_positive_rate=fpr,
    )
|
|
||
|
|
||
def load_recording_pair(
    wav_path: Path,
    ann_path: Path,
) -> tuple[FloatArray, int, RecordingAnnotation]:
    """Load a WAV file and its annotation.

    Integer PCM of width 1, 2 or 4 bytes is decoded and divided by the
    matching scale constant. Multi-channel recordings are down-mixed to
    mono by averaging the channels, so the returned buffer always holds
    one sample per frame at ``sample_rate``.

    Args:
        wav_path: Path to the PCM WAV file.
        ann_path: Path to the JSON annotation for that recording.

    Returns:
        ``(samples, sample_rate, annotation)``.

    Raises:
        FileNotFoundError: if ``wav_path`` does not exist (``load_annotation``
            raises the same for a missing ``ann_path``).
        ValueError: if the WAV sample width is not 1, 2 or 4 bytes.
    """
    if not wav_path.exists():
        msg = f"WAV file not found: {wav_path}"
        raise FileNotFoundError(msg)

    with wave.open(str(wav_path), "rb") as wf:
        sample_rate = wf.getframerate()
        n_channels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        raw = wf.readframes(wf.getnframes())

    # WAV PCM is little-endian; explicit "<" dtypes keep decoding correct on
    # big-endian hosts as well.
    if sampwidth == 1:
        pcm = (np.frombuffer(raw, dtype=np.uint8).astype(np.float64) - UINT8_OFFSET) / UINT8_OFFSET
    elif sampwidth == PCM16_SAMPWIDTH:
        pcm = np.frombuffer(raw, dtype="<i2").astype(np.float64) / PCM16_SCALE
    elif sampwidth == PCM32_SAMPWIDTH:
        pcm = np.frombuffer(raw, dtype="<i4").astype(np.float64) / PCM32_SCALE
    else:
        msg = f"Unsupported WAV sample width: {sampwidth} bytes"
        raise ValueError(msg)

    if n_channels > 1:
        # Frames are interleaved (ch0, ch1, ..., ch0, ch1, ...). Drop any
        # trailing partial frame, then average the channels so that sample
        # index / sample_rate is a real time offset.
        usable = len(pcm) - (len(pcm) % n_channels)
        pcm = pcm[:usable].reshape(-1, n_channels).mean(axis=1)

    annotation = load_annotation(ann_path)
    return pcm, sample_rate, annotation
|
|
||
|
|
||
def evaluate_recording(
    wav_path: Path,
    ann_path: Path,
    *,
    detector: DetectorCallable,
) -> EvaluationResult:
    """Score ``detector`` on one recording against its annotation.

    The detector's results are compared with the annotation's events
    position by position, so it must return exactly one result per
    annotated event; otherwise ``compute_metrics`` raises ``ValueError``.
    An annotated event counts as a ground-truth positive when its type is
    in ``POSITIVE_EVENT_TYPES``.
    """
    samples, sample_rate, annotation = load_recording_pair(wav_path, ann_path)
    detections = detector(samples, sample_rate, annotation)

    expected_flags = [entry.type in POSITIVE_EVENT_TYPES for entry in annotation.events]
    predicted_flags = [detection.detected for detection in detections]
    metrics = compute_metrics(predicted_flags, expected_flags)

    latencies_ms: list[float] = []
    unmatched_hit_times_s: list[float] = []
    overlap_ratios: list[float] = []
    for detection in detections:
        if detection.detection_latency_ms is not None:
            latencies_ms.append(detection.detection_latency_ms)

        # A hit with no matched annotation is reported as a false-positive time.
        is_unmatched_hit = detection.detected and detection.matched_annotation_index is None
        if is_unmatched_hit and detection.detected_time_s is not None:
            unmatched_hit_times_s.append(detection.detected_time_s)

        if detection.vad_overlap_ratio is not None:
            overlap_ratios.append(detection.vad_overlap_ratio)

    return EvaluationResult(
        recording_id=annotation.recording_id,
        metrics=metrics,
        event_count=len(annotation.events),
        detection_latencies_ms=latencies_ms,
        false_positive_times_s=unmatched_hit_times_s,
        vad_overlap_ratios=overlap_ratios,
    )
|
|
||
|
|
||
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the evaluation tool.

    The only option is ``--recordings``, a directory path that defaults to
    ``data/evaluation``.
    """
    default_dir = Path("data/evaluation")
    cli = argparse.ArgumentParser(
        description="Evaluate sound detection accuracy against annotated recordings."
    )
    cli.add_argument(
        "--recordings",
        type=Path,
        default=default_dir,
        help="Directory containing WAV and JSON annotation pairs (default: data/evaluation)",
    )
    return cli
|
|
||
|
|
||
def _main(argv: list[str] | None = None) -> int:
    """Evaluate every WAV/JSON pair in the recordings directory.

    Each ``*.wav`` file is paired with the ``.json`` file of the same stem.
    One JSON-serialised ``EvaluationResult`` per successfully evaluated pair
    is written to stdout.

    Args:
        argv: Command-line arguments; ``None`` lets argparse use ``sys.argv``.

    Returns:
        ``0`` if at least one pair was evaluated, otherwise ``1``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    recordings_dir: Path = args.recordings

    pairs: list[tuple[Path, Path]] = []
    for wav_path in sorted(recordings_dir.glob("*.wav")):
        ann_path = wav_path.with_suffix(".json")
        if ann_path.exists():
            pairs.append((wav_path, ann_path))
        else:
            # Report skipped recordings instead of dropping them silently.
            LOGGER.warning("Skipping %s: annotation file %s not found", wav_path, ann_path)

    if not pairs:
        # Give the operator a reason for the non-zero exit code.
        LOGGER.error("No WAV/JSON recording pairs found in %s", recordings_dir)
        return 1

    def _placeholder_detector(
        _samples: FloatArray,
        _sample_rate: int,
        annotation: RecordingAnnotation,
    ) -> list[DetectionEventResult]:
        # Stand-in detector: reports "not detected" for every annotated event.
        return [
            DetectionEventResult(
                detected=False,
                matched_annotation_index=index,
            )
            for index, _event in enumerate(annotation.events)
        ]

    evaluated_any = False
    for wav_path, ann_path in pairs:
        try:
            result = evaluate_recording(
                wav_path,
                ann_path,
                detector=_placeholder_detector,
            )
        except Exception:
            # Top-level CLI boundary: log the failure and move on to the next pair.
            LOGGER.exception("Failed to evaluate recording pair: %s / %s", wav_path, ann_path)
            continue

        sys.stdout.write(f"{result.model_dump_json()}\n")
        evaluated_any = True

    if not evaluated_any:
        LOGGER.error("None of the %d recording pair(s) could be evaluated", len(pairs))
    return 0 if evaluated_any else 1
|
ToaruPen marked this conversation as resolved.
|
||
|
|
||
|
|
||
| if __name__ == "__main__": # pragma: no cover | ||
| sys.exit(_main()) | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "DetectionEventResult", | ||
| "DetectionMetrics", | ||
| "DetectorCallable", | ||
| "EvaluationResult", | ||
| "compute_metrics", | ||
| "evaluate_recording", | ||
| "load_recording_pair", | ||
| ] | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.