Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/llm/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[Tool]] = NOT_GIVEN,
) -> asyncio.Future[GenerationCreatedEvent]: ... # can raise RealtimeError on Timeout

# commit the input audio buffer to the server
Expand Down
90 changes: 63 additions & 27 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from .. import inference, llm, stt, tts, utils, vad
from ..llm.chat_context import Instructions
from ..llm.tool_context import (
FunctionToolInfo,
RawFunctionToolInfo,
StopResponse,
ToolFlag,
get_fnc_tool_names,
Expand Down Expand Up @@ -977,6 +975,7 @@ def _generate_reply(
chat_ctx: NotGivenOr[llm.ChatContext | None] = NOT_GIVEN,
instructions: NotGivenOr[str | Instructions] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[str]] = NOT_GIVEN,
allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
schedule_speech: bool = True,
input_details: InputDetails = DEFAULT_INPUT_DETAILS,
Expand All @@ -1002,24 +1001,39 @@ def _generate_reply(
# when generate_reply is called inside a function_tool, set tool_choice to "none" by default # noqa: E501
tool_choice = "none"

tools = self.tools
all_tools = self.tools.copy()

# resolve tool names to Tool objects if tools param is given
resolved_tools: NotGivenOr[list[llm.Tool | llm.Toolset]] = NOT_GIVEN
if is_given(tools):
tool_ctx = llm.ToolContext(all_tools)
toolset_dict = {t.id: t for t in tool_ctx.toolsets}
tool_dict = {t.id: t for t in tool_ctx.flatten()}
resolved_tools = list[llm.Tool | llm.Toolset]()
for name in set(tools):
tool = toolset_dict.get(name) or tool_dict.get(name)
if tool is None:
raise ValueError(
f"tool '{name}' not found in agent's registered tools. "
f"Available tools: {list(tool_ctx.function_tools.keys())}"
)
resolved_tools.append(tool)

# if tool has the IGNORE_ON_ENTER flag, every generate_reply inside on_enter will ignore it
if on_enter_data := _OnEnterContextVar.get(None):
if on_enter_data.agent == self._agent and on_enter_data.session == self._session:
filtered_tools: list[llm.Tool | llm.Toolset] = []
for tool in tools:
info: RawFunctionToolInfo | FunctionToolInfo | None = None
if isinstance(tool, llm.RawFunctionTool | llm.FunctionTool):
info = tool.info

if info and (info.flags & ToolFlag.IGNORE_ON_ENTER):
to_filter = resolved_tools if is_given(resolved_tools) else all_tools
for tool in to_filter:
if (
isinstance(tool, llm.RawFunctionTool | llm.FunctionTool)
and tool.info.flags & ToolFlag.IGNORE_ON_ENTER
):
continue

# TODO(long): add IGNORE_ON_ENTER to ToolSet?
filtered_tools.append(tool)

tools = filtered_tools
to_filter[:] = filtered_tools

handle = SpeechHandle.create(
allow_interruptions=allow_interruptions
Expand All @@ -1039,7 +1053,7 @@ def _generate_reply(
# TODO(theomonnom): support llm.ChatMessage for the realtime model
user_input=user_message.text_content if user_message else None,
instructions=instructions or None,
# TODO(theomonnom): the list of tools should always be passed here
tools=resolved_tools if is_given(resolved_tools) else None,
model_settings=ModelSettings(tool_choice=tool_choice),
),
speech_handle=handle,
Expand All @@ -1051,7 +1065,7 @@ def _generate_reply(
self._pipeline_reply_task(
speech_handle=handle,
chat_ctx=chat_ctx or self._agent._chat_ctx,
tools=tools,
tools=resolved_tools if is_given(resolved_tools) else all_tools,
new_message=user_message if is_given(user_message) else None,
instructions=instructions or None,
model_settings=ModelSettings(
Expand Down Expand Up @@ -2597,6 +2611,7 @@ async def _realtime_reply_task(
*,
speech_handle: SpeechHandle,
model_settings: ModelSettings,
tools: list[llm.Tool | llm.Toolset] | None = None,
user_input: str | None = None,
instructions: str | None = None,
tool_reply: bool = False,
Expand All @@ -2620,20 +2635,36 @@ async def _realtime_reply_task(
self._agent._chat_ctx._upsert_item(msg)
self._session._conversation_item_added(msg)

per_response_tool_choice = (
self._rt_session.realtime_model.capabilities.per_response_tool_choice
)
ori_tool_choice = self._tool_choice
if utils.is_given(model_settings.tool_choice) and not per_response_tool_choice:
self._rt_session.update_options(tool_choice=model_settings.tool_choice)

ori_tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN
ori_tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN
try:
if not (
per_response_tool_choice
:= self._rt_session.realtime_model.capabilities.per_response_tool_choice
):
# update the tool and tool choice at the session level if they are specified
if (
is_given(model_settings.tool_choice)
and model_settings.tool_choice != self._tool_choice
):
ori_tool_choice = self._tool_choice
self._rt_session.update_options(tool_choice=model_settings.tool_choice)

if tools is not None:
ori_tools = self._rt_session.tools.flatten()
await self._rt_session.update_tools(llm.ToolContext(tools).flatten())

try:
generation_ev = await self._rt_session.generate_reply(
instructions=instructions or NOT_GIVEN,
tool_choice=(
model_settings.tool_choice if per_response_tool_choice else NOT_GIVEN
),
tools=(
llm.ToolContext(tools).flatten()
if per_response_tool_choice and tools
else NOT_GIVEN
),
)
except llm.RealtimeError as e:
logger.error(
Expand All @@ -2652,13 +2683,18 @@ async def _realtime_reply_task(
instructions=instructions,
)
finally:
# reset tool_choice value (only needed for non-per-response models)
if (
not per_response_tool_choice
and utils.is_given(model_settings.tool_choice)
and model_settings.tool_choice != ori_tool_choice
):
self._rt_session.update_options(tool_choice=ori_tool_choice)
# reset tool_choice and tools
if is_given(ori_tool_choice):
try:
self._rt_session.update_options(tool_choice=ori_tool_choice)
except Exception:
logger.exception("failed to reset tool_choice")

if is_given(ori_tools):
try:
await self._rt_session.update_tools(ori_tools)
except Exception:
logger.exception("failed to reset tools")

@utils.log_exceptions(logger=logger)
async def _realtime_generation_task(
Expand Down
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,7 @@ def generate_reply(
user_input: NotGivenOr[str | llm.ChatMessage] = NOT_GIVEN,
instructions: NotGivenOr[str | Instructions] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[str]] = NOT_GIVEN,
allow_interruptions: NotGivenOr[bool] = NOT_GIVEN,
chat_ctx: NotGivenOr[ChatContext] = NOT_GIVEN,
input_modality: Literal["text", "audio"] = "text",
Expand All @@ -1111,6 +1112,8 @@ def generate_reply(
instructions (NotGivenOr[str], optional): Additional instructions for generating the reply.
tool_choice (NotGivenOr[llm.ToolChoice], optional): Specifies the external tool to use when
generating the reply. If generate_reply is invoked within a function_tool, defaults to "none".
tools (NotGivenOr[list[str]], optional): List of tool IDs to make available for this response.
When set, only the specified tools can be used. Tool IDs must match registered tools on the agent.
allow_interruptions (NotGivenOr[bool], optional): Indicates whether the user can interrupt this speech.
chat_ctx (NotGivenOr[ChatContext], optional): The chat context to use for generating the reply.
Defaults to the chat context of the current agent if not provided.
Expand Down Expand Up @@ -1144,6 +1147,7 @@ def generate_reply(
user_message=user_message if user_message else None,
instructions=instructions,
tool_choice=tool_choice,
tools=tools,
allow_interruptions=allow_interruptions,
chat_ctx=chat_ctx,
input_details=InputDetails(modality=input_modality),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,7 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
"""Generate a reply from the model.

Expand Down Expand Up @@ -2033,6 +2034,10 @@ def generate_reply(
update_chat_ctx() which sends interactive text to Nova Sonic.
This method handles the instructions parameter for system-level prompts.
"""
if is_given(tools):
logger.warning(
"per-response tools is not supported by AWS Nova Sonic Realtime API, ignoring"
)
# Check if generate_reply is supported (requires mixed modalities)
if self._realtime_model.modalities != "mixed":
logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,10 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
if is_given(tools):
logger.warning("per-response tools is not supported by Google Realtime API, ignoring")
if self._opts.model == "gemini-3.1-flash-live-preview":
logger.warning(
"generate_reply is not compatible with 'gemini-3.1-flash-live-preview' and will be ignored."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
raise NotImplementedError(
"generate_reply is not yet supported by the PersonaPlex realtime model."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1356,7 +1356,7 @@ async def update_tools(self, tools: list[llm.Tool]) -> None:
self._tools = llm.ToolContext(retained_tools)

# this function can be overridden
def _create_tools_update_event(self, tools: list[llm.Tool]) -> dict[str, Any]:
def _convert_tools_to_oai(self, tools: list[llm.Tool]) -> list[RealtimeFunctionTool]:
oai_tools: list[RealtimeFunctionTool] = []

for tool in tools:
Expand Down Expand Up @@ -1385,6 +1385,11 @@ def _create_tools_update_event(self, tools: list[llm.Tool]) -> dict[str, Any]:
)
continue

return oai_tools

def _create_tools_update_event(self, tools: list[llm.Tool]) -> dict[str, Any]:
oai_tools = self._convert_tools_to_oai(tools)

event = self._wrap_session_update(
event_id=utils.shortuuid("tools_update_"),
session=RealtimeSessionCreateRequest.model_construct(
Expand Down Expand Up @@ -1449,6 +1454,7 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
event_id = utils.shortuuid("response_create_")
fut = asyncio.Future[llm.GenerationCreatedEvent]()
Expand All @@ -1460,6 +1466,8 @@ def generate_reply(
)
if is_given(tool_choice):
params.tool_choice = to_oai_tool_choice(tool_choice)
if is_given(tools):
params.tools = self._convert_tools_to_oai(tools) # type: ignore

self.send_event(
ResponseCreateEvent(type="response.create", event_id=event_id, response=params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,10 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
if is_given(tools):
logger.warning("per-response tools is not supported by Phonic Realtime API, ignoring")
payload = GenerateReplyPayload(
system_message=instructions if is_given(instructions) else None,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,11 @@ def generate_reply(
*,
instructions: NotGivenOr[str] = NOT_GIVEN,
tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
tools: NotGivenOr[list[llm.Tool]] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
"""Generate a reply from the LLM based on the instructions."""
if is_given(tools):
logger.warning("per-response tools is not supported by Ultravox Realtime API, ignoring")
# Cancel prior pending generation if exists
if self._pending_generation_fut and not self._pending_generation_fut.done():
logger.warning(
Expand Down
Loading