Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,75 @@ async def pop_record(self, context: list) -> None:
for idx in reversed(indexs_to_pop):
context.pop(idx)

context[:] = self._fix_tool_call_pairs_in_dict_context(context)

@staticmethod
def _fix_tool_call_pairs_in_dict_context(context: list[dict]) -> list[dict]:
    """Drop incomplete assistant/tool chains from a dict-based history.

    A chain is an assistant message with a truthy ``tool_calls`` entry
    followed by the ``tool`` messages that come directly after it. A chain
    is kept only when ``Provider._is_complete_tool_chain`` accepts it;
    otherwise the assistant message and its tool messages are all dropped.
    ``tool`` messages that appear while no chain is open are dropped too.
    Every other message is kept in its original order.

    Args:
        context: Message dicts, each read through its ``role`` and
            ``tool_calls`` keys.

    Returns:
        ``context`` itself when it is empty, otherwise a new list holding
        the surviving messages.
    """
    if not context:
        return context

    kept: list[dict] = []
    # The chain currently being collected: the assistant message first,
    # then its tool replies. An empty list means no chain is open.
    open_chain: list[dict] = []

    def close_chain() -> None:
        # Commit the open chain only when it is fully paired, then reset.
        if open_chain and Provider._is_complete_tool_chain(
            open_chain[0],
            open_chain[1:],
        ):
            kept.extend(open_chain)
        open_chain.clear()

    for entry in context:
        kind = entry.get("role")
        if kind == "tool":
            # A tool reply with no open chain is an orphan and is discarded.
            if open_chain:
                open_chain.append(entry)
        elif kind == "assistant" and entry.get("tool_calls"):
            # A new tool-calling assistant ends whatever chain came before.
            close_chain()
            open_chain.append(entry)
        else:
            # Any ordinary message also terminates the open chain.
            close_chain()
            kept.append(entry)

    close_chain()
    return kept

@staticmethod
def _is_complete_tool_chain(
    assistant_message: dict,
    tool_messages: list[dict],
) -> bool:
    """Check whether a dict-based assistant/tool chain is fully paired.

    The chain is complete when every entry of ``tool_calls`` carries a
    distinct, non-``None`` ``id`` and ``tool_messages`` answers each of
    those ids exactly once, with no unknown or repeated ``tool_call_id``.

    Malformed input (a non-list ``tool_calls`` value, non-dict entries, or
    unhashable ids) is reported as an incomplete chain instead of raising,
    so the caller can simply drop the chain.

    Args:
        assistant_message: Assistant message dict; its ``tool_calls`` key
            is read.
        tool_messages: The tool reply dicts that followed the assistant
            message; each one's ``tool_call_id`` key is read.

    Returns:
        ``True`` when calls and replies pair up one-to-one, else ``False``.
    """
    from collections.abc import Hashable

    tool_calls = assistant_message.get("tool_calls") or []
    if not isinstance(tool_calls, (list, tuple)):
        return False

    # A set gives O(1) membership checks. A repeated call id could never be
    # answered "exactly once", so it is rejected up front.
    expected_ids: set = set()
    for tool_call in tool_calls:
        call_id = tool_call.get("id") if isinstance(tool_call, dict) else None
        if (
            call_id is None
            or not isinstance(call_id, Hashable)
            or call_id in expected_ids
        ):
            return False
        expected_ids.add(call_id)
    if not expected_ids:
        return False

    answered_ids: set = set()
    for tool_message in tool_messages:
        reply_id = (
            tool_message.get("tool_call_id")
            if isinstance(tool_message, dict)
            else None
        )
        if (
            reply_id is None
            or not isinstance(reply_id, Hashable)
            or reply_id not in expected_ids
            or reply_id in answered_ids
        ):
            return False
        answered_ids.add(reply_id)

    return len(answered_ids) == len(expected_ids)

