Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nemo_retriever/src/nemo_retriever/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor
from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor, ASRGPUActor
from nemo_retriever.audio.asr_actor import asr_params_from_env
from nemo_retriever.audio.chunk_actor import MediaChunkActor
from nemo_retriever.audio.media_interface import MediaInterface
Expand All @@ -23,6 +23,7 @@
__all__ = [
"ASRActor",
"ASRCPUActor",
"ASRGPUActor",
"ASRParams",
"app",
"asr_params_from_env",
Expand Down
57 changes: 47 additions & 10 deletions nemo_retriever/src/nemo_retriever/audio/asr_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from nemo_retriever.graph.abstract_operator import AbstractOperator
from nemo_retriever.graph.cpu_operator import CPUOperator
from nemo_retriever.graph.gpu_operator import GPUOperator
from nemo_retriever.graph.operator_archetype import ArchetypeOperator
from nemo_retriever.params import ASRParams

Expand Down Expand Up @@ -123,7 +124,13 @@ def _get_client(params: ASRParams): # noqa: ANN201
)


class ASRCPUActor(AbstractOperator, CPUOperator):
def _create_local_model(): # noqa: ANN201
from nemo_retriever.model.local import ParakeetCTC1B1ASR

return ParakeetCTC1B1ASR()


class _ASRBaseActor(AbstractOperator):
"""
Ray Data map_batches callable: chunk rows (path/bytes) -> rows with text (transcript).

Expand All @@ -134,17 +141,17 @@ class ASRCPUActor(AbstractOperator, CPUOperator):
segments are emitted as multiple rows per chunk.
"""

def __init__(self, params: ASRParams | None = None) -> None:
def __init__(
self,
params: ASRParams | None = None,
*,
client: Any = None,
model: Any = None,
) -> None:
super().__init__(params=params)
self._params = params or ASRParams()
if _use_remote(self._params):
self._client = _get_client(self._params)
self._model = None
else:
self._client = None
from nemo_retriever.model.local import ParakeetCTC1B1ASR

self._model = ParakeetCTC1B1ASR()
self._client = client
self._model = model

def preprocess(self, data: Any, **kwargs: Any) -> Any:
return data
Expand Down Expand Up @@ -369,10 +376,40 @@ def _transcribe_one(self, row: pd.Series) -> List[Dict[str, Any]]:
return self._build_output_rows(row, transcript)


class ASRCPUActor(_ASRBaseActor, CPUOperator):
"""CPU actor for ASR.

Uses remote Parakeet when endpoints are configured; otherwise falls back to
the local Parakeet model so CPU-only environments continue to work.
"""

def __init__(self, params: ASRParams | None = None) -> None:
resolved_params = params or ASRParams()
client = _get_client(resolved_params) if _use_remote(resolved_params) else None
model = None if client is not None else _create_local_model()
super().__init__(params=resolved_params, client=client, model=model)


class ASRGPUActor(_ASRBaseActor, GPUOperator):
"""GPU actor for ASR using the local Parakeet model."""

def __init__(self, params: ASRParams | None = None) -> None:
resolved_params = params or ASRParams()
if _use_remote(resolved_params):
raise ValueError("ASRGPUActor does not support remote endpoints. Use ASRCPUActor instead.")
super().__init__(params=resolved_params, client=None, model=_create_local_model())


class ASRActor(ArchetypeOperator):
"""Graph-facing ASR archetype resolved to the best concrete runtime implementation."""

_cpu_variant_class = ASRCPUActor
_gpu_variant_class = ASRGPUActor

@classmethod
def prefers_cpu_variant(cls, operator_kwargs: dict[str, Any] | None = None) -> bool:
params = (operator_kwargs or {}).get("params")
return isinstance(params, ASRParams) and _use_remote(params)

def __init__(self, params: ASRParams | None = None) -> None:
resolved_params = params or ASRParams()
Expand Down
67 changes: 66 additions & 1 deletion nemo_retriever/tests/test_asr_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import pandas as pd

