diff --git a/nemo_retriever/src/nemo_retriever/audio/__init__.py b/nemo_retriever/src/nemo_retriever/audio/__init__.py index 7ce292aa05..b0cf70550c 100644 --- a/nemo_retriever/src/nemo_retriever/audio/__init__.py +++ b/nemo_retriever/src/nemo_retriever/audio/__init__.py @@ -11,7 +11,7 @@ from __future__ import annotations -from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor +from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor, ASRGPUActor from nemo_retriever.audio.asr_actor import asr_params_from_env from nemo_retriever.audio.chunk_actor import MediaChunkActor from nemo_retriever.audio.media_interface import MediaInterface @@ -23,6 +23,7 @@ __all__ = [ "ASRActor", "ASRCPUActor", + "ASRGPUActor", "ASRParams", "app", "asr_params_from_env", diff --git a/nemo_retriever/src/nemo_retriever/audio/asr_actor.py b/nemo_retriever/src/nemo_retriever/audio/asr_actor.py index fe55bcd7cd..4edd8697c9 100644 --- a/nemo_retriever/src/nemo_retriever/audio/asr_actor.py +++ b/nemo_retriever/src/nemo_retriever/audio/asr_actor.py @@ -28,6 +28,7 @@ from nemo_retriever.graph.abstract_operator import AbstractOperator from nemo_retriever.graph.cpu_operator import CPUOperator +from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.graph.operator_archetype import ArchetypeOperator from nemo_retriever.params import ASRParams @@ -123,7 +124,13 @@ def _get_client(params: ASRParams): # noqa: ANN201 ) -class ASRCPUActor(AbstractOperator, CPUOperator): +def _create_local_model(): # noqa: ANN201 + from nemo_retriever.model.local import ParakeetCTC1B1ASR + + return ParakeetCTC1B1ASR() + + +class _ASRBaseActor(AbstractOperator): """ Ray Data map_batches callable: chunk rows (path/bytes) -> rows with text (transcript). @@ -134,17 +141,17 @@ class ASRCPUActor(AbstractOperator, CPUOperator): segments are emitted as multiple rows per chunk. """ - def __init__(self, params: ASRParams | None = None) -> None: + def __init__( + self, + params: ASRParams | None = None, + *, + client: Any = None, + model: Any = None, + ) -> None: super().__init__(params=params) self._params = params or ASRParams() - if _use_remote(self._params): - self._client = _get_client(self._params) - self._model = None - else: - self._client = None - from nemo_retriever.model.local import ParakeetCTC1B1ASR - - self._model = ParakeetCTC1B1ASR() + self._client = client + self._model = model def preprocess(self, data: Any, **kwargs: Any) -> Any: return data @@ -369,10 +376,40 @@ def _transcribe_one(self, row: pd.Series) -> List[Dict[str, Any]]: return self._build_output_rows(row, transcript) +class ASRCPUActor(_ASRBaseActor, CPUOperator): + """CPU actor for ASR. + + Uses remote Parakeet when endpoints are configured; otherwise falls back to + the local Parakeet model so CPU-only environments continue to work. + """ + + def __init__(self, params: ASRParams | None = None) -> None: + resolved_params = params or ASRParams() + client = _get_client(resolved_params) if _use_remote(resolved_params) else None + model = None if client is not None else _create_local_model() + super().__init__(params=resolved_params, client=client, model=model) + + +class ASRGPUActor(_ASRBaseActor, GPUOperator): + """GPU actor for ASR using the local Parakeet model.""" + + def __init__(self, params: ASRParams | None = None) -> None: + resolved_params = params or ASRParams() + if _use_remote(resolved_params): + raise ValueError("ASRGPUActor does not support remote endpoints. Use ASRCPUActor instead.") + super().__init__(params=resolved_params, client=None, model=_create_local_model()) + + class ASRActor(ArchetypeOperator): """Graph-facing ASR archetype resolved to the best concrete runtime implementation.""" _cpu_variant_class = ASRCPUActor + _gpu_variant_class = ASRGPUActor + + @classmethod + def prefers_cpu_variant(cls, operator_kwargs: dict[str, Any] | None = None) -> bool: + params = (operator_kwargs or {}).get("params") + return isinstance(params, ASRParams) and _use_remote(params) def __init__(self, params: ASRParams | None = None) -> None: resolved_params = params or ASRParams() diff --git a/nemo_retriever/tests/test_asr_actor.py b/nemo_retriever/tests/test_asr_actor.py index c7297e3912..fef8131f1e 100644 --- a/nemo_retriever/tests/test_asr_actor.py +++ b/nemo_retriever/tests/test_asr_actor.py @@ -17,9 +17,10 @@ import pandas as pd -from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor +from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor, ASRGPUActor from nemo_retriever.audio.asr_actor import apply_asr_to_df from nemo_retriever.params import ASRParams +from nemo_retriever.utils.ray_resource_hueristics import Resources def test_strip_pad_from_transcript(): @@ -273,3 +274,67 @@ def test_local_asr_apply_asr_to_df(): sys.modules.pop("nemo_retriever.model.local", None) else: sys.modules["nemo_retriever.model.local"] = prev_local + + +def test_local_asr_gpu_actor_does_not_call_get_client(): + """ASRGPUActor should use the local model path and never create a remote client.""" + mock_model = MagicMock() + mock_model.transcribe.return_value = ["gpu local transcript"] + mock_class = MagicMock(return_value=mock_model) + mock_local = MagicMock() + mock_local.ParakeetCTC1B1ASR = mock_class + prev_local = sys.modules.get("nemo_retriever.model.local") + sys.modules["nemo_retriever.model.local"] = mock_local + try: + with patch("nemo_retriever.audio.asr_actor._get_client") as mock_get: + params = ASRParams(audio_endpoints=(None, None)) + actor = ASRGPUActor(params=params) + + mock_get.assert_not_called() + assert actor._client is None + assert actor._model is mock_model + + batch = pd.DataFrame( + [ + { + "path": "/tmp/chunk.wav", + "bytes": b"fake_audio_bytes", + "source_path": "/tmp/source.wav", + "duration": 1.0, + "chunk_index": 0, + "metadata": {}, + "page_number": 0, + } + ] + ) + out = actor(batch) + + assert len(out) == 1 + assert out["text"].iloc[0] == "gpu local transcript" + mock_model.transcribe.assert_called_once() + finally: + if prev_local is None: + sys.modules.pop("nemo_retriever.model.local", None) + else: + sys.modules["nemo_retriever.model.local"] = prev_local + + +def test_asr_gpu_actor_rejects_remote_configuration(): + with patch("nemo_retriever.audio.asr_actor._get_client") as mock_get: + params = ASRParams(audio_endpoints=("localhost:50051", None)) + try: + ASRGPUActor(params=params) + except ValueError as exc: + assert "does not support remote endpoints" in str(exc) + else: + raise AssertionError("ASRGPUActor should reject remote endpoint configuration") + mock_get.assert_not_called() + + +def test_asr_actor_resolution_prefers_cpu_for_remote_and_gpu_for_local(): + remote_params = ASRParams(audio_endpoints=("localhost:50051", None)) + local_params = ASRParams(audio_endpoints=(None, None)) + gpu_resources = Resources(cpu_count=8, gpu_count=1) + + assert ASRActor.resolve_operator_class(gpu_resources, {"params": remote_params}) is ASRCPUActor + assert ASRActor.resolve_operator_class(gpu_resources, {"params": local_params}) is ASRGPUActor diff --git a/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py b/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py index 7baf23ff08..87a7611a4e 100644 --- a/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py +++ b/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py @@ -22,6 +22,7 @@ def test_is_standalone_class(self): assert isinstance(GPUOperator(), GPUOperator) def test_gpu_operators_have_flag(self): + from nemo_retriever.audio.asr_actor import ASRGPUActor from nemo_retriever.page_elements.page_elements import PageElementDetectionGPUActor from nemo_retriever.chart.chart_detection import GraphicElementsGPUActor from nemo_retriever.table.table_detection import TableStructureGPUActor @@ -33,6 +34,7 @@ def test_gpu_operators_have_flag(self): from nemo_retriever.rerank.rerank import NemotronRerankGPUActor from nemo_retriever.text_embed.text_embed import TextEmbedGPUActor + assert issubclass(ASRGPUActor, GPUOperator) assert issubclass(PageElementDetectionGPUActor, GPUOperator) assert issubclass(GraphicElementsGPUActor, GPUOperator) assert issubclass(TableStructureGPUActor, GPUOperator)