diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 56f23afe350..39052eb425b 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -32,7 +32,7 @@ from api.db.services.mcp_server_service import MCPServerService from api.db.services.tenant_llm_service import TenantLLMService from common.connection_utils import timeout -from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool +from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, mcp_tool_metadata_to_openai_tool from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt @@ -97,13 +97,16 @@ def __init__(self, canvas, id, param: LLMParam): indexed_meta["function"]["name"] = indexed_name self.tool_meta.append(indexed_meta) + tool_idx = len(self.tools) for mcp in self._param.mcp: _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"]) custom_header = self._param.custom_header tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header) for tnm, meta in mcp["tools"].items(): - self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta)) - self.tools[tnm] = tool_call_session + indexed_name = f"{tnm}_{tool_idx}" + tool_idx += 1 + self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta, function_name=indexed_name)) + self.tools[indexed_name] = MCPToolBinding(tool_call_session, tnm) self.callback = partial(self._canvas.tool_use_callback, id) self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback) if self.tool_meta: diff --git a/agent/tools/base.py b/agent/tools/base.py index f5a42de4d10..0110a84142f 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -23,7 +23,7 @@ from agent.component.base import ComponentParamBase, ComponentBase from common.misc_utils import hash_str2int from rag.prompts.generator import kb_prompt -from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession +from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, ToolCallSession from timeit import default_timer as timer @@ -52,16 +52,18 @@ def __init__(self, tools_map: dict[str, object], callback: partial): self.tools_map = tools_map self.callback = callback - def tool_call(self, name: str, arguments: dict[str, Any]) -> Any: - return asyncio.run(self.tool_call_async(name, arguments)) + def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> Any: + return asyncio.run(self.tool_call_async(name, arguments, request_timeout=timeout)) - async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any: + async def tool_call_async(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> Any: assert name in self.tools_map, f"LLM tool {name} does not exist" logging.info(f"[ToolCall] invoke name={name} arguments={str(arguments)[:200]}") st = timer() tool_obj = self.tools_map[name] - if isinstance(tool_obj, MCPToolCallSession): - resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60) + if isinstance(tool_obj, MCPToolBinding): + resp = await thread_pool_exec(tool_obj.session.tool_call, tool_obj.original_name, arguments, request_timeout) + elif isinstance(tool_obj, MCPToolCallSession): + resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, request_timeout) elif hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async): resp = await tool_obj.invoke_async(**arguments) else: diff --git a/common/mcp_tool_call_conn.py b/common/mcp_tool_call_conn.py index 95e3581bb0b..676978d052e 100644 --- a/common/mcp_tool_call_conn.py +++ b/common/mcp_tool_call_conn.py @@ -20,6 +20,7 @@ import weakref from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError +from dataclasses import dataclass from string import Template from typing import Any, Literal, Protocol @@ -36,7 +37,13 @@ class ToolCallSession(Protocol): - def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ... + def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: ... + + +@dataclass(frozen=True) +class MCPToolBinding: + session: ToolCallSession + original_name: str class MCPToolCallSession(ToolCallSession): @@ -316,12 +323,12 @@ def shutdown_all_mcp_sessions(): logging.info("All MCPToolCallSession instances have been closed.") -def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]: +def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict, function_name: str | None = None) -> dict[str, Any]: if isinstance(mcp_tool, dict): return { "type": "function", "function": { - "name": mcp_tool["name"], + "name": function_name or mcp_tool["name"], "description": mcp_tool["description"], "parameters": mcp_tool["inputSchema"], }, @@ -330,7 +337,7 @@ def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]: return { "type": "function", "function": { - "name": mcp_tool.name, + "name": function_name or mcp_tool.name, "description": mcp_tool.description, "parameters": mcp_tool.inputSchema, },