from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor
from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor, ASRGPUActor
from nemo_retriever.audio.asr_actor import apply_asr_to_df
from nemo_retriever.params import ASRParams
from nemo_retriever.utils.ray_resource_hueristics import Resources


def test_strip_pad_from_transcript():
Expand Down Expand Up @@ -273,3 +274,67 @@ def test_local_asr_apply_asr_to_df():
sys.modules.pop("nemo_retriever.model.local", None)
else:
sys.modules["nemo_retriever.model.local"] = prev_local


def test_local_asr_gpu_actor_does_not_call_get_client():
"""ASRGPUActor should use the local model path and never create a remote client."""
mock_model = MagicMock()
mock_model.transcribe.return_value = ["gpu local transcript"]
mock_class = MagicMock(return_value=mock_model)
mock_local = MagicMock()
mock_local.ParakeetCTC1B1ASR = mock_class
prev_local = sys.modules.get("nemo_retriever.model.local")
sys.modules["nemo_retriever.model.local"] = mock_local
try:
with patch("nemo_retriever.audio.asr_actor._get_client") as mock_get:
params = ASRParams(audio_endpoints=(None, None))
actor = ASRGPUActor(params=params)

mock_get.assert_not_called()
assert actor._client is None
assert actor._model is mock_model

batch = pd.DataFrame(
[
{
"path": "/tmp/chunk.wav",
"bytes": b"fake_audio_bytes",
"source_path": "/tmp/source.wav",
"duration": 1.0,
"chunk_index": 0,
"metadata": {},
"page_number": 0,
}
]
)
out = actor(batch)

assert len(out) == 1
assert out["text"].iloc[0] == "gpu local transcript"
mock_model.transcribe.assert_called_once()
finally:
if prev_local is None:
sys.modules.pop("nemo_retriever.model.local", None)
else:
sys.modules["nemo_retriever.model.local"] = prev_local


def test_asr_gpu_actor_rejects_remote_configuration():
with patch("nemo_retriever.audio.asr_actor._get_client") as mock_get:
params = ASRParams(audio_endpoints=("localhost:50051", None))
try:
ASRGPUActor(params=params)
except ValueError as exc:
assert "does not support remote endpoints" in str(exc)
else:
raise AssertionError("ASRGPUActor should reject remote endpoint configuration")
mock_get.assert_not_called()


def test_asr_actor_resolution_prefers_cpu_for_remote_and_gpu_for_local():
remote_params = ASRParams(audio_endpoints=("localhost:50051", None))
local_params = ASRParams(audio_endpoints=(None, None))
gpu_resources = Resources(cpu_count=8, gpu_count=1)

assert ASRActor.resolve_operator_class(gpu_resources, {"params": remote_params}) is ASRCPUActor
assert ASRActor.resolve_operator_class(gpu_resources, {"params": local_params}) is ASRGPUActor
2 changes: 2 additions & 0 deletions nemo_retriever/tests/test_operator_flags_and_cpu_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_is_standalone_class(self):
assert isinstance(GPUOperator(), GPUOperator)

def test_gpu_operators_have_flag(self):
from nemo_retriever.audio.asr_actor import ASRGPUActor
from nemo_retriever.page_elements.page_elements import PageElementDetectionGPUActor
from nemo_retriever.chart.chart_detection import GraphicElementsGPUActor
from nemo_retriever.table.table_detection import TableStructureGPUActor
Expand All @@ -33,6 +34,7 @@ def test_gpu_operators_have_flag(self):
from nemo_retriever.rerank.rerank import NemotronRerankGPUActor
from nemo_retriever.text_embed.text_embed import TextEmbedGPUActor

assert issubclass(ASRGPUActor, GPUOperator)
assert issubclass(PageElementDetectionGPUActor, GPUOperator)
assert issubclass(GraphicElementsGPUActor, GPUOperator)
assert issubclass(TableStructureGPUActor, GPUOperator)
Expand Down
Loading