def _ensure_message_to_dicts(
self,
messages: list[dict] | list[Message] | None,
Expand Down
138 changes: 138 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,144 @@ async def test_handle_api_error_model_not_vlm_after_fallback_raises():
await provider.terminate()


@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
Comment on lines +168 to +169
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a test for multiple tool calls and their tool responses to ensure whole tool-call chains are cleaned up correctly after truncation.

The current test covers a single assistant tool_calls → tool message pair. To better exercise _fix_tool_call_pairs_in_dict_context, please add a variant where one assistant message has multiple tool_calls with multiple corresponding tool messages (e.g., call_1, call_2). Then simulate truncation that drops part of a chain (some tools, or the assistant but not all tools) and assert that payloads['messages'] never contains partial or orphaned tool chains—only fully intact chains or none at all.

Suggested change
@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_multi_tool_chains():
    """Truncated multi-tool chains must be removed as a whole.

    ``_fix_tool_call_pairs_in_dict_context`` takes the message list itself
    and returns the repaired list, so the assertions run on the returned
    value rather than on the input.
    """
    provider = _make_provider()
    try:
        payloads = {
            "messages": [
                {"role": "system", "content": "system"},
                {"role": "user", "content": "Run multiple tools"},
                {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "type": "function",
                            "function": {"name": "tool_a", "arguments": "{}"},
                        },
                        {
                            "id": "call_2",
                            "type": "function",
                            "function": {"name": "tool_b", "arguments": "{}"},
                        },
                    ],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_1",
                    "name": "tool_a",
                    "content": "result a",
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_2",
                    "name": "tool_b",
                    "content": "result b",
                },
            ]
        }

        # Case 1: truncate away the assistant message but leave tool messages.
        truncated = payloads["messages"][:2] + payloads["messages"][3:]
        fixed = provider._fix_tool_call_pairs_in_dict_context(truncated)
        # No orphan tool messages should remain.
        assert all(m["role"] != "tool" for m in fixed)

        # Case 2: truncate away one tool message from a multi-tool chain.
        truncated = payloads["messages"][:-1]
        fixed = provider._fix_tool_call_pairs_in_dict_context(truncated)
        # The remaining context must not contain partial tool-call chains:
        # every tool_call id present on assistant messages must be present
        # on tool messages, and vice versa.
        tool_call_ids = {
            tc["id"]
            for m in fixed
            if m["role"] == "assistant"
            for tc in m.get("tool_calls", [])
        }
        tool_msg_ids = {m["tool_call_id"] for m in fixed if m["role"] == "tool"}
        assert tool_call_ids == tool_msg_ids
    finally:
        await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
    """A context-length error must not leave a broken tool chain behind.

    The history contains one assistant tool call with its tool reply. After
    ``_handle_api_error`` handles the error, only the system message and the
    final assistant answer are expected in ``payloads["messages"]``.
    """
    provider = _make_provider()
    try:
        search_call = {
            "id": "call_1",
            "type": "function",
            "function": {"name": "search", "arguments": "{}"},
        }
        messages = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "Run tool"},
            {"role": "assistant", "content": "", "tool_calls": [search_call]},
            {"role": "tool", "content": "Tool result", "tool_call_id": "call_1"},
            {"role": "assistant", "content": "Final answer"},
        ]
        payloads = {"messages": messages}
        # The handler receives the very same list object as the payload holds.
        context_query = payloads["messages"]

        success, *_ = await provider._handle_api_error(
            Exception("maximum context length exceeded"),
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        expected_messages = [
            {"role": "system", "content": "system"},
            {"role": "assistant", "content": "Final answer"},
        ]
        assert success is False
        assert payloads["messages"] == expected_messages
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_fix_tool_call_pairs_in_dict_context_removes_partial_multi_tool_chain():
    """Multi-tool chains survive only when every call has its reply.

    Three cases are checked: the intact chain is returned unchanged, a chain
    whose assistant message is missing loses its tool replies, and a chain
    missing one tool reply is removed entirely.
    """
    provider = _make_provider()
    try:
        system_msg = {"role": "system", "content": "system"}
        user_msg = {"role": "user", "content": "Run multiple tools"}
        assistant_calls = {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "tool_a", "arguments": "{}"},
                },
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {"name": "tool_b", "arguments": "{}"},
                },
            ],
        }
        reply_a = {
            "role": "tool",
            "tool_call_id": "call_1",
            "name": "tool_a",
            "content": "result a",
        }
        reply_b = {
            "role": "tool",
            "tool_call_id": "call_2",
            "name": "tool_b",
            "content": "result b",
        }
        final_msg = {"role": "assistant", "content": "Final answer"}

        full_chain = [
            system_msg,
            user_msg,
            assistant_calls,
            reply_a,
            reply_b,
            final_msg,
        ]
        fix = provider._fix_tool_call_pairs_in_dict_context
        # Expected result whenever the chain is broken in any way.
        chain_removed = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "Run multiple tools"},
            {"role": "assistant", "content": "Final answer"},
        ]

        # Intact chain: nothing is removed.
        assert fix(full_chain) == full_chain

        # Assistant message gone: both tool replies are orphans.
        missing_assistant = full_chain[:2] + full_chain[3:]
        assert fix(missing_assistant) == chain_removed

        # Second tool reply gone: the partial chain is dropped as a whole.
        missing_one_tool = full_chain[:-2] + [full_chain[-1]]
        assert fix(missing_one_tool) == chain_removed
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_handle_api_error_context_length_preserves_remaining_valid_messages():
    """Plain messages left after a context-length error stay untouched.

    The history has no tool calls. After ``_handle_api_error`` handles the
    error, the oldest user/assistant exchange is expected to be gone while
    the system message and the newest exchange remain.
    """
    provider = _make_provider()
    try:
        messages = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "old question"},
            {"role": "assistant", "content": "old answer"},
            {"role": "user", "content": "new question"},
            {"role": "assistant", "content": "new answer"},
        ]
        payloads = {"messages": messages}
        # The handler receives the very same list object as the payload holds.
        context_query = payloads["messages"]

        success, *_ = await provider._handle_api_error(
            Exception("maximum context length exceeded"),
            payloads=payloads,
            context_query=context_query,
            func_tool=None,
            chosen_key="test-key",
            available_api_keys=["test-key"],
            retry_cnt=0,
            max_retries=10,
        )

        expected_messages = [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "new question"},
            {"role": "assistant", "content": "new answer"},
        ]
        assert success is False
        assert payloads["messages"] == expected_messages
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_with_unserializable_body():
provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
Expand Down
Loading