diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index f5ab15ed09..ad7002bda9 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -27,6 +27,9 @@ from astrbot.core.star.register import register_on_plugin_error as on_plugin_error from astrbot.core.star.register import register_on_plugin_loaded as on_plugin_loaded from astrbot.core.star.register import register_on_plugin_unloaded as on_plugin_unloaded +from astrbot.core.star.register import ( + register_on_prompt_assembly as on_prompt_assembly, +) from astrbot.core.star.register import register_on_using_llm_tool as on_using_llm_tool from astrbot.core.star.register import ( register_on_waiting_llm_request as on_waiting_llm_request, @@ -54,6 +57,7 @@ "on_astrbot_loaded", "on_decorating_result", "on_llm_request", + "on_prompt_assembly", "on_llm_response", "on_plugin_error", "on_plugin_loaded", diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 2b4a04907e..242a92767c 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -13,7 +13,6 @@ from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool -from astrbot.core.agent.message import TextPart from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS @@ -56,12 +55,37 @@ extract_persona_custom_error_message_from_persona, set_persona_custom_error_message_on_event, ) +from astrbot.core.pipeline.context_utils import call_event_hook from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.prompt import ( + CONTEXT_ORDER_FILE_EXTRACT, + CONTEXT_ORDER_PERSONA_BEGIN_DIALOGS, + SYSTEM_BLOCK_ORDER_KB, + SYSTEM_BLOCK_ORDER_LIVE_MODE, + SYSTEM_BLOCK_ORDER_PERSONA, + SYSTEM_BLOCK_ORDER_ROUTER, + 
SYSTEM_BLOCK_ORDER_RUNTIME, + SYSTEM_BLOCK_ORDER_SAFETY, + SYSTEM_BLOCK_ORDER_SKILLS, + SYSTEM_BLOCK_ORDER_TOOL_USE, + USER_APPEND_ORDER_ATTACHMENTS, + USER_APPEND_ORDER_QUOTED, + USER_APPEND_ORDER_SYSTEM_REMINDER, + PromptAssembly, + PromptMutation, + add_context_prefix, + add_context_suffix, + add_system_block, + add_user_text, + build_prompt_trace_snapshot, + render_prompt_assembly, + summarize_provider_request_base, +) from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt from astrbot.core.star.context import Context -from astrbot.core.star.star_handler import star_map +from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.tools.cron_tools import ( CREATE_CRON_JOB_TOOL, DELETE_CRON_JOB_TOOL, @@ -197,6 +221,7 @@ async def _apply_kb( req: ProviderRequest, plugin_context: Context, config: MainAgentBuildConfig, + assembly: PromptAssembly, ) -> None: if not config.kb_agentic_mode: if req.prompt is None: @@ -209,10 +234,12 @@ async def _apply_kb( ) if not kb_result: return - if req.system_prompt is not None: - req.system_prompt += ( - f"\n\n[Related Knowledge Base Results]:\n{kb_result}" - ) + add_system_block( + assembly, + source="knowledge_base", + order=SYSTEM_BLOCK_ORDER_KB, + content=f"\n\n[Related Knowledge Base Results]:\n{kb_result}", + ) except Exception as exc: # noqa: BLE001 logger.error("Error occurred while retrieving knowledge base: %s", exc) else: @@ -225,6 +252,7 @@ async def _apply_file_extract( event: AstrMessageEvent, req: ProviderRequest, config: MainAgentBuildConfig, + assembly: PromptAssembly, ) -> None: file_paths = [] file_names = [] @@ -259,14 +287,19 @@ async def _apply_file_extract( return for file_content, file_name in zip(file_contents, file_names): - req.contexts.append( - { - "role": "system", - "content": ( - "File Extract Results of user uploaded files:\n" - f"{file_content}\nFile 
Name: {file_name or 'Unknown'}" - ), - }, + add_context_suffix( + assembly, + source="file_extract", + order=CONTEXT_ORDER_FILE_EXTRACT, + messages=[ + { + "role": "system", + "content": ( + "File Extract Results of user uploaded files:\n" + f"{file_content}\nFile Name: {file_name or 'Unknown'}" + ), + } + ], ) @@ -280,12 +313,20 @@ def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None: req.prompt = f"{prefix}{req.prompt}" -def _apply_local_env_tools(req: ProviderRequest) -> None: +def _apply_local_env_tools( + req: ProviderRequest, + assembly: PromptAssembly, +) -> None: if req.func_tool is None: req.func_tool = ToolSet() req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL) req.func_tool.add_tool(LOCAL_PYTHON_TOOL) - req.system_prompt = f"{req.system_prompt or ''}\n{_build_local_mode_prompt()}\n" + add_system_block( + assembly, + source="runtime:local", + order=SYSTEM_BLOCK_ORDER_RUNTIME, + content=f"\n{_build_local_mode_prompt()}\n", + ) def _build_local_mode_prompt() -> str: @@ -308,6 +349,7 @@ async def _ensure_persona_and_skills( cfg: dict, plugin_context: Context, event: AstrMessageEvent, + assembly: PromptAssembly, ) -> None: """Ensure persona and skills are applied to the request's system prompt or user prompt.""" if not req.conversation: @@ -332,11 +374,26 @@ async def _ensure_persona_and_skills( if persona: # Inject persona system prompt if prompt := persona["prompt"]: - req.system_prompt += f"\n# Persona Instructions\n\n{prompt}\n" + add_system_block( + assembly, + source="persona", + order=SYSTEM_BLOCK_ORDER_PERSONA, + content=f"\n# Persona Instructions\n\n{prompt}\n", + ) if begin_dialogs := copy.deepcopy(persona.get("_begin_dialogs_processed")): - req.contexts[:0] = begin_dialogs + add_context_prefix( + assembly, + source="persona_begin_dialogs", + order=CONTEXT_ORDER_PERSONA_BEGIN_DIALOGS, + messages=begin_dialogs, + ) elif use_webchat_special_default: - req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT + add_system_block( + 
assembly, + source="persona", + order=SYSTEM_BLOCK_ORDER_PERSONA, + content=CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT, + ) # Inject skills prompt runtime = cfg.get("computer_use_runtime", "local") @@ -351,12 +408,22 @@ async def _ensure_persona_and_skills( allowed = set(persona["skills"]) skills = [skill for skill in skills if skill.name in allowed] if skills: - req.system_prompt += f"\n{build_skills_prompt(skills)}\n" + add_system_block( + assembly, + source="skills", + order=SYSTEM_BLOCK_ORDER_SKILLS, + content=f"\n{build_skills_prompt(skills)}\n", + ) if runtime == "none": - req.system_prompt += ( - "User has not enabled the Computer Use feature. " - "You cannot use shell or Python to perform skills. " - "If you need to use these capabilities, ask the user to enable Computer Use in the AstrBot WebUI -> Config." + add_system_block( + assembly, + source="skills_runtime_notice", + order=SYSTEM_BLOCK_ORDER_SKILLS + 10, + content=( + "User has not enabled the Computer Use feature. " + "You cannot use shell or Python to perform skills. " + "If you need to use these capabilities, ask the user to enable Computer Use in the AstrBot WebUI -> Config." 
+ ), ) tmgr = plugin_context.get_llm_tool_manager() @@ -438,7 +505,12 @@ async def _ensure_persona_and_skills( .get("router_system_prompt", "") ).strip() if router_prompt: - req.system_prompt += f"\n{router_prompt}\n" + add_system_block( + assembly, + source="router", + order=SYSTEM_BLOCK_ORDER_ROUTER, + content=f"\n{router_prompt}\n", + ) try: event.trace.record( "sel_persona", @@ -483,6 +555,7 @@ async def _ensure_img_caption( cfg: dict, plugin_context: Context, image_caption_provider: str, + assembly: PromptAssembly, ) -> None: try: compressed_urls = [] @@ -498,20 +571,34 @@ async def _ensure_img_caption( plugin_context, ) if caption: - req.extra_user_content_parts.append( - TextPart(text=f"{caption}") + add_user_text( + assembly, + source="image_caption", + order=USER_APPEND_ORDER_ATTACHMENTS, + text=f"{caption}", ) req.image_urls = [] except Exception as exc: # noqa: BLE001 logger.error("处理图片描述失败: %s", exc) - req.extra_user_content_parts.append(TextPart(text="[Image Captioning Failed]")) + add_user_text( + assembly, + source="image_caption", + order=USER_APPEND_ORDER_ATTACHMENTS, + text="[Image Captioning Failed]", + ) finally: req.image_urls = [] -def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None: - req.extra_user_content_parts.append( - TextPart(text=f"[Image Attachment in quoted message: path {image_path}]") +def _append_quoted_image_attachment( + assembly: PromptAssembly, + image_path: str, +) -> None: + add_user_text( + assembly, + source="quoted_attachment", + order=USER_APPEND_ORDER_ATTACHMENTS, + text=f"[Image Attachment in quoted message: path {image_path}]", ) @@ -582,6 +669,7 @@ async def _process_quote_message( req: ProviderRequest, img_cap_prov_id: str, plugin_context: Context, + assembly: PromptAssembly, quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS, config: MainAgentBuildConfig | None = None, ) -> None: @@ -656,7 +744,12 @@ async def _process_quote_message( quoted_content 
= "\n".join(content_parts) quoted_text = f"\n{quoted_content}\n" - req.extra_user_content_parts.append(TextPart(text=quoted_text)) + add_user_text( + assembly, + source="quoted_message", + order=USER_APPEND_ORDER_QUOTED, + text=quoted_text, + ) def _append_system_reminders( @@ -664,6 +757,7 @@ def _append_system_reminders( req: ProviderRequest, cfg: dict, timezone: str | None, + assembly: PromptAssembly, ) -> None: system_parts: list[str] = [] if cfg.get("identifier"): @@ -700,7 +794,12 @@ def _append_system_reminders( system_content = ( "" + "\n".join(system_parts) + "" ) - req.extra_user_content_parts.append(TextPart(text=system_content)) + add_user_text( + assembly, + source="system_reminder", + order=USER_APPEND_ORDER_SYSTEM_REMINDER, + text=system_content, + ) async def _decorate_llm_request( @@ -708,6 +807,7 @@ async def _decorate_llm_request( req: ProviderRequest, plugin_context: Context, config: MainAgentBuildConfig, + assembly: PromptAssembly, ) -> None: cfg = config.provider_settings or plugin_context.get_config( umo=event.unified_msg_origin @@ -716,7 +816,7 @@ async def _decorate_llm_request( _apply_prompt_prefix(req, cfg) if req.conversation: - await _ensure_persona_and_skills(req, cfg, plugin_context, event) + await _ensure_persona_and_skills(req, cfg, plugin_context, event, assembly) img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or "" if img_cap_prov_id and req.image_urls: @@ -726,6 +826,7 @@ async def _decorate_llm_request( cfg, plugin_context, img_cap_prov_id, + assembly, ) img_cap_prov_id = cfg.get("default_image_caption_provider_id") or "" @@ -735,6 +836,7 @@ async def _decorate_llm_request( req, img_cap_prov_id, plugin_context, + assembly, quoted_message_settings, config, ) @@ -742,7 +844,7 @@ async def _decorate_llm_request( tz = config.timezone if tz is None: tz = plugin_context.get_config().get("timezone") - _append_system_reminders(event, req, cfg, tz) + _append_system_reminders(event, req, cfg, tz, assembly) def 
_modalities_fix(provider: Provider, req: ProviderRequest) -> None: @@ -919,9 +1021,19 @@ async def _handle_webchat( ) -def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) -> None: +def _apply_llm_safety_mode( + config: MainAgentBuildConfig, + req: ProviderRequest, + assembly: PromptAssembly, +) -> None: if config.safety_mode_strategy == "system_prompt": - req.system_prompt = f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}\n\n{req.system_prompt}" + add_system_block( + assembly, + source="safety", + order=SYSTEM_BLOCK_ORDER_SAFETY, + content=f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}\n\n", + prepend=True, + ) else: logger.warning( "Unsupported llm_safety_mode strategy: %s.", @@ -930,12 +1042,13 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) - def _apply_sandbox_tools( - config: MainAgentBuildConfig, req: ProviderRequest, session_id: str + config: MainAgentBuildConfig, + req: ProviderRequest, + session_id: str, + assembly: PromptAssembly, ) -> None: if req.func_tool is None: req.func_tool = ToolSet() - if req.system_prompt is None: - req.system_prompt = "" booter = config.sandbox_cfg.get("booter", "shipyard_neo") if booter == "shipyard": ep = config.sandbox_cfg.get("shipyard_endpoint", "") @@ -953,23 +1066,33 @@ def _apply_sandbox_tools( if booter == "shipyard_neo": # Neo-specific path rule: filesystem tools operate relative to sandbox # workspace root. Do not prepend "/workspace". - req.system_prompt += ( - "\n[Shipyard Neo File Path Rule]\n" - "When using sandbox filesystem tools (upload/download/read/write/list/delete), " - "always pass paths relative to the sandbox workspace root. 
" - "Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n" + add_system_block( + assembly, + source="runtime:sandbox_path_rule", + order=SYSTEM_BLOCK_ORDER_RUNTIME, + content=( + "\n[Shipyard Neo File Path Rule]\n" + "When using sandbox filesystem tools (upload/download/read/write/list/delete), " + "always pass paths relative to the sandbox workspace root. " + "Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n" + ), ) - req.system_prompt += ( - "\n[Neo Skill Lifecycle Workflow]\n" - "When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n" - "Preferred sequence:\n" - "1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n" - "2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a candidate.\n" - "3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n" - "For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n" - "Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n" - "To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n" + add_system_block( + assembly, + source="runtime:sandbox_skill_lifecycle", + order=SYSTEM_BLOCK_ORDER_RUNTIME + 10, + content=( + "\n[Neo Skill Lifecycle Workflow]\n" + "When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n" + "Preferred sequence:\n" + "1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n" + "2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a 
candidate.\n" + "3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n" + "For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n" + "Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n" + "To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n" + ), ) # Determine sandbox capabilities from an already-booted session. @@ -1002,7 +1125,12 @@ def _apply_sandbox_tools( req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL) req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL) - req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n" + add_system_block( + assembly, + source="runtime:sandbox", + order=SYSTEM_BLOCK_ORDER_RUNTIME + 20, + content=f"\n{SANDBOX_MODE_PROMPT}\n", + ) def _proactive_cron_job_tools(req: ProviderRequest) -> None: @@ -1089,6 +1217,8 @@ async def build_main_agent( logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。") return None + prompt_assembly = PromptAssembly() + if req is None: if event.get_extra("provider_request"): req = event.get_extra("provider_request") @@ -1121,16 +1251,20 @@ async def build_main_agent( if _is_generated_compressed_image_path(path, image_path): event.track_temporary_local_file(image_path) req.image_urls.append(image_path) - req.extra_user_content_parts.append( - TextPart(text=f"[Image Attachment: path {image_path}]") + add_user_text( + prompt_assembly, + source="attachment", + order=USER_APPEND_ORDER_ATTACHMENTS, + text=f"[Image Attachment: path {image_path}]", ) elif isinstance(comp, File): file_path = await comp.get_file() file_name = comp.name or os.path.basename(file_path) - req.extra_user_content_parts.append( - TextPart( - text=f"[File Attachment: name {file_name}, path {file_path}]" - ) + add_user_text( + prompt_assembly, + source="attachment", + 
order=USER_APPEND_ORDER_ATTACHMENTS, + text=f"[File Attachment: name {file_name}, path {file_path}]", ) # quoted message attachments reply_comps = [ @@ -1154,17 +1288,18 @@ async def build_main_agent( if _is_generated_compressed_image_path(path, image_path): event.track_temporary_local_file(image_path) req.image_urls.append(image_path) - _append_quoted_image_attachment(req, image_path) + _append_quoted_image_attachment(prompt_assembly, image_path) elif isinstance(reply_comp, File): file_path = await reply_comp.get_file() file_name = reply_comp.name or os.path.basename(file_path) - req.extra_user_content_parts.append( - TextPart( - text=( - f"[File Attachment in quoted message: " - f"name {file_name}, path {file_path}]" - ) - ) + add_user_text( + prompt_assembly, + source="quoted_attachment", + order=USER_APPEND_ORDER_ATTACHMENTS, + text=( + f"[File Attachment in quoted message: " + f"name {file_name}, path {file_path}]" + ), ) # Fallback quoted image extraction for reply-id-only payloads, or when @@ -1204,7 +1339,7 @@ async def build_main_agent( continue req.image_urls.append(image_ref) fallback_quoted_image_count += 1 - _append_quoted_image_attachment(req, image_ref) + _append_quoted_image_attachment(prompt_assembly, image_ref) except Exception as exc: # noqa: BLE001 logger.warning( "Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s", @@ -1225,34 +1360,35 @@ async def build_main_agent( if config.file_extract_enabled: try: - await _apply_file_extract(event, req, config) + await _apply_file_extract(event, req, config, prompt_assembly) except Exception as exc: # noqa: BLE001 logger.error("Error occurred while applying file extract: %s", exc) if not req.prompt and not req.image_urls: - if not event.get_group_id() and req.extra_user_content_parts: + if not event.get_group_id() and ( + req.extra_user_content_parts or prompt_assembly.user_append_parts + ): req.prompt = "" else: return None - await _decorate_llm_request(event, req, plugin_context, 
config) + await _decorate_llm_request(event, req, plugin_context, config, prompt_assembly) - await _apply_kb(event, req, plugin_context, config) + await _apply_kb(event, req, plugin_context, config, prompt_assembly) if not req.session_id: req.session_id = event.unified_msg_origin _modalities_fix(provider, req) _plugin_tool_fix(event, req) - _sanitize_context_by_modalities(config, provider, req) if config.llm_safety_mode: - _apply_llm_safety_mode(config, req) + _apply_llm_safety_mode(config, req, prompt_assembly) if config.computer_use_runtime == "sandbox": - _apply_sandbox_tools(config, req, req.session_id) + _apply_sandbox_tools(config, req, req.session_id, prompt_assembly) elif config.computer_use_runtime == "local": - _apply_local_env_tools(req) + _apply_local_env_tools(req, prompt_assembly) agent_runner = AgentRunner() astr_agent_ctx = AstrAgentContext( @@ -1284,11 +1420,40 @@ async def build_main_agent( if config.tool_schema_mode == "full" else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE ) - req.system_prompt += f"\n{tool_prompt}\n" + add_system_block( + prompt_assembly, + source="tool_use", + order=SYSTEM_BLOCK_ORDER_TOOL_USE, + content=f"\n{tool_prompt}\n", + ) action_type = event.get_extra("action_type") if action_type == "live": - req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n" + add_system_block( + prompt_assembly, + source="live_mode", + order=SYSTEM_BLOCK_ORDER_LIVE_MODE, + content=f"\n{LIVE_MODE_SYSTEM_PROMPT}\n", + ) + + prompt_assembly.metadata["base_request"] = summarize_provider_request_base(req) + core_prompt_trace_snapshot = build_prompt_trace_snapshot(prompt_assembly) + if await call_event_hook( + event, + EventType.OnPromptAssemblyEvent, + PromptMutation(prompt_assembly), + ): + return None + + render_prompt_assembly(req, prompt_assembly) + try: + event.trace.record( + "core_prompt_assembly", + **core_prompt_trace_snapshot, + ) + except Exception: + logger.debug("Failed to record core_prompt_assembly trace", exc_info=True) + 
_sanitize_context_by_modalities(config, provider, req) reset_coro = agent_runner.reset( provider=provider, diff --git a/astrbot/core/prompt/__init__.py b/astrbot/core/prompt/__init__.py new file mode 100644 index 0000000000..bee8a96cd1 --- /dev/null +++ b/astrbot/core/prompt/__init__.py @@ -0,0 +1,58 @@ +from .assembly import ( + PromptMutation, + add_context_prefix, + add_context_suffix, + add_system_block, + add_user_part, + add_user_text, +) +from .models import ( + CONTEXT_ORDER_FILE_EXTRACT, + CONTEXT_ORDER_PERSONA_BEGIN_DIALOGS, + SYSTEM_BLOCK_ORDER_KB, + SYSTEM_BLOCK_ORDER_LIVE_MODE, + SYSTEM_BLOCK_ORDER_PERSONA, + SYSTEM_BLOCK_ORDER_ROUTER, + SYSTEM_BLOCK_ORDER_RUNTIME, + SYSTEM_BLOCK_ORDER_SAFETY, + SYSTEM_BLOCK_ORDER_SKILLS, + SYSTEM_BLOCK_ORDER_TOOL_USE, + USER_APPEND_ORDER_ATTACHMENTS, + USER_APPEND_ORDER_QUOTED, + USER_APPEND_ORDER_SYSTEM_REMINDER, + ContextContribution, + PromptAssembly, + SystemBlock, + UserAppendPart, +) +from .renderer import render_prompt_assembly +from .tracing import build_prompt_trace_snapshot, summarize_provider_request_base + +__all__ = [ + "CONTEXT_ORDER_FILE_EXTRACT", + "CONTEXT_ORDER_PERSONA_BEGIN_DIALOGS", + "SYSTEM_BLOCK_ORDER_KB", + "SYSTEM_BLOCK_ORDER_LIVE_MODE", + "SYSTEM_BLOCK_ORDER_PERSONA", + "SYSTEM_BLOCK_ORDER_ROUTER", + "SYSTEM_BLOCK_ORDER_RUNTIME", + "SYSTEM_BLOCK_ORDER_SAFETY", + "SYSTEM_BLOCK_ORDER_SKILLS", + "SYSTEM_BLOCK_ORDER_TOOL_USE", + "USER_APPEND_ORDER_ATTACHMENTS", + "USER_APPEND_ORDER_QUOTED", + "USER_APPEND_ORDER_SYSTEM_REMINDER", + "ContextContribution", + "PromptAssembly", + "PromptMutation", + "SystemBlock", + "UserAppendPart", + "add_context_prefix", + "add_context_suffix", + "add_system_block", + "add_user_part", + "add_user_text", + "build_prompt_trace_snapshot", + "render_prompt_assembly", + "summarize_provider_request_base", +] diff --git a/astrbot/core/prompt/assembly.py b/astrbot/core/prompt/assembly.py new file mode 100644 index 0000000000..dc2281b0b3 --- /dev/null +++ 
b/astrbot/core/prompt/assembly.py @@ -0,0 +1,305 @@ +""" +Prompt Assembly 注册方法 — 向 PromptAssembly 中添加各通道内容的便捷 API。 + +每个函数对应一个通道,负责创建对应的数据对象并追加到 assembly 中。 +所有函数都会跳过空内容(空字符串、空消息列表),避免产生无意义的区块。 +""" + +from __future__ import annotations + +import copy + +from astrbot.core import logger +from astrbot.core.agent.message import ContentPart, TextPart + +from .models import ( + ContextContribution, + ContextPosition, + PromptAssembly, + SystemBlock, + UserAppendPart, +) + + +class PromptMutation: + """Restricted facade exposed to prompt assembly hooks.""" + + __slots__ = ( + "_assembly", + "_warned_context_prefix_sources", + "_warned_plugin_orders", + ) + + def __init__(self, assembly: PromptAssembly) -> None: + self._assembly = assembly + self._warned_context_prefix_sources: set[str] = set() + self._warned_plugin_orders: set[tuple[str, int]] = set() + + def add_system( + self, + text: str, + source: str, + order: int, + *, + visible_in_trace: bool = True, + ) -> None: + self._warn_if_reserved_order(source, order) + add_system_block( + self._assembly, + source=source, + order=order, + content=text, + visible_in_trace=visible_in_trace, + ) + + def add_user_text( + self, + text: str, + source: str, + order: int, + *, + visible_in_trace: bool = True, + ) -> None: + self._warn_if_reserved_order(source, order) + add_user_text( + self._assembly, + source=source, + order=order, + text=text, + visible_in_trace=visible_in_trace, + ) + + def add_context_prefix( + self, + messages: list[dict], + source: str, + order: int, + *, + visible_in_trace: bool = True, + ) -> None: + # order 警告:插件开发者应避免使用过低的 order 值(如 <900),以免与核心保留的提示块发生排序冲突。 + self._warn_if_reserved_order(source, order) + # Context prefix 插在历史最前,最有可能影响 KV cache 效率,发出警告提示插件开发者确认是否真的需要使用 context prefix + self._warn_if_context_prefix_affects_cache(source, messages) + add_context_prefix( + self._assembly, + source=source, + order=order, + messages=messages, + visible_in_trace=visible_in_trace, + ) + + def add_context_suffix( + 
self, + messages: list[dict], + source: str, + order: int, + *, + visible_in_trace: bool = True, + ) -> None: + self._warn_if_reserved_order(source, order) + add_context_suffix( + self._assembly, + source=source, + order=order, + messages=messages, + visible_in_trace=visible_in_trace, + ) + + def _warn_if_reserved_order(self, source: str, order: int) -> None: + if order >= 900: + return + warn_key = (source, order) + if warn_key in self._warned_plugin_orders: + return + self._warned_plugin_orders.add(warn_key) + logger.warning( + "Prompt assembly plugin order %s for source %s overlaps the core-reserved range. " + "Prefer plugin orders >= 900.", + order, + source, + ) + + def _warn_if_context_prefix_affects_cache( + self, + source: str, + messages: list[dict], + ) -> None: + if not messages or source in self._warned_context_prefix_sources: + return + self._warned_context_prefix_sources.add(source) + logger.warning( + "Prompt assembly context prefix from source %s is prepended to the message history " + "and may reduce provider-side KV cache prefix reuse. 
Prefer add_system() for static " + "policy text, and reserve add_context_prefix() for few-shot examples or synthetic " + "history that must appear before the conversation.", + source, + ) + + +def add_system_block( + assembly: PromptAssembly, + *, + source: str, + order: int, + content: str, + prepend: bool = False, + visible_in_trace: bool = True, +) -> None: + """向 assembly 注册一个 system prompt 区块。 + + Args: + assembly: 目标 assembly 容器 + source: 来源标识,如 "persona", "safety", "kb" + order: 通道内排序权重,使用 models.py 中的常量 + content: 区块文本内容,为空时静默跳过 + prepend: True 则插到 system_prompt 最前面(仅 SAFETY 使用) + visible_in_trace: 是否在 trace 快照中可见 + """ + if not content: + return + assembly.system_blocks.append( + SystemBlock( + source=source, + order=order, + content=content, + prepend=prepend, + visible_in_trace=visible_in_trace, + ) + ) + + +def add_user_part( + assembly: PromptAssembly, + *, + source: str, + order: int, + part: ContentPart, + visible_in_trace: bool = True, +) -> None: + """向 assembly 注册一个用户消息追加片段(通用版,接受任意 ContentPart)。 + + Args: + assembly: 目标 assembly 容器 + source: 来源标识,如 "attachment", "quoted_message" + order: 通道内排序权重 + part: 内容片段(TextPart、ImagePart 等) + visible_in_trace: 是否在 trace 快照中可见 + """ + assembly.user_append_parts.append( + UserAppendPart( + source=source, + order=order, + part=part, + visible_in_trace=visible_in_trace, + ) + ) + + +def add_user_text( + assembly: PromptAssembly, + *, + source: str, + order: int, + text: str, + visible_in_trace: bool = True, +) -> None: + """向 assembly 注册一段纯文本追加到用户消息(add_user_part 的文本快捷方式)。 + + Args: + assembly: 目标 assembly 容器 + source: 来源标识 + order: 通道内排序权重 + text: 纯文本内容,为空时静默跳过 + visible_in_trace: 是否在 trace 快照中可见 + """ + if not text: + return + add_user_part( + assembly, + source=source, + order=order, + part=TextPart(text=text), + visible_in_trace=visible_in_trace, + ) + + +def add_context_prefix( + assembly: PromptAssembly, + *, + source: str, + order: int, + messages: list[dict], + visible_in_trace: bool = True, +) -> None: 
+ """向 assembly 注册一组前缀消息(插到对话历史最前面)。 + + 典型用途:人格预设对话示例(persona begin_dialogs)。 + + Args: + assembly: 目标 assembly 容器 + source: 来源标识,如 "persona_begin_dialogs" + order: 通道内排序权重 + messages: OpenAI 格式消息列表,为空时静默跳过 + visible_in_trace: 是否在 trace 快照中可见 + """ + _add_context_contribution( + assembly, + source=source, + order=order, + messages=messages, + position="prefix", + visible_in_trace=visible_in_trace, + ) + + +def add_context_suffix( + assembly: PromptAssembly, + *, + source: str, + order: int, + messages: list[dict], + visible_in_trace: bool = True, +) -> None: + """向 assembly 注册一组后缀消息(追加到对话历史末尾)。 + + 典型用途:文件提取合成消息(file extract results)。 + + Args: + assembly: 目标 assembly 容器 + source: 来源标识,如 "file_extract" + order: 通道内排序权重 + messages: OpenAI 格式消息列表,为空时静默跳过 + visible_in_trace: 是否在 trace 快照中可见 + """ + _add_context_contribution( + assembly, + source=source, + order=order, + messages=messages, + position="suffix", + visible_in_trace=visible_in_trace, + ) + + +def _add_context_contribution( + assembly: PromptAssembly, + *, + source: str, + order: int, + messages: list[dict], + position: ContextPosition, + visible_in_trace: bool = True, +) -> None: + """context 贡献的内部实现,被 add_context_prefix / add_context_suffix 调用。""" + if not messages: + return + assembly.context_contributions.append( + ContextContribution( + source=source, + order=order, + messages=copy.deepcopy(messages), + visible_in_trace=visible_in_trace, + position=position, + ) + ) diff --git a/astrbot/core/prompt/models.py b/astrbot/core/prompt/models.py new file mode 100644 index 0000000000..db59c203a9 --- /dev/null +++ b/astrbot/core/prompt/models.py @@ -0,0 +1,155 @@ +""" +Prompt Assembly data models. + +This module defines the three-channel prompt assembly structures used by +``assembly.py``, ``renderer.py`` and ``tracing.py``. + +The core types exported via ``astrbot.core.prompt`` are part of the extension +surface for prompt assembly integrations. 
Internal wiring around those types +may still evolve without separate notice. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +from astrbot.core.agent.message import ContentPart + +# --------------------------------------------------------------------------- +# system_blocks 排序常量 — 控制 system prompt 内各区块的渲染顺序 +# +# 数值为整数,按升序排列。只在 system_blocks 通道内排序,不影响其他通道。 +# 间隔为 100,便于未来在两个现有区块之间插入新区块(如 order=350 或 650) +# 而无需重新编号。 +# --------------------------------------------------------------------------- + +# LLM 安全模式提示(防注入、内容策略等)。prepend=True,强制置于 system_prompt 最前面 +SYSTEM_BLOCK_ORDER_SAFETY = 100 +# 人格/角色设定指令 — 定义 AI 的身份、语气、行为规则 +SYSTEM_BLOCK_ORDER_PERSONA = 200 +# 技能(Skills)提示词 — 描述可用技能的触发条件和行为 +SYSTEM_BLOCK_ORDER_SKILLS = 300 +# 知识库检索结果(非 agentic 模式)— 将 RAG 检索到的相关片段注入 system prompt +SYSTEM_BLOCK_ORDER_KB = 400 +# Sub-agent 路由器提示 — 多 Agent 编排场景下,指导 LLM 选择合适的子 Agent +SYSTEM_BLOCK_ORDER_ROUTER = 500 +# 运行时环境提示 — sandbox 或 local 模式下的环境说明和约束 +SYSTEM_BLOCK_ORDER_RUNTIME = 600 +# 工具调用格式说明 — 指导 LLM 如何使用注册的工具(function calling 格式等) +SYSTEM_BLOCK_ORDER_TOOL_USE = 700 +# Live 模式提示 — 实时交互模式下的特殊行为指令 +SYSTEM_BLOCK_ORDER_LIVE_MODE = 800 + +# --------------------------------------------------------------------------- +# user_append_parts 排序常量 — 控制追加到用户消息中各内容的顺序 +# +# 这些内容最终写入 req.extra_user_content_parts,作为用户消息的一部分发送, +# 而非进入 system prompt,避免干扰模型的指令遵循能力。 +# --------------------------------------------------------------------------- + +# 附件/图片描述通知 — 告知 LLM 用户上传了哪些文件或图片及其摘要 +USER_APPEND_ORDER_ATTACHMENTS = 100 +# 引用消息内容 — 用户引用/回复的历史消息文本及图片说明 +USER_APPEND_ORDER_QUOTED = 200 +# 系统提醒 — 用户 ID、昵称、群组名称、当前日期时间等元信息 +USER_APPEND_ORDER_SYSTEM_REMINDER = 300 + +# --------------------------------------------------------------------------- +# context_contributions 排序常量 — 控制插入到对话历史(req.contexts)前后的消息顺序 +# +# position="prefix" 的消息插到历史最前面(如预设对话示例), +# position="suffix" 的消息追加到历史末尾(如文件提取合成消息)。 +# 
--------------------------------------------------------------------------- + +# 人格预设对话示例 — persona 中配置的 begin_dialogs,作为 few-shot 示例引导 LLM 行为 +CONTEXT_ORDER_PERSONA_BEGIN_DIALOGS = 100 +# 文件提取合成消息 — 从上传文件中提取的文本内容,包装为 system 角色的合成历史消息 +CONTEXT_ORDER_FILE_EXTRACT = 100 + +# context 贡献的插入位置:prefix 插到历史最前面,suffix 追加到历史末尾 +ContextPosition = Literal["prefix", "suffix"] + + +# --------------------------------------------------------------------------- +# 数据类定义 +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class SystemBlock: + """system prompt 中的一个结构化区块。 + + Attributes: + source: 来源标识(如 "persona", "router", "sandbox"), + 用于 trace 调试时快速定位区块来源 + order: 通道内排序权重,值越小越靠前 + content: 区块的文本内容 + prepend: 是否插入到 system_prompt 最前面(而非追加到末尾)。 + 目前仅 SAFETY 区块使用 True,确保安全规则始终在最前面 + visible_in_trace: 是否在结构化 trace 快照中可见,用于控制敏感内容的可见性 + """ + + source: str + order: int + content: str + prepend: bool = False + visible_in_trace: bool = True + + +@dataclass(slots=True) +class UserAppendPart: + """追加到用户消息中的一个内容片段。 + + Attributes: + source: 来源标识(如 "attachment", "quoted_message", "system_reminder") + order: 通道内排序权重 + part: 实际内容,支持 TextPart、ImagePart 等多种类型 + visible_in_trace: 是否在 trace 快照中可见 + """ + + source: str + order: int + part: ContentPart + visible_in_trace: bool = True + + +@dataclass(slots=True) +class ContextContribution: + """插入到对话历史(req.contexts)中的一组合成消息。 + + Attributes: + source: 来源标识(如 "persona_begin_dialogs", "file_extract") + order: 通道内排序权重 + messages: OpenAI 格式的消息列表,如 [{"role": "user", "content": "..."}] + position: 插入位置:"prefix" 插到历史最前,"suffix" 追加到历史末尾 + visible_in_trace: 是否在 trace 快照中可见 + """ + + source: str + order: int + messages: list[dict] + position: ContextPosition + visible_in_trace: bool = True + + +@dataclass(slots=True) +class PromptAssembly: + """请求级别的 prompt 组装容器。 + + 在 build_main_agent() 中创建,各 core helper 向其中注册区块, + 最终由 renderer 一次性渲染回 ProviderRequest。生命周期仅限于单次请求。 + + Attributes: + system_blocks: 所有 system 
prompt 区块 + user_append_parts: 所有追加到用户消息的内容片段 + context_contributions: 所有插入到对话历史前后的合成消息 + metadata: 附加元数据,可用于记录请求级别的调试信息 + rendered: 是否已完成渲染,防止重复渲染 + """ + + system_blocks: list[SystemBlock] = field(default_factory=list) + user_append_parts: list[UserAppendPart] = field(default_factory=list) + context_contributions: list[ContextContribution] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + rendered: bool = False diff --git a/astrbot/core/prompt/renderer.py b/astrbot/core/prompt/renderer.py new file mode 100644 index 0000000000..71d1e92cf9 --- /dev/null +++ b/astrbot/core/prompt/renderer.py @@ -0,0 +1,86 @@ +""" +Prompt Assembly renderer. + +This is the final step of prompt assembly: core helpers register structured +blocks into a request-scoped PromptAssembly, then render_prompt_assembly() +writes the three channels back into the request fields once. +""" + +from __future__ import annotations + +import copy + +from astrbot.core.provider.entities import ProviderRequest + +from .models import PromptAssembly + + +def render_prompt_assembly( + req: ProviderRequest, + assembly: PromptAssembly, +) -> ProviderRequest: + """将 PromptAssembly 渲染回 ProviderRequest。 + + 渲染逻辑: + - system_blocks 中 prepend=True 的块按 order 排序后拼到 system_prompt 最前面, + prepend=False 的块按 order 排序后追加到末尾 + - user_append_parts 按 order 排序后依次追加到 extra_user_content_parts + - context_contributions 按 position 分为 prefix/suffix, + 分别插到 contexts 的最前面和最末尾 + + Args: + req: 待渲染的目标 ProviderRequest + assembly: 已完成注册的 PromptAssembly + + Returns: + 渲染后的同一个 req 对象(原地修改,返回引用便于链式调用) + """ + if assembly.rendered: + return req + + # TODO: Remove this compatibility shim after all helper entrypoints are + # tightened to accept only fully initialized ProviderRequest instances. + # Some legacy unit tests and helper-only call paths still pass lightweight + # request-like objects that only provide the fields they mutate. 
+ if ( + not hasattr(req, "extra_user_content_parts") + or req.extra_user_content_parts is None + ): + req.extra_user_content_parts = [] + if not hasattr(req, "contexts") or req.contexts is None: + req.contexts = [] + + # --- system_blocks → req.system_prompt --- + # 将 prepend 块和 append 块分离,分别拼到原始 system_prompt 的前面和后面 + prepend_parts: list[str] = [] + append_parts: list[str] = [] + for block in sorted(assembly.system_blocks, key=lambda item: item.order): + if block.prepend: + prepend_parts.append(block.content) + else: + append_parts.append(block.content) + prepend_prompt = "".join(prepend_parts) + append_prompt = "".join(append_parts) + system_prompt = f"{prepend_prompt}{req.system_prompt or ''}{append_prompt}" + req.system_prompt = system_prompt + + # --- user_append_parts → req.extra_user_content_parts --- + # 按 order 排序后依次追加,保持用户消息中各片段的有序性 + for item in sorted(assembly.user_append_parts, key=lambda part: part.order): + req.extra_user_content_parts.append(item.part) + + # --- context_contributions → req.contexts --- + # 分离 prefix 和 suffix,最终组合为 prefix + 原始历史 + suffix + prefix_messages: list[dict] = [] + suffix_messages: list[dict] = [] + for contribution in sorted( + assembly.context_contributions, key=lambda item: item.order + ): + target = ( + prefix_messages if contribution.position == "prefix" else suffix_messages + ) + target.extend(copy.deepcopy(contribution.messages)) + + req.contexts = prefix_messages + list(req.contexts) + suffix_messages + assembly.rendered = True + return req diff --git a/astrbot/core/prompt/tracing.py b/astrbot/core/prompt/tracing.py new file mode 100644 index 0000000000..f3c39080d5 --- /dev/null +++ b/astrbot/core/prompt/tracing.py @@ -0,0 +1,161 @@ +""" +Prompt Assembly 追踪快照 — 生成结构化的 prompt 组装过程调试信息。 + +与现有的 astr_agent_prepare 原始 trace 互补: + - 现有 trace: 记录最终渲染后的完整 system_prompt 字符串(不可逆) + - 结构化 trace: 记录每个区块的来源、排序和内容(可追溯到具体模块) + +结构化 trace 仅包含 core 拥有的区块,不包含插件 hook 后的修改。 +插件修改仍然通过现有的原始 trace 观察。 +""" + +from __future__ import 
def summarize_provider_request_base(req: ProviderRequest) -> dict:
    """Return size-only statistics about a request before assembly rendering.

    No prompt text is copied into the result; only counts and lengths are
    reported so the summary is safe to store in a trace.

    List-valued fields that are missing or ``None`` are counted as empty.
    This mirrors the tolerance the renderer already shows for lightweight
    request-like objects, so tracing does not crash where rendering would not.

    Args:
        req: Request (or request-like object) to summarize.

    Returns:
        Dict with ``system_prompt_chars``, ``context_count``,
        ``extra_user_part_count``, ``image_count`` and ``has_prompt``.
    """

    def _count(attr: str) -> int:
        # `or ()` maps both a missing attribute and an explicit None to 0.
        return len(getattr(req, attr, None) or ())

    prompt = req.prompt
    return {
        "system_prompt_chars": len(req.system_prompt or ""),
        "context_count": _count("contexts"),
        "extra_user_part_count": _count("extra_user_content_parts"),
        "image_count": _count("image_urls"),
        "has_prompt": bool(prompt and prompt.strip()),
    }


def build_prompt_trace_snapshot(assembly: PromptAssembly) -> dict:
    """Build a redacted structured trace snapshot from a PromptAssembly.

    Returned structure (values are illustrative):
        {
            "system_blocks": [
                {"source": "safety", "order": 100, "prepend": True, "char_count": 120},
                {"source": "persona", "order": 200, "prepend": False, "char_count": 80},
            ],
            "user_append_parts": [
                {"source": "system_reminder", "order": 300,
                 "part": {"type": <part.type>, "char_count": 42}},
            ],
            "context_prefix": [
                {"source": "persona_begin_dialogs", "order": 100,
                 "message_count": 2, "roles": ["user", "assistant"],
                 "text_char_count": 57, "non_text_part_count": 0},
            ],
            "context_suffix": [
                {"source": "file_extract", "order": 100,
                 "message_count": 1, "roles": ["system"],
                 "text_char_count": 900, "non_text_part_count": 0},
            ],
            "metadata": {...},
        }

    Every list is sorted by ``order`` and contains only entries whose
    ``visible_in_trace`` flag is True. Entries carry summary figures
    (character counts, message counts, roles) rather than the original text.
    ``metadata`` is a deep copy of ``assembly.metadata``.

    Args:
        assembly: PromptAssembly to summarize.

    Returns:
        Dict intended for traces and debugging.
    """
    return {
        "system_blocks": [
            {
                "source": block.source,
                "order": block.order,
                "prepend": block.prepend,
                "char_count": len(block.content),
            }
            for block in sorted(assembly.system_blocks, key=lambda item: item.order)
            if block.visible_in_trace
        ],
        "user_append_parts": [
            {
                "source": item.source,
                "order": item.order,
                "part": _summarize_content_part(item.part),
            }
            for item in sorted(assembly.user_append_parts, key=lambda part: part.order)
            if item.visible_in_trace
        ],
        "context_prefix": _summarize_context_contributions(assembly, "prefix"),
        "context_suffix": _summarize_context_contributions(assembly, "suffix"),
        # Deep copy so later mutation of the assembly cannot alter a stored trace.
        "metadata": copy.deepcopy(assembly.metadata),
    }


def _summarize_context_contributions(
    assembly: PromptAssembly, position: str
) -> list[dict]:
    """Summarize the visible context contributions at one position, by order."""
    return [
        {
            "source": item.source,
            "order": item.order,
            **_summarize_messages(item.messages),
        }
        for item in sorted(
            assembly.context_contributions,
            key=lambda contribution: contribution.order,
        )
        if item.position == position and item.visible_in_trace
    ]


def _summarize_content_part(part: ContentPart) -> dict:
    """Describe a content part by type and size without exposing its payload.

    Text and think parts report a character count (think parts also report
    whether ``encrypted`` is set); image and audio URL parts only report
    whether an id is present. Any other part type yields just its ``type``.
    """
    summary: dict[str, object] = {"type": part.type}
    if isinstance(part, TextPart):
        summary["char_count"] = len(part.text)
    elif isinstance(part, ThinkPart):
        summary["char_count"] = len(part.think)
        summary["has_encrypted"] = bool(part.encrypted)
    elif isinstance(part, ImageURLPart):
        summary["has_id"] = bool(part.image_url.id)
    elif isinstance(part, AudioURLPart):
        summary["has_id"] = bool(part.audio_url.id)
    return summary


def _summarize_messages(messages: list[dict]) -> dict:
    """Reduce a list of message dicts to counts and roles.

    String content adds its length to ``text_char_count``. List content is
    walked part by part: dict parts with ``type == "text"`` add the length of
    their ``text`` value, every other part increments ``non_text_part_count``.
    Content of any other kind (including a missing key) contributes nothing.
    A message without a ``role`` key is reported as ``"unknown"``.
    """
    roles: list[str] = []
    text_char_count = 0
    non_text_part_count = 0
    for message in messages:
        roles.append(str(message.get("role", "unknown")))
        content = message.get("content")
        if isinstance(content, str):
            text_char_count += len(content)
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text_char_count += len(str(part.get("text", "")))
                else:
                    non_text_part_count += 1
    return {
        "message_count": len(messages),
        "roles": roles,
        "text_char_count": text_char_count,
        "non_text_part_count": non_text_part_count,
    }
通常不会以同样的方式破坏历史前缀缓存,但它们依然会改变最终请求内容。 + + Example: + ```py + @on_prompt_assembly() + async def add_prompt_block(self, event, mutation) -> None: + mutation.add_system( + "\\n[Plugin Policy]\\nKeep answers brief.\\n", + source="plugin:my_plugin", + order=950, + ) + ``` + """ + + def decorator(awaitable): + _ = get_handler_or_create(awaitable, EventType.OnPromptAssemblyEvent, **kwargs) + return awaitable + + return decorator + + def register_on_llm_response(**kwargs): """当有 LLM 请求后的事件 diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index d28ac726ae..18a9f0b50d 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -63,6 +63,14 @@ def get_handlers_by_event_type( plugins_name: list[str] | None = None, ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnPromptAssemblyEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... 
+ @overload def get_handlers_by_event_type( self, @@ -211,6 +219,9 @@ class EventType(enum.Enum): AdapterMessageEvent = enum.auto() # 收到适配器发来的消息 OnWaitingLLMRequestEvent = enum.auto() # 等待调用 LLM(在获取锁之前,仅通知) + OnPromptAssemblyEvent = ( + enum.auto() + ) # 组装 PromptAssembly(渲染 ProviderRequest 之前) OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件) OnLLMResponseEvent = enum.auto() # LLM 响应后 OnDecoratingResultEvent = enum.auto() # 发送消息前 diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index d151bbe6f6..866379c0d8 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -77,6 +77,7 @@ def __init__( self.translated_event_type = { EventType.AdapterMessageEvent: "平台消息下发时", + EventType.OnPromptAssemblyEvent: "Prompt 组装时", EventType.OnLLMRequestEvent: "LLM 请求时", EventType.OnLLMResponseEvent: "LLM 响应后", EventType.OnDecoratingResultEvent: "回复消息前", diff --git a/tests/test_profile_aware_tools.py b/tests/test_profile_aware_tools.py index e8c2954380..a54e1c8324 100644 --- a/tests/test_profile_aware_tools.py +++ b/tests/test_profile_aware_tools.py @@ -6,6 +6,7 @@ from unittest.mock import patch import pytest +from astrbot.core.prompt import PromptAssembly # ═══════════════════════════════════════════════════════════════ @@ -93,10 +94,8 @@ def test_no_session_registers_all(self): config = _make_config("shipyard_neo") req = _make_req() - with patch( - "astrbot.core.computer.computer_client.session_booter", {} - ): - fn(config, req, "session-1") + with patch("astrbot.core.computer.computer_client.session_booter", {}): + fn(config, req, "session-1", PromptAssembly()) names = self._tool_names(req) assert "astrbot_execute_browser" in names @@ -116,7 +115,7 @@ def test_with_browser_capability(self): "astrbot.core.computer.computer_client.session_booter", {"session-1": fake_booter}, ): - fn(config, req, "session-1") + fn(config, req, "session-1", PromptAssembly()) names = self._tool_names(req) assert 
"astrbot_execute_browser" in names @@ -126,15 +125,13 @@ def test_without_browser_capability(self): fn = _import_apply_sandbox_tools() config = _make_config("shipyard_neo") req = _make_req() - fake_booter = SimpleNamespace( - capabilities=["python", "shell", "filesystem"] - ) + fake_booter = SimpleNamespace(capabilities=["python", "shell", "filesystem"]) with patch( "astrbot.core.computer.computer_client.session_booter", {"session-1": fake_booter}, ): - fn(config, req, "session-1") + fn(config, req, "session-1", PromptAssembly()) names = self._tool_names(req) assert "astrbot_execute_browser" not in names @@ -154,7 +151,7 @@ def test_skill_tools_always_registered(self): "astrbot.core.computer.computer_client.session_booter", {"session-1": fake_booter}, ): - fn(config, req, "session-1") + fn(config, req, "session-1", PromptAssembly()) names = self._tool_names(req) assert "astrbot_create_skill_candidate" in names diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 9a42abd733..4c3dc0a75c 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1,5 +1,6 @@ """Tests for astr_main_agent module.""" +import json import os from unittest.mock import AsyncMock, MagicMock, patch @@ -12,6 +13,7 @@ from astrbot.core.message.components import File, Image, Plain, Reply from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.prompt import PromptAssembly, render_prompt_assembly from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest @@ -70,6 +72,7 @@ def mock_event(): event.get_platform_id.return_value = "test_platform" event.get_group_id.return_value = None event.get_sender_name.return_value = "TestUser" + event.is_stopped.return_value = False event.trace = MagicMock() event.plugins_name = None return event @@ -114,6 +117,10 @@ def 
_setup_conversation_for_build(conv_mgr, cid: str = "conv-id") -> MagicMock: return conversation +def _render_assembly(req: ProviderRequest, assembly: PromptAssembly) -> None: + render_prompt_assembly(req, assembly) + + class TestMainAgentBuildConfig: """Tests for MainAgentBuildConfig dataclass.""" @@ -304,12 +311,14 @@ async def test_apply_kb_without_agentic_mode(self, mock_event, mock_context): config = module.MainAgentBuildConfig( tool_call_timeout=60, kb_agentic_mode=False ) + assembly = PromptAssembly() with patch( "astrbot.core.astr_main_agent.retrieve_knowledge_base", AsyncMock(return_value="KB result"), ): - await module._apply_kb(mock_event, req, mock_context, config) + await module._apply_kb(mock_event, req, mock_context, config, assembly) + _render_assembly(req, assembly) assert "[Related Knowledge Base Results]:" in req.system_prompt assert "KB result" in req.system_prompt @@ -320,8 +329,9 @@ async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context): module = ama req = ProviderRequest(prompt="test question") config = module.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=True) + assembly = PromptAssembly() - await module._apply_kb(mock_event, req, mock_context, config) + await module._apply_kb(mock_event, req, mock_context, config, assembly) assert req.func_tool is not None @@ -333,8 +343,10 @@ async def test_apply_kb_no_prompt(self, mock_event, mock_context): config = module.MainAgentBuildConfig( tool_call_timeout=60, kb_agentic_mode=False ) + assembly = PromptAssembly() - await module._apply_kb(mock_event, req, mock_context, config) + await module._apply_kb(mock_event, req, mock_context, config, assembly) + _render_assembly(req, assembly) assert req.system_prompt == "System" @@ -346,12 +358,14 @@ async def test_apply_kb_no_result(self, mock_event, mock_context): config = module.MainAgentBuildConfig( tool_call_timeout=60, kb_agentic_mode=False ) + assembly = PromptAssembly() with patch( 
"astrbot.core.astr_main_agent.retrieve_knowledge_base", AsyncMock(return_value=None), ): - await module._apply_kb(mock_event, req, mock_context, config) + await module._apply_kb(mock_event, req, mock_context, config, assembly) + _render_assembly(req, assembly) assert req.system_prompt == "System" @@ -362,8 +376,9 @@ async def test_apply_kb_with_existing_tools(self, mock_event, mock_context): existing_tools = ToolSet() req = ProviderRequest(prompt="test", func_tool=existing_tools) config = module.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=True) + assembly = PromptAssembly() - await module._apply_kb(mock_event, req, mock_context, config) + await module._apply_kb(mock_event, req, mock_context, config, assembly) assert req.func_tool is not None @@ -381,16 +396,25 @@ async def test_file_extract_basic(self, mock_event, sample_config): mock_event.message_obj.message = [mock_file] req = ProviderRequest(prompt="Summarize") + assembly = PromptAssembly() with patch( "astrbot.core.astr_main_agent.extract_file_moonshotai" ) as mock_extract: mock_extract.return_value = "File content" - await module._apply_file_extract(mock_event, req, sample_config) + await module._apply_file_extract(mock_event, req, sample_config, assembly) + _render_assembly(req, assembly) - assert len(req.contexts) == 1 - assert "File Extract Results" in req.contexts[0]["content"] + assert req.contexts == [ + { + "role": "system", + "content": ( + "File Extract Results of user uploaded files:\n" + "File content\nFile Name: test.pdf" + ), + } + ] @pytest.mark.asyncio async def test_file_extract_no_files(self, mock_event, sample_config): @@ -398,8 +422,10 @@ async def test_file_extract_no_files(self, mock_event, sample_config): module = ama mock_event.message_obj.message = [Plain(text="Hello")] req = ProviderRequest(prompt="Hello") + assembly = PromptAssembly() - await module._apply_file_extract(mock_event, req, sample_config) + await module._apply_file_extract(mock_event, req, sample_config, 
assembly) + _render_assembly(req, assembly) assert len(req.contexts) == 0 @@ -415,13 +441,15 @@ async def test_file_extract_in_reply(self, mock_event, sample_config): mock_event.message_obj.message = [mock_reply] req = ProviderRequest(prompt="Summarize") + assembly = PromptAssembly() with patch( "astrbot.core.astr_main_agent.extract_file_moonshotai" ) as mock_extract: mock_extract.return_value = "Reply content" - await module._apply_file_extract(mock_event, req, sample_config) + await module._apply_file_extract(mock_event, req, sample_config, assembly) + _render_assembly(req, assembly) assert len(req.contexts) == 1 @@ -435,13 +463,14 @@ async def test_file_extract_no_prompt(self, mock_event, sample_config): mock_event.message_obj.message = [mock_file] req = ProviderRequest(prompt=None) + assembly = PromptAssembly() with patch( "astrbot.core.astr_main_agent.extract_file_moonshotai" ) as mock_extract: mock_extract.return_value = "Content" - await module._apply_file_extract(mock_event, req, sample_config) + await module._apply_file_extract(mock_event, req, sample_config, assembly) assert req.prompt == "总结一下文件里面讲了什么?" @@ -460,8 +489,10 @@ async def test_file_extract_no_api_key(self, mock_event): mock_event.message_obj.message = [mock_file] req = ProviderRequest(prompt="Summarize") + assembly = PromptAssembly() - await module._apply_file_extract(mock_event, req, config) + await module._apply_file_extract(mock_event, req, config, assembly) + _render_assembly(req, assembly) assert len(req.contexts) == 0 @@ -481,8 +512,12 @@ async def test_ensure_persona_from_session(self, mock_event, mock_context): mock_event.trace = MagicMock(record=MagicMock()) req = ProviderRequest() req.conversation = MagicMock(persona_id=None) + assembly = PromptAssembly() - await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + await module._ensure_persona_and_skills( + req, {}, mock_context, mock_event, assembly + ) + _render_assembly(req, assembly) assert "You are helpful." 
in req.system_prompt @@ -497,8 +532,12 @@ async def test_ensure_persona_from_conversation(self, mock_event, mock_context): ) req = ProviderRequest() req.conversation = MagicMock(persona_id="conv-persona") + assembly = PromptAssembly() - await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + await module._ensure_persona_and_skills( + req, {}, mock_context, mock_event, assembly + ) + _render_assembly(req, assembly) assert "Custom persona." in req.system_prompt @@ -512,8 +551,12 @@ async def test_ensure_persona_none_explicit(self, mock_event, mock_context): ) req = ProviderRequest() req.conversation = MagicMock(persona_id="[%None]") + assembly = PromptAssembly() - await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + await module._ensure_persona_and_skills( + req, {}, mock_context, mock_event, assembly + ) + _render_assembly(req, assembly) assert "Persona Instructions" not in req.system_prompt @@ -534,8 +577,11 @@ async def test_ensure_tools_from_persona(self, mock_event, mock_context): req = ProviderRequest() req.conversation = MagicMock(persona_id="persona") + assembly = PromptAssembly() - await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + await module._ensure_persona_and_skills( + req, {}, mock_context, mock_event, assembly + ) assert req.func_tool is not None @@ -565,9 +611,10 @@ async def test_subagent_dedupe_uses_default_persona_tools( tmgr = mock_context.get_llm_tool_manager.return_value tmgr.func_list = [tool_a, tool_b] tmgr.get_full_tool_set.return_value = ToolSet([tool_a, tool_b]) - tmgr.get_func.side_effect = lambda name: {"tool_a": tool_a, "tool_b": tool_b}.get( - name - ) + tmgr.get_func.side_effect = lambda name: { + "tool_a": tool_a, + "tool_b": tool_b, + }.get(name) handoff = MagicMock() handoff.name = "transfer_to_planner" @@ -588,8 +635,11 @@ async def test_subagent_dedupe_uses_default_persona_tools( req = ProviderRequest() req.conversation = MagicMock(persona_id=None) + assembly 
= PromptAssembly() - await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + await module._ensure_persona_and_skills( + req, {}, mock_context, mock_event, assembly + ) assert req.func_tool is not None assert "transfer_to_planner" in req.func_tool.names() @@ -607,8 +657,12 @@ async def test_decorate_llm_request_basic( """Test basic LLM request decoration.""" module = ama req = ProviderRequest(prompt="Hello", system_prompt="System") + assembly = PromptAssembly() - await module._decorate_llm_request(mock_event, req, mock_context, sample_config) + await module._decorate_llm_request( + mock_event, req, mock_context, sample_config, assembly + ) + _render_assembly(req, assembly) assert req.prompt == "Hello" assert req.system_prompt == "System" @@ -621,11 +675,14 @@ async def test_decorate_llm_request_with_prefix(self, mock_event, mock_context): config = module.MainAgentBuildConfig( tool_call_timeout=60, provider_settings={"prompt_prefix": "AI: "} ) + assembly = PromptAssembly() with patch.object(mock_context, "get_config") as mock_get_config: mock_get_config.return_value = {} - await module._decorate_llm_request(mock_event, req, mock_context, config) + await module._decorate_llm_request( + mock_event, req, mock_context, config, assembly + ) assert req.prompt == "AI: Hello" @@ -640,14 +697,46 @@ async def test_decorate_llm_request_prefix_with_placeholder( tool_call_timeout=60, provider_settings={"prompt_prefix": "AI {{prompt}} - Please respond:"}, ) + assembly = PromptAssembly() with patch.object(mock_context, "get_config") as mock_get_config: mock_get_config.return_value = {} - await module._decorate_llm_request(mock_event, req, mock_context, config) + await module._decorate_llm_request( + mock_event, req, mock_context, config, assembly + ) assert req.prompt == "AI Hello - Please respond:" + @pytest.mark.asyncio + async def test_decorate_llm_request_keeps_quote_and_system_reminder_in_user_parts( + self, mock_event, mock_context + ): + module = ama + 
req = ProviderRequest(prompt="Hello") + quote = MagicMock(spec=Reply) + quote.sender_nickname = "Alice" + quote.message_str = "Quoted hi" + quote.chain = [] + mock_event.message_obj.message = [quote] + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={"identifier": True, "datetime_system_prompt": False}, + ) + assembly = PromptAssembly() + + with patch.object(mock_context, "get_config") as mock_get_config: + mock_get_config.return_value = {} + await module._decorate_llm_request( + mock_event, req, mock_context, config, assembly + ) + _render_assembly(req, assembly) + + assert [part.text for part in req.extra_user_content_parts] == [ + "\n(Alice): Quoted hi\n", + "User ID: user123, Nickname: TestUser", + ] + @pytest.mark.asyncio async def test_decorate_llm_request_no_conversation(self, mock_event, mock_context): """Test decoration when no conversation exists.""" @@ -655,11 +744,15 @@ async def test_decorate_llm_request_no_conversation(self, mock_event, mock_conte req = ProviderRequest(prompt="Hello") req.conversation = None config = module.MainAgentBuildConfig(tool_call_timeout=60) + assembly = PromptAssembly() with patch.object(mock_context, "get_config") as mock_get_config: mock_get_config.return_value = {} - await module._decorate_llm_request(mock_event, req, mock_context, config) + await module._decorate_llm_request( + mock_event, req, mock_context, config, assembly + ) + _render_assembly(req, assembly) assert req.prompt == "Hello" @@ -1103,7 +1196,273 @@ async def test_build_main_agent_with_existing_request( ) assert result is not None - assert result.provider_request == existing_req + + @pytest.mark.asyncio + async def test_build_main_agent_prompt_assembly_ordering( + self, mock_event, mock_context, mock_provider + ): + module = ama + mock_event.platform_meta.support_proactive_message = True + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + 
mock_context.get_config.return_value = { + "provider_settings": {}, + "subagent_orchestrator": { + "main_enable": True, + "router_system_prompt": "ROUTER_BLOCK", + }, + } + conversation = _setup_conversation_for_build(mock_context.conversation_manager) + conversation.persona_id = "persona-1" + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=( + "persona-1", + { + "prompt": "PERSONA_BLOCK", + "skills": None, + "tools": None, + "_begin_dialogs_processed": [], + }, + None, + False, + ) + ) + tool_manager = MagicMock() + tool_manager.get_full_tool_set.return_value = ToolSet() + tool_manager.func_list = [] + mock_context.get_llm_tool_manager.return_value = tool_manager + mock_context.subagent_orchestrator = MagicMock(handoffs=[]) + + with ( + patch.object(module, "SkillManager") as mock_skill_manager_cls, + patch.object(module, "build_skills_prompt", return_value="SKILLS_BLOCK"), + patch.object( + module, "_build_local_mode_prompt", return_value="LOCAL_BLOCK" + ), + patch.object(module, "LLM_SAFETY_MODE_SYSTEM_PROMPT", "SAFE_BLOCK"), + patch.object(module, "TOOL_CALL_PROMPT", "TOOL_BLOCK"), + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_skill_manager = MagicMock() + mock_skill_manager.list_skills.return_value = [MagicMock(name="skill-a")] + mock_skill_manager_cls.return_value = mock_skill_manager + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + llm_safety_mode=False, + ), + ) + + assert result is not None + assert result.provider_request.system_prompt == ( + "\n# Persona Instructions\n\nPERSONA_BLOCK\n" + "\nSKILLS_BLOCK\n" + "\nROUTER_BLOCK\n" + "\nLOCAL_BLOCK\n" + "\nTOOL_BLOCK\n" + ) + core_trace = next( + call + for call in 
mock_event.trace.record.call_args_list + if call.args and call.args[0] == "core_prompt_assembly" + ) + assert [item["source"] for item in core_trace.kwargs["system_blocks"]] == [ + "persona", + "skills", + "router", + "runtime:local", + "tool_use", + ] + + @pytest.mark.asyncio + async def test_build_main_agent_prepends_persona_begin_dialogs( + self, mock_event, mock_context, mock_provider + ): + module = ama + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = {} + conversation = _setup_conversation_for_build(mock_context.conversation_manager) + conversation.persona_id = "persona-1" + conversation.history = json.dumps([{"role": "user", "content": "history"}]) + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=( + "persona-1", + { + "prompt": "PERSONA_BLOCK", + "skills": None, + "tools": None, + "_begin_dialogs_processed": [ + {"role": "assistant", "content": "example"} + ], + }, + None, + False, + ) + ) + tool_manager = MagicMock() + tool_manager.get_full_tool_set.return_value = ToolSet() + tool_manager.func_list = [] + mock_context.get_llm_tool_manager.return_value = tool_manager + + with ( + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + llm_safety_mode=False, + computer_use_runtime="none", + add_cron_tools=False, + ), + ) + + assert result is not None + assert result.provider_request.contexts == [ + {"role": "assistant", "content": "example"}, + {"role": "user", "content": "history"}, + ] + + @pytest.mark.asyncio + async def 
test_build_main_agent_prompt_assembly_hook_can_add_blocks( + self, mock_event, mock_context, mock_provider + ): + module = ama + mock_event.platform_meta.support_proactive_message = False + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = {} + _setup_conversation_for_build(mock_context.conversation_manager) + tool_manager = MagicMock() + tool_manager.get_full_tool_set.return_value = ToolSet() + tool_manager.func_list = [] + mock_context.get_llm_tool_manager.return_value = tool_manager + + async def fake_call_event_hook(event, hook_type, *args): + if hook_type == module.EventType.OnPromptAssemblyEvent: + mutation = args[0] + mutation.add_system("\nPLUGIN_SYSTEM\n", "plugin:test", 950) + mutation.add_user_text("plugin user", "plugin:test", 950) + mutation.add_context_prefix( + [{"role": "system", "content": "plugin prefix"}], + "plugin:test", + 950, + ) + mutation.add_context_suffix( + [{"role": "assistant", "content": "plugin suffix"}], + "plugin:test", + 950, + ) + return False + + with ( + patch.object( + module, "call_event_hook", AsyncMock(side_effect=fake_call_event_hook) + ), + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + llm_safety_mode=False, + computer_use_runtime="none", + add_cron_tools=False, + ), + ) + + assert result is not None + assert result.provider_request.system_prompt == "\nPLUGIN_SYSTEM\n" + assert [ + part.text for part in result.provider_request.extra_user_content_parts + ] == [ + "plugin user", + ] + assert result.provider_request.contexts == [ + {"role": "system", "content": "plugin 
prefix"}, + {"role": "assistant", "content": "plugin suffix"}, + ] + core_trace = next( + call + for call in mock_event.trace.record.call_args_list + if call.args and call.args[0] == "core_prompt_assembly" + ) + assert core_trace.kwargs["system_blocks"] == [] + assert core_trace.kwargs["user_append_parts"] == [] + assert core_trace.kwargs["context_prefix"] == [] + assert core_trace.kwargs["context_suffix"] == [] + assert core_trace.kwargs["metadata"]["base_request"]["has_prompt"] is True + + @pytest.mark.asyncio + async def test_build_main_agent_prompt_assembly_hook_can_cancel_request( + self, mock_event, mock_context, mock_provider + ): + module = ama + mock_event.platform_meta.support_proactive_message = False + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = {} + _setup_conversation_for_build(mock_context.conversation_manager) + tool_manager = MagicMock() + tool_manager.get_full_tool_set.return_value = ToolSet() + tool_manager.func_list = [] + mock_context.get_llm_tool_manager.return_value = tool_manager + + async def fake_call_event_hook(event, hook_type, *args): + return hook_type == module.EventType.OnPromptAssemblyEvent + + with ( + patch.object( + module, "call_event_hook", AsyncMock(side_effect=fake_call_event_hook) + ), + patch.object( + module, "render_prompt_assembly" + ) as mock_render_prompt_assembly, + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + llm_safety_mode=False, + computer_use_runtime="none", + add_cron_tools=False, + ), + ) + + assert result is None + 
mock_render_prompt_assembly.assert_not_called() + mock_runner.reset.assert_not_called() + assert not any( + call.args and call.args[0] == "core_prompt_assembly" + for call in mock_event.trace.record.call_args_list + ) class TestHandleWebchat: @@ -1363,8 +1722,10 @@ def test_apply_llm_safety_mode_system_prompt_strategy(self): safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt="Original prompt") + assembly = PromptAssembly() - module._apply_llm_safety_mode(config, req) + module._apply_llm_safety_mode(config, req, assembly) + _render_assembly(req, assembly) assert "You are running in Safe Mode" in req.system_prompt assert "Original prompt" in req.system_prompt @@ -1377,8 +1738,10 @@ def test_apply_llm_safety_mode_prepends_safety_prompt(self): safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt="My custom prompt") + assembly = PromptAssembly() - module._apply_llm_safety_mode(config, req) + module._apply_llm_safety_mode(config, req, assembly) + _render_assembly(req, assembly) assert req.system_prompt.startswith("You are running in Safe Mode") assert "My custom prompt" in req.system_prompt @@ -1391,8 +1754,10 @@ def test_apply_llm_safety_mode_with_none_system_prompt(self): safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt=None) + assembly = PromptAssembly() - module._apply_llm_safety_mode(config, req) + module._apply_llm_safety_mode(config, req, assembly) + _render_assembly(req, assembly) assert "You are running in Safe Mode" in req.system_prompt @@ -1404,9 +1769,11 @@ def test_apply_llm_safety_mode_unsupported_strategy(self): safety_mode_strategy="unsupported_strategy", ) req = ProviderRequest(prompt="Test", system_prompt="Original") + assembly = PromptAssembly() with patch("astrbot.core.astr_main_agent.logger") as mock_logger: - module._apply_llm_safety_mode(config, req) + module._apply_llm_safety_mode(config, req, assembly) + _render_assembly(req, 
assembly) mock_logger.warning.assert_called_once() assert ( @@ -1423,8 +1790,10 @@ def test_apply_llm_safety_mode_empty_system_prompt(self): safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt="") + assembly = PromptAssembly() - module._apply_llm_safety_mode(config, req) + module._apply_llm_safety_mode(config, req, assembly) + _render_assembly(req, assembly) assert "You are running in Safe Mode" in req.system_prompt @@ -1441,8 +1810,9 @@ def test_apply_sandbox_tools_creates_toolset_if_none(self): sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", func_tool=None) + assembly = PromptAssembly() - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) assert req.func_tool is not None assert isinstance(req.func_tool, ToolSet) @@ -1456,8 +1826,9 @@ def test_apply_sandbox_tools_adds_required_tools(self): sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", func_tool=None) + assembly = PromptAssembly() - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) tool_names = req.func_tool.names() assert "astrbot_execute_shell" in tool_names @@ -1474,8 +1845,10 @@ def test_apply_sandbox_tools_adds_sandbox_prompt(self): sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", system_prompt="Original prompt") + assembly = PromptAssembly() - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) + _render_assembly(req, assembly) assert "sandboxed environment" in req.system_prompt @@ -1492,11 +1865,12 @@ def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch): }, ) req = ProviderRequest(prompt="Test", func_tool=None) + assembly = PromptAssembly() monkeypatch.delenv("SHIPYARD_ENDPOINT", raising=False) monkeypatch.delenv("SHIPYARD_ACCESS_TOKEN", raising=False) - module._apply_sandbox_tools(config, req, "session-123") + 
module._apply_sandbox_tools(config, req, "session-123", assembly) assert os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com" assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token" @@ -1514,9 +1888,10 @@ def test_apply_sandbox_tools_shipyard_missing_endpoint(self): }, ) req = ProviderRequest(prompt="Test", func_tool=None) + assembly = PromptAssembly() with patch("astrbot.core.astr_main_agent.logger") as mock_logger: - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) mock_logger.error.assert_called_once() assert ( @@ -1537,9 +1912,10 @@ def test_apply_sandbox_tools_shipyard_missing_access_token(self): }, ) req = ProviderRequest(prompt="Test", func_tool=None) + assembly = PromptAssembly() with patch("astrbot.core.astr_main_agent.logger") as mock_logger: - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) mock_logger.error.assert_called_once() @@ -1556,8 +1932,9 @@ def test_apply_sandbox_tools_preserves_existing_toolset(self): existing_tool.name = "existing_tool" existing_toolset.add_tool(existing_tool) req = ProviderRequest(prompt="Test", func_tool=existing_toolset) + assembly = PromptAssembly() - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) assert "existing_tool" in req.func_tool.names() assert "astrbot_execute_shell" in req.func_tool.names() @@ -1571,8 +1948,10 @@ def test_apply_sandbox_tools_appends_to_existing_system_prompt(self): sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", system_prompt="Base prompt") + assembly = PromptAssembly() - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) + _render_assembly(req, assembly) assert req.system_prompt.startswith("Base prompt") assert "sandboxed environment" in req.system_prompt @@ -1586,8 
+1965,10 @@ def test_apply_sandbox_tools_with_none_system_prompt(self): sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", system_prompt=None) + assembly = PromptAssembly() - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req, "session-123", assembly) + _render_assembly(req, assembly) assert isinstance(req.system_prompt, str) assert "sandboxed environment" in req.system_prompt diff --git a/tests/unit/test_internal_agent_trace.py b/tests/unit/test_internal_agent_trace.py new file mode 100644 index 0000000000..248b0f1c62 --- /dev/null +++ b/tests/unit/test_internal_agent_trace.py @@ -0,0 +1,135 @@ +"""Focused tests for InternalAgentSubStage trace behavior.""" + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + InternalAgentSubStage, +) +from astrbot.core.astr_main_agent import MainAgentBuildConfig +from astrbot.core.provider.entities import LLMResponse, ProviderRequest + + +@pytest.mark.asyncio +async def test_internal_agent_prepare_trace_keeps_string_system_prompt(): + stage = InternalAgentSubStage() + stage.ctx = MagicMock() + stage.ctx.plugin_manager = MagicMock() + stage.ctx.plugin_manager.context = MagicMock() + stage.streaming_response = False + stage.unsupported_streaming_strategy = "turn_off" + stage.max_step = 3 + stage.show_tool_use = False + stage.show_tool_call_result = False + stage.show_reasoning = False + stage.main_agent_cfg = MainAgentBuildConfig(tool_call_timeout=60) + stage._save_to_history = AsyncMock() + + provider = MagicMock() + provider.provider_config = {"id": "provider-1", "api_base": ""} + provider.get_model.return_value = "gpt-4" + + req = ProviderRequest(prompt="hello", system_prompt="SYSTEM") + + agent_runner = MagicMock() + agent_runner.done.return_value = True + agent_runner.get_final_llm_resp.return_value = LLMResponse( + 
role="assistant", completion_text="done" + ) + agent_runner.stats = MagicMock() + agent_runner.stats.to_dict.return_value = {} + agent_runner.run_context.messages = [] + agent_runner.was_aborted.return_value = False + agent_runner.provider = provider + + async def noop(): + return None + + build_result = MagicMock( + agent_runner=agent_runner, + provider_request=req, + provider=provider, + reset_coro=noop(), + ) + + @asynccontextmanager + async def fake_lock(*args, **kwargs): + yield + + async def fake_run_agent(*args, **kwargs): + if False: + yield None + + def consume_task(coro): + coro.close() + return MagicMock() + + event = MagicMock() + event.message_str = "hello" + event.message_obj.message = [] + event.unified_msg_origin = "test:private:1" + event.platform_meta.support_streaming_message = False + event.get_extra.return_value = None + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + event.set_result = MagicMock() + event.is_stopped.return_value = False + event.trace = MagicMock() + + with ( + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.build_main_agent", + AsyncMock(return_value=build_result), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.call_event_hook", + AsyncMock(return_value=False), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.session_lock_manager.acquire_lock", + fake_lock, + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.try_capture_follow_up", + return_value=None, + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.register_active_runner" + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.unregister_active_runner" + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.run_agent", + fake_run_agent, + ), + patch( + 
"astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.asyncio.create_task", + side_effect=consume_task, + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal._record_internal_agent_stats", + AsyncMock(), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.Metric.upload", + AsyncMock(), + ), + patch( + "astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal.decoded_blocked", + [], + create=True, + ), + ): + async for _ in stage.process(event, ""): + pass + + prepare_trace = next( + call + for call in event.trace.record.call_args_list + if call.args and call.args[0] == "astr_agent_prepare" + ) + assert prepare_trace.kwargs["system_prompt"] == "SYSTEM" + assert isinstance(prepare_trace.kwargs["system_prompt"], str) diff --git a/tests/unit/test_prompt_renderer.py b/tests/unit/test_prompt_renderer.py new file mode 100644 index 0000000000..5d16897c7c --- /dev/null +++ b/tests/unit/test_prompt_renderer.py @@ -0,0 +1,274 @@ +"""Tests for prompt assembly rendering helpers.""" + +from unittest.mock import patch + +from astrbot.core.agent.message import TextPart +from astrbot.core.prompt import ( + PromptAssembly, + PromptMutation, + add_context_prefix, + add_context_suffix, + add_system_block, + add_user_part, + build_prompt_trace_snapshot, + render_prompt_assembly, + summarize_provider_request_base, +) +from astrbot.core.provider.entities import ProviderRequest + + +def test_render_prompt_assembly_orders_channels(): + req = ProviderRequest( + prompt="hello", + system_prompt="BASE", + contexts=[{"role": "user", "content": "history"}], + ) + assembly = PromptAssembly() + + add_system_block(assembly, source="skills", order=300, content="\nSKILLS\n") + add_system_block(assembly, source="persona", order=200, content="\nPERSONA\n") + add_system_block( + assembly, + source="safety", + order=100, + content="SAFE\n", + prepend=True, + ) + add_user_part( + assembly, + source="quoted", + 
order=200, + part=TextPart(text="quoted"), + ) + add_user_part( + assembly, + source="attachment", + order=100, + part=TextPart(text="attachment"), + ) + add_context_prefix( + assembly, + source="prefix", + order=100, + messages=[{"role": "system", "content": "prefix"}], + ) + add_context_suffix( + assembly, + source="suffix", + order=100, + messages=[{"role": "assistant", "content": "suffix"}], + ) + + render_prompt_assembly(req, assembly) + + assert req.system_prompt == "SAFE\nBASE\nPERSONA\n\nSKILLS\n" + assert [part.text for part in req.extra_user_content_parts] == [ + "attachment", + "quoted", + ] + assert req.contexts == [ + {"role": "system", "content": "prefix"}, + {"role": "user", "content": "history"}, + {"role": "assistant", "content": "suffix"}, + ] + + +def test_render_prompt_assembly_is_render_once(): + req = ProviderRequest(prompt="hello") + assembly = PromptAssembly() + add_system_block(assembly, source="persona", order=200, content="\nPERSONA\n") + add_user_part( + assembly, + source="attachment", + order=100, + part=TextPart(text="attachment"), + ) + + render_prompt_assembly(req, assembly) + render_prompt_assembly(req, assembly) + + assert req.system_prompt == "\nPERSONA\n" + assert [part.text for part in req.extra_user_content_parts] == ["attachment"] + + +def test_render_prompt_assembly_empty_is_noop(): + req = ProviderRequest( + prompt="hello", + system_prompt="BASE", + contexts=[{"role": "user", "content": "history"}], + ) + + render_prompt_assembly(req, PromptAssembly()) + + assert req.system_prompt == "BASE" + assert req.extra_user_content_parts == [] + assert req.contexts == [{"role": "user", "content": "history"}] + + +def test_build_prompt_trace_snapshot_contains_sorted_sources(): + req = ProviderRequest( + prompt="hello", + system_prompt="BASE", + contexts=[{"role": "user", "content": "history"}], + ) + assembly = PromptAssembly( + metadata={ + "kind": "test", + "base_request": summarize_provider_request_base(req), + } + ) + 
add_system_block(assembly, source="skills", order=300, content="\nSKILLS\n") + add_system_block( + assembly, + source="persona", + order=200, + content="\nPERSONA\n", + prepend=True, + ) + add_user_part( + assembly, + source="attachment", + order=100, + part=TextPart(text="attachment"), + ) + add_context_suffix( + assembly, + source="file_extract", + order=100, + messages=[{"role": "system", "content": "suffix"}], + ) + + snapshot = build_prompt_trace_snapshot(assembly) + + assert [item["source"] for item in snapshot["system_blocks"]] == [ + "persona", + "skills", + ] + assert snapshot["user_append_parts"] == [ + { + "source": "attachment", + "order": 100, + "part": {"type": "text", "char_count": 10}, + } + ] + assert snapshot["system_blocks"] == [ + { + "source": "persona", + "order": 200, + "prepend": True, + "char_count": 9, + }, + { + "source": "skills", + "order": 300, + "prepend": False, + "char_count": 8, + }, + ] + assert snapshot["context_prefix"] == [] + assert snapshot["context_suffix"] == [ + { + "source": "file_extract", + "order": 100, + "message_count": 1, + "roles": ["system"], + "text_char_count": 6, + "non_text_part_count": 0, + } + ] + assert snapshot["metadata"] == { + "kind": "test", + "base_request": { + "system_prompt_chars": 4, + "context_count": 1, + "extra_user_part_count": 0, + "image_count": 0, + "has_prompt": True, + }, + } + assert "content" not in snapshot["system_blocks"][0] + assert "messages" not in snapshot["context_suffix"][0] + + +def test_prompt_mutation_facade_dispatches_to_helper_functions(): + assembly = PromptAssembly() + mutation = PromptMutation(assembly) + + with ( + patch("astrbot.core.prompt.assembly.add_system_block") as mock_add_system, + patch("astrbot.core.prompt.assembly.add_user_text") as mock_add_user_text, + patch( + "astrbot.core.prompt.assembly.add_context_prefix" + ) as mock_add_context_prefix, + patch( + "astrbot.core.prompt.assembly.add_context_suffix" + ) as mock_add_context_suffix, + 
patch("astrbot.core.prompt.assembly.logger.warning"), + ): + mutation.add_system("\nPLUGIN\n", "plugin:test", 950) + mutation.add_user_text("plugin user", "plugin:test", 950) + mutation.add_context_prefix( + [{"role": "system", "content": "prefix"}], "plugin:test", 950 + ) + mutation.add_context_suffix( + [{"role": "assistant", "content": "suffix"}], "plugin:test", 950 + ) + + mock_add_system.assert_called_once_with( + assembly, + source="plugin:test", + order=950, + content="\nPLUGIN\n", + visible_in_trace=True, + ) + mock_add_user_text.assert_called_once_with( + assembly, + source="plugin:test", + order=950, + text="plugin user", + visible_in_trace=True, + ) + mock_add_context_prefix.assert_called_once_with( + assembly, + source="plugin:test", + order=950, + messages=[{"role": "system", "content": "prefix"}], + visible_in_trace=True, + ) + mock_add_context_suffix.assert_called_once_with( + assembly, + source="plugin:test", + order=950, + messages=[{"role": "assistant", "content": "suffix"}], + visible_in_trace=True, + ) + + +def test_prompt_mutation_warns_once_for_reserved_plugin_order(): + assembly = PromptAssembly() + mutation = PromptMutation(assembly) + + with patch("astrbot.core.prompt.assembly.logger.warning") as mock_warning: + mutation.add_system("\nPLUGIN\n", "plugin:test", 850) + mutation.add_user_text("plugin user", "plugin:test", 850) + + mock_warning.assert_called_once() + + +def test_prompt_mutation_warns_once_for_context_prefix_cache_impact(): + assembly = PromptAssembly() + mutation = PromptMutation(assembly) + + with patch("astrbot.core.prompt.assembly.logger.warning") as mock_warning: + mutation.add_context_prefix( + [{"role": "system", "content": "prefix"}], + "plugin:test", + 950, + ) + mutation.add_context_prefix( + [{"role": "system", "content": "prefix again"}], + "plugin:test", + 950, + ) + + mock_warning.assert_called_once()