diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 345ad7b743..fefbb8580b 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -172,6 +172,110 @@ async def pop_record(self, context: list) -> None:
         for idx in reversed(indexs_to_pop):
             context.pop(idx)
 
+        # Popping records can leave tool results without their assistant tool
+        # call (or the reverse), so drop any chain that is no longer complete.
+        context[:] = self._fix_tool_call_pairs_in_dict_context(context)
+
+    @staticmethod
+    def _fix_tool_call_pairs_in_dict_context(context: list[dict]) -> list[dict]:
+        """Remove orphaned tool call chains from dict-based message history.
+
+        A chain is an assistant message with a non-empty ``tool_calls`` value
+        followed by the ``tool`` messages that come directly after it. A chain
+        is kept only when ``_is_complete_tool_chain`` accepts it; otherwise the
+        assistant message and its tool messages are dropped together. ``tool``
+        messages that do not follow such an assistant message are dropped too.
+        All other messages are kept in their original order.
+
+        Args:
+            context: Message dicts; their ``role`` and ``tool_calls`` keys are
+                read here.
+
+        Returns:
+            A new list with the surviving messages, or ``context`` itself when
+            it is empty.
+        """
+        if not context:
+            return context
+
+        fixed_context: list[dict] = []
+        pending_assistant: dict | None = None
+        pending_tools: list[dict] = []
+
+        def flush_pending_if_valid() -> None:
+            """Emit the buffered chain if it is complete, then reset the buffer."""
+            nonlocal pending_assistant, pending_tools
+            if pending_assistant is not None and Provider._is_complete_tool_chain(
+                pending_assistant,
+                pending_tools,
+            ):
+                fixed_context.append(pending_assistant)
+                fixed_context.extend(pending_tools)
+            # Reset unconditionally so an incomplete chain never leaks forward.
+            pending_assistant = None
+            pending_tools = []
+
+        for message in context:
+            role = message.get("role")
+            if role == "tool":
+                # A tool result with no open assistant tool call is an orphan
+                # and is skipped.
+                if pending_assistant is not None:
+                    pending_tools.append(message)
+                continue
+
+            if role == "assistant" and message.get("tool_calls"):
+                flush_pending_if_valid()
+                pending_assistant = message
+                continue
+
+            # Any other message closes the chain that is currently buffered.
+            flush_pending_if_valid()
+            fixed_context.append(message)
+
+        flush_pending_if_valid()
+        return fixed_context
+
+    @staticmethod
+    def _is_complete_tool_chain(
+        assistant_message: dict,
+        tool_messages: list[dict],
+    ) -> bool:
+        """Check whether a dict-based assistant/tool chain is fully paired.
+
+        The chain is complete when every entry of ``tool_calls`` has a unique,
+        non-``None`` ``id`` and ``tool_messages`` answers each of those ids
+        exactly once, with no unknown or repeated ``tool_call_id``.
+
+        Args:
+            assistant_message: Assistant message dict holding ``tool_calls``.
+            tool_messages: The ``tool`` message dicts collected for it.
+
+        Returns:
+            ``True`` if the chain is complete, ``False`` otherwise.
+        """
+        tool_calls = assistant_message.get("tool_calls") or []
+        # A set gives O(1) membership checks; a missing or duplicated id shows
+        # up as ``None`` in the set or as a size mismatch with ``tool_calls``.
+        expected_ids = {tool_call.get("id") for tool_call in tool_calls}
+        if (
+            not tool_calls
+            or None in expected_ids
+            or len(expected_ids) != len(tool_calls)
+        ):
+            return False
+
+        seen_ids: set[str] = set()
+        for tool_message in tool_messages:
+            tool_call_id = tool_message.get("tool_call_id")
+            # ``None`` is never in ``expected_ids`` here, so a missing
+            # ``tool_call_id`` fails the membership test as well.
+            if tool_call_id not in expected_ids or tool_call_id in seen_ids:
+                return False
+            seen_ids.add(tool_call_id)
+
+        return len(seen_ids) == len(expected_ids)
+
     def _ensure_message_to_dicts(
         self,
         messages: list[dict] | list[Message] | None,
diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py
index 0040f0be62..869987b560 100644
--- a/tests/test_openai_source.py
+++ b/tests/test_openai_source.py
@@ -165,6 +165,144 @@ async def test_handle_api_error_model_not_vlm_after_fallback_raises():
         await provider.terminate()
 
 
+@pytest.mark.asyncio
+async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
+    provider = _make_provider()
+    try:
+        payloads = {
+            "messages": [
+                {"role": "system", "content": "system"},
+                {"role": "user", "content": "Run tool"},
+                {
+                    "role": "assistant",
+                    "content": "",
+                    "tool_calls": [
+                        {
+                            "id": "call_1",
+                            "type": "function",
+                            "function": {"name": "search", "arguments": "{}"},
+                        }
+                    ],
+                },
+                {"role": "tool", "content": "Tool result", "tool_call_id": "call_1"},
+                {"role": "assistant", "content": "Final answer"},
+            ]
+        }
+        context_query = payloads["messages"]
+
+        success, *_rest = await provider._handle_api_error(
+            Exception("maximum context length exceeded"),
+            payloads=payloads,
+            context_query=context_query,
+            func_tool=None,
+            chosen_key="test-key",
+            available_api_keys=["test-key"],
+            retry_cnt=0,
+            max_retries=10,
+        )
+
+        assert success is False
+        assert payloads["messages"] == [
+            {"role": "system", "content": "system"},
+            {"role": "assistant", "content": "Final answer"},
+        ]
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_fix_tool_call_pairs_in_dict_context_removes_partial_multi_tool_chain():
+    provider = _make_provider()
+    try:
+        full_chain = [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "Run multiple tools"},
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "id": "call_1",
+                        "type": "function",
+                        "function": {"name": "tool_a", "arguments": "{}"},
+                    },
+                    {
+                        "id": "call_2",
+                        "type": "function",
+                        "function": {"name": "tool_b", "arguments": "{}"},
+                    },
+                ],
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_1",
+                "name": "tool_a",
+                "content": "result a",
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_2",
+                "name": "tool_b",
+                "content": "result b",
+            },
+            {"role": "assistant", "content": "Final answer"},
+        ]
+
+        assert provider._fix_tool_call_pairs_in_dict_context(full_chain) == full_chain
+
+        missing_assistant = full_chain[:2] + full_chain[3:]
+        assert provider._fix_tool_call_pairs_in_dict_context(missing_assistant) == [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "Run multiple tools"},
+            {"role": "assistant", "content": "Final answer"},
+        ]
+
+        missing_one_tool = full_chain[:-2] + [full_chain[-1]]
+        assert provider._fix_tool_call_pairs_in_dict_context(missing_one_tool) == [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "Run multiple tools"},
+            {"role": "assistant", "content": "Final answer"},
+        ]
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_handle_api_error_context_length_preserves_remaining_valid_messages():
+    provider = _make_provider()
+    try:
+        payloads = {
+            "messages": [
+                {"role": "system", "content": "system"},
+                {"role": "user", "content": "old question"},
+                {"role": "assistant", "content": "old answer"},
+                {"role": "user", "content": "new question"},
+                {"role": "assistant", "content": "new answer"},
+            ]
+        }
+        context_query = payloads["messages"]
+
+        success, *_rest = await provider._handle_api_error(
+            Exception("maximum context length exceeded"),
+            payloads=payloads,
+            context_query=context_query,
+            func_tool=None,
+            chosen_key="test-key",
+            available_api_keys=["test-key"],
+            retry_cnt=0,
+            max_retries=10,
+        )
+
+        assert success is False
+        assert payloads["messages"] == [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "new question"},
+            {"role": "assistant", "content": "new answer"},
+        ]
+    finally:
+        await provider.terminate()
+
+
 @pytest.mark.asyncio
 async def test_handle_api_error_content_moderated_with_unserializable_body():
     provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})