diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 54c739bbda..95ddfb5b07 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -159,6 +159,8 @@ ShellToolLocalSkill, ShellToolSkillReference, Tool, + ToolOrigin, + ToolOriginType, ToolOutputFileContent, ToolOutputFileContentDict, ToolOutputImage, @@ -359,6 +361,8 @@ def enable_verbose_stdout_logging(): "MCPApprovalResponseItem", "ToolCallItem", "ToolCallOutputItem", + "ToolOrigin", + "ToolOriginType", "ReasoningItem", "ItemHelpers", "RunHooks", diff --git a/src/agents/agent.py b/src/agents/agent.py index 5d700ebaa3..69b7537a6a 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -46,6 +46,8 @@ FunctionToolResult, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, _build_handled_function_tool_error_handler, _build_wrapped_function_tool, _log_function_tool_invocation, @@ -886,6 +888,11 @@ async def dispatch_stream_events() -> None: strict_json_schema=True, is_enabled=is_enabled, needs_approval=needs_approval, + tool_origin=ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_name=self.name, + agent_tool_name=tool_name_resolved, + ), ) run_agent_tool._is_agent_tool = True run_agent_tool._agent_instance = self diff --git a/src/agents/items.py b/src/agents/items.py index 9d6219f37d..a7789682d6 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -54,6 +54,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .logger import logger from .tool import ( + ToolOrigin, ToolOutputFileContent, ToolOutputImage, ToolOutputText, @@ -358,6 +359,9 @@ class ToolCallItem(RunItemBase[Any]): title: str | None = None """Optional short display label if known at item creation time.""" + tool_origin: ToolOrigin | None = None + """Optional metadata describing the source of a function-tool-backed item.""" + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -382,6 +386,9 @@ class ToolCallOutputItem(RunItemBase[Any]): type: Literal["tool_call_output_item"] = "tool_call_output_item" + tool_origin: ToolOrigin | None = None + """Optional metadata describing the source of a function-tool-backed item.""" + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. @@ -493,6 +500,9 @@ class ToolApprovalItem(RunItemBase[Any]): tool_namespace: str | None = None """Optional Responses API namespace for function-tool approvals.""" + tool_origin: ToolOrigin | None = None + """Optional metadata describing where the approved tool call came from.""" + tool_lookup_key: FunctionToolLookupKey | None = field( default=None, kw_only=True, diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 33bea065c5..caea4324d3 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -27,6 +27,8 @@ FunctionTool, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, ToolOutputImageDict, ToolOutputTextDict, _build_handled_function_tool_error_handler, @@ -313,6 +315,10 @@ def to_function_tool( strict_json_schema=is_strict, needs_approval=needs_approval, mcp_title=resolve_mcp_tool_title(tool), + tool_origin=ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name=server.name, + ), ) return function_tool diff --git a/src/agents/run_internal/approvals.py b/src/agents/run_internal/approvals.py index 2c6bd6c94f..4d44d1ec94 100644 --- a/src/agents/run_internal/approvals.py +++ b/src/agents/run_internal/approvals.py @@ -13,6 +13,7 @@ from ..agent import Agent from ..items import ItemHelpers, RunItem, ToolApprovalItem, ToolCallOutputItem, TResponseInputItem +from ..tool import ToolOrigin from .items import ReasoningItemIdPolicy, run_item_to_input_item # -------------------------- @@ -28,6 +29,7 @@ def append_approval_error_output( tool_name: str, call_id: str | None, message: str, + tool_origin: ToolOrigin | None = None, ) -> None: """Emit a synthetic tool output so users see why an approval failed.""" error_tool_call = _build_function_tool_call_for_approval_error(tool_call, tool_name, call_id) @@ -36,6 +38,7 @@ def append_approval_error_output( output=message, raw_item=ItemHelpers.tool_call_output_item(error_tool_call, message), agent=agent, + tool_origin=tool_origin, ) ) diff --git a/src/agents/run_internal/items.py b/src/agents/run_internal/items.py index 3e0693b02e..8132d44ab2 100644 --- a/src/agents/run_internal/items.py +++ b/src/agents/run_internal/items.py @@ -307,6 +307,7 @@ def function_rejection_item( *, rejection_message: str = REJECTION_MESSAGE, scope_id: str | None = None, + tool_origin: Any = None, ) -> ToolCallOutputItem: """Build a ToolCallOutputItem representing a rejected function tool call.""" if isinstance(tool_call, ResponseFunctionToolCall): @@ -315,6 +316,7 @@ def function_rejection_item( output=rejection_message, raw_item=ItemHelpers.tool_call_output_item(tool_call, rejection_message), agent=agent, + tool_origin=tool_origin, ) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 36e34b4f56..993d079c6a 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -11,7 +11,12 @@ from collections.abc import Awaitable, Callable, Mapping from typing import Any, TypeVar, cast -from openai.types.responses import Response, ResponseCompletedEvent, ResponseOutputItemDoneEvent +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputItemDoneEvent, +) from openai.types.responses.response_output_item import McpCall, McpListTools from openai.types.responses.response_prompt_param import ResponsePromptParam from openai.types.responses.response_reasoning_item import ResponseReasoningItem @@ -63,7 +68,14 @@ RawResponsesStreamEvent, RunItemStreamEvent, ) -from ..tool import FunctionTool, Tool, dispose_resolved_computers +from ..tool import ( + FunctionTool, + Tool, + ToolOrigin, + ToolOriginType, + dispose_resolved_computers, + get_function_tool_origin, +) from ..tracing import Span, SpanError, agent_span, get_current_trace from ..tracing.model_tracing import get_model_tracing_impl from ..tracing.span_data import AgentSpanData @@ -130,6 +142,7 @@ from .streaming import stream_step_items_to_queue, stream_step_result_to_queue from .tool_actions import ApplyPatchAction, ComputerAction, LocalShellAction, ShellAction from .tool_execution import ( + build_litellm_json_tool_call, coerce_shell_call, execute_apply_patch_calls, execute_computer_actions, @@ -1383,8 +1396,16 @@ async def rewind_model_request() -> None: matched_tool = ( tool_map.get(tool_lookup_key) if tool_lookup_key is not None else None ) + if ( + matched_tool is None + and output_schema is not None + and isinstance(output_item, ResponseFunctionToolCall) + and output_item.name == "json_tool_call" + ): + matched_tool = build_litellm_json_tool_call(output_item) tool_description: str | None = None tool_title: str | None = None + tool_origin = None if isinstance(output_item, McpCall): metadata = hosted_mcp_tool_metadata.get( (output_item.server_label, output_item.name) @@ -1392,15 +1413,21 @@ async def rewind_model_request() -> None: if metadata is not None: tool_description = metadata.description tool_title = metadata.title + tool_origin = ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name=output_item.server_label, + ) elif matched_tool is not None: tool_description = getattr(matched_tool, "description", None) tool_title = getattr(matched_tool, "_mcp_title", None) + tool_origin = get_function_tool_origin(matched_tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), agent=agent, description=tool_description, title=tool_title, + tool_origin=tool_origin, ) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=tool_item, name="tool_called") diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index f2a807020f..2435f89180 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -71,6 +71,8 @@ ShellCallOutcome, ShellCommandOutput, Tool, + ToolOrigin, + get_function_tool_origin, invoke_function_tool, maybe_invoke_function_tool_failure_error_function, resolve_computer, @@ -980,6 +982,7 @@ async def on_invoke_tool(_ctx: ToolContext[Any], value: Any) -> Any: on_invoke_tool=on_invoke_tool, strict_json_schema=True, is_enabled=True, + _emit_tool_origin=False, ) @@ -992,6 +995,7 @@ async def resolve_approval_status( context_wrapper: RunContextWrapper[Any], tool_namespace: str | None = None, tool_lookup_key: FunctionToolLookupKey | None = None, + tool_origin: ToolOrigin | None = None, on_approval: Callable[[RunContextWrapper[Any], ToolApprovalItem], Any] | None = None, ) -> tuple[bool | None, ToolApprovalItem]: """Build approval item, run on_approval hook if needed, and return latest approval status.""" @@ -1000,6 +1004,7 @@ async def resolve_approval_status( raw_item=raw_item, tool_name=tool_name, tool_namespace=tool_namespace, + tool_origin=tool_origin, tool_lookup_key=tool_lookup_key, ) approval_status = context_wrapper.get_approval_status( @@ -1538,6 +1543,7 @@ async def _maybe_execute_tool_approval( raw_item=raw_tool_call, tool_name=func_tool.name, tool_namespace=tool_namespace, + tool_origin=get_function_tool_origin(func_tool), tool_lookup_key=tool_lookup_key, _allow_bare_name_alias=should_allow_bare_name_approval_alias( func_tool, @@ -1578,6 +1584,7 @@ async def _maybe_execute_tool_approval( tool_call, rejection_message=rejection_message, scope_id=self.tool_state_scope_id, + tool_origin=get_function_tool_origin(func_tool), ), ) @@ -1773,6 +1780,7 @@ def _build_function_tool_results(self) -> list[FunctionToolResult]: output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=self.agent, + tool_origin=get_function_tool_origin(tool_run.function_tool), ) else: # Skip tool output until nested interruptions are resolved. @@ -1964,7 +1972,14 @@ async def execute_approved_tools( if isinstance(tool_name, str) and tool_name: tool_map[tool_name] = tool - def _append_error(message: str, *, tool_call: Any, tool_name: str, call_id: str) -> None: + def _append_error( + message: str, + *, + tool_call: Any, + tool_name: str, + call_id: str, + tool_origin: ToolOrigin | None = None, + ) -> None: append_approval_error_output( message=message, tool_call=tool_call, @@ -1972,6 +1987,7 @@ def _append_error(message: str, *, tool_call: Any, tool_name: str, call_id: str) call_id=call_id, generated_items=generated_items, agent=agent, + tool_origin=tool_origin, ) async def _resolve_tool_run( @@ -1999,14 +2015,25 @@ async def _resolve_tool_run( call_id = extract_tool_call_id(tool_call) if not call_id: + resolved_tool = tool_map.get(approval_key) if approval_key is not None else None + if resolved_tool is None and tool_namespace is None: + resolved_tool = tool_map.get(tool_name) _append_error( message="Tool approval item missing call ID.", tool_call=tool_call, tool_name=tool_name, call_id="unknown", + tool_origin=( + get_function_tool_origin(resolved_tool) + if isinstance(resolved_tool, FunctionTool) + else None + ), ) return None + resolved_tool = tool_map.get(approval_key) if approval_key is not None else None + if resolved_tool is None and tool_namespace is None: + resolved_tool = tool_map.get(tool_name) approval_status = context_wrapper.get_approval_status( tool_name, call_id, @@ -2015,9 +2042,6 @@ async def _resolve_tool_run( tool_lookup_key=tool_lookup_key, ) if approval_status is False: - resolved_tool = tool_map.get(approval_key) if approval_key is not None else None - if resolved_tool is None and tool_namespace is None: - resolved_tool = tool_map.get(tool_name) message = REJECTION_MESSAGE if isinstance(resolved_tool, FunctionTool): message = await resolve_approval_rejection_message( @@ -2035,6 +2059,11 @@ async def _resolve_tool_run( tool_call=tool_call, tool_name=tool_name, call_id=call_id, + tool_origin=( + get_function_tool_origin(resolved_tool) + if isinstance(resolved_tool, FunctionTool) + else None + ), ) return None @@ -2044,12 +2073,15 @@ async def _resolve_tool_run( tool_call=tool_call, tool_name=tool_name, call_id=call_id, + tool_origin=( + get_function_tool_origin(resolved_tool) + if isinstance(resolved_tool, FunctionTool) + else None + ), ) return None - tool = tool_map.get(approval_key) if approval_key is not None else None - if tool is None and tool_namespace is None: - tool = tool_map.get(tool_name) + tool = resolved_tool if tool is None: _append_error( message=f"Tool '{display_tool_name}' not found.", diff --git a/src/agents/run_internal/tool_planning.py b/src/agents/run_internal/tool_planning.py index dabb83b4ac..ea836541cc 100644 --- a/src/agents/run_internal/tool_planning.py +++ b/src/agents/run_internal/tool_planning.py @@ -22,7 +22,7 @@ ToolCallOutputItem, ) from ..run_context import RunContextWrapper -from ..tool import FunctionTool, MCPToolApprovalRequest +from ..tool import FunctionTool, MCPToolApprovalRequest, get_function_tool_origin from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from .run_steps import ( ToolRunApplyPatchCall, @@ -410,11 +410,17 @@ async def _collect_runs_by_approval( if approval_status is True: approved_runs.append(run) else: + function_tool = get_mapping_or_attr(run, "function_tool") pending_item = existing_pending or ToolApprovalItem( agent=agent, raw_item=get_mapping_or_attr(run, "tool_call"), tool_name=tool_name, tool_namespace=get_tool_call_namespace(get_mapping_or_attr(run, "tool_call")), + tool_origin=( + get_function_tool_origin(function_tool) + if isinstance(function_tool, FunctionTool) + else None + ), tool_lookup_key=get_function_tool_lookup_key_for_call( get_mapping_or_attr(run, "tool_call") ), diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index c34c720fcc..ab51f99033 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -79,6 +79,9 @@ LocalShellTool, ShellTool, Tool, + ToolOrigin, + ToolOriginType, + get_function_tool_origin, ) from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from ..tracing import SpanError, handoff_span @@ -723,6 +726,7 @@ async def _record_function_rejection( tool_call, rejection_message=rejection_message, scope_id=tool_state_scope_id, + tool_origin=get_function_tool_origin(function_tool), ) ) if isinstance(call_id, str): @@ -1034,6 +1038,7 @@ def _add_unmatched_pending(approval: ToolApprovalItem) -> None: raw_item=run.tool_call, tool_name=run.function_tool.name, tool_namespace=get_tool_call_namespace(run.tool_call), + tool_origin=get_function_tool_origin(run.function_tool), tool_lookup_key=get_function_tool_lookup_key_for_call(run.tool_call), _allow_bare_name_alias=should_allow_bare_name_approval_alias( run.function_tool, @@ -1523,6 +1528,10 @@ def _dump_output_item(raw_item: Any) -> dict[str, Any]: agent=agent, description=metadata.description if metadata is not None else None, title=metadata.title if metadata is not None else None, + tool_origin=ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name=output.server_label, + ), ) ) tools_used.append("mcp") @@ -1634,11 +1643,19 @@ def _dump_output_item(raw_item: Any) -> dict[str, Any]: func_tool = function_map.get(lookup_key) if lookup_key is not None else None if func_tool is None: if output_schema is not None and output.name == "json_tool_call": - items.append(ToolCallItem(raw_item=output, agent=agent)) + synthetic_tool = build_litellm_json_tool_call(output) + items.append( + ToolCallItem( + raw_item=output, + agent=agent, + description=synthetic_tool.description, + tool_origin=get_function_tool_origin(synthetic_tool), + ) + ) functions.append( ToolRunFunction( tool_call=output, - function_tool=build_litellm_json_tool_call(output), + function_tool=synthetic_tool, ) ) continue @@ -1659,6 +1676,7 @@ def _dump_output_item(raw_item: Any) -> dict[str, Any]: agent=agent, description=func_tool.description, title=func_tool._mcp_title, + tool_origin=get_function_tool_origin(func_tool), ) ) functions.append( diff --git a/src/agents/run_state.py b/src/agents/run_state.py index dcda9e073c..c89bde3a0e 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -80,6 +80,7 @@ HostedMCPTool, LocalShellTool, ShellTool, + ToolOrigin, ) from .tool_guardrails import ( AllowBehavior, @@ -135,6 +136,11 @@ _MISSING_CONTEXT_SENTINEL = object() +def _deserialize_tool_origin(data: Any) -> ToolOrigin | None: + """Best-effort deserialization for optional tool origin metadata.""" + return ToolOrigin.from_json_dict(data) + + @dataclass class RunState(Generic[TContext, TAgent]): """Serializable snapshot of an agent run, including context, usage, and interruptions. @@ -784,6 +790,9 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: result["description"] = item.description if hasattr(item, "title") and item.title is not None: result["title"] = item.title + tool_origin = getattr(item, "tool_origin", None) + if isinstance(tool_origin, ToolOrigin): + result["tool_origin"] = tool_origin.to_json_dict() return result @@ -1226,6 +1235,8 @@ def _serialize_tool_approval_interruption( interruption_dict["tool_name"] = interruption.tool_name if interruption.tool_namespace is not None: interruption_dict["tool_namespace"] = interruption.tool_namespace + if interruption.tool_origin is not None: + interruption_dict["tool_origin"] = interruption.tool_origin.to_json_dict() tool_lookup_key = serialize_function_tool_lookup_key( getattr(interruption, "tool_lookup_key", None) ) @@ -1897,6 +1908,7 @@ def _deserialize_tool_approval_item( tool_name = item_data.get("tool_name") tool_namespace = item_data.get("tool_namespace") + tool_origin = _deserialize_tool_origin(item_data.get("tool_origin")) tool_lookup_key = deserialize_function_tool_lookup_key(item_data.get("tool_lookup_key")) allow_bare_name_alias = item_data.get("allow_bare_name_alias") is True raw_item = _deserialize_tool_approval_raw_item(raw_item_data) @@ -1905,6 +1917,7 @@ def _deserialize_tool_approval_item( raw_item=raw_item, tool_name=tool_name, tool_namespace=tool_namespace, + tool_origin=tool_origin, tool_lookup_key=tool_lookup_key, _allow_bare_name_alias=allow_bare_name_alias, ) @@ -2505,12 +2518,14 @@ def _resolve_agent_info( # Preserve display metadata if it was stored with the item. description = item_data.get("description") title = item_data.get("title") + tool_origin = _deserialize_tool_origin(item_data.get("tool_origin")) result.append( ToolCallItem( agent=agent, raw_item=raw_item_tool, description=description, title=title, + tool_origin=tool_origin, ) ) @@ -2525,6 +2540,7 @@ def _resolve_agent_info( agent=agent, raw_item=raw_item_output, output=item_data.get("output", ""), + tool_origin=_deserialize_tool_origin(item_data.get("tool_origin")), ) ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 1ac3c29ae3..91dfc77fc6 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -10,6 +10,7 @@ import weakref from collections.abc import Awaitable, Mapping from dataclasses import dataclass, field +from enum import Enum from types import UnionType from typing import ( TYPE_CHECKING, @@ -163,6 +164,62 @@ class ToolOutputFileContentDict(TypedDict, total=False): ValidToolOutputPydanticModels ) + +class ToolOriginType(str, Enum): + """Enumerates the runtime source of a function-tool-backed run item.""" + + FUNCTION = "function" + MCP = "mcp" + AGENT_AS_TOOL = "agent_as_tool" + + +@dataclass(frozen=True) +class ToolOrigin: + """Serializable metadata describing where a function-tool-backed item came from.""" + + type: ToolOriginType + mcp_server_name: str | None = None + agent_name: str | None = None + agent_tool_name: str | None = None + + def to_json_dict(self) -> dict[str, str]: + """Convert the metadata to a JSON-compatible dict.""" + result: dict[str, str] = {"type": self.type.value} + if self.mcp_server_name is not None: + result["mcp_server_name"] = self.mcp_server_name + if self.agent_name is not None: + result["agent_name"] = self.agent_name + if self.agent_tool_name is not None: + result["agent_tool_name"] = self.agent_tool_name + return result + + @classmethod + def from_json_dict(cls, data: Any) -> ToolOrigin | None: + """Deserialize tool origin metadata from JSON-compatible data.""" + if not isinstance(data, Mapping): + return None + + raw_type = data.get("type") + if not isinstance(raw_type, str): + return None + + try: + origin_type = ToolOriginType(raw_type) + except ValueError: + return None + + def _optional_string(key: str) -> str | None: + value = data.get(key) + return value if isinstance(value, str) else None + + return cls( + type=origin_type, + mcp_server_name=_optional_string("mcp_server_name"), + agent_name=_optional_string("agent_name"), + agent_tool_name=_optional_string("agent_tool_name"), + ) + + ComputerLike = Union[Computer, AsyncComputer] ComputerT = TypeVar("ComputerT", bound=ComputerLike) ComputerT_co = TypeVar("ComputerT_co", bound=ComputerLike, covariant=True) @@ -326,6 +383,12 @@ class FunctionTool: _mcp_title: str | None = field(default=None, kw_only=True, repr=False) """Internal MCP display title used for ToolCallItem metadata.""" + _tool_origin: ToolOrigin | None = field(default=None, kw_only=True, repr=False) + """Internal scalar metadata describing the origin of function-tool-backed items.""" + + _emit_tool_origin: bool = field(default=True, kw_only=True, repr=False) + """Whether runtime item generation should emit tool origin metadata for this tool.""" + @property def qualified_name(self) -> str: """Return the public qualified name used to identify this function tool.""" @@ -428,6 +491,7 @@ def _build_wrapped_function_tool( defer_loading: bool = False, sync_invoker: bool = False, mcp_title: str | None = None, + tool_origin: ToolOrigin | None = None, ) -> FunctionTool: """Create a FunctionTool with copied-tool-aware failure handling bound in one place.""" on_invoke_tool = with_function_tool_failure_error_handler( @@ -453,11 +517,19 @@ def _build_wrapped_function_tool( timeout_error_function=timeout_error_function, defer_loading=defer_loading, _mcp_title=mcp_title, + _tool_origin=tool_origin, ), failure_error_function, ) +def get_function_tool_origin(function_tool: FunctionTool) -> ToolOrigin | None: + """Return scalar origin metadata for a function tool.""" + if not function_tool._emit_tool_origin: + return None + return function_tool._tool_origin or ToolOrigin(type=ToolOriginType.FUNCTION) + + @dataclass class FileSearchTool: """A hosted tool that lets the LLM search through a vector store. Currently only supported with diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py new file mode 100644 index 0000000000..a277452cf6 --- /dev/null +++ b/tests/test_tool_origin.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +import gc +import json +import weakref +from collections.abc import Sequence +from typing import Any, TypeVar, cast + +import pytest +from mcp import Tool as MCPTool +from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool +from pydantic import BaseModel + +from agents import ( + Agent, + HostedMCPTool, + ModelResponse, + RunConfig, + RunContextWrapper, + RunHooks, + Runner, + RunState, + ToolCallItem, + ToolCallOutputItem, + ToolOrigin, + ToolOriginType, + Usage, + function_tool, +) +from agents.items import MCPListToolsItem, ToolApprovalItem +from agents.mcp import MCPUtil +from agents.run_internal import run_loop +from agents.run_internal.run_loop import get_output_schema +from agents.run_internal.tool_execution import execute_function_tool_calls +from tests.fake_model import FakeModel +from tests.mcp.helpers import FakeMCPServer +from tests.test_responses import get_function_tool_call, get_text_message +from tests.utils.factories import make_run_state, make_tool_call, roundtrip_state + +TItem = TypeVar("TItem") + + +def _first_item(items: Sequence[object], item_type: type[TItem]) -> TItem: + for item in items: + if isinstance(item, item_type): + return item + raise AssertionError(f"Expected item of type {item_type.__name__}.") + + +class StructuredOutputPayload(BaseModel): + status: str + + +def _make_hosted_mcp_list_tools(server_label: str, tool_name: str) -> McpListTools: + return McpListTools( + id=f"list_{server_label}", + server_label=server_label, + tools=[ + McpListToolsTool( + name=tool_name, + input_schema={}, + description="Search the docs.", + annotations={"title": "Search Docs"}, + ) + ], + type="mcp_list_tools", + ) + + +@pytest.mark.asyncio +async def test_runner_attaches_function_tool_origin_to_call_and_output_items() -> None: + model = FakeModel() + + @function_tool(name_override="lookup_account") + def lookup_account() -> str: + return "account" + + agent = Agent(name="tool-origin-agent", model=model, tools=[lookup_account]) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("lookup_account", json.dumps({}), call_id="call_lookup")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="hello") + + expected = ToolOrigin(type=ToolOriginType.FUNCTION) + assert _first_item(result.new_items, ToolCallItem).tool_origin == expected + assert _first_item(result.new_items, ToolCallOutputItem).tool_origin == expected + + +@pytest.mark.asyncio +async def test_rejected_function_tool_output_preserves_tool_origin() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool() -> str: + raise AssertionError("The tool should not run when rejected.") + + agent = Agent(name="approval-agent", model=model, tools=[approval_tool]) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("approval_tool", json.dumps({}), call_id="call_approval")], + [get_text_message("done")], + ] + ) + + first_run = await Runner.run(agent, input="hello") + assert first_run.interruptions + + state = first_run.to_state() + state.reject(first_run.interruptions[0]) + resumed = await Runner.run(agent, state) + + assert _first_item(resumed.new_items, ToolCallOutputItem).tool_origin == ToolOrigin( + type=ToolOriginType.FUNCTION + ) + + +def test_tool_call_output_item_preserves_positional_type_argument() -> None: + agent = Agent(name="positional") + item = ToolCallOutputItem( + agent, + { + "type": "function_call_output", + "call_id": "call_positional", + "output": "result", + }, + "result", + "tool_call_output_item", + ) + + assert item.type == "tool_call_output_item" + assert item.tool_origin is None + + +@pytest.mark.asyncio +async def test_runner_attaches_local_mcp_tool_origin_to_call_and_output_items() -> None: + model = FakeModel() + server = FakeMCPServer( + server_name="docs_server", + tools=[ + MCPTool( + name="search_docs", + inputSchema={}, + description="Search the docs.", + title="Search Docs", + ) + ], + ) + agent = Agent(name="mcp-agent", model=model, mcp_servers=[server]) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("search_docs", json.dumps({}), call_id="call_search_docs")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="hello") + + expected = ToolOrigin(type=ToolOriginType.MCP, mcp_server_name="docs_server") + assert _first_item(result.new_items, ToolCallItem).tool_origin == expected + assert _first_item(result.new_items, ToolCallOutputItem).tool_origin == expected + + +@pytest.mark.asyncio +async def test_streamed_tool_call_item_includes_local_mcp_origin() -> None: + model = FakeModel() + server = FakeMCPServer( + server_name="docs_server", + tools=[ + MCPTool( + name="search_docs", + inputSchema={}, + description=None, + title="Search Docs", + ) + ], + ) + agent = Agent(name="stream-mcp-agent", model=model, mcp_servers=[server]) + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("search_docs", json.dumps({}), call_id="call_stream_search")], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="hello") + seen_tool_item: ToolCallItem | None = None + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event.item, ToolCallItem) + and seen_tool_item is None + ): + seen_tool_item = event.item + + assert seen_tool_item is not None + assert seen_tool_item.tool_origin == ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name="docs_server", + ) + + +def test_process_model_response_attaches_hosted_mcp_tool_origin() -> None: + agent = Agent(name="hosted-mcp") + hosted_tool = HostedMCPTool( + tool_config=cast( + Any, + { + "type": "mcp", + "server_label": "docs_server", + "server_url": "https://example.com/mcp", + }, + ) + ) + existing_items = [ + MCPListToolsItem( + agent=agent, + raw_item=_make_hosted_mcp_list_tools("docs_server", "search_docs"), + ) + ] + response = ModelResponse( + output=[ + McpCall( + id="mcp_call_1", + arguments="{}", + name="search_docs", + server_label="docs_server", + type="mcp_call", + status="completed", + ) + ], + usage=Usage(), + response_id="resp_hosted_mcp", + ) + + processed = run_loop.process_model_response( + agent=agent, + all_tools=[hosted_tool], + response=response, + output_schema=None, + handoffs=[], + existing_items=existing_items, + ) + + tool_call_item = _first_item(processed.new_items, ToolCallItem) + assert tool_call_item.tool_origin == ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name="docs_server", + ) + + +@pytest.mark.asyncio +async def test_streamed_tool_call_item_includes_hosted_mcp_origin() -> None: + model = FakeModel() + hosted_tool = HostedMCPTool( + tool_config=cast( + Any, + { + "type": "mcp", + "server_label": "docs_server", + "server_url": "https://example.com/mcp", + }, + ) + ) + agent = Agent(name="stream-hosted-mcp", model=model, tools=[hosted_tool]) + model.add_multiple_turn_outputs( + [ + [ + _make_hosted_mcp_list_tools("docs_server", "search_docs"), + McpCall( + id="mcp_call_stream_1", + arguments="{}", + name="search_docs", + server_label="docs_server", + type="mcp_call", + status="completed", + ), + ], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="hello") + seen_tool_item: ToolCallItem | None = None + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event.item, ToolCallItem) + and isinstance(event.item.raw_item, McpCall) + ): + seen_tool_item = event.item + break + + assert seen_tool_item is not None + assert seen_tool_item.tool_origin == ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name="docs_server", + ) + + +def test_local_mcp_tool_origin_does_not_retain_server_object() -> None: + server = FakeMCPServer(server_name="docs_server") + function_tool = MCPUtil.to_function_tool( + MCPTool( + name="search_docs", + inputSchema={}, + description="Search the docs.", + title="Search Docs", + ), + server, + convert_schemas_to_strict=False, + ) + item = ToolCallItem( + agent=Agent(name="release-agent"), + raw_item=make_tool_call(name="search_docs"), + description=function_tool.description, + title=function_tool._mcp_title, + tool_origin=function_tool._tool_origin, + ) + + server_ref = weakref.ref(server) + item.release_agent() + + del function_tool + del server + gc.collect() + + assert server_ref() is None + assert item.tool_origin == ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name="docs_server", + ) + + +@pytest.mark.asyncio +async def test_json_tool_call_does_not_emit_function_tool_origin() -> None: + agent = Agent(name="structured-output", output_type=StructuredOutputPayload) + response = ModelResponse( + output=[ + get_function_tool_call( + "json_tool_call", + StructuredOutputPayload(status="ok").model_dump_json(), + call_id="call_json_tool", + ) + ], + usage=Usage(), + response_id="resp_json_tool", + ) + context_wrapper = RunContextWrapper(None) + processed = run_loop.process_model_response( + agent=agent, + all_tools=[], + response=response, + output_schema=get_output_schema(agent), + handoffs=[], + ) + + tool_call_item = _first_item(processed.new_items, ToolCallItem) + assert tool_call_item.tool_origin is None + + function_results, _, _ = await execute_function_tool_calls( + agent=agent, + tool_runs=processed.functions, + hooks=RunHooks(), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + tool_output_item = _first_item( + [result.run_item for result in function_results if result.run_item is not None], + ToolCallOutputItem, + ) + assert tool_output_item.tool_origin is None + + +@pytest.mark.asyncio +async def test_run_state_roundtrip_preserves_distinct_agent_tool_names() -> None: + outer_agent = Agent(name="outer") + worker_a = Agent(name="worker") + worker_b = Agent(name="worker") + + tool_a = worker_a.as_tool(tool_name="worker_lookup_a", tool_description="Worker A") + tool_b = worker_b.as_tool(tool_name="worker_lookup_b", tool_description="Worker B") + + state: RunState[Any, Agent[Any]] = make_run_state(outer_agent) + state._generated_items.extend( + [ + ToolCallItem( + agent=outer_agent, + raw_item=make_tool_call(call_id="call_worker_a", name=tool_a.name), + description=tool_a.description, + tool_origin=tool_a._tool_origin, + ), + ToolCallItem( + agent=outer_agent, + raw_item=make_tool_call(call_id="call_worker_b", name=tool_b.name), + description=tool_b.description, + tool_origin=tool_b._tool_origin, + ), + ] + ) + + restored = await roundtrip_state(outer_agent, state) + restored_items = [item for item in restored._generated_items if isinstance(item, ToolCallItem)] + + assert [item.tool_origin for item in restored_items] == [ + ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_name="worker", + agent_tool_name="worker_lookup_a", + ), + ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_name="worker", + agent_tool_name="worker_lookup_b", + ), + ] + + +@pytest.mark.asyncio +async def test_run_state_from_json_reads_legacy_1_5_without_tool_origin() -> None: + agent = Agent(name="legacy") + state: RunState[Any, Agent[Any]] = make_run_state(agent) + state._generated_items.append( + ToolCallItem( + agent=agent, + raw_item=make_tool_call(call_id="call_legacy", name="legacy_tool"), + description="Legacy tool", + tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION), + ) + ) + + restored = await roundtrip_state( + agent, + state, + mutate_json=lambda data: { + **data, + "$schemaVersion": "1.5", + "generated_items": [ + {key: value for key, value in item.items() if key != "tool_origin"} + for item in data["generated_items"] + ], + }, + ) + + restored_item = _first_item(restored._generated_items, ToolCallItem) + assert restored_item.description == "Legacy tool" + assert restored_item.tool_origin is None + + +@pytest.mark.asyncio +async def test_run_state_roundtrip_preserves_tool_origin_on_approval_interruptions() -> None: + agent = Agent(name="approval-origin") + state: RunState[Any, Agent[Any]] = make_run_state(agent) + state._generated_items.append( + ToolApprovalItem( + agent=agent, + raw_item=make_tool_call(call_id="call_approval", name="approval_tool"), + tool_name="approval_tool", + tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION), + ) + ) + + restored = await roundtrip_state(agent, state) + + approval_item = _first_item(restored._generated_items, ToolApprovalItem) + assert approval_item.tool_origin == ToolOrigin(type=ToolOriginType.FUNCTION) + + +@pytest.mark.asyncio +async def test_run_state_from_json_reads_legacy_1_6_approval_without_tool_origin() -> None: + agent = Agent(name="approval-origin-legacy") + state: RunState[Any, Agent[Any]] = make_run_state(agent) + state._generated_items.append( + ToolApprovalItem( + agent=agent, + raw_item=make_tool_call(call_id="call_legacy_approval", name="approval_tool"), + tool_name="approval_tool", + tool_origin=ToolOrigin(type=ToolOriginType.FUNCTION), + ) + ) + + restored = await roundtrip_state( + agent, + state, + mutate_json=lambda data: { + **data, + "$schemaVersion": "1.6", + "generated_items": [ + {key: value for key, value in item.items() if key != "tool_origin"} + for item in data["generated_items"] + ], + }, + ) + + approval_item = _first_item(restored._generated_items, ToolApprovalItem) + assert approval_item.tool_origin is None