Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import Generic
from typing import Any, Generic

from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -125,6 +125,38 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"


_EMPTY_MCP_ARGUMENT = object()


def _sanitize_mcp_arguments(value: Any) -> Any:
"""Remove empty optional payload values before sending to MCP tools."""
if value is None:
return _EMPTY_MCP_ARGUMENT

if isinstance(value, str):
return value if value != "" else _EMPTY_MCP_ARGUMENT

if isinstance(value, list):
cleaned_items = []
for item in value:
cleaned_item = _sanitize_mcp_arguments(item)
if cleaned_item is _EMPTY_MCP_ARGUMENT:
continue
cleaned_items.append(cleaned_item)
return cleaned_items if cleaned_items else _EMPTY_MCP_ARGUMENT

if isinstance(value, dict):
cleaned_dict = {}
for key, item in value.items():
cleaned_item = _sanitize_mcp_arguments(item)
if cleaned_item is _EMPTY_MCP_ARGUMENT:
continue
cleaned_dict[key] = cleaned_item
return cleaned_dict if cleaned_dict else _EMPTY_MCP_ARGUMENT

return value


class MCPClient:
def __init__(self) -> None:
# Initialize session and client objects
Expand Down Expand Up @@ -347,6 +379,17 @@ async def call_tool_with_reconnect(
anyio.ClosedResourceError: raised after reconnection failure
"""

sanitized_arguments = _sanitize_mcp_arguments(arguments)
if sanitized_arguments is _EMPTY_MCP_ARGUMENT:
sanitized_arguments = {}
if sanitized_arguments != arguments:
logger.debug(
"Sanitized MCP tool %s arguments from %s to %s",
tool_name,
arguments,
sanitized_arguments,
Comment on lines +410 to +415
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 suggestion (security): Debug logging of full arguments may expose sensitive data and be expensive

This logs both the raw and sanitized arguments, which can leak credentials/PII into logs and adds overhead for large payloads. Consider redacting known sensitive fields, truncating large values, or guarding this behind a more verbose/debug-only flag so it’s not enabled in typical deployments.

)

@retry(
retry=retry_if_exception_type(anyio.ClosedResourceError),
stop=stop_after_attempt(2),
Expand All @@ -361,7 +404,7 @@ async def _call_with_retry():
try:
return await self.session.call_tool(
name=tool_name,
arguments=arguments,
arguments=sanitized_arguments,
read_timeout_seconds=read_timeout_seconds,
)
except anyio.ClosedResourceError:
Expand Down
127 changes: 127 additions & 0 deletions tests/unit/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import importlib.util
import logging
import sys
import types
from pathlib import Path
from typing import Generic, TypeVar
from unittest.mock import AsyncMock

import pytest

REPO_ROOT = Path(__file__).resolve().parents[2]
MCP_CLIENT_MODULE_PATH = REPO_ROOT / "astrbot/core/agent/mcp_client.py"


def load_mcp_client_module():
package_names = [
"astrbot",
"astrbot.core",
"astrbot.core.agent",
"astrbot.core.utils",
]
for name in package_names:
if name not in sys.modules:
module = types.ModuleType(name)
module.__path__ = []
sys.modules[name] = module

astrbot_module = sys.modules["astrbot"]
astrbot_module.logger = logging.getLogger("astrbot-test")

log_pipe_module = types.ModuleType("astrbot.core.utils.log_pipe")
log_pipe_module.LogPipe = type("LogPipe", (), {})
sys.modules[log_pipe_module.__name__] = log_pipe_module

run_context_module = types.ModuleType("astrbot.core.agent.run_context")
run_context_module.TContext = TypeVar("TContext")

class ContextWrapper(Generic[run_context_module.TContext]):
pass

run_context_module.ContextWrapper = ContextWrapper
sys.modules[run_context_module.__name__] = run_context_module

tool_module = types.ModuleType("astrbot.core.agent.tool")
tool_module.FunctionTool = type("FunctionTool", (), {})
sys.modules[tool_module.__name__] = tool_module

anyio_module = types.ModuleType("anyio")
anyio_module.ClosedResourceError = type("ClosedResourceError", (Exception,), {})
sys.modules["anyio"] = anyio_module

Comment on lines +50 to +53
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Isolate sys.modules stubs to each test

load_mcp_client_module() writes fake modules into sys.modules and never restores them, so once this helper runs, later tests in the same pytest process can import these incomplete stubs (anyio, mcp, and related astrbot.* entries) instead of real modules, causing order-dependent failures or hiding real regressions. Please patch these entries via monkeypatch (or equivalent teardown) so global import state is restored after each test.

Useful? React with 👍 / 👎.

mcp_module = types.ModuleType("mcp")
mcp_module.Tool = type("Tool", (), {})
mcp_module.ClientSession = type("ClientSession", (), {})
mcp_module.ListToolsResult = type("ListToolsResult", (), {})
mcp_module.StdioServerParameters = type("StdioServerParameters", (), {})
mcp_module.stdio_client = lambda *args, **kwargs: None
mcp_module.types = types.SimpleNamespace(
LoggingMessageNotificationParams=type(
"LoggingMessageNotificationParams", (), {}
),
CallToolResult=type("CallToolResult", (), {}),
)
sys.modules["mcp"] = mcp_module

mcp_client_module = types.ModuleType("mcp.client")
sys.modules[mcp_client_module.__name__] = mcp_client_module

mcp_client_sse_module = types.ModuleType("mcp.client.sse")
mcp_client_sse_module.sse_client = lambda *args, **kwargs: None
sys.modules[mcp_client_sse_module.__name__] = mcp_client_sse_module

mcp_client_streamable_http_module = types.ModuleType(
"mcp.client.streamable_http"
)
mcp_client_streamable_http_module.streamablehttp_client = (
lambda *args, **kwargs: None
)
sys.modules[mcp_client_streamable_http_module.__name__] = (
mcp_client_streamable_http_module
)

spec = importlib.util.spec_from_file_location(
"astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH
)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module


def test_sanitize_mcp_arguments_removes_nested_empty_collections():
mcp_client_module = load_mcp_client_module()

sanitized = mcp_client_module._sanitize_mcp_arguments(
{
"query": "hello",
"filters": {"tags": [], "scope": {}},
"metadata": {"owner": "", "visibility": None},
}
)

assert sanitized == {"query": "hello"}


@pytest.mark.asyncio
async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments():
mcp_client_module = load_mcp_client_module()

client = mcp_client_module.MCPClient()
client.session = types.SimpleNamespace(call_tool=AsyncMock(return_value="ok"))

result = await client.call_tool_with_reconnect(
tool_name="search",
arguments={"filters": {}, "query": ""},
read_timeout_seconds=mcp_client_module.timedelta(seconds=1),
)

assert result == "ok"
client.session.call_tool.assert_awaited_once_with(
name="search",
arguments={},
read_timeout_seconds=mcp_client_module.timedelta(seconds=1),
)