feat(stt): add speaker diarization support to STT interface and proxy #5283
russellmartin-livekit wants to merge 5 commits into main
The inference STT capabilities.diarization was hardcoded to False, which caused MultiSpeakerAdapter to not work since it checks capabilities.diarization before enabling diarization.

This change:
- Adds diarize option to DeepgramOptions TypedDict
- Adds speaker_labels option to AssemblyaiOptions TypedDict
- Detects diarization params in extra_kwargs and sets capabilities
- Updates capabilities when update_options() is called with diarization
- Adds comprehensive tests for diarization capability detection

Fixes AGT-2608

Slack thread: https://live-kit.slack.com/archives/C06TN33TV44/p1772573869144129?thread_ts=1771977322.899519&cid=C06TN33TV44
https://claude.ai/code/session_01VRKQuBXiq8BHKr9AiJ6uEw
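The detection described above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: `SketchSTT`, `Capabilities`, `DIARIZATION_KEYS`, and `detect_diarization` are hypothetical names, and only the `diarize` and `speaker_labels` keys and the `update_options(extra=...)` call shape are taken from the description and tests.

```python
from typing import Any, Mapping, Optional

# Hypothetical constant: provider flags that are assumed to request diarization.
DIARIZATION_KEYS = ("diarize", "speaker_labels")


class Capabilities:
    """Stand-in for the STT capabilities object (assumption, simplified)."""

    def __init__(self, diarization: bool = False) -> None:
        self.diarization = diarization


def detect_diarization(extra_kwargs: Optional[Mapping[str, Any]]) -> bool:
    """Return True if any known diarization flag is present and truthy."""
    if not extra_kwargs:
        return False
    return any(bool(extra_kwargs.get(key)) for key in DIARIZATION_KEYS)


class SketchSTT:
    """Toy wrapper that derives the capability instead of hardcoding False."""

    def __init__(self, extra_kwargs: Optional[Mapping[str, Any]] = None) -> None:
        self._extra = dict(extra_kwargs or {})
        self.capabilities = Capabilities(detect_diarization(self._extra))

    def update_options(self, *, extra: Optional[Mapping[str, Any]] = None) -> None:
        # Merge the new options, then recompute the capability from the result.
        if extra:
            self._extra.update(extra)
        self.capabilities = Capabilities(detect_diarization(self._extra))
```

Recomputing the capability from the merged options, rather than only from the new ones, is one way to keep `update_options(extra={"diarize": False})` able to turn diarization back off.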
class TestSTTDiarizationCapabilities:
    """Tests for STT diarization capability detection from extra_kwargs."""

    def test_no_diarization_by_default(self):
        """Without diarization params, capabilities.diarization is False."""
        stt = _make_stt()
        assert stt.capabilities.diarization is False

    def test_diarization_enabled_with_deepgram_diarize(self):
        """Deepgram's 'diarize' param enables diarization capability."""
        stt = _make_stt(extra_kwargs={"diarize": True})
        assert stt.capabilities.diarization is True

    def test_diarization_disabled_with_diarize_false(self):
        """Deepgram's 'diarize: False' keeps diarization capability False."""
        stt = _make_stt(extra_kwargs={"diarize": False})
        assert stt.capabilities.diarization is False

    def test_diarization_enabled_with_assemblyai_speaker_labels(self):
        """AssemblyAI's 'speaker_labels' param enables diarization capability."""
        stt = _make_stt(model="assemblyai/universal-streaming", extra_kwargs={"speaker_labels": True})
        assert stt.capabilities.diarization is True

    def test_diarization_disabled_with_speaker_labels_false(self):
        """AssemblyAI's 'speaker_labels: False' keeps diarization capability False."""
        stt = _make_stt(model="assemblyai/universal-streaming", extra_kwargs={"speaker_labels": False})
        assert stt.capabilities.diarization is False

    def test_diarization_with_other_extra_kwargs(self):
        """Diarization works alongside other extra_kwargs."""
        stt = _make_stt(extra_kwargs={"diarize": True, "punctuate": True, "smart_format": True})
        assert stt.capabilities.diarization is True

    def test_update_options_enables_diarization(self):
        """update_options with diarization params enables diarization capability."""
        stt = _make_stt()
        assert stt.capabilities.diarization is False
        stt.update_options(extra={"diarize": True})
        assert stt.capabilities.diarization is True

    def test_update_options_disables_diarization(self):
        """update_options can disable diarization by setting params to False."""
        stt = _make_stt(extra_kwargs={"diarize": True})
        assert stt.capabilities.diarization is True
        stt.update_options(extra={"diarize": False})
        assert stt.capabilities.diarization is False
Those tests aren't really useful?
Yeah, it's overkill; removed most of them.
This update introduces the ability to determine the speaker ID in the SpeechData object when processing speech events. If the speaker is not provided and the speech is final, the system will infer the speaker based on the most common speaker ID from the recognized words. Additionally, several outdated tests related to diarization capabilities have been removed, and the test for toggling diarization has been updated for clarity. This change improves the accuracy of speaker identification in speech recognition scenarios.
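The fallback described in this commit message (pick the most common word-level speaker ID when the final result carries none) might look roughly like the sketch below. `infer_speaker_id` is a hypothetical helper, and the input is assumed to be a plain list of per-word speaker IDs rather than the real word objects.

```python
from collections import Counter
from typing import Optional, Sequence


def infer_speaker_id(word_speakers: Sequence[Optional[str]]) -> Optional[str]:
    """Return the most frequent non-None speaker ID, or None if there is none."""
    counts = Counter(s for s in word_speakers if s is not None)
    if not counts:
        return None
    # On a tie, most_common keeps the ID that was seen first.
    return counts.most_common(1)[0][0]
```

Words without a speaker label are ignored, so a transcript with a few unlabeled words still resolves to the dominant speaker.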
end_time=self.start_time_offset + data.get("start", 0) + data.get("duration", 0),
confidence=data.get("confidence", 1.0),
text=text,
speaker_id=f"S{speaker}" if speaker is not None else None,
What's tricky with diarization in our inference API is that we need a way for the speaker_id to be consistent across every provider.
Maybe some of them aren't even int, but str?
I could standardize on the inference side, or I could just pass through whatever we get from the provider.
I think the gateway should standardize it.
It would be OK if it was extra, but everything inside the core "inference" API should be the same across every provider.
Agreed, my latest changes standardize it on the gateway side.
Standardized on the gateway side to return string ints so it stays compatible with plugins that require strings. Also added support for word-level diarization.
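For illustration only, one way a gateway could map arbitrary provider labels (ints, or strings such as letter labels) to string ints is sketched below. `SpeakerIdNormalizer` is a hypothetical name, and numbering speakers in order of first appearance is an assumption; the actual gateway change lives in the linked agent-gateway PR and is not shown here.

```python
from typing import Dict, Hashable, Optional


class SpeakerIdNormalizer:
    """Map provider-specific speaker labels to stable string ints per session."""

    def __init__(self) -> None:
        # Provider label (as a string) -> index assigned on first appearance.
        self._index: Dict[str, int] = {}

    def normalize(self, raw: Optional[Hashable]) -> Optional[str]:
        if raw is None:
            return None
        key = str(raw)
        if key not in self._index:
            self._index[key] = len(self._index)
        return str(self._index[key])
```

Because the same label always maps to the same index within one normalizer instance, downstream consumers see a consistent speaker_id type regardless of provider.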
This update introduces a new speaker_id attribute to the TimedString class, allowing for better tracking of speaker information in speech recognition scenarios. The STT processing logic has been updated to utilize this new attribute, enhancing the accuracy of speaker identification in the SpeechData object.
This update modifies the speaker_id attribute in the TimedString class to be of type str instead of int. The STT processing logic has been adjusted accordingly to ensure compatibility with this change, enhancing the handling of speaker identification in speech recognition scenarios.
This update simplifies the extraction of speaker_id in the SpeechData object by directly using the value from the input data and words, removing unnecessary conversions. This change enhances code clarity and maintains compatibility with recent updates to the TimedString class.
Related gateway change: https://github.com/livekit/agent-gateway/pull/557
Changes in agents:
- Updates DeepgramOptions and AssemblyaiOptions to explicitly type diarization flags (diarize and speaker_labels).
- Updates the STT inference wrapper to dynamically set the diarization capability based on the provided extra_kwargs during initialization and update_options.
- Adds speaker_id to TimedString and populates it from the inference proxy's response.

Fixes AGT-2608