Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions livekit-agents/livekit/agents/inference/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class DeepgramOptions(TypedDict, total=False):
numerals: bool
mip_opt_out: bool
vad_events: bool # default: False
diarize: bool # default: False - enables speaker diarization


class AssemblyaiOptions(TypedDict, total=False):
Expand All @@ -78,6 +79,7 @@ class AssemblyaiOptions(TypedDict, total=False):
max_turn_silence: int # default: not specified
keyterms_prompt: list[str] # default: not specified
prompt: str # default: not specified (u3-rt-pro only, mutually exclusive with keyterms_prompt)
speaker_labels: bool # default: False - enables speaker diarization


class ElevenlabsOptions(TypedDict, total=False):
Expand Down Expand Up @@ -286,10 +288,19 @@ def __init__(
a list of FallbackModel instances.
conn_options (APIConnectOptions, optional): Connection options for request attempts.
"""
# Check extra_kwargs for diarization parameters across different providers
# Deepgram uses "diarize", AssemblyAI uses "speaker_labels"
diarization_enabled = False
if is_given(extra_kwargs):
diarization_enabled = bool(
extra_kwargs.get("diarize") or extra_kwargs.get("speaker_labels")
)

super().__init__(
capabilities=stt.STTCapabilities(
streaming=True,
interim_results=True,
diarization=diarization_enabled,
aligned_transcript="word",
offline_recognize=False,
),
Expand Down Expand Up @@ -410,6 +421,12 @@ def update_options(
self._opts.language = LanguageCode(language)
if is_given(extra):
self._opts.extra_kwargs.update(extra)
# Update diarization capability based on extra_kwargs
diarization_enabled = bool(
self._opts.extra_kwargs.get("diarize")
or self._opts.extra_kwargs.get("speaker_labels")
)
self._capabilities = replace(self._capabilities, diarization=diarization_enabled)

for stream in self._streams:
stream.update_options(model=model, language=language, extra=extra)
Expand Down Expand Up @@ -651,13 +668,15 @@ def _process_transcript(self, data: dict, is_final: bool) -> None:
end_time=self.start_time_offset + data.get("start", 0) + data.get("duration", 0),
confidence=data.get("confidence", 1.0),
text=text,
speaker_id=data.get("speaker_id"),
words=[
TimedString(
text=word.get("word", ""),
start_time=word.get("start", 0) + self.start_time_offset,
end_time=word.get("end", 0) + self.start_time_offset,
start_time_offset=self.start_time_offset,
confidence=word.get("confidence", 0.0),
speaker_id=word.get("speaker_id"),
)
for word in words
],
Expand Down
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class TimedString(str):
confidence: NotGivenOr[float]
start_time_offset: NotGivenOr[float]
# offset relative to the start of the audio input stream or session in seconds, used in STT plugins
speaker_id: str | None

def __new__(
cls,
Expand All @@ -151,10 +152,12 @@ def __new__(
end_time: NotGivenOr[float] = NOT_GIVEN,
confidence: NotGivenOr[float] = NOT_GIVEN,
start_time_offset: NotGivenOr[float] = NOT_GIVEN,
speaker_id: str | None = None,
) -> "TimedString":
obj = super().__new__(cls, text)
obj.start_time = start_time
obj.end_time = end_time
obj.confidence = confidence
obj.start_time_offset = start_time_offset
obj.speaker_id = speaker_id
return obj
28 changes: 28 additions & 0 deletions tests/test_inference_stt_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,31 @@ def test_connect_options_full_custom(self):
assert stt._opts.conn_options.timeout == 60.0
assert stt._opts.conn_options.max_retry == 10
assert stt._opts.conn_options.retry_interval == 2.0


class TestSTTDiarizationCapabilities:
    """Check that the diarization capability flag follows the options in extra_kwargs."""

    def test_no_diarization_by_default(self):
        """An STT built without any diarization option reports diarization as False."""
        engine = _make_stt()
        assert engine.capabilities.diarization is False

    def test_diarization_enabled_with_deepgram_diarize(self):
        """Supplying Deepgram's 'diarize' option turns the capability on."""
        options = {"diarize": True}
        engine = _make_stt(extra_kwargs=options)
        assert engine.capabilities.diarization is True

    def test_diarization_enabled_with_assemblyai_speaker_labels(self):
        """Supplying AssemblyAI's 'speaker_labels' option turns the capability on."""
        engine = _make_stt(
            model="assemblyai/universal-streaming",
            extra_kwargs={"speaker_labels": True},
        )
        assert engine.capabilities.diarization is True

    def test_update_options_toggles_diarization(self):
        """The capability can be switched on and back off through update_options."""
        engine = _make_stt()
        assert engine.capabilities.diarization is False
        # Enable first, then disable, checking the capability after each update.
        for flag in (True, False):
            engine.update_options(extra={"diarize": flag})
            assert engine.capabilities.diarization is flag
Comment on lines +262 to +287
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those tests aren't really useful?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's overkill; removed most of them.

Loading