diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 1fb4b03368..1afbecfd83 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -37,6 +37,8 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.star.session_plugin_manager import SessionPluginManager +from astrbot.core.star.star import star_map from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history from astrbot.core.utils.image_ref_utils import is_supported_image_ref @@ -44,6 +46,26 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + @classmethod + def _tool_enabled_for_session( + cls, + tool: FunctionTool, + session_config: dict | None, + ) -> bool: + mp = tool.handler_module_path + if not mp: + return True + + plugin = star_map.get(mp) + if not plugin: + return True + + return SessionPluginManager.is_plugin_enabled_for_session_config( + plugin.name, + session_config, + reserved=plugin.reserved, + ) + @classmethod def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: if image_urls_raw is None: @@ -193,7 +215,7 @@ def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]: return {} @classmethod - def _build_handoff_toolset( + async def _build_handoff_toolset( cls, run_context: ContextWrapper[AstrAgentContext], tools: list[str | FunctionTool] | None, @@ -201,6 +223,9 @@ def _build_handoff_toolset( ctx = run_context.context.context event = run_context.context.event cfg = ctx.get_config(umo=event.unified_msg_origin) + session_config = await SessionPluginManager.get_session_plugin_config( + event.unified_msg_origin + ) provider_settings = cfg.get("provider_settings", {}) runtime = str(provider_settings.get("computer_use_runtime", "local")) runtime_computer_tools = cls._get_runtime_computer_tools(runtime) @@ -212,7 +237,10 @@ def _build_handoff_toolset( for registered_tool in llm_tools.func_list: if isinstance(registered_tool, HandoffTool): continue - if registered_tool.active: + if registered_tool.active and cls._tool_enabled_for_session( + registered_tool, + session_config, + ): toolset.add_tool(registered_tool) for runtime_tool in runtime_computer_tools.values(): toolset.add_tool(runtime_tool) @@ -225,14 +253,19 @@ def _build_handoff_toolset( for tool_name_or_obj in tools: if isinstance(tool_name_or_obj, str): registered_tool = llm_tools.get_func(tool_name_or_obj) - if registered_tool and registered_tool.active: + if ( + registered_tool + and registered_tool.active + and cls._tool_enabled_for_session(registered_tool, session_config) + ): toolset.add_tool(registered_tool) continue runtime_tool = runtime_computer_tools.get(tool_name_or_obj) if runtime_tool: toolset.add_tool(runtime_tool) elif isinstance(tool_name_or_obj, FunctionTool): - toolset.add_tool(tool_name_or_obj) + if cls._tool_enabled_for_session(tool_name_or_obj, session_config): + toolset.add_tool(tool_name_or_obj) return None if toolset.empty() else toolset @classmethod @@ -264,7 +297,7 @@ async def _execute_handoff( tool_args["image_urls"] = image_urls # Build handoff toolset from registered tools plus runtime computer tools. - toolset = cls._build_handoff_toolset(run_context, tool.agent.tools) + toolset = await cls._build_handoff_toolset(run_context, tool.agent.tools) ctx = run_context.context.context event = run_context.context.event diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 2b4a04907e..8703c509f1 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -61,6 +61,7 @@ from astrbot.core.provider.entities import ProviderRequest from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt from astrbot.core.star.context import Context +from astrbot.core.star.session_plugin_manager import SessionPluginManager from astrbot.core.star.star_handler import star_map from astrbot.core.tools.cron_tools import ( CREATE_CRON_JOB_TOOL, @@ -846,33 +847,49 @@ def _sanitize_context_by_modalities( req.contexts = sanitized_contexts -def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: +async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: """根据事件中的插件设置,过滤请求中的工具列表。 注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留, 因为它们不属于任何插件,不应被插件过滤逻辑影响。 """ - if event.plugins_name is not None and req.func_tool: - new_tool_set = ToolSet() - for tool in req.func_tool.tools: - if isinstance(tool, MCPTool): - # 保留 MCP 工具 - new_tool_set.add_tool(tool) - continue - mp = tool.handler_module_path - if not mp: - # 没有 plugin 归属信息的工具(如 subagent transfer_to_*) - # 不应受到会话插件过滤影响。 - new_tool_set.add_tool(tool) - continue - plugin = star_map.get(mp) - if not plugin: - # 无法解析插件归属时,保守保留工具,避免误过滤。 - new_tool_set.add_tool(tool) - continue - if plugin.name in event.plugins_name or plugin.reserved: - new_tool_set.add_tool(tool) - req.func_tool = new_tool_set + if not req.func_tool: + return + + session_config = await SessionPluginManager.get_session_plugin_config( + event.unified_msg_origin + ) + new_tool_set = ToolSet() + for tool in req.func_tool.tools: + if isinstance(tool, MCPTool): + # 保留 MCP 工具 + new_tool_set.add_tool(tool) + continue + mp = tool.handler_module_path + if not mp: + # 没有 plugin 归属信息的工具(如 subagent transfer_to_*) + # 不应受到会话插件过滤影响。 + new_tool_set.add_tool(tool) + continue + plugin = star_map.get(mp) + if not plugin: + # 无法解析插件归属时,保守保留工具,避免误过滤。 + new_tool_set.add_tool(tool) + continue + if ( + event.plugins_name is not None + and not plugin.reserved + and plugin.name not in event.plugins_name + ): + continue + if not SessionPluginManager.is_plugin_enabled_for_session_config( + plugin.name, + session_config, + reserved=plugin.reserved, + ): + continue + new_tool_set.add_tool(tool) + req.func_tool = new_tool_set async def _handle_webchat( @@ -1243,7 +1260,7 @@ async def build_main_agent( req.session_id = event.unified_msg_origin _modalities_fix(provider, req) - _plugin_tool_fix(event, req) + await _plugin_tool_fix(event, req) _sanitize_context_by_modalities(config, provider, req) if config.llm_safety_mode: diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 9402ce3e62..bc26f9d398 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -5,6 +5,7 @@ from astrbot import logger from astrbot.core.message.message_event_result import CommandResult, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.session_plugin_manager import SessionPluginManager from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry @@ -89,11 +90,24 @@ async def call_event_hook( hook_type, plugins_name=event.plugins_name, ) + session_config = await SessionPluginManager.get_session_plugin_config( + event.unified_msg_origin + ) for handler in handlers: + plugin = star_map.get(handler.handler_module_path) + if plugin and not SessionPluginManager.is_plugin_enabled_for_session_config( + plugin.name, + session_config, + reserved=plugin.reserved, + ): + logger.debug( + f"插件 {plugin.name} 在会话 {event.unified_msg_origin} 中被禁用,跳过 hook {handler.handler_name}", + ) + continue try: assert inspect.iscoroutinefunction(handler.handler) logger.debug( - f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", + f"hook({hook_type.name}) -> {plugin.name if plugin else handler.handler_module_path} - {handler.handler_name}", ) await handler.handler(event, *args, **kwargs) except BaseException: @@ -101,7 +115,7 @@ async def call_event_hook( if event.is_stopped(): logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。", + f"{plugin.name if plugin else handler.handler_module_path} - {handler.handler_name} 终止了事件传播。", ) return True diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index a81113415b..474b0cc9ca 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -8,43 +8,67 @@ class SessionPluginManager: """管理会话级别的插件启停状态""" @staticmethod - async def is_plugin_enabled_for_session( - session_id: str, - plugin_name: str, - ) -> bool: - """检查插件是否在指定会话中启用 - - Args: - session_id: 会话ID (unified_msg_origin) - plugin_name: 插件名称 - - Returns: - bool: True表示启用,False表示禁用 - - """ - # 获取会话插件配置 + async def get_session_plugin_config(session_id: str) -> dict: + """获取指定会话的插件配置。""" session_plugin_config = await sp.get_async( scope="umo", scope_id=session_id, key="session_plugin_config", default={}, ) - session_config = session_plugin_config.get(session_id, {}) + return session_plugin_config.get(session_id, {}) + + @staticmethod + def is_plugin_enabled_for_session_config( + plugin_name: str | None, + session_config: dict | None, + *, + reserved: bool = False, + ) -> bool: + """检查插件是否在指定会话配置中启用。""" + if reserved or not plugin_name: + return True + + if not session_config: + return True enabled_plugins = session_config.get("enabled_plugins", []) disabled_plugins = session_config.get("disabled_plugins", []) - # 如果插件在禁用列表中,返回False if plugin_name in disabled_plugins: return False - # 如果插件在启用列表中,返回True if plugin_name in enabled_plugins: return True - # 如果都没有配置,默认为启用(兼容性考虑) return True + @staticmethod + async def is_plugin_enabled_for_session( + session_id: str, + plugin_name: str, + *, + reserved: bool = False, + ) -> bool: + """检查插件是否在指定会话中启用 + + Args: + session_id: 会话ID (unified_msg_origin) + plugin_name: 插件名称 + + Returns: + bool: True表示启用,False表示禁用 + + """ + session_config = await SessionPluginManager.get_session_plugin_config( + session_id + ) + return SessionPluginManager.is_plugin_enabled_for_session_config( + plugin_name, + session_config, + reserved=reserved, + ) + @staticmethod async def filter_handlers_by_session( event: AstrMessageEvent, @@ -65,14 +89,9 @@ async def filter_handlers_by_session( session_id = event.unified_msg_origin filtered_handlers = [] - session_plugin_config = await sp.get_async( - scope="umo", - scope_id=session_id, - key="session_plugin_config", - default={}, + session_config = await SessionPluginManager.get_session_plugin_config( + session_id ) - session_config = session_plugin_config.get(session_id, {}) - disabled_plugins = session_config.get("disabled_plugins", []) for handler in handlers: # 获取处理器对应的插件 @@ -91,7 +110,11 @@ async def filter_handlers_by_session( continue # 检查插件是否在当前会话中启用 - if plugin.name in disabled_plugins: + if not SessionPluginManager.is_plugin_enabled_for_session_config( + plugin.name, + session_config, + reserved=plugin.reserved, + ): logger.debug( f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}", ) diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 5fab9fe0a2..101eff4969 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -1,9 +1,11 @@ from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch import mcp import pytest from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.message.components import Image @@ -272,6 +274,44 @@ async def _fake_convert_to_file_path(self): assert image_urls == [] +@pytest.mark.asyncio +async def test_build_handoff_toolset_filters_session_disabled_plugin_tool(): + plugin_tool = FunctionTool( + name="memorix_tool", + description="memorix tool", + parameters={"type": "object", "properties": {}}, + handler_module_path="test_plugin", + active=True, + ) + run_context = ContextWrapper( + context=SimpleNamespace( + event=_DummyEvent([]), + context=SimpleNamespace( + get_config=lambda **_kwargs: {"provider_settings": {"computer_use_runtime": "none"}} + ), + ) + ) + + with patch( + "astrbot.core.astr_agent_tool_exec.SessionPluginManager.get_session_plugin_config", + new=AsyncMock(return_value={"disabled_plugins": ["astrbot_plugin_memorix"]}), + ) as mock_get_config, patch( + "astrbot.core.astr_agent_tool_exec.llm_tools" + ) as mock_llm_tools, patch( + "astrbot.core.astr_agent_tool_exec.star_map" + ) as mock_star_map: + mock_llm_tools.func_list = [plugin_tool] + mock_plugin = MagicMock() + mock_plugin.name = "astrbot_plugin_memorix" + mock_plugin.reserved = False + mock_star_map.get.return_value = mock_plugin + + toolset = await FunctionToolExecutor._build_handoff_toolset(run_context, None) + + mock_get_config.assert_awaited_once() + assert toolset is None + + @pytest.mark.asyncio async def test_execute_handoff_passes_tool_call_timeout_to_tool_loop_agent( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 9a42abd733..84d63c8534 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -805,17 +805,23 @@ def test_sanitize_removes_image_blocks(self, mock_provider): class TestPluginToolFix: """Tests for _plugin_tool_fix function.""" - def test_plugin_tool_fix_none_plugins(self, mock_event): + @pytest.mark.asyncio + async def test_plugin_tool_fix_none_plugins(self, mock_event): """Test plugin tool fix when no plugins specified.""" module = ama req = ProviderRequest(func_tool=ToolSet()) mock_event.plugins_name = None - module._plugin_tool_fix(mock_event, req) + with patch( + "astrbot.core.astr_main_agent.SessionPluginManager.get_session_plugin_config", + new=AsyncMock(return_value={}), + ): + await module._plugin_tool_fix(mock_event, req) assert req.func_tool is not None - def test_plugin_tool_fix_filters_by_plugin(self, mock_event): + @pytest.mark.asyncio + async def test_plugin_tool_fix_filters_by_plugin(self, mock_event): """Test plugin tool fix filters tools by enabled plugins.""" module = ama mcp_tool = MagicMock(spec=MCPTool) @@ -839,12 +845,17 @@ def test_plugin_tool_fix_filters_by_plugin(self, mock_event): mock_plugin.reserved = False mock_star_map.get.return_value = mock_plugin - module._plugin_tool_fix(mock_event, req) + with patch( + "astrbot.core.astr_main_agent.SessionPluginManager.get_session_plugin_config", + new=AsyncMock(return_value={}), + ): + await module._plugin_tool_fix(mock_event, req) assert "mcp_tool" in req.func_tool.names() assert "plugin_tool" in req.func_tool.names() - def test_plugin_tool_fix_mcp_preserved(self, mock_event): + @pytest.mark.asyncio + async def test_plugin_tool_fix_mcp_preserved(self, mock_event): """Test that MCP tools are always preserved.""" module = ama mcp_tool = MagicMock(spec=MCPTool) @@ -858,11 +869,18 @@ def test_plugin_tool_fix_mcp_preserved(self, mock_event): mock_event.plugins_name = ["other_plugin"] with patch("astrbot.core.astr_main_agent.star_map"): - module._plugin_tool_fix(mock_event, req) + with patch( + "astrbot.core.astr_main_agent.SessionPluginManager.get_session_plugin_config", + new=AsyncMock(return_value={}), + ): + await module._plugin_tool_fix(mock_event, req) assert "mcp_tool" in req.func_tool.names() - def test_plugin_tool_fix_preserves_tools_without_plugin_origin(self, mock_event): + @pytest.mark.asyncio + async def test_plugin_tool_fix_preserves_tools_without_plugin_origin( + self, mock_event + ): """Tools without handler_module_path should not be filtered out.""" module = ama handoff_tool = FunctionTool( @@ -880,10 +898,44 @@ def test_plugin_tool_fix_preserves_tools_without_plugin_origin(self, mock_event) mock_event.plugins_name = ["other_plugin"] with patch("astrbot.core.astr_main_agent.star_map"): - module._plugin_tool_fix(mock_event, req) + with patch( + "astrbot.core.astr_main_agent.SessionPluginManager.get_session_plugin_config", + new=AsyncMock(return_value={}), + ): + await module._plugin_tool_fix(mock_event, req) assert "transfer_to_demo_agent" in req.func_tool.names() + @pytest.mark.asyncio + async def test_plugin_tool_fix_filters_session_disabled_plugin(self, mock_event): + """Session-disabled plugin tools should not enter the request toolset.""" + module = ama + plugin_tool = FunctionTool( + name="memorix_search", + description="memorix tool", + parameters={"type": "object", "properties": {}}, + handler_module_path="test_plugin", + active=True, + ) + req = ProviderRequest(func_tool=ToolSet([plugin_tool])) + mock_event.plugins_name = ["*"] + + with patch("astrbot.core.astr_main_agent.star_map") as mock_star_map: + mock_plugin = MagicMock() + mock_plugin.name = "astrbot_plugin_memorix" + mock_plugin.reserved = False + mock_star_map.get.return_value = mock_plugin + + with patch( + "astrbot.core.astr_main_agent.SessionPluginManager.get_session_plugin_config", + new=AsyncMock( + return_value={"disabled_plugins": ["astrbot_plugin_memorix"]} + ), + ): + await module._plugin_tool_fix(mock_event, req) + + assert "memorix_search" not in req.func_tool.names() + class TestBuildMainAgent: """Tests for build_main_agent function.""" diff --git a/tests/unit/test_context_utils.py b/tests/unit/test_context_utils.py new file mode 100644 index 0000000000..370da9fb40 --- /dev/null +++ b/tests/unit/test_context_utils.py @@ -0,0 +1,38 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.pipeline.context_utils import call_event_hook +from astrbot.core.star.star_handler import EventType + + +@pytest.mark.asyncio +async def test_call_event_hook_skips_session_disabled_plugin(): + event = MagicMock() + event.plugins_name = ["*"] + event.unified_msg_origin = "test_platform:private:session123" + event.is_stopped.return_value = False + + handler = MagicMock() + handler.handler_name = "on_llm_request" + handler.handler_module_path = "test_plugin" + handler.handler = AsyncMock() + + with patch( + "astrbot.core.pipeline.context_utils.star_handlers_registry.get_handlers_by_event_type", + return_value=[handler], + ), patch( + "astrbot.core.pipeline.context_utils.star_map" + ) as mock_star_map, patch( + "astrbot.core.pipeline.context_utils.SessionPluginManager.get_session_plugin_config", + new=AsyncMock(return_value={"disabled_plugins": ["astrbot_plugin_memorix"]}), + ): + mock_plugin = MagicMock() + mock_plugin.name = "astrbot_plugin_memorix" + mock_plugin.reserved = False + mock_star_map.get.return_value = mock_plugin + + stopped = await call_event_hook(event, EventType.OnLLMRequestEvent) + + assert stopped is False + handler.handler.assert_not_awaited()