-
Notifications
You must be signed in to change notification settings - Fork 77
Refactor realtime_ws to use built-in silero VAD #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
6f01e93
d4e74eb
16ac56d
53c226e
d425b33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,7 +5,7 @@ use axum::{ | |||||||||||||
| response::IntoResponse, | ||||||||||||||
| }; | ||||||||||||||
| use base64::Engine; | ||||||||||||||
| use bytes::{BufMut, Bytes, BytesMut}; | ||||||||||||||
| use bytes::{BufMut, BytesMut}; | ||||||||||||||
| use futures_util::{ | ||||||||||||||
| sink::SinkExt, | ||||||||||||||
| stream::{SplitStream, StreamExt}, | ||||||||||||||
|
|
@@ -15,13 +15,7 @@ use tokio::sync::mpsc; | |||||||||||||
| use uuid::Uuid; | ||||||||||||||
|
|
||||||||||||||
| use crate::{ | ||||||||||||||
| ai::{ | ||||||||||||||
| ChatSession, | ||||||||||||||
| bailian::cosyvoice, | ||||||||||||||
| elevenlabs, | ||||||||||||||
| openai::realtime::*, | ||||||||||||||
| vad::{VadRealtimeClient, VadRealtimeEvent}, | ||||||||||||||
| }, | ||||||||||||||
| ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession}, | ||||||||||||||
| config::*, | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -44,11 +38,11 @@ pub struct RealtimeSession { | |||||||||||||
| pub input_audio_buffer: BytesMut, | ||||||||||||||
| pub triggered: bool, | ||||||||||||||
| pub is_generating: bool, | ||||||||||||||
| pub vad_realtime_client: Option<VadRealtimeClient>, | ||||||||||||||
| pub vad_session: Option<VadSession>, | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| impl RealtimeSession { | ||||||||||||||
| pub fn new(chat_session: ChatSession) -> Self { | ||||||||||||||
| pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self { | ||||||||||||||
| Self { | ||||||||||||||
| client: reqwest::Client::new(), | ||||||||||||||
| chat_session, | ||||||||||||||
|
|
@@ -58,7 +52,7 @@ impl RealtimeSession { | |||||||||||||
| input_audio_buffer: BytesMut::new(), | ||||||||||||||
| triggered: false, | ||||||||||||||
| is_generating: false, | ||||||||||||||
| vad_realtime_client: None, | ||||||||||||||
| vad_session, | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
@@ -72,7 +66,6 @@ pub struct StableRealtimeConfig { | |||||||||||||
|
|
||||||||||||||
| enum RealtimeEvent { | ||||||||||||||
| ClientEvent(ClientEvent), | ||||||||||||||
| VadEvent(VadRealtimeEvent), | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| pub async fn ws_handler( | ||||||||||||||
|
|
@@ -100,24 +93,36 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) { | |||||||||||||
| chat_session.system_prompts = parts.sys_prompts; | ||||||||||||||
| chat_session.messages = parts.dynamic_prompts; | ||||||||||||||
|
|
||||||||||||||
| // 创建新的 Realtime 会话 | ||||||||||||||
| let mut session = RealtimeSession::new(chat_session); | ||||||||||||||
| let mut realtime_rx: Option<_> = None; | ||||||||||||||
|
|
||||||||||||||
| if let Some(vad_realtime_url) = &config.asr.vad_realtime_url { | ||||||||||||||
| match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await { | ||||||||||||||
| Ok((client, rx)) => { | ||||||||||||||
| session.vad_realtime_client = Some(client); | ||||||||||||||
| realtime_rx = Some(rx); | ||||||||||||||
| log::info!("Connected to VAD realtime service at {}", vad_realtime_url); | ||||||||||||||
| } | ||||||||||||||
| Err(e) => { | ||||||||||||||
| log::error!("Failed to connect to VAD realtime service: {}", e); | ||||||||||||||
| // Initialize built-in silero VAD session | ||||||||||||||
| let device = burn::backend::ndarray::NdArrayDevice::default(); | ||||||||||||||
| let vad_session = match silero_vad_burn::SileroVAD6Model::new(&device) { | ||||||||||||||
| Ok(vad_model) => { | ||||||||||||||
| match crate::ai::vad::VadSession::new(&config.asr.vad, Box::new(vad_model), device) { | ||||||||||||||
| Ok(session) => { | ||||||||||||||
| log::info!("Initialized built-in silero VAD session"); | ||||||||||||||
| Some(session) | ||||||||||||||
| } | ||||||||||||||
| Err(e) => { | ||||||||||||||
| log::error!("Failed to create VAD session: {}", e); | ||||||||||||||
| None | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| Err(e) => { | ||||||||||||||
| log::error!( | ||||||||||||||
| "Failed to load silero VAD model: {}. \ | ||||||||||||||
| This may be due to missing model files or insufficient memory.", | ||||||||||||||
| e | ||||||||||||||
| ); | ||||||||||||||
| None | ||||||||||||||
| } | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // 创建新的 Realtime 会话 | ||||||||||||||
| let has_vad = vad_session.is_some(); | ||||||||||||||
| let mut session = RealtimeSession::new(chat_session, vad_session); | ||||||||||||||
|
|
||||||||||||||
| let turn_detection = if realtime_rx.is_some() { | ||||||||||||||
| let turn_detection = if has_vad { | ||||||||||||||
| TurnDetection::server_vad() | ||||||||||||||
| } else { | ||||||||||||||
| TurnDetection::none() | ||||||||||||||
|
|
@@ -244,33 +249,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) { | |||||||||||||
| None | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| async fn select_event( | ||||||||||||||
| socket: &mut SplitStream<WebSocket>, | ||||||||||||||
| realtime_rx: &mut Option<crate::ai::vad::VadRealtimeRx>, | ||||||||||||||
| ) -> Option<RealtimeEvent> { | ||||||||||||||
| if let Some(rx) = realtime_rx { | ||||||||||||||
| tokio::select! { | ||||||||||||||
| client_event = recv_client_event(socket) => { | ||||||||||||||
| client_event.map(RealtimeEvent::ClientEvent) | ||||||||||||||
| } | ||||||||||||||
| vad_event = rx.next_event() => { | ||||||||||||||
| match vad_event { | ||||||||||||||
| Ok(event) => Some(RealtimeEvent::VadEvent(event)), | ||||||||||||||
| Err(e) => { | ||||||||||||||
| log::error!("Failed to receive VAD event: {}", e); | ||||||||||||||
| None | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
| } else { | ||||||||||||||
| recv_client_event(socket) | ||||||||||||||
| .await | ||||||||||||||
| .map(RealtimeEvent::ClientEvent) | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await { | ||||||||||||||
| while let Some(event) = recv_client_event(&mut receiver) | ||||||||||||||
| .await | ||||||||||||||
| .map(RealtimeEvent::ClientEvent) | ||||||||||||||
| { | ||||||||||||||
| if let Err(e) = handle_client_message( | ||||||||||||||
| event, | ||||||||||||||
| &mut session, | ||||||||||||||
|
|
@@ -360,14 +342,14 @@ async fn handle_client_message( | |||||||||||||
| return Ok(()); | ||||||||||||||
| } | ||||||||||||||
| if turn_detection.turn_type == TurnDetectionType::ServerVad | ||||||||||||||
| && session.vad_realtime_client.is_none() | ||||||||||||||
| && session.vad_session.is_none() | ||||||||||||||
| { | ||||||||||||||
| let error_event = ServerEvent::Error { | ||||||||||||||
| event_id: Uuid::new_v4().to_string(), | ||||||||||||||
| error: ErrorDetails { | ||||||||||||||
| error_type: "invalid_request_error".to_string(), | ||||||||||||||
| code: Some("vad_realtime_not_connected".to_string()), | ||||||||||||||
| message: "VAD realtime service is not connected".to_string(), | ||||||||||||||
| code: Some("vad_not_available".to_string()), | ||||||||||||||
| message: "VAD session is not available".to_string(), | ||||||||||||||
| param: Some("turn_detection.type".to_string()), | ||||||||||||||
| event_id: None, | ||||||||||||||
| }, | ||||||||||||||
|
|
@@ -433,11 +415,11 @@ async fn handle_client_message( | |||||||||||||
| .as_ref() | ||||||||||||||
| .map(|t| t.turn_type == TurnDetectionType::ServerVad) | ||||||||||||||
| .unwrap_or_default() | ||||||||||||||
| && session.vad_realtime_client.is_some(); | ||||||||||||||
| && session.vad_session.is_some(); | ||||||||||||||
|
|
||||||||||||||
| log::debug!( | ||||||||||||||
| "Server VAD status: {} {:?}", | ||||||||||||||
| session.vad_realtime_client.is_some(), | ||||||||||||||
| session.vad_session.is_some(), | ||||||||||||||
| session.config.turn_detection | ||||||||||||||
| ); | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -473,26 +455,55 @@ async fn handle_client_message( | |||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Process audio through built-in silero VAD | ||||||||||||||
| if server_vad { | ||||||||||||||
| let samples_24k = audio_data | ||||||||||||||
| .chunks_exact(2) | ||||||||||||||
| .map(|chunk| { | ||||||||||||||
| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32 | ||||||||||||||
| }) | ||||||||||||||
| .collect::<Vec<f32>>(); | ||||||||||||||
| let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000); | ||||||||||||||
|
|
||||||||||||||
| let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k); | ||||||||||||||
| log::debug!( | ||||||||||||||
| "Sending audio chunk to VAD realtime service, length: {}", | ||||||||||||||
| sample_16k.len() | ||||||||||||||
| ); | ||||||||||||||
| session | ||||||||||||||
| .vad_realtime_client | ||||||||||||||
| .as_mut() | ||||||||||||||
| .unwrap() | ||||||||||||||
| .push_audio_16k_chunk(Bytes::from(sample_16k)) | ||||||||||||||
| .await?; | ||||||||||||||
| if let Some(vad_session) = session.vad_session.as_mut() { | ||||||||||||||
| // Convert 24kHz PCM16 to 16kHz f32 for VAD | ||||||||||||||
| let samples_24k: Vec<f32> = audio_data | ||||||||||||||
| .chunks_exact(2) | ||||||||||||||
| .map(|chunk| { | ||||||||||||||
| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32 | ||||||||||||||
| }) | ||||||||||||||
| .collect(); | ||||||||||||||
|
Comment on lines
+471
to
+476
|
||||||||||||||
| let samples_16k = | ||||||||||||||
| wav_io::resample::linear(samples_24k, 1, 24000, 16000); | ||||||||||||||
|
|
||||||||||||||
| // Process through VAD in chunks | ||||||||||||||
| let chunk_size = VadSession::vad_chunk_size(); | ||||||||||||||
| let mut speech_detected = false; | ||||||||||||||
| for chunk in samples_16k.chunks(chunk_size) { | ||||||||||||||
| if let Ok(is_speech) = vad_session.detect(chunk) { | ||||||||||||||
| if is_speech { | ||||||||||||||
| speech_detected = true; | ||||||||||||||
| } else if session.triggered { | ||||||||||||||
| // Speech ended - trigger commit | ||||||||||||||
| log::info!("VAD detected speech end, triggering commit"); | ||||||||||||||
| if handle_audio_buffer_commit(session, tx, None, asr) | ||||||||||||||
| .await? | ||||||||||||||
| { | ||||||||||||||
| generate_response(session, tx, tts).await?; | ||||||||||||||
| } | ||||||||||||||
| session.triggered = false; | ||||||||||||||
| if let Some(vs) = session.vad_session.as_mut() { | ||||||||||||||
| vs.reset_state(); | ||||||||||||||
| } | ||||||||||||||
| break; | ||||||||||||||
|
||||||||||||||
| if let Some(vs) = session.vad_session.as_mut() { | |
| vs.reset_state(); | |
| } | |
| break; | |
| // Reset VAD state so we can detect subsequent speech segments | |
| vad_session.reset_state(); |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The audio_start_ms is hardcoded to 0, which doesn't reflect the actual timestamp when speech started. This should track the cumulative audio duration processed to provide accurate timing information for the InputAudioBufferSpeechStarted event.
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VAD session state is reset before validating speech in the committed buffer. This reset occurs regardless of whether the real-time VAD is currently active (session.triggered). If speech was ongoing and commit is called, resetting the state could cause inconsistency between the real-time VAD state and the validation check. Consider only resetting when appropriate or maintaining separate VAD instances for real-time and validation.
| vad_session.reset_state(); | |
| // Only reset VAD state when real-time VAD is not currently triggered | |
| if !session.triggered { | |
| vad_session.reset_state(); | |
| } |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VAD processing logic (conversion of 24kHz PCM16 to 16kHz f32, chunking, and detection) is duplicated between the inline audio processing (lines 454-500) and the commit handler (lines 668-688). Consider extracting this into a helper function to improve maintainability and ensure consistent behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This audio conversion logic (24kHz PCM16 to f32) is duplicated in three locations: here (lines 462-467), in the handle_audio_buffer_commit function (lines 677-680), and appears twice within the same InputAudioBufferAppend handler. Consider extracting this into a helper function to improve maintainability and reduce duplication.