Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions agent/component/agent_with_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.tenant_llm_service import TenantLLMService
from common.connection_utils import timeout
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt


Expand Down Expand Up @@ -97,13 +97,16 @@ def __init__(self, canvas, id, param: LLMParam):
indexed_meta["function"]["name"] = indexed_name
self.tool_meta.append(indexed_meta)

tool_idx = len(self.tools)
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
custom_header = self._param.custom_header
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header)
for tnm, meta in mcp["tools"].items():
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
self.tools[tnm] = tool_call_session
indexed_name = f"{tnm}_{tool_idx}"
tool_idx += 1
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta, function_name=indexed_name))
self.tools[indexed_name] = MCPToolBinding(tool_call_session, tnm)
self.callback = partial(self._canvas.tool_use_callback, id)
self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
if self.tool_meta:
Expand Down
14 changes: 8 additions & 6 deletions agent/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from agent.component.base import ComponentParamBase, ComponentBase
from common.misc_utils import hash_str2int
from rag.prompts.generator import kb_prompt
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, ToolCallSession
from timeit import default_timer as timer


Expand Down Expand Up @@ -52,16 +52,18 @@ def __init__(self, tools_map: dict[str, object], callback: partial):
self.tools_map = tools_map
self.callback = callback

def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
return asyncio.run(self.tool_call_async(name, arguments))
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> Any:
return asyncio.run(self.tool_call_async(name, arguments, request_timeout=timeout))

async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any:
async def tool_call_async(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
logging.info(f"[ToolCall] invoke name={name} arguments={str(arguments)[:200]}")
st = timer()
tool_obj = self.tools_map[name]
if isinstance(tool_obj, MCPToolCallSession):
resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60)
if isinstance(tool_obj, MCPToolBinding):
resp = await thread_pool_exec(tool_obj.session.tool_call, tool_obj.original_name, arguments, request_timeout)
elif isinstance(tool_obj, MCPToolCallSession):
resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, request_timeout)
elif hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async):
resp = await tool_obj.invoke_async(**arguments)
else:
Expand Down
15 changes: 11 additions & 4 deletions common/mcp_tool_call_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import weakref
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from string import Template
from typing import Any, Literal, Protocol

Expand All @@ -36,7 +37,13 @@


class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: ...


@dataclass(frozen=True)
class MCPToolBinding:
session: ToolCallSession
original_name: str


class MCPToolCallSession(ToolCallSession):
Expand Down Expand Up @@ -316,12 +323,12 @@ def shutdown_all_mcp_sessions():
logging.info("All MCPToolCallSession instances have been closed.")


def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict, function_name: str | None = None) -> dict[str, Any]:
if isinstance(mcp_tool, dict):
return {
"type": "function",
"function": {
"name": mcp_tool["name"],
"name": function_name or mcp_tool["name"],
"description": mcp_tool["description"],
"parameters": mcp_tool["inputSchema"],
},
Expand All @@ -330,7 +337,7 @@ def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
return {
"type": "function",
"function": {
"name": mcp_tool.name,
"name": function_name or mcp_tool.name,
"description": mcp_tool.description,
"parameters": mcp_tool.inputSchema,
},
Expand Down
Loading