diff --git a/daemon/src/daemon/tools/__init__.py b/daemon/src/daemon/tools/__init__.py new file mode 100644 index 00000000..bdec2fc8 --- /dev/null +++ b/daemon/src/daemon/tools/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +__all__: list[str] = [] diff --git a/daemon/src/daemon/tools/annotation_schema.py b/daemon/src/daemon/tools/annotation_schema.py new file mode 100644 index 00000000..7d76c0a2 --- /dev/null +++ b/daemon/src/daemon/tools/annotation_schema.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from pydantic import Field, model_validator + +from daemon.contracts._base import ContractModel + +if TYPE_CHECKING: + from pathlib import Path + +EventType = Literal["bell", "multi_voice", "quiet", "speech", "noise"] + + +class AnnotationEvent(ContractModel): # type: ignore[explicit-any] + time_s: float = Field(ge=0.0) + end_s: float | None = None + type: EventType + notes: str = "" + + @model_validator(mode="after") + def _validate_end_s(self) -> AnnotationEvent: + if self.end_s is not None and self.end_s <= self.time_s: + msg = "end_s must be strictly greater than time_s" + raise ValueError(msg) + return self + + +class RecordingAnnotation(ContractModel): # type: ignore[explicit-any] + """Staff-annotated ground truth for a single recording session. + + ``events`` may be empty when a recording has no annotated events + (e.g. a silence-only baseline). Consumers must handle zero-length + lists gracefully to avoid divide-by-zero in metric aggregation. + """ + + recording_id: str + events: list[AnnotationEvent] + + +def save_annotation(annotation: RecordingAnnotation, path: Path) -> None: + """Serialize a RecordingAnnotation to a JSON file.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(annotation.model_dump_json(indent=2), encoding="utf-8") + + +def load_annotation(path: Path) -> RecordingAnnotation: + """Deserialize a RecordingAnnotation from a JSON file.""" + if not path.exists(): + msg = f"Annotation file not found: {path}" + raise FileNotFoundError(msg) + return RecordingAnnotation.model_validate_json(path.read_text(encoding="utf-8")) + + +__all__ = [ + "AnnotationEvent", + "EventType", + "RecordingAnnotation", + "load_annotation", + "save_annotation", +] diff --git a/daemon/src/daemon/tools/evaluate_detection.py b/daemon/src/daemon/tools/evaluate_detection.py new file mode 100644 index 00000000..be3f9c94 --- /dev/null +++ b/daemon/src/daemon/tools/evaluate_detection.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import argparse +import logging +import sys +import wave +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import numpy.typing as npt +from pydantic import Field + +from daemon.contracts._base import ContractModel +from daemon.tools.annotation_schema import RecordingAnnotation, load_annotation + +FloatArray = npt.NDArray[np.float64] +POSITIVE_EVENT_TYPES = frozenset({"bell", "speech", "multi_voice"}) +UINT8_OFFSET = 128.0 +PCM16_SAMPWIDTH = 2 +PCM32_SAMPWIDTH = 4 +PCM16_SCALE = 32768.0 +PCM32_SCALE = 2147483648.0 +LOGGER = logging.getLogger(__name__) + + +class DetectionMetrics(ContractModel): # type: ignore[explicit-any] + precision: float = Field(ge=0.0, le=1.0) + recall: float = Field(ge=0.0, le=1.0) + f1: float = Field(ge=0.0, le=1.0) + false_positive_rate: float = Field(ge=0.0, le=1.0) + + +class EvaluationResult(ContractModel): # type: ignore[explicit-any] + recording_id: str + metrics: DetectionMetrics + event_count: int + 
detection_latencies_ms: list[float] = Field(default_factory=list) + false_positive_times_s: list[float] = Field(default_factory=list) + vad_overlap_ratios: list[float] = Field(default_factory=list) + + +class DetectionEventResult(ContractModel): # type: ignore[explicit-any] + detected: bool + detected_time_s: float | None = Field(default=None, ge=0.0) + confidence: float | None = Field(default=None, ge=0.0, le=1.0) + matched_annotation_index: int | None = Field(default=None, ge=0) + detection_latency_ms: float | None = Field(default=None, ge=0.0) + vad_overlap_ratio: float | None = Field(default=None, ge=0.0, le=1.0) + + +DetectorCallable = Callable[[FloatArray, int, RecordingAnnotation], list[DetectionEventResult]] + + +def compute_metrics(predicted: list[bool], ground_truth: list[bool]) -> DetectionMetrics: + """Compute precision, recall, F1, and FPR from boolean lists.""" + if len(predicted) != len(ground_truth): + msg = ( + f"predicted and ground_truth must have the same length, " + f"got {len(predicted)} and {len(ground_truth)}" + ) + raise ValueError(msg) + + tp = 0 + fp = 0 + fn = 0 + tn = 0 + for predicted_flag, ground_truth_flag in zip(predicted, ground_truth, strict=True): + if predicted_flag and ground_truth_flag: + tp += 1 + elif predicted_flag: + fp += 1 + elif ground_truth_flag: + fn += 1 + else: + tn += 1 + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0 + + return DetectionMetrics( + precision=precision, + recall=recall, + f1=f1, + false_positive_rate=fpr, + ) + + +def load_recording_pair( + wav_path: Path, + ann_path: Path, +) -> tuple[FloatArray, int, RecordingAnnotation]: + """Load a WAV file and its annotation, returning samples, sample_rate, annotation.""" + if not wav_path.exists(): + msg = f"WAV file not found: {wav_path}" + raise FileNotFoundError(msg) + + with wave.open(str(wav_path), "rb") as wf: + sample_rate = wf.getframerate() + n_frames = wf.getnframes() + raw = wf.readframes(n_frames) + sampwidth = wf.getsampwidth() + + if sampwidth == 1: + pcm = (np.frombuffer(raw, dtype=np.uint8).astype(np.float64) - UINT8_OFFSET) / UINT8_OFFSET + elif sampwidth == PCM16_SAMPWIDTH: + pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float64) / PCM16_SCALE + elif sampwidth == PCM32_SAMPWIDTH: + pcm = np.frombuffer(raw, dtype=np.int32).astype(np.float64) / PCM32_SCALE + else: + msg = f"Unsupported WAV sample width: {sampwidth} bytes" + raise ValueError(msg) + + annotation = load_annotation(ann_path) + return pcm, sample_rate, annotation + + +def evaluate_recording( + wav_path: Path, + ann_path: Path, + *, + detector: DetectorCallable, +) -> EvaluationResult: + """Run detector over a recording and compute metrics against annotations.""" + samples, sample_rate, annotation = load_recording_pair(wav_path, ann_path) + + detected_events = detector(samples, sample_rate, annotation) + ground_truth = [event.type in POSITIVE_EVENT_TYPES for event in annotation.events] + detected_flags = [event.detected for event in detected_events] + + metrics = compute_metrics(detected_flags, ground_truth) + return EvaluationResult( + recording_id=annotation.recording_id, + metrics=metrics, + event_count=len(annotation.events), + detection_latencies_ms=[ + event.detection_latency_ms + for event in detected_events + if event.detection_latency_ms is not None + ], + false_positive_times_s=[ + 
event.detected_time_s + for event in detected_events + if ( + event.detected + and event.matched_annotation_index is None + and event.detected_time_s is not None + ) + ], + vad_overlap_ratios=[ + event.vad_overlap_ratio + for event in detected_events + if event.vad_overlap_ratio is not None + ], + ) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Evaluate sound detection accuracy against annotated recordings." + ) + parser.add_argument( + "--recordings", + type=Path, + default=Path("data/evaluation"), + help="Directory containing WAV and JSON annotation pairs (default: data/evaluation)", + ) + return parser + + +def _main(argv: list[str] | None = None) -> int: + parser = _build_parser() + args = parser.parse_args(argv) + recordings_dir: Path = args.recordings + + pairs: list[tuple[Path, Path]] = [] + for wav_path in sorted(recordings_dir.glob("*.wav")): + ann_path = wav_path.with_suffix(".json") + if ann_path.exists(): + pairs.append((wav_path, ann_path)) + + if not pairs: + return 1 + + def _placeholder_detector( + _samples: FloatArray, + _sample_rate: int, + annotation: RecordingAnnotation, + ) -> list[DetectionEventResult]: + return [ + DetectionEventResult( + detected=False, + matched_annotation_index=index, + ) + for index, _event in enumerate(annotation.events) + ] + + evaluated_any = False + for wav_path, ann_path in pairs: + try: + result = evaluate_recording( + wav_path, + ann_path, + detector=_placeholder_detector, + ) + except Exception: + LOGGER.exception("Failed to evaluate recording pair: %s / %s", wav_path, ann_path) + continue + sys.stdout.write(f"{result.model_dump_json()}\n") + evaluated_any = True + + return 0 if evaluated_any else 1 + + +if __name__ == "__main__": # pragma: no cover + sys.exit(_main()) + + +__all__ = [ + "DetectionEventResult", + "DetectionMetrics", + "DetectorCallable", + "EvaluationResult", + "compute_metrics", + "evaluate_recording", + "load_recording_pair", +] diff --git a/daemon/src/daemon/tools/optimize_thresholds.py b/daemon/src/daemon/tools/optimize_thresholds.py new file mode 100644 index 00000000..e3aa6c0b --- /dev/null +++ b/daemon/src/daemon/tools/optimize_thresholds.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import argparse +import sys +from collections.abc import Callable +from pathlib import Path +from typing import Literal + +from pydantic import Field + +from daemon.contracts._base import ContractModel +from daemon.tools.evaluate_detection import EvaluationResult + +MetricName = Literal["f1", "precision", "recall"] + +# Type alias for the evaluator callable injected by callers / tests +EvaluatorCallable = Callable[ + [float, float, float, int, list[tuple[Path, Path]]], + list[EvaluationResult], +] + + +class ThresholdGrid(ContractModel): # type: ignore[explicit-any] + sensitivity_threshold: list[float] = Field(default=[0.5, 0.6, 0.7, 0.8, 0.9]) + energy_gate_threshold: list[float] = Field(default=[0.005, 0.01, 0.02, 0.05]) + speech_prob_threshold: list[float] = Field(default=[0.3, 0.4, 0.5, 0.6]) + overlap_hold_ms: list[int] = Field(default=[100, 200, 300, 500]) + + +class OptimizeConfig(ContractModel): # type: ignore[explicit-any] + metric: MetricName = "f1" + + +class GridSearchResult(ContractModel): # type: ignore[explicit-any] + sensitivity_threshold: float + energy_gate_threshold: float + speech_prob_threshold: float + overlap_hold_ms: int + score: float + metric: str + + +def select_best(results: list[GridSearchResult]) -> GridSearchResult: + """Return the 
highest-scoring result, preserving input order on ties.""" + if not results: + msg = "Cannot select best from empty results list" + raise ValueError(msg) + _, best_result = max( + enumerate(results), + key=lambda item: (item[1].score, -item[0]), + ) + return best_result + + +def run_grid_search( + *, + recording_pairs: list[tuple[Path, Path]], + grid: ThresholdGrid, + config: OptimizeConfig, + evaluator: EvaluatorCallable, +) -> GridSearchResult: + """Exhaustive grid search over threshold combinations. + + The *evaluator* callable is injected so callers can supply a real + SoundDetectionService wrapper or a test mock. + """ + if not recording_pairs: + msg = "Cannot run grid search with empty recording pairs list" + raise ValueError(msg) + + all_results: list[GridSearchResult] = [] + + for sensitivity in grid.sensitivity_threshold: + for energy_gate in grid.energy_gate_threshold: + for speech_prob in grid.speech_prob_threshold: + for overlap_ms in grid.overlap_hold_ms: + eval_results = evaluator( + sensitivity, + energy_gate, + speech_prob, + overlap_ms, + recording_pairs, + ) + if not eval_results: + continue + + # Macro-average the chosen metric across recordings + scores = [getattr(r.metrics, config.metric) for r in eval_results] + avg_score = sum(scores) / len(scores) + + all_results.append( + GridSearchResult( + sensitivity_threshold=sensitivity, + energy_gate_threshold=energy_gate, + speech_prob_threshold=speech_prob, + overlap_hold_ms=overlap_ms, + score=avg_score, + metric=config.metric, + ) + ) + + return select_best(all_results) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Grid-search sound detection thresholds over evaluation recordings." + ) + parser.add_argument( + "--recordings", + type=Path, + default=Path("data/evaluation"), + help="Directory containing WAV + JSON annotation pairs (default: data/evaluation)", + ) + parser.add_argument( + "--metric", + choices=["f1", "precision", "recall"], + default="f1", + help="Metric to optimise (default: f1)", + ) + return parser + + +def _main(argv: list[str] | None = None) -> int: + parser = _build_parser() + args = parser.parse_args(argv) + + recordings_dir: Path = args.recordings + metric: MetricName = args.metric + + pairs: list[tuple[Path, Path]] = [] + for wav_path in sorted(recordings_dir.glob("*.wav")): + ann_path = wav_path.with_suffix(".json") + if ann_path.exists(): + pairs.append((wav_path, ann_path)) + + if not pairs: + return 1 + + # Lazy import to avoid pulling config dependencies during test collection. 
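+    # NOTE: the _real_evaluator stub defined below raises NotImplementedError,
+    # so this CLI currently prints "Grid search failed: ..." and exits 1;
+    # run_grid_search is exercised in tests via an injected mock evaluator.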
+ from daemon.config import load_config_directory, resolve_config_dir + + def _real_evaluator( + sensitivity: float, + energy_gate: float, + speech_prob: float, + overlap_ms: int, + eval_pairs: list[tuple[Path, Path]], + ) -> list[EvaluationResult]: + del sensitivity, energy_gate, speech_prob, overlap_ms, eval_pairs + msg = "Real evaluator is not implemented for optimize_thresholds CLI yet" + raise NotImplementedError(msg) + + grid = ThresholdGrid() + cfg = OptimizeConfig(metric=metric) + snapshot = load_config_directory(resolve_config_dir()) + settings = snapshot.settings + current_thresholds = { + "sensitivity_threshold": settings.sound_detection.sensitivity_threshold, + "energy_gate_threshold": settings.sound_detection.gate_rms_threshold, + "speech_prob_threshold": settings.turn_arbiter.speech_prob_threshold, + "overlap_hold_ms": settings.turn_arbiter.overlap_hold_ms, + } + + try: + result = run_grid_search( + recording_pairs=pairs, + grid=grid, + config=cfg, + evaluator=_real_evaluator, + ) + except (NotImplementedError, RuntimeError, ValueError) as exc: + sys.stderr.write(f"Grid search failed: {exc}\n") + return 1 + + sys.stdout.write(f"metric: {result.metric}\n") + sys.stdout.write("threshold\tcurrent\trecommended\n") + sys.stdout.write( + "sensitivity_threshold\t" + f"{current_thresholds['sensitivity_threshold']}\t{result.sensitivity_threshold}\n" + ) + sys.stdout.write( + "energy_gate_threshold\t" + f"{current_thresholds['energy_gate_threshold']}\t{result.energy_gate_threshold}\n" + ) + sys.stdout.write( + "speech_prob_threshold\t" + f"{current_thresholds['speech_prob_threshold']}\t{result.speech_prob_threshold}\n" + ) + sys.stdout.write( + f"overlap_hold_ms\t{current_thresholds['overlap_hold_ms']}\t{result.overlap_hold_ms}\n" + ) + return 0 + + +if __name__ == "__main__": # pragma: no cover + sys.exit(_main()) + + +__all__ = [ + "EvaluatorCallable", + "GridSearchResult", + "MetricName", + "OptimizeConfig", + "ThresholdGrid", + "run_grid_search", + "select_best", +] diff --git a/daemon/src/daemon/tools/record_evaluation.py b/daemon/src/daemon/tools/record_evaluation.py new file mode 100644 index 00000000..41ea9138 --- /dev/null +++ b/daemon/src/daemon/tools/record_evaluation.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import argparse +import io +import sys +import wave +from datetime import UTC, datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +from pydantic import Field + +from daemon.contracts._base import ContractModel + +if TYPE_CHECKING: + import numpy.typing as npt + + from daemon.audio.input import AudioInput + + +class RecordingMetadata(ContractModel): # type: ignore[explicit-any] + recording_id: str + wav_path: str + device_type: str + sample_rate: int = Field(gt=0) + duration_s: float = Field(gt=0.0) + timestamp: str + retention_policy: str + expires_at: str | None = None + + +def record_audio( # noqa: PLR0913, PLR0914 + *, + audio_input: AudioInput, + duration_s: float, + output_dir: Path, + recording_id: str | None = None, + ephemeral: bool = False, + retention_days: int | None = None, +) -> tuple[Path, Path]: + """Record audio for *duration_s* seconds and save WAV + metadata JSON. + + Returns the (wav_path, metadata_json_path) tuple. 
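+
+    When *ephemeral* is true, both files are deleted again immediately
+    after being written, so the returned paths will not exist on disk.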
+ """ + if duration_s <= 0: + msg = "duration_s must be positive" + raise ValueError(msg) + if retention_days is not None and retention_days <= 0: + msg = "retention_days must be positive" + raise ValueError(msg) + output_dir.mkdir(parents=True, exist_ok=True) + + rec_id = recording_id or _generate_id() + wav_path = output_dir / f"{rec_id}.wav" + meta_path = output_dir / f"{rec_id}.json" + + sample_rate = audio_input.device_info.sample_rate + total_frames = max(1, int(sample_rate * duration_s)) + chunk_frames = max(1, int(sample_rate * 0.1)) + + collected: list[float] = [] + frames_remaining = total_frames + while frames_remaining > 0: + to_read = min(chunk_frames, frames_remaining) + audio_frames = audio_input.read_frames(to_read) + frames_read = len(audio_frames.samples) // max(audio_frames.channels, 1) + if frames_read == 0: + msg = "audio input returned no frames during recording" + raise RuntimeError(msg) + collected.extend(audio_frames.samples) + frames_remaining -= frames_read + + samples_np: npt.NDArray[np.float64] = np.array(collected, dtype=np.float64) + _write_wav(wav_path, samples_np, sample_rate=sample_rate) + + actual_duration_s = len(collected) / sample_rate + timestamp = datetime.now(UTC).isoformat() + expires_at = ( + (datetime.now(UTC) + timedelta(days=retention_days)).isoformat() + if retention_days is not None + else None + ) + retention_policy = "ephemeral" if ephemeral else "retained" + metadata = RecordingMetadata( + recording_id=rec_id, + wav_path=str(wav_path), + device_type=audio_input.device_info.device_type, + sample_rate=sample_rate, + duration_s=actual_duration_s, + timestamp=timestamp, + retention_policy=retention_policy, + expires_at=expires_at, + ) + meta_path.write_text(metadata.model_dump_json(indent=2), encoding="utf-8") + if ephemeral: + wav_path.unlink(missing_ok=True) + meta_path.unlink(missing_ok=True) + + return wav_path, meta_path + + +def _generate_id() -> str: + return datetime.now(UTC).strftime("rec-%Y%m%d-%H%M%S-%f") + + +def _write_wav( + path: Path, + samples: npt.NDArray[np.float64], + *, + sample_rate: int, +) -> None: + pcm = np.clip(samples, -1.0, 1.0) + pcm_i16 = (pcm * 32767.0).astype(np.int16) + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(pcm_i16.tobytes()) + path.write_bytes(buf.getvalue()) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Record audio from the Audio HAL for evaluation purposes." 
+ ) + parser.add_argument( + "--duration", + type=float, + default=300.0, + help="Recording duration in seconds (default: 300)", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("data/evaluation"), + help="Output directory for WAV and metadata files (default: data/evaluation)", + ) + parser.add_argument( + "--ephemeral", + action="store_true", + help="Delete the recorded WAV and metadata immediately after writing them", + ) + parser.add_argument( + "--retention-days", + type=int, + default=None, + help="Retention period to record in metadata before cleanup (default: no expiry)", + ) + return parser + + +def _main(argv: list[str] | None = None) -> int: + parser = _build_parser() + args = parser.parse_args(argv) + + output_dir: Path = args.output + duration_s: float = args.duration + ephemeral: bool = args.ephemeral + retention_days: int | None = args.retention_days + + # Lazy import to avoid pulling in heavy deps during test collection + from daemon.audio import create_audio_input + from daemon.config import load_config_directory, resolve_config_dir + + snapshot = load_config_directory(resolve_config_dir()) + audio_input = create_audio_input(snapshot.settings.audio) # pyright: ignore[reportUnknownMemberType] + + try: + _wav_path, _meta_path = record_audio( + audio_input=audio_input, + duration_s=duration_s, + output_dir=output_dir, + ephemeral=ephemeral, + retention_days=retention_days, + ) + except KeyboardInterrupt: + return 1 + finally: + audio_input.close() # pyright: ignore[reportUnknownMemberType] + + return 0 + + +if __name__ == "__main__": # pragma: no cover + sys.exit(_main()) + + +__all__ = ["RecordingMetadata", "record_audio"] diff --git a/data/evaluation/.gitkeep b/data/evaluation/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/docs/adr/decision-log.md b/docs/adr/decision-log.md index ff635c9f..bfea235a 100644 --- a/docs/adr/decision-log.md +++ b/docs/adr/decision-log.md @@ -116,3 +116,4 @@ individual decision artifact under `docs/adr/decisions/`. - 2026-04-05T02:44:36Z | adr_required=false | Reduce turn arbiter claim_window_ms from 2500ms to 1500ms as issue #126 Phase 1 conversation control policy | [details](decisions/2026-04-05-reduce-turn-arbiter-claim-window-ms-from-2500ms-to-1500ms-as-issue-126-phase-1-conversation-control-policy.md) - 2026-04-05T02:47:15Z | adr_required=false | Add comprehensive Playwright E2E coverage for the memory browser workflows and update local mock review APIs. 
| [details](decisions/2026-04-05-add-comprehensive-playwright-e2e-coverage-for-the-memory-browser-workflows-and-update-local-mock-review-apis.md) - 2026-04-05T07:58:42Z | adr_required=false | Add backend degraded mode handling and auto-recovery for subsystem failures | [details](decisions/2026-04-05-add-backend-degraded-mode-handling-and-auto-recovery-for-subsystem-failures.md) +- 2026-04-05T08:02:38Z | adr_required=false | Address PR #138 review findings for recording evaluation tools coverage, WAV decoding, ground truth mapping, and tests | [details](decisions/2026-04-05-address-pr-138-review-findings-for-recording-evaluation-tools-coverage-wav-decoding-ground-truth-mapping-and-tests.md) diff --git a/docs/adr/decisions/2026-04-05-address-pr-138-review-findings-for-recording-evaluation-tools-coverage-wav-decoding-ground-truth-mapping-and-tests.md b/docs/adr/decisions/2026-04-05-address-pr-138-review-findings-for-recording-evaluation-tools-coverage-wav-decoding-ground-truth-mapping-and-tests.md new file mode 100644 index 00000000..2c4b1e57 --- /dev/null +++ b/docs/adr/decisions/2026-04-05-address-pr-138-review-findings-for-recording-evaluation-tools-coverage-wav-decoding-ground-truth-mapping-and-tests.md @@ -0,0 +1,8 @@ +# ADR Decision Record + +timestamp: 2026-04-05T08:02:38Z +change: Address PR #138 review findings for recording evaluation tools coverage, WAV decoding, ground truth mapping, and tests +adr_required: false +rationale: Changes are localized to existing tool modules and tests, with no architectural boundary or policy change. +files: [] +adr_paths: [] diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 00000000..bdec2fc8 --- /dev/null +++ b/tests/tools/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +__all__: list[str] = [] diff --git a/tests/tools/conftest.py b/tests/tools/conftest.py new file mode 100644 index 00000000..34931686 --- /dev/null +++ b/tests/tools/conftest.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import io +import wave +from typing import TYPE_CHECKING + +import numpy as np +import numpy.typing as npt +from daemon.tools.annotation_schema import RecordingAnnotation, save_annotation + +if TYPE_CHECKING: + from pathlib import Path + +SAMPLE_RATE = 16_000 + + +def make_wav_bytes(samples: npt.NDArray[np.float64], *, sample_rate: int = SAMPLE_RATE) -> bytes: + """Encode a float64 sample array as PCM-16 WAV bytes.""" + pcm = np.clip(samples, -1.0, 1.0) + pcm_i16 = (pcm * 32767.0).astype(np.int16) + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(pcm_i16.tobytes()) + return buf.getvalue() + + +def write_wav(path: Path, duration_s: float = 1.0) -> None: + """Write a silent WAV file of the given duration to *path*.""" + samples = np.zeros(int(SAMPLE_RATE * duration_s), dtype=np.float64) + path.write_bytes(make_wav_bytes(samples)) + + +def write_annotation(path: Path, annotation: RecordingAnnotation) -> None: + """Serialize *annotation* to *path* as JSON.""" + save_annotation(annotation, path) + + +__all__ = ["SAMPLE_RATE", "make_wav_bytes", "write_annotation", "write_wav"] diff --git a/tests/tools/test_annotation_schema.py b/tests/tools/test_annotation_schema.py new file mode 100644 index 00000000..05c45ef3 --- /dev/null +++ b/tests/tools/test_annotation_schema.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +import pytest +from 
daemon.tools.annotation_schema import ( + AnnotationEvent, + RecordingAnnotation, + load_annotation, + save_annotation, +) +from pydantic import ValidationError + +if TYPE_CHECKING: + from pathlib import Path + +# --------------------------------------------------------------------------- +# AnnotationEvent validation +# --------------------------------------------------------------------------- + + +def test_annotation_event_valid_bell() -> None: + event = AnnotationEvent(time_s=1.0, type="bell") + assert event.time_s == 1.0 + assert event.type == "bell" + assert event.end_s is None + assert event.notes == "" + + +def test_annotation_event_valid_with_end() -> None: + expected_end = 2.5 + event = AnnotationEvent(time_s=0.5, end_s=expected_end, type="speech", notes="child speaking") + assert event.end_s == expected_end + assert event.notes == "child speaking" + + +def test_annotation_event_all_types() -> None: + for event_type in ("bell", "multi_voice", "quiet", "speech", "noise"): + event = AnnotationEvent(time_s=0.0, type=event_type) + assert event.type == event_type + + +def test_annotation_event_negative_time_raises() -> None: + with pytest.raises(ValidationError): + AnnotationEvent(time_s=-1.0, type="bell") + + +def test_annotation_event_zero_time_valid() -> None: + event = AnnotationEvent(time_s=0.0, type="quiet") + assert event.time_s == 0.0 + + +def test_annotation_event_end_equal_to_start_raises() -> None: + with pytest.raises(ValidationError): + AnnotationEvent(time_s=1.0, end_s=1.0, type="bell") + + +def test_annotation_event_end_before_start_raises() -> None: + with pytest.raises(ValidationError): + AnnotationEvent(time_s=2.0, end_s=1.0, type="bell") + + +def test_annotation_event_invalid_type_raises() -> None: + with pytest.raises(ValidationError): + AnnotationEvent(time_s=0.0, type="unknown") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# RecordingAnnotation validation +# --------------------------------------------------------------------------- + + +def test_recording_annotation_valid() -> None: + annotation = RecordingAnnotation( + recording_id="rec-001", + events=[AnnotationEvent(time_s=1.0, type="bell")], + ) + assert annotation.recording_id == "rec-001" + assert len(annotation.events) == 1 + + +def test_recording_annotation_empty_events() -> None: + annotation = RecordingAnnotation(recording_id="rec-001", events=[]) + assert annotation.events == [] + + +def test_recording_annotation_multiple_events() -> None: + events = [ + AnnotationEvent(time_s=1.0, type="bell"), + AnnotationEvent(time_s=5.0, end_s=8.0, type="speech"), + AnnotationEvent(time_s=15.0, type="quiet"), + ] + annotation = RecordingAnnotation(recording_id="rec-002", events=events) + expected_count = 3 + assert len(annotation.events) == expected_count + + +# --------------------------------------------------------------------------- +# load/save round-trip +# --------------------------------------------------------------------------- + + +def test_save_and_load_annotation(tmp_path: Path) -> None: + annotation = RecordingAnnotation( + recording_id="rec-roundtrip", + events=[ + AnnotationEvent(time_s=1.0, type="bell", notes="clear ring"), + AnnotationEvent(time_s=5.0, end_s=8.0, type="speech"), + ], + ) + filepath = tmp_path / "rec-roundtrip.json" + save_annotation(annotation, filepath) + + loaded = load_annotation(filepath) + assert loaded.recording_id == annotation.recording_id + expected_event_count = 2 + expected_end_s = 8.0 + assert 
len(loaded.events) == expected_event_count + assert loaded.events[0].type == "bell" + assert loaded.events[0].notes == "clear ring" + assert loaded.events[1].end_s == expected_end_s + + +def test_save_produces_valid_json(tmp_path: Path) -> None: + annotation = RecordingAnnotation( + recording_id="rec-json", + events=[AnnotationEvent(time_s=0.0, type="quiet")], + ) + filepath = tmp_path / "rec-json.json" + save_annotation(annotation, filepath) + + raw = json.loads(filepath.read_text()) + assert raw["recording_id"] == "rec-json" + assert isinstance(raw["events"], list) + + +def test_load_annotation_missing_file_raises(tmp_path: Path) -> None: + with pytest.raises(FileNotFoundError): + load_annotation(tmp_path / "does_not_exist.json") + + +def test_load_annotation_invalid_json_raises(tmp_path: Path) -> None: + bad_file = tmp_path / "bad.json" + bad_file.write_text("not valid json", encoding="utf-8") + with pytest.raises((ValueError, ValidationError)): + load_annotation(bad_file) diff --git a/tests/tools/test_evaluate_detection.py b/tests/tools/test_evaluate_detection.py new file mode 100644 index 00000000..07e41a23 --- /dev/null +++ b/tests/tools/test_evaluate_detection.py @@ -0,0 +1,390 @@ +from __future__ import annotations + +import logging +import wave +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from daemon.tools import evaluate_detection +from daemon.tools.annotation_schema import AnnotationEvent, RecordingAnnotation +from daemon.tools.evaluate_detection import ( + DetectionEventResult, + DetectionMetrics, + EvaluationResult, + _main, + compute_metrics, + evaluate_recording, + load_recording_pair, +) +from pydantic import ValidationError + +from tests.tools.conftest import write_annotation, write_wav + +if TYPE_CHECKING: + from pathlib import Path + + import numpy.typing as npt + +SAMPLE_RATE = 16_000 + + +def _write_custom_wav(path: Path, *, sample_rate: int, sampwidth: int, raw_frames: bytes) -> None: + with wave.open(str(path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(sampwidth) + wf.setframerate(sample_rate) + wf.writeframes(raw_frames) + + +# --------------------------------------------------------------------------- +# DetectionMetrics model +# --------------------------------------------------------------------------- + + +def test_detection_metrics_perfect() -> None: + metrics = DetectionMetrics(precision=1.0, recall=1.0, f1=1.0, false_positive_rate=0.0) + assert metrics.f1 == 1.0 + + +def test_detection_metrics_bounds() -> None: + with pytest.raises(ValidationError): + DetectionMetrics(precision=1.5, recall=1.0, f1=1.0, false_positive_rate=0.0) + + +# --------------------------------------------------------------------------- +# compute_metrics +# --------------------------------------------------------------------------- + + +def test_compute_metrics_perfect_detection() -> None: + predicted = [True, True, False, False] + ground_truth = [True, True, False, False] + metrics = compute_metrics(predicted, ground_truth) + assert metrics.precision == pytest.approx(1.0) + assert metrics.recall == pytest.approx(1.0) + assert metrics.f1 == pytest.approx(1.0) + assert metrics.false_positive_rate == pytest.approx(0.0) + + +def test_compute_metrics_all_false_positives() -> None: + predicted = [True, True, False, False] + ground_truth = [False, False, False, False] + metrics = compute_metrics(predicted, ground_truth) + assert metrics.precision == pytest.approx(0.0) + assert metrics.false_positive_rate == pytest.approx(0.5) + + +def 
test_compute_metrics_all_missed() -> None: + predicted = [False, False, False, False] + ground_truth = [True, True, False, False] + metrics = compute_metrics(predicted, ground_truth) + assert metrics.recall == pytest.approx(0.0) + assert metrics.f1 == pytest.approx(0.0) + + +def test_compute_metrics_no_positives() -> None: + predicted = [False, False] + ground_truth = [False, False] + metrics = compute_metrics(predicted, ground_truth) + assert metrics.f1 == pytest.approx(0.0) + + +def test_compute_metrics_mismatched_length_raises() -> None: + with pytest.raises(ValueError, match="length"): + compute_metrics([True, False], [True]) + + +# --------------------------------------------------------------------------- +# load_recording_pair +# --------------------------------------------------------------------------- + + +def test_load_recording_pair_returns_samples_and_annotation(tmp_path: Path) -> None: + wav_path = tmp_path / "rec.wav" + ann_path = tmp_path / "rec.json" + write_wav(wav_path, 1.0) + annotation = RecordingAnnotation( + recording_id="rec", + events=[AnnotationEvent(time_s=0.5, type="bell")], + ) + write_annotation(ann_path, annotation) + + samples, sample_rate, loaded_ann = load_recording_pair(wav_path, ann_path) + assert len(samples) == SAMPLE_RATE + assert sample_rate == SAMPLE_RATE + assert loaded_ann.recording_id == "rec" + assert len(loaded_ann.events) == 1 + + +def test_load_recording_pair_missing_wav_raises(tmp_path: Path) -> None: + ann_path = tmp_path / "rec.json" + annotation = RecordingAnnotation(recording_id="rec", events=[]) + write_annotation(ann_path, annotation) + with pytest.raises(FileNotFoundError): + load_recording_pair(tmp_path / "no.wav", ann_path) + + +def test_load_recording_pair_missing_annotation_raises(tmp_path: Path) -> None: + wav_path = tmp_path / "rec.wav" + write_wav(wav_path, 0.1) + + with pytest.raises(FileNotFoundError): + load_recording_pair(wav_path, tmp_path / "missing.json") + + +def test_load_recording_pair_decodes_uint8_wav(tmp_path: Path) -> None: + wav_path = tmp_path / "uint8.wav" + ann_path = tmp_path / "uint8.json" + samples_u8 = np.array([0, 128, 255], dtype=np.uint8) + _write_custom_wav( + wav_path, + sample_rate=SAMPLE_RATE, + sampwidth=1, + raw_frames=samples_u8.tobytes(), + ) + write_annotation(ann_path, RecordingAnnotation(recording_id="uint8", events=[])) + + samples, sample_rate, loaded_ann = load_recording_pair(wav_path, ann_path) + + assert sample_rate == SAMPLE_RATE + assert loaded_ann.recording_id == "uint8" + assert samples.tolist() == pytest.approx([-1.0, 0.0, 127.0 / 128.0]) + + +def test_load_recording_pair_decodes_int32_wav(tmp_path: Path) -> None: + wav_path = tmp_path / "int32.wav" + ann_path = tmp_path / "int32.json" + samples_i32 = np.array([np.iinfo(np.int32).min, 0, np.iinfo(np.int32).max], dtype=np.int32) + _write_custom_wav( + wav_path, + sample_rate=SAMPLE_RATE, + sampwidth=4, + raw_frames=samples_i32.tobytes(), + ) + write_annotation(ann_path, RecordingAnnotation(recording_id="int32", events=[])) + + samples, _sample_rate, loaded_ann = load_recording_pair(wav_path, ann_path) + + assert loaded_ann.recording_id == "int32" + assert samples.tolist() == pytest.approx([-1.0, 0.0, 2147483647.0 / 2147483648.0]) + + +def test_load_recording_pair_unsupported_sampwidth_raises(tmp_path: Path) -> None: + wav_path = tmp_path / "int24.wav" + ann_path = tmp_path / "int24.json" + _write_custom_wav(wav_path, sample_rate=SAMPLE_RATE, sampwidth=3, raw_frames=b"\x00\x00\x00") + write_annotation(ann_path, 
RecordingAnnotation(recording_id="int24", events=[])) + + with pytest.raises(ValueError, match="Unsupported WAV sample width"): + load_recording_pair(wav_path, ann_path) + + +# --------------------------------------------------------------------------- +# evaluate_recording — integration with a dummy detector callable +# --------------------------------------------------------------------------- + + +def test_evaluate_recording_with_dummy_detector(tmp_path: Path) -> None: + wav_path = tmp_path / "rec.wav" + ann_path = tmp_path / "rec.json" + write_wav(wav_path, 2.0) + + annotation = RecordingAnnotation( + recording_id="rec", + events=[AnnotationEvent(time_s=0.5, type="bell")], + ) + write_annotation(ann_path, annotation) + + # Dummy detector: always says "detected" at 0.5 s + def _dummy_detector( + _samples: npt.NDArray[np.float64], + _sample_rate: int, + annotation: RecordingAnnotation, + ) -> list[DetectionEventResult]: + return [ + DetectionEventResult( + detected=True, + detected_time_s=event.time_s, + confidence=1.0, + matched_annotation_index=index, + detection_latency_ms=0.0, + vad_overlap_ratio=1.0, + ) + for index, event in enumerate(annotation.events) + ] + + result = evaluate_recording(wav_path, ann_path, detector=_dummy_detector) + assert isinstance(result, EvaluationResult) + assert result.recording_id == "rec" + assert result.metrics.f1 == pytest.approx(1.0) + + +def test_evaluate_recording_partial_detection(tmp_path: Path) -> None: + wav_path = tmp_path / "partial.wav" + ann_path = tmp_path / "partial.json" + write_wav(wav_path, 2.0) + annotation = RecordingAnnotation( + recording_id="partial", + events=[ + AnnotationEvent(time_s=0.1, type="bell"), + AnnotationEvent(time_s=0.7, type="speech"), + AnnotationEvent(time_s=1.4, type="multi_voice"), + ], + ) + write_annotation(ann_path, annotation) + + def _partial_detector( + _samples: npt.NDArray[np.float64], + _sample_rate: int, + _annotation: RecordingAnnotation, + ) -> list[DetectionEventResult]: + return [ + DetectionEventResult( + detected=True, + matched_annotation_index=0, + detection_latency_ms=5.0, + ), + DetectionEventResult(detected=False), + DetectionEventResult( + detected=True, + matched_annotation_index=2, + detection_latency_ms=12.0, + ), + ] + + result = evaluate_recording(wav_path, ann_path, detector=_partial_detector) + + assert result.metrics.precision == pytest.approx(1.0) + assert result.metrics.recall == pytest.approx(2.0 / 3.0) + assert result.metrics.f1 == pytest.approx(0.8) + assert result.metrics.false_positive_rate == pytest.approx(0.0) + assert result.detection_latencies_ms == [5.0, 12.0] + + +def test_evaluate_recording_quiet_and_noise_are_negative_ground_truth(tmp_path: Path) -> None: + wav_path = tmp_path / "quiet.wav" + ann_path = tmp_path / "quiet.json" + write_wav(wav_path, 2.0) + annotation = RecordingAnnotation( + recording_id="quiet-mix", + events=[ + AnnotationEvent(time_s=0.1, type="quiet"), + AnnotationEvent(time_s=0.7, type="bell"), + AnnotationEvent(time_s=1.2, type="noise"), + AnnotationEvent(time_s=1.5, type="speech"), + ], + ) + write_annotation(ann_path, annotation) + + def _mixed_detector( + _samples: npt.NDArray[np.float64], + _sample_rate: int, + _annotation: RecordingAnnotation, + ) -> list[DetectionEventResult]: + return [ + DetectionEventResult(detected=True, detected_time_s=0.1), + DetectionEventResult(detected=True, matched_annotation_index=1, vad_overlap_ratio=0.9), + DetectionEventResult(detected=False), + DetectionEventResult(detected=False), + ] + + result = 
evaluate_recording(wav_path, ann_path, detector=_mixed_detector) + + assert result.metrics.precision == pytest.approx(0.5) + assert result.metrics.recall == pytest.approx(0.5) + assert result.metrics.f1 == pytest.approx(0.5) + assert result.metrics.false_positive_rate == pytest.approx(0.5) + assert result.false_positive_times_s == [0.1] + assert result.vad_overlap_ratios == [0.9] + + +# --------------------------------------------------------------------------- +# _main +# --------------------------------------------------------------------------- + + +def test_main_returns_1_when_no_pairs(tmp_path: Path) -> None: + assert _main(["--recordings", str(tmp_path)]) == 1 + + +def test_main_logs_and_continues_on_pair_failure( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, + capsys: pytest.CaptureFixture[str], +) -> None: + good_wav = tmp_path / "good.wav" + bad_wav = tmp_path / "bad.wav" + good_ann = good_wav.with_suffix(".json") + bad_ann = bad_wav.with_suffix(".json") + for wav_path, ann_path, recording_id in ( + (good_wav, good_ann, "good"), + (bad_wav, bad_ann, "bad"), + ): + write_wav(wav_path, 0.1) + write_annotation( + ann_path, + RecordingAnnotation( + recording_id=recording_id, + events=[AnnotationEvent(time_s=0.1, type="bell")], + ), + ) + + calls: list[Path] = [] + + def _fake_evaluate_recording( + wav_path: Path, + ann_path: Path, + *, + detector: evaluate_detection.DetectorCallable, + ) -> EvaluationResult: + del ann_path, detector + calls.append(wav_path) + if wav_path == bad_wav: + raise ValueError + return EvaluationResult( + recording_id="good", + metrics=DetectionMetrics( + precision=1.0, + recall=0.5, + f1=2.0 / 3.0, + false_positive_rate=0.0, + ), + event_count=1, + ) + + monkeypatch.setattr(evaluate_detection, "evaluate_recording", _fake_evaluate_recording) + + with caplog.at_level(logging.ERROR): + result = _main(["--recordings", str(tmp_path)]) + + stdout = capsys.readouterr().out + assert result == 0 + assert calls == [bad_wav, good_wav] + assert good_wav in calls + assert "Failed to evaluate recording pair" in caplog.text + assert '"recording_id":"good"' in stdout + assert '"f1":0.6666666666666666' in stdout + + +# --------------------------------------------------------------------------- +# EvaluationResult +# --------------------------------------------------------------------------- + + +def test_evaluation_result_has_recording_id() -> None: + metrics = DetectionMetrics( + precision=0.9, + recall=0.8, + f1=0.85, + false_positive_rate=0.1, + ) + expected_event_count = 5 + result = EvaluationResult( + recording_id="rec-001", + metrics=metrics, + event_count=expected_event_count, + ) + assert result.recording_id == "rec-001" + assert result.event_count == expected_event_count diff --git a/tests/tools/test_optimize_thresholds.py b/tests/tools/test_optimize_thresholds.py new file mode 100644 index 00000000..94721b36 --- /dev/null +++ b/tests/tools/test_optimize_thresholds.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +import sys +import types +from typing import TYPE_CHECKING + +import pytest +from daemon.tools import optimize_thresholds +from daemon.tools.annotation_schema import AnnotationEvent, RecordingAnnotation +from daemon.tools.evaluate_detection import DetectionMetrics, EvaluationResult +from daemon.tools.optimize_thresholds import ( + GridSearchResult, + OptimizeConfig, + ThresholdGrid, + _main, + run_grid_search, + select_best, +) +from pydantic import ValidationError + +from tests.tools.conftest import 
write_annotation, write_wav + +if TYPE_CHECKING: + from pathlib import Path + +SAMPLE_RATE = 16_000 + + +# --------------------------------------------------------------------------- +# ThresholdGrid +# --------------------------------------------------------------------------- + + +def test_threshold_grid_default_has_values() -> None: + grid = ThresholdGrid() + assert len(grid.sensitivity_threshold) > 0 + assert len(grid.energy_gate_threshold) > 0 + assert len(grid.speech_prob_threshold) > 0 + assert len(grid.overlap_hold_ms) > 0 + + +def test_threshold_grid_all_values_in_range() -> None: + grid = ThresholdGrid() + for v in grid.sensitivity_threshold: + assert 0.0 <= v <= 1.0 + for v in grid.energy_gate_threshold: + assert v >= 0.0 + for v in grid.speech_prob_threshold: + assert 0.0 <= v <= 1.0 + for v in grid.overlap_hold_ms: + assert v >= 0 + + +def test_threshold_grid_custom() -> None: + grid = ThresholdGrid( + sensitivity_threshold=[0.7, 0.8], + energy_gate_threshold=[0.01, 0.02], + speech_prob_threshold=[0.4, 0.5], + overlap_hold_ms=[100, 200], + ) + assert grid.sensitivity_threshold == [0.7, 0.8] + + +# --------------------------------------------------------------------------- +# OptimizeConfig +# --------------------------------------------------------------------------- + + +def test_optimize_config_default_metric() -> None: + cfg = OptimizeConfig() + assert cfg.metric == "f1" + + +def test_optimize_config_valid_metrics() -> None: + for metric in ("f1", "precision", "recall"): + cfg = OptimizeConfig(metric=metric) + assert cfg.metric == metric + + +def test_optimize_config_invalid_metric_raises() -> None: + with pytest.raises(ValidationError): + OptimizeConfig(metric="accuracy") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# GridSearchResult +# --------------------------------------------------------------------------- + + +def test_grid_search_result_fields() -> None: + result = GridSearchResult( + sensitivity_threshold=0.8, + energy_gate_threshold=0.02, + speech_prob_threshold=0.5, + overlap_hold_ms=300, + score=0.9, + metric="f1", + ) + expected_score = 0.9 + assert result.score == expected_score + assert result.metric == "f1" + + +# --------------------------------------------------------------------------- +# select_best +# --------------------------------------------------------------------------- + + +def test_select_best_picks_highest_score() -> None: + results = [ + GridSearchResult( + sensitivity_threshold=0.7, + energy_gate_threshold=0.01, + speech_prob_threshold=0.4, + overlap_hold_ms=200, + score=0.6, + metric="f1", + ), + GridSearchResult( + sensitivity_threshold=0.8, + energy_gate_threshold=0.02, + speech_prob_threshold=0.5, + overlap_hold_ms=300, + score=0.9, + metric="f1", + ), + GridSearchResult( + sensitivity_threshold=0.9, + energy_gate_threshold=0.03, + speech_prob_threshold=0.6, + overlap_hold_ms=400, + score=0.75, + metric="f1", + ), + ] + best = select_best(results) + assert best.score == pytest.approx(0.9) + assert best.sensitivity_threshold == pytest.approx(0.8) + + +def test_select_best_empty_raises() -> None: + with pytest.raises(ValueError, match="empty"): + select_best([]) + + +def test_select_best_keeps_first_result_on_tie() -> None: + first = GridSearchResult( + sensitivity_threshold=0.7, + energy_gate_threshold=0.01, + speech_prob_threshold=0.4, + overlap_hold_ms=200, + score=0.9, + metric="f1", + ) + second = GridSearchResult( + sensitivity_threshold=0.9, + 
energy_gate_threshold=0.03, + speech_prob_threshold=0.6, + overlap_hold_ms=400, + score=0.9, + metric="f1", + ) + + assert select_best([first, second]) is first + + +# --------------------------------------------------------------------------- +# run_grid_search — with a mock evaluator +# --------------------------------------------------------------------------- + + +def test_run_grid_search_returns_best(tmp_path: Path) -> None: + wav_path = tmp_path / "rec.wav" + ann_path = tmp_path / "rec.json" + write_wav(wav_path, 1.0) + annotation = RecordingAnnotation( + recording_id="rec", + events=[AnnotationEvent(time_s=0.5, type="bell")], + ) + write_annotation(ann_path, annotation) + + recording_pairs = [(wav_path, ann_path)] + + call_count = 0 + + def _mock_evaluator( + sensitivity: float, + _energy_gate: float, + _speech_prob: float, + _overlap_ms: int, + _pairs: list[tuple[Path, Path]], + ) -> list[EvaluationResult]: + nonlocal call_count + call_count += 1 + score = 0.5 + sensitivity * 0.3 + metrics = DetectionMetrics( + precision=score, + recall=score, + f1=score, + false_positive_rate=1.0 - score, + ) + return [EvaluationResult(recording_id="rec", metrics=metrics, event_count=1)] + + grid = ThresholdGrid( + sensitivity_threshold=[0.6, 0.9], + energy_gate_threshold=[0.01, 0.02], + speech_prob_threshold=[0.5], + overlap_hold_ms=[300], + ) + cfg = OptimizeConfig(metric="f1") + + best = run_grid_search( + recording_pairs=recording_pairs, + grid=grid, + config=cfg, + evaluator=_mock_evaluator, + ) + + expected_call_count = 4 + assert best.sensitivity_threshold == pytest.approx(0.9) + assert call_count == expected_call_count + + +def test_run_grid_search_no_recordings_raises() -> None: + grid = ThresholdGrid( + sensitivity_threshold=[0.8], + energy_gate_threshold=[0.02], + speech_prob_threshold=[0.5], + overlap_hold_ms=[300], + ) + cfg = OptimizeConfig(metric="f1") + + def _dummy_evaluator( + _sensitivity: float, + _energy_gate: float, + _speech_prob: float, + _overlap_ms: int, + _pairs: list[tuple[Path, Path]], + ) -> list[EvaluationResult]: + return [] + + with pytest.raises(ValueError, match="empty"): + run_grid_search( + recording_pairs=[], + grid=grid, + config=cfg, + evaluator=_dummy_evaluator, + ) + + +# --------------------------------------------------------------------------- +# _main +# --------------------------------------------------------------------------- + + +def test_main_returns_1_when_no_pairs(tmp_path: Path) -> None: + assert _main(["--recordings", str(tmp_path)]) == 1 + + +def test_main_returns_0_when_grid_search_succeeds( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + wav_path = tmp_path / "rec.wav" + ann_path = tmp_path / "rec.json" + write_wav(wav_path, 0.1) + write_annotation( + ann_path, + RecordingAnnotation( + recording_id="rec", + events=[AnnotationEvent(time_s=0.1, type="bell")], + ), + ) + + def _fake_run_grid_search(**_kwargs: object) -> GridSearchResult: + return GridSearchResult( + sensitivity_threshold=0.8, + energy_gate_threshold=0.02, + speech_prob_threshold=0.5, + overlap_hold_ms=300, + score=0.9, + metric="precision", + ) + + monkeypatch.setattr(optimize_thresholds, "run_grid_search", _fake_run_grid_search) + settings = types.SimpleNamespace( + sound_detection=types.SimpleNamespace( + sensitivity_threshold=0.7, + gate_rms_threshold=0.01, + ), + turn_arbiter=types.SimpleNamespace( + speech_prob_threshold=0.4, + overlap_hold_ms=200, + ), + ) + monkeypatch.setitem( + sys.modules, + "daemon.config", + 
types.SimpleNamespace( + load_config_directory=lambda _config_dir: types.SimpleNamespace(settings=settings), + resolve_config_dir=lambda: tmp_path, + ), + ) + + assert _main(["--recordings", str(tmp_path), "--metric", "precision"]) == 0 + stdout = capsys.readouterr().out + assert "metric: precision" in stdout + assert "threshold\tcurrent\trecommended" in stdout + assert "sensitivity_threshold\t0.7\t0.8" in stdout + assert "energy_gate_threshold\t0.01\t0.02" in stdout + assert "speech_prob_threshold\t0.4\t0.5" in stdout + assert "overlap_hold_ms\t200\t300" in stdout + + +def test_main_returns_1_when_grid_search_raises( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + wav_path = tmp_path / "rec.wav" + ann_path = tmp_path / "rec.json" + write_wav(wav_path, 0.1) + write_annotation( + ann_path, + RecordingAnnotation( + recording_id="rec", + events=[AnnotationEvent(time_s=0.1, type="bell")], + ), + ) + + def _failing_run_grid_search(**_kwargs: object) -> GridSearchResult: + raise ValueError("boom") + + monkeypatch.setattr(optimize_thresholds, "run_grid_search", _failing_run_grid_search) + settings = types.SimpleNamespace( + sound_detection=types.SimpleNamespace( + sensitivity_threshold=0.7, + gate_rms_threshold=0.01, + ), + turn_arbiter=types.SimpleNamespace( + speech_prob_threshold=0.4, + overlap_hold_ms=200, + ), + ) + monkeypatch.setitem( + sys.modules, + "daemon.config", + types.SimpleNamespace( + load_config_directory=lambda _config_dir: types.SimpleNamespace(settings=settings), + resolve_config_dir=lambda: tmp_path, + ), + ) + + assert _main(["--recordings", str(tmp_path)]) == 1 + stderr = capsys.readouterr().err + assert "Grid search failed: boom" in stderr diff --git a/tests/tools/test_record_evaluation.py b/tests/tools/test_record_evaluation.py new file mode 100644 index 00000000..d1199815 --- /dev/null +++ b/tests/tools/test_record_evaluation.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import json +import sys +import types +import wave +from pathlib import Path + +import pytest +from daemon.audio.input import AudioFrames +from daemon.contracts import AudioCapability, AudioDeviceInfo +from daemon.tools import record_evaluation +from daemon.tools.record_evaluation import RecordingMetadata, _main, record_audio +from pydantic import ValidationError + +SAMPLE_RATE = 16_000 + + +class _FakeAudioInput: + """Minimal AudioInput stub that returns silence.""" + + def __init__(self) -> None: + self.device_info = AudioDeviceInfo( + device_id="fake-device", + device_type="single_mic", + capabilities=[AudioCapability.BASIC], + sample_rate=SAMPLE_RATE, + channels=1, + ) + self.backend_name = "fake" + self.mode = "test" + self.closed = False + + def read_frames(self, frame_count: int) -> AudioFrames: # noqa: PLR6301 + return AudioFrames( + samples=tuple(0.0 for _ in range(frame_count)), + sample_rate=SAMPLE_RATE, + ) + + def read_doa(self) -> int | None: # noqa: PLR6301 + return None + + def snapshot_raw_channels(self, frame_count: int) -> tuple[AudioFrames, ...]: + return (self.read_frames(frame_count),) + + def close(self) -> None: + self.closed = True + + +# --------------------------------------------------------------------------- +# RecordingMetadata model +# --------------------------------------------------------------------------- + + +def test_recording_metadata_fields() -> None: + meta = RecordingMetadata( + recording_id="rec-001", + wav_path="data/evaluation/rec-001.wav", + device_type="single_mic", + 
sample_rate=SAMPLE_RATE, + duration_s=5.0, + timestamp="2026-04-05T10:00:00+00:00", + retention_policy="retained", + ) + assert meta.recording_id == "rec-001" + assert meta.sample_rate == SAMPLE_RATE + expected_duration = 5.0 + assert meta.duration_s == expected_duration + + +def test_recording_metadata_positive_duration() -> None: + with pytest.raises(ValidationError): + RecordingMetadata( + recording_id="bad", + wav_path="x.wav", + device_type="single_mic", + sample_rate=SAMPLE_RATE, + duration_s=0.0, + timestamp="2026-04-05T10:00:00+00:00", + retention_policy="retained", + ) + + +def test_recording_metadata_negative_duration_raises() -> None: + with pytest.raises(ValidationError): + RecordingMetadata( + recording_id="bad", + wav_path="x.wav", + device_type="single_mic", + sample_rate=SAMPLE_RATE, + duration_s=-1.0, + timestamp="2026-04-05T10:00:00+00:00", + retention_policy="retained", + ) + + +def test_record_audio_zero_duration_raises(tmp_path: Path) -> None: + audio_input = _FakeAudioInput() + wav_path = tmp_path / "bad-zero.wav" + meta_path = tmp_path / "bad-zero.json" + with pytest.raises(ValueError, match="positive"): + record_audio( + audio_input=audio_input, + duration_s=0.0, + output_dir=tmp_path, + recording_id="bad-zero", + ) + assert not wav_path.exists() + assert not meta_path.exists() + assert list(tmp_path.iterdir()) == [] + + +def test_record_audio_negative_duration_raises(tmp_path: Path) -> None: + audio_input = _FakeAudioInput() + wav_path = tmp_path / "bad-negative.wav" + meta_path = tmp_path / "bad-negative.json" + with pytest.raises(ValueError, match="positive"): + record_audio( + audio_input=audio_input, + duration_s=-1.0, + output_dir=tmp_path, + recording_id="bad-negative", + ) + assert not wav_path.exists() + assert not meta_path.exists() + assert list(tmp_path.iterdir()) == [] + + +# --------------------------------------------------------------------------- +# record_audio creates WAV + metadata files +# --------------------------------------------------------------------------- + + +def test_record_audio_creates_wav_and_metadata(tmp_path: Path) -> None: + audio_input = _FakeAudioInput() + wav_path, meta_path = record_audio( + audio_input=audio_input, + duration_s=0.1, + output_dir=tmp_path, + recording_id="test-rec", + retention_days=7, + ) + + assert wav_path.exists() + assert meta_path.exists() + assert wav_path.suffix == ".wav" + assert meta_path.suffix == ".json" + raw = json.loads(meta_path.read_text()) + assert raw["retention_policy"] == "retained" + assert raw["expires_at"] is not None + + +def test_record_audio_wav_is_valid(tmp_path: Path) -> None: + audio_input = _FakeAudioInput() + wav_path, _ = record_audio( + audio_input=audio_input, + duration_s=0.1, + output_dir=tmp_path, + recording_id="test-rec", + ) + with wave.open(str(wav_path), "rb") as wf: + assert wf.getnchannels() == 1 + assert wf.getframerate() == SAMPLE_RATE + assert wf.getnframes() > 0 + + +def test_record_audio_metadata_content(tmp_path: Path) -> None: + audio_input = _FakeAudioInput() + _, meta_path = record_audio( + audio_input=audio_input, + duration_s=0.1, + output_dir=tmp_path, + recording_id="test-rec", + ) + raw = json.loads(meta_path.read_text()) + assert raw["recording_id"] == "test-rec" + assert raw["sample_rate"] == SAMPLE_RATE + assert raw["device_type"] == "single_mic" + assert raw["duration_s"] > 0.0 + + +def test_record_audio_auto_generates_id(tmp_path: Path) -> None: + audio_input = _FakeAudioInput() + wav_path, meta_path = record_audio( + audio_input=audio_input, + 
duration_s=0.05, + output_dir=tmp_path, + retention_days=3, + ) + assert wav_path.exists() + meta_raw = json.loads(meta_path.read_text()) + assert meta_raw["recording_id"] # not empty + assert meta_raw["expires_at"] is not None + + +def test_record_audio_output_dir_created(tmp_path: Path) -> None: + new_dir = tmp_path / "nested" / "dir" + audio_input = _FakeAudioInput() + wav_path, meta_path = record_audio( + audio_input=audio_input, + duration_s=0.05, + output_dir=new_dir, + recording_id="auto-dir", + ephemeral=True, + ) + assert new_dir.exists() + assert not wav_path.exists() + assert not meta_path.exists() + + +# --------------------------------------------------------------------------- +# _main +# --------------------------------------------------------------------------- + + +def test_main_records_and_closes_audio_input( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + audio_input = _FakeAudioInput() + settings = types.SimpleNamespace(audio=object()) + calls: list[tuple[float, Path, bool, int | None]] = [] + + def _fake_record_audio(**kwargs: object) -> tuple[Path, Path]: + duration_s = kwargs["duration_s"] + output_dir = kwargs["output_dir"] + ephemeral = kwargs.get("ephemeral", False) + retention_days = kwargs.get("retention_days") + assert isinstance(duration_s, float) + assert isinstance(output_dir, Path) + assert isinstance(ephemeral, bool) + assert retention_days is None or isinstance(retention_days, int) + calls.append((duration_s, output_dir, ephemeral, retention_days)) + wav_path = output_dir / "cli.wav" + meta_path = output_dir / "cli.json" + wav_path.write_bytes(b"wav") + meta_path.write_text("{}", encoding="utf-8") + return wav_path, meta_path + + monkeypatch.setattr(record_evaluation, "record_audio", _fake_record_audio) + monkeypatch.setitem( + sys.modules, + "daemon.audio", + types.SimpleNamespace(create_audio_input=lambda _audio: audio_input), + ) + monkeypatch.setitem( + sys.modules, + "daemon.config", + types.SimpleNamespace( + load_config_directory=lambda _config_dir: types.SimpleNamespace(settings=settings), + resolve_config_dir=lambda: tmp_path, + ), + ) + + assert ( + _main([ + "--duration", + "1.5", + "--output", + str(tmp_path), + "--ephemeral", + "--retention-days", + "2", + ]) + == 0 + ) + assert calls == [(1.5, tmp_path, True, 2)] + assert audio_input.closed is True + + +def test_main_returns_1_on_keyboard_interrupt_and_closes_audio_input( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + audio_input = _FakeAudioInput() + settings = types.SimpleNamespace(audio=object()) + + def _interrupting_record_audio(**kwargs: object) -> tuple[Path, Path]: + del kwargs + raise KeyboardInterrupt + + monkeypatch.setattr(record_evaluation, "record_audio", _interrupting_record_audio) + monkeypatch.setitem( + sys.modules, + "daemon.audio", + types.SimpleNamespace(create_audio_input=lambda _audio: audio_input), + ) + monkeypatch.setitem( + sys.modules, + "daemon.config", + types.SimpleNamespace( + load_config_directory=lambda _config_dir: types.SimpleNamespace(settings=settings), + resolve_config_dir=lambda: tmp_path, + ), + ) + + assert _main(["--duration", "0.1", "--output", str(tmp_path)]) == 1 + assert audio_input.closed is True
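# ---------------------------------------------------------------------------
# Usage sketch: a minimal, illustrative example of how the annotation schema
# and evaluator compose. The paths and the toy detector below are assumptions
# for demonstration, not part of the tools above.
# ---------------------------------------------------------------------------
from pathlib import Path

import numpy as np
import numpy.typing as npt

from daemon.tools.annotation_schema import (
    AnnotationEvent,
    RecordingAnnotation,
    save_annotation,
)
from daemon.tools.evaluate_detection import DetectionEventResult, evaluate_recording

# Ground truth: one bell event 1.2 s into the (assumed pre-recorded) WAV.
annotation = RecordingAnnotation(
    recording_id="rec-demo",
    events=[AnnotationEvent(time_s=1.2, type="bell", notes="doorbell")],
)
save_annotation(annotation, Path("data/evaluation/rec-demo.json"))


def demo_detector(
    _samples: npt.NDArray[np.float64],
    _sample_rate: int,
    annotation: RecordingAnnotation,
) -> list[DetectionEventResult]:
    # Toy detector: reports every annotated event as detected 50 ms late.
    return [
        DetectionEventResult(
            detected=True,
            detected_time_s=event.time_s + 0.05,
            confidence=0.9,
            matched_annotation_index=index,
            detection_latency_ms=50.0,
        )
        for index, event in enumerate(annotation.events)
    ]


result = evaluate_recording(
    Path("data/evaluation/rec-demo.wav"),  # assumed to exist on disk
    Path("data/evaluation/rec-demo.json"),
    detector=demo_detector,
)
print(result.metrics.f1)  # 1.0: the toy detector matches every event

# The same injection pattern drives run_grid_search in optimize_thresholds:
# callers supply an evaluator callable, so thresholds can be tuned against
# recorded data without touching the live audio pipeline.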