Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def __init__(
self._http_session = http_session
self._http_session_owned = False
self._sessions = weakref.WeakSet[RealtimeSession]()
self._provider_label = "OpenAI Realtime API"

@property
def model(self) -> str:
Expand Down Expand Up @@ -791,7 +792,7 @@ async def _main_task(self) -> None:

async def _reconnect() -> None:
logger.debug(
"reconnecting to OpenAI Realtime API",
f"reconnecting to {self._realtime_model._provider_label}",
extra={"max_session_duration": self._realtime_model._opts.max_session_duration},
)

Expand Down Expand Up @@ -834,7 +835,7 @@ async def _reconnect() -> None:
self._remote_chat_ctx = old_chat_ctx # restore the old chat context
raise APIConnectionError(
message=(
"Failed to send message to OpenAI Realtime API during session re-connection"
f"Failed to send message to {self._realtime_model._provider_label} during session re-connection"
),
) from e

Expand All @@ -846,7 +847,7 @@ async def _reconnect() -> None:
self._response_created_futures.clear()
self._close_current_generation("session reconnection")

logger.debug("reconnected to OpenAI Realtime API")
logger.debug(f"reconnected to {self._realtime_model._provider_label}")
self.emit("session_reconnected", llm.RealtimeSessionReconnectedEvent())

reconnecting = False
Expand All @@ -865,7 +866,7 @@ async def _reconnect() -> None:
elif num_retries == max_retries:
self._emit_error(e, recoverable=False)
raise APIConnectionError(
f"OpenAI Realtime API connection failed after {num_retries} attempts",
f"{self._realtime_model._provider_label} connection failed after {num_retries} attempts",
) from e
else:
self._emit_error(e, recoverable=True)
Expand All @@ -874,7 +875,7 @@ async def _reconnect() -> None:
num_retries
)
logger.warning(
f"OpenAI Realtime API connection failed, retrying in {retry_interval}s",
f"{self._realtime_model._provider_label} connection failed, retrying in {retry_interval}s",
exc_info=e,
extra={"attempt": num_retries, "max_retries": max_retries},
)
Expand Down Expand Up @@ -918,10 +919,12 @@ async def _create_ws_conn(self) -> aiohttp.ClientWebSocketResponse:
self._report_connection_acquired(time.perf_counter() - t0)
return ws
except aiohttp.ClientError as e:
raise APIConnectionError("OpenAI Realtime API client connection error") from e
raise APIConnectionError(
f"{self._realtime_model._provider_label} client connection error"
) from e
except asyncio.TimeoutError as e:
raise APIConnectionError(
message="OpenAI Realtime API connection timed out",
message=f"{self._realtime_model._provider_label} connection timed out",
) from e

async def _run_ws(self, ws_conn: aiohttp.ClientWebSocketResponse) -> None:
Expand Down Expand Up @@ -970,7 +973,9 @@ async def _recv_task() -> None:
return

# this will trigger a reconnection
raise APIConnectionError(message="OpenAI S2S connection closed unexpectedly")
raise APIConnectionError(
message=f"{self._realtime_model._provider_label} connection closed unexpectedly"
)

if msg.type != aiohttp.WSMsgType.TEXT:
continue
Expand Down Expand Up @@ -1371,7 +1376,8 @@ def _create_tools_update_event(self, tools: list[llm.Tool]) -> dict[str, Any]:
continue # currently only xAI supports ProviderTools
else:
logger.error(
"OpenAI Realtime API doesn't support this tool type", extra={"tool": tool}
f"{self._realtime_model._provider_label} doesn't support this tool type",
extra={"tool": tool},
)
continue

Expand All @@ -1380,7 +1386,7 @@ def _create_tools_update_event(self, tools: list[llm.Tool]) -> dict[str, Any]:
oai_tools.append(session_tool)
except ValidationError:
logger.error(
"OpenAI Realtime API doesn't support this tool",
f"{self._realtime_model._provider_label} doesn't support this tool",
extra={"tool": tool_desc},
)
continue
Expand Down Expand Up @@ -1646,7 +1652,9 @@ def _handle_response_content_part_added(self, event: ResponseContentPartAddedEve
assert (item_type := event.part.type) is not None, "part.type is None"

if item_type == "text" and self._realtime_model.capabilities.audio_output:
logger.warning("Text response received from OpenAI Realtime API in audio modality.")
logger.warning(
f"Text response received from {self._realtime_model._provider_label} in audio modality."
)

with contextlib.suppress(asyncio.InvalidStateError):
self._current_generation.messages[item_id].modalities.set_result(
Expand Down Expand Up @@ -1708,7 +1716,7 @@ def _handle_conversion_item_input_audio_transcription_failed(
self, event: ConversationItemInputAudioTranscriptionFailedEvent
) -> None:
logger.error(
"OpenAI Realtime API failed to transcribe input audio",
f"{self._realtime_model._provider_label} failed to transcribe input audio",
extra={"error": event.error},
)

Expand Down Expand Up @@ -1885,14 +1893,15 @@ def _handle_response_done_but_not_complete(self, event: ResponseDoneEvent) -> No
if event.response.status == "completed":
return

provider_label = self._realtime_model._provider_label
if event.response.status == "failed":
if event.response.status_details and hasattr(event.response.status_details, "error"):
error_type = getattr(event.response.status_details.error, "type", "unknown")
error_body = event.response.status_details.error
message = f"OpenAI Realtime API response failed with error type: {error_type}"
message = f"{provider_label} response failed with error type: {error_type}"
else:
error_body = None
message = "OpenAI Realtime API response failed with unknown error"
message = f"{provider_label} response failed with unknown error"
self._emit_error(
APIError(
message=message,
Expand All @@ -1905,7 +1914,8 @@ def _handle_response_done_but_not_complete(self, event: ResponseDoneEvent) -> No
)
elif event.response.status in {"cancelled", "incomplete"}:
logger.debug(
"OpenAI Realtime API response done but not complete with status: %s",
"%s response done but not complete with status: %s",
provider_label,
event.response.status,
extra={
"event_id": event.response.id,
Expand All @@ -1919,13 +1929,14 @@ def _handle_error(self, event: RealtimeErrorEvent) -> None:
if event.error.message.startswith("Cancellation failed"):
return

provider_label = self._realtime_model._provider_label
logger.error(
"OpenAI Realtime API returned an error",
f"{provider_label} returned an error",
extra={"error": event.error},
)
self._emit_error(
APIError(
message="OpenAI Realtime API returned an error",
message=f"{provider_label} returned an error",
body=event.error,
retryable=True,
),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import asyncio
import os
import time
from collections import OrderedDict
from typing import Any

import aiohttp
from openai.types.beta.realtime.session import TurnDetection
from openai.types.realtime import AudioTranscription, RealtimeConversationItemFunctionCall
from openai.types.realtime import (
AudioTranscription,
ConversationItemAdded,
ConversationItemDeletedEvent,
ConversationItemInputAudioTranscriptionCompletedEvent,
RealtimeConversationItemFunctionCall,
)
from openai.types.realtime.realtime_audio_input_turn_detection import ServerVad

from livekit.agents import llm
Expand Down Expand Up @@ -72,6 +80,7 @@ def __init__(
conn_options=conn_options,
)
self._capabilities.per_response_tool_choice = False
self._provider_label = "xAI Realtime API"

def session(self) -> "RealtimeSession":
sess = RealtimeSession(self)
Expand All @@ -87,6 +96,9 @@ def __init__(self, realtime_model: RealtimeModel) -> None:
self._xai_model: RealtimeModel = realtime_model
self._session_connected_at: float = 0.0

# keep the order of item deletion futures
self._item_delete_future = OrderedDict[str, asyncio.Future]()

async def _run_ws(self, ws_conn: aiohttp.ClientWebSocketResponse) -> None:
    """Run the websocket loop, recording when this session connected.

    The timestamp is captured before delegating so other handlers can
    measure time-since-connect while the base class processes messages.
    """
    self._session_connected_at = time.time()
    await super()._run_ws(ws_conn)
Expand Down Expand Up @@ -127,3 +139,34 @@ def _handle_function_call(self, item: RealtimeConversationItemFunctionCall) -> N
return

super()._handle_function_call(item)

def _handle_conversion_item_added(self, event: ConversationItemAdded) -> None:
    """Patch xAI's `conversation.item.added` events before normal handling.

    xAI always reports ``previous_item_id`` as ``None``; substitute the id
    of the last item currently in the remote chat context so the base
    handler links the new item into the right position.
    """
    if event.previous_item_id is None:
        tail = self._remote_chat_ctx._tail
        if tail:
            event.previous_item_id = tail.item.id
        else:
            event.previous_item_id = None

    super()._handle_conversion_item_added(event)

def _handle_conversion_item_deleted(self, event: ConversationItemDeletedEvent) -> None:
    """Patch xAI's `conversation.item.deleted` events before normal handling.

    xAI emits the event with an empty ``item_id``; assume it refers to the
    oldest pending deletion and fill in the first key of the ordered
    deletion-future map.
    """
    if event.item_id == "" and self._item_delete_future:
        # next(iter(...)) reads the first (oldest) key in O(1) instead of
        # materializing the whole key list just to index element 0
        event.item_id = next(iter(self._item_delete_future))

    super()._handle_conversion_item_deleted(event)

def _handle_conversion_item_input_audio_transcription_completed(
    self, event: ConversationItemInputAudioTranscriptionCompletedEvent
) -> None:
    """Deduplicate transcripts before delegating to the base handler.

    The audio transcription already arrives with the item when it is
    added; if the stored message content matches the completed transcript
    exactly, clear it so appending the transcript does not duplicate it.
    """
    remote_item = self._remote_chat_ctx.get(event.item_id)
    if remote_item:
        item = remote_item.item
        is_duplicate = item.type == "message" and item.text_content == event.transcript
        if is_duplicate:
            item.content.clear()

    super()._handle_conversion_item_input_audio_transcription_completed(event)
Loading