Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,48 @@ async def pop_record(self, context: list) -> None:
for idx in reversed(indexs_to_pop):
context.pop(idx)

context[:] = self._fix_tool_call_pairs_in_dict_context(context)

@staticmethod
def _fix_tool_call_pairs_in_dict_context(context: list[dict]) -> list[dict]:
"""Remove orphaned tool call chains from dict-based message history."""
if not context:
return context

fixed_context: list[dict] = []
pending_assistant: dict | None = None
pending_tools: list[dict] = []

def flush_pending_if_valid() -> None:
nonlocal pending_assistant, pending_tools
if pending_assistant is not None and pending_tools:
fixed_context.append(pending_assistant)
fixed_context.extend(pending_tools)
pending_assistant = None
pending_tools = []

for message in context:
role = message.get("role")
if role == "tool":
if pending_assistant is not None:
pending_tools.append(message)
continue

if (
role == "assistant"
and message.get("tool_calls") is not None
and len(message.get("tool_calls")) > 0
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition to identify an assistant message with tool calls can be simplified. In Python, an empty list is falsy, so message.get("tool_calls") is sufficient to check for both the existence of the key and that the list is not empty. This also avoids redundant calls to .get() and len().

            if role == "assistant" and message.get("tool_calls"):

flush_pending_if_valid()
pending_assistant = message
continue

flush_pending_if_valid()
fixed_context.append(message)

flush_pending_if_valid()
return fixed_context

def _ensure_message_to_dicts(
self,
messages: list[dict] | list[Message] | None,
Expand Down
81 changes: 81 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,87 @@ async def test_handle_api_error_model_not_vlm_after_fallback_raises():
await provider.terminate()


@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
Comment on lines +168 to +169
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a test for multiple tool calls and their tool responses to ensure whole tool-call chains are cleaned up correctly after truncation.

The current test covers a single assistant tool_calls → tool message pair. To better exercise _fix_tool_call_pairs_in_dict_context, please add a variant where one assistant message has multiple tool_calls with multiple corresponding tool messages (e.g., call_1, call_2). Then simulate truncation that drops part of a chain (some tools, or the assistant but not all tools) and assert that payloads['messages'] never contains partial or orphaned tool chains—only fully intact chains or none at all.

Suggested change
@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():
@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_multi_tool_chains():
provider = _make_provider()
try:
payloads = {
"messages": [
{"role": "system", "content": "system"},
{"role": "user", "content": "Run multiple tools"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "tool_a", "arguments": "{}"},
},
{
"id": "call_2",
"type": "function",
"function": {"name": "tool_b", "arguments": "{}"},
},
],
},
{
"role": "tool",
"tool_call_id": "call_1",
"name": "tool_a",
"content": "result a",
},
{
"role": "tool",
"tool_call_id": "call_2",
"name": "tool_b",
"content": "result b",
},
]
}
# Case 1: truncate away the assistant message but leave tool messages
truncated = {
"messages": payloads["messages"][:2] + payloads["messages"][3:]
}
provider._fix_tool_call_pairs_in_dict_context(truncated)
# No orphan tool messages should remain
assert all(m["role"] != "tool" for m in truncated["messages"])
# Case 2: truncate away one tool message from a multi-tool chain
truncated = {
"messages": payloads["messages"][:-1]
}
provider._fix_tool_call_pairs_in_dict_context(truncated)
# The remaining context must not contain partial tool-call chains:
# every tool_call id present on assistant messages must be present
# on tool messages, and vice versa.
tool_call_ids = {
tc["id"]
for m in truncated["messages"]
if m["role"] == "assistant"
for tc in m.get("tool_calls", [])
}
tool_msg_ids = {
m["tool_call_id"]
for m in truncated["messages"]
if m["role"] == "tool"
}
assert tool_call_ids == tool_msg_ids
finally:
await provider.terminate()
@pytest.mark.asyncio
async def test_handle_api_error_context_length_removes_orphaned_tool_messages():

provider = _make_provider()
try:
payloads = {
"messages": [
{"role": "system", "content": "system"},
{"role": "user", "content": "Run tool"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
},
{"role": "tool", "content": "Tool result", "tool_call_id": "call_1"},
{"role": "assistant", "content": "Final answer"},
]
}
context_query = payloads["messages"]

success, *_rest = await provider._handle_api_error(
Exception("maximum context length exceeded"),
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)

assert success is False
assert payloads["messages"] == [
{"role": "system", "content": "system"},
{"role": "assistant", "content": "Final answer"},
]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_handle_api_error_context_length_preserves_remaining_valid_messages():
provider = _make_provider()
try:
payloads = {
"messages": [
{"role": "system", "content": "system"},
{"role": "user", "content": "old question"},
{"role": "assistant", "content": "old answer"},
{"role": "user", "content": "new question"},
{"role": "assistant", "content": "new answer"},
]
}
context_query = payloads["messages"]

success, *_rest = await provider._handle_api_error(
Exception("maximum context length exceeded"),
payloads=payloads,
context_query=context_query,
func_tool=None,
chosen_key="test-key",
available_api_keys=["test-key"],
retry_cnt=0,
max_retries=10,
)

assert success is False
assert payloads["messages"] == [
{"role": "system", "content": "system"},
{"role": "user", "content": "new question"},
{"role": "assistant", "content": "new answer"},
]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_handle_api_error_content_moderated_with_unserializable_body():
provider = _make_provider({"image_moderation_error_patterns": ["blocked"]})
Expand Down
Loading