-
-
Notifications
You must be signed in to change notification settings - Fork 2k
[codex] Sanitize empty optional MCP arguments #7160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
a89eebe
53233e7
4200317
97dc7c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
| import sys | ||
| from contextlib import AsyncExitStack | ||
| from datetime import timedelta | ||
| from typing import Generic | ||
| from typing import Any, Generic | ||
|
|
||
| from tenacity import ( | ||
| before_sleep_log, | ||
|
|
@@ -125,6 +125,38 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: | |
| return False, f"{e!s}" | ||
|
|
||
|
|
||
# Unique marker meaning "this value was empty and should be dropped".
# A dedicated object is used because None is itself one of the values removed.
_EMPTY_MCP_ARGUMENT = object()


def _sanitize_mcp_arguments(value: Any) -> Any:
    """Strip empty optional values from an MCP tool payload, recursively.

    ``None``, the empty string, and lists/dicts that end up with no
    remaining entries are treated as empty. Anything else (numbers,
    booleans, non-empty strings, other object types) is returned as-is.

    Args:
        value: The payload, or a nested part of it.

    Returns:
        The cleaned value, or ``_EMPTY_MCP_ARGUMENT`` when nothing is left.
        Lists and dicts are rebuilt as new plain ``list``/``dict`` objects.
    """
    if value is None or (isinstance(value, str) and value == ""):
        return _EMPTY_MCP_ARGUMENT

    if isinstance(value, list):
        kept_items = [
            cleaned
            for cleaned in map(_sanitize_mcp_arguments, value)
            if cleaned is not _EMPTY_MCP_ARGUMENT
        ]
        # A list whose items were all dropped is itself considered empty.
        return kept_items or _EMPTY_MCP_ARGUMENT

    if isinstance(value, dict):
        cleaned_pairs = (
            (key, _sanitize_mcp_arguments(item)) for key, item in value.items()
        )
        kept_entries = {
            key: cleaned
            for key, cleaned in cleaned_pairs
            if cleaned is not _EMPTY_MCP_ARGUMENT
        }
        # Same rule for mappings: no surviving keys means "empty".
        return kept_entries or _EMPTY_MCP_ARGUMENT

    return value
|
|
||
|
|
||
| class MCPClient: | ||
| def __init__(self) -> None: | ||
| # Initialize session and client objects | ||
|
|
@@ -347,6 +379,17 @@ async def call_tool_with_reconnect( | |
| anyio.ClosedResourceError: raised after reconnection failure | ||
| """ | ||
|
|
||
| sanitized_arguments = _sanitize_mcp_arguments(arguments) | ||
| if sanitized_arguments is _EMPTY_MCP_ARGUMENT: | ||
| sanitized_arguments = {} | ||
| if sanitized_arguments != arguments: | ||
| logger.debug( | ||
| "Sanitized MCP tool %s arguments from %s to %s", | ||
| tool_name, | ||
| arguments, | ||
| sanitized_arguments, | ||
|
Comment on lines
+410
to
+415
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🚨 suggestion (security): Debug logging of full arguments may expose sensitive data and be expensive. This logs both the raw and sanitized arguments, which can leak credentials/PII into logs and adds overhead for large payloads. Consider redacting known sensitive fields, truncating large values, or guarding this behind a more verbose/debug-only flag so it’s not enabled in typical deployments. |
||
| ) | ||
|
|
||
| @retry( | ||
| retry=retry_if_exception_type(anyio.ClosedResourceError), | ||
| stop=stop_after_attempt(2), | ||
|
|
@@ -361,7 +404,7 @@ async def _call_with_retry(): | |
| try: | ||
| return await self.session.call_tool( | ||
| name=tool_name, | ||
| arguments=arguments, | ||
| arguments=sanitized_arguments, | ||
| read_timeout_seconds=read_timeout_seconds, | ||
| ) | ||
| except anyio.ClosedResourceError: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import importlib.util | ||
| import logging | ||
| import sys | ||
| import types | ||
| from pathlib import Path | ||
| from typing import Generic, TypeVar | ||
| from unittest.mock import AsyncMock | ||
|
|
||
| import pytest | ||
|
|
||
# Directory two levels above the folder that contains this test file.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Source file of the module under test; it is loaded directly from this path.
MCP_CLIENT_MODULE_PATH = REPO_ROOT / "astrbot" / "core" / "agent" / "mcp_client.py"
|
|
||
|
|
||
def load_mcp_client_module():
    """Execute ``mcp_client.py`` from source against lightweight stubs.

    Stand-ins for the module's ``astrbot``, ``anyio`` and ``mcp`` imports are
    placed in ``sys.modules`` only while the file at
    ``MCP_CLIENT_MODULE_PATH`` is being executed. Afterwards every
    ``sys.modules`` entry touched here, and the ``logger`` attribute set on
    the ``astrbot`` module, is restored to its previous state. This keeps
    the stubs from replacing real packages for other tests that run in the
    same interpreter.

    Returns:
        The executed module object. Names it imported at load time stay
        bound to the stubs.
    """
    absent = object()  # marks "was not present before we touched it"

    def make_module(name, **attributes):
        """Create a stub module exposing the given attributes."""
        stub = types.ModuleType(name)
        for attribute, attribute_value in attributes.items():
            setattr(stub, attribute, attribute_value)
        return stub

    stubs = {}

    # Parent packages are faked only when the real ones are not imported yet.
    for package_name in (
        "astrbot",
        "astrbot.core",
        "astrbot.core.agent",
        "astrbot.core.utils",
    ):
        if package_name not in sys.modules:
            stubs[package_name] = make_module(package_name, __path__=[])

    stubs["astrbot.core.utils.log_pipe"] = make_module(
        "astrbot.core.utils.log_pipe",
        LogPipe=type("LogPipe", (), {}),
    )

    context_type = TypeVar("TContext")

    class ContextWrapper(Generic[context_type]):
        pass

    stubs["astrbot.core.agent.run_context"] = make_module(
        "astrbot.core.agent.run_context",
        TContext=context_type,
        ContextWrapper=ContextWrapper,
    )

    stubs["astrbot.core.agent.tool"] = make_module(
        "astrbot.core.agent.tool",
        FunctionTool=type("FunctionTool", (), {}),
    )

    stubs["anyio"] = make_module(
        "anyio",
        ClosedResourceError=type("ClosedResourceError", (Exception,), {}),
    )

    stubs["mcp"] = make_module(
        "mcp",
        Tool=type("Tool", (), {}),
        ClientSession=type("ClientSession", (), {}),
        ListToolsResult=type("ListToolsResult", (), {}),
        StdioServerParameters=type("StdioServerParameters", (), {}),
        stdio_client=lambda *args, **kwargs: None,
        types=types.SimpleNamespace(
            LoggingMessageNotificationParams=type(
                "LoggingMessageNotificationParams", (), {}
            ),
            CallToolResult=type("CallToolResult", (), {}),
        ),
    )
    stubs["mcp.client"] = make_module("mcp.client")
    stubs["mcp.client.sse"] = make_module(
        "mcp.client.sse",
        sse_client=lambda *args, **kwargs: None,
    )
    stubs["mcp.client.streamable_http"] = make_module(
        "mcp.client.streamable_http",
        streamablehttp_client=lambda *args, **kwargs: None,
    )

    spec = importlib.util.spec_from_file_location(
        "astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH
    )
    assert spec and spec.loader
    module = importlib.util.module_from_spec(spec)
    # Register the module under its own name while it executes.
    stubs[spec.name] = module

    # Remember what each name pointed at so it can be put back afterwards.
    previous_modules = {name: sys.modules.get(name, absent) for name in stubs}
    sys.modules.update(stubs)

    astrbot_module = sys.modules["astrbot"]
    previous_logger = getattr(astrbot_module, "logger", absent)
    astrbot_module.logger = logging.getLogger("astrbot-test")
    try:
        spec.loader.exec_module(module)
    finally:
        # Undo the logger override; this matters when the real package was loaded.
        if previous_logger is absent:
            del astrbot_module.logger
        else:
            astrbot_module.logger = previous_logger

        # Restore sys.modules exactly for the names that were replaced.
        for name, previous in previous_modules.items():
            if previous is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
    return module
|
|
||
|
|
||
def test_sanitize_mcp_arguments_removes_nested_empty_collections():
    """Nested empty lists, dicts, strings and None values are all pruned."""
    sanitize = load_mcp_client_module()._sanitize_mcp_arguments
    payload = {
        "query": "hello",
        "filters": {"tags": [], "scope": {}},
        "metadata": {"owner": "", "visibility": None},
    }

    # "filters" and "metadata" contain only empty values, so they vanish too.
    assert sanitize(payload) == {"query": "hello"}
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments():
    """A payload that sanitizes to nothing is forwarded as an empty dict."""
    module = load_mcp_client_module()
    timeout = module.timedelta(seconds=1)
    call_tool = AsyncMock(return_value="ok")

    client = module.MCPClient()
    client.session = types.SimpleNamespace(call_tool=call_tool)

    outcome = await client.call_tool_with_reconnect(
        tool_name="search",
        arguments={"filters": {}, "query": ""},
        read_timeout_seconds=timeout,
    )

    assert outcome == "ok"
    # Every argument was empty, so the session must receive {}.
    call_tool.assert_awaited_once_with(
        name="search",
        arguments={},
        read_timeout_seconds=timeout,
    )
Uh oh!
There was an error while loading. Please reload this page.