Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,49 @@ async def text_chat_stream(
raise NotImplementedError()

async def pop_record(self, context: list) -> None:
    """Drop the oldest non-system records from ``context`` in place.

    Records are removed in atomic "units" so that tool-call bookkeeping
    stays consistent:

    * an ``assistant`` record carrying ``tool_calls`` plus every ``tool``
      record that directly follows it form one unit;
    * a run of leading ``tool`` records with no preceding assistant call
      (orphans) forms one unit;
    * any other non-system record is a unit of one.

    Units are removed oldest-first until at least ``target_records``
    records are gone. The first unit is always removed, whatever its size,
    so every call makes progress; a further unit is removed only if the
    running total stays within ``max_records``.

    Args:
        context: Conversation history as a list of message dicts. It is
            mutated in place; ``system`` records are never removed.
    """
    # Aim for roughly two records per call, but tolerate up to three so a
    # tool-call unit never has to be split.
    target_records = 2
    max_records = 3

    def _tool_run_end(start: int) -> int:
        """Return the index of the last consecutive ``tool`` record after ``start``."""
        end = start
        while end + 1 < len(context) and context[end + 1].get("role") == "tool":
            end += 1
        return end

    removed = 0
    while removed < target_records:
        # Oldest record that is not a system prompt, if any remains.
        start = next(
            (i for i, rec in enumerate(context) if rec.get("role") != "system"),
            None,
        )
        if start is None:
            break

        head = context[start]
        role = head.get("role")
        # Both an assistant tool-call and a leading orphan tool record drag
        # the directly following tool records along with them.
        if role == "tool" or (role == "assistant" and head.get("tool_calls")):
            end = _tool_run_end(start)
        else:
            end = start

        unit_size = end - start + 1
        # Never block the first removal; only cap follow-up units.
        if removed > 0 and removed + unit_size > max_records:
            break
        del context[start : end + 1]
        removed += unit_size

def _ensure_message_to_dicts(
self,
Expand Down
128 changes: 118 additions & 10 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from pathlib import Path
from types import SimpleNamespace
from urllib.parse import urlparse, urlunparse

import pytest
from openai.types.chat.chat_completion import ChatCompletion
Expand Down Expand Up @@ -244,6 +246,112 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history():
await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_removes_assistant_tool_calls_with_following_tools_atomically():
    """An assistant tool-call record is evicted together with its tool reply."""
    system_msg = {"role": "system", "content": "system"}
    kept_msg = {"role": "user", "content": "keep me"}

    provider = _make_provider()
    try:
        history = [
            system_msg,
            {"role": "assistant", "tool_calls": [{"id": "call_1"}], "content": None},
            {"role": "tool", "tool_call_id": "call_1", "content": "result"},
            kept_msg,
        ]

        await provider.pop_record(history)

        # Only the system prompt and the trailing user message survive.
        assert history == [system_msg, kept_msg]
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_removes_leading_orphan_tool_messages():
    """A leading tool record without its assistant call is dropped first."""
    system_msg = {"role": "system", "content": "system"}
    old_assistant = {"role": "assistant", "content": "old assistant"}
    new_user = {"role": "user", "content": "new user"}

    provider = _make_provider()
    try:
        history = [
            system_msg,
            {"role": "tool", "tool_call_id": "call_1", "content": "orphan"},
            {"role": "user", "content": "old user"},
            old_assistant,
            new_user,
        ]

        await provider.pop_record(history)

        # The orphan tool record and the oldest user record are gone.
        assert history == [system_msg, old_assistant, new_user]
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_normal_messages_no_regression():
    """Plain user/assistant history still loses its oldest two records."""
    system_msg = {"role": "system", "content": "system"}
    second_turn = [
        {"role": "user", "content": "user2"},
        {"role": "assistant", "content": "assistant2"},
    ]

    provider = _make_provider()
    try:
        history = [
            system_msg,
            {"role": "user", "content": "user1"},
            {"role": "assistant", "content": "assistant1"},
            *second_turn,
        ]

        await provider.pop_record(history)

        # The first user/assistant exchange is removed; the second remains.
        assert history == [system_msg, *second_turn]
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_assistant_with_multiple_tool_calls():
    """An assistant record with two tool calls is removed with both tool replies."""
    system_msg = {"role": "system", "content": "system"}
    kept_msg = {"role": "user", "content": "keep me"}
    tool_call_turn = [
        {
            "role": "assistant",
            "tool_calls": [{"id": "call_1"}, {"id": "call_2"}],
            "content": None,
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "result1"},
        {"role": "tool", "tool_call_id": "call_2", "content": "result2"},
    ]

    provider = _make_provider()
    try:
        history = [system_msg, *tool_call_turn, kept_msg]

        await provider.pop_record(history)

        # All three records of the tool-call turn are evicted together.
        assert history == [system_msg, kept_msg]
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_only_system_messages():
    """A history holding only the system prompt is left untouched."""
    provider = _make_provider()
    try:
        history = [{"role": "system", "content": "system"}]
        expected = [{"role": "system", "content": "system"}]

        await provider.pop_record(history)

        assert history == expected
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_groq_payload_drops_reasoning_content_from_assistant_history():
provider = _make_groq_provider()
Expand Down Expand Up @@ -782,9 +890,8 @@ async def test_prepare_chat_payload_materializes_context_file_uri_image_urls(tmp
async def test_file_uri_to_path_preserves_windows_drive_letter():
    """``file:///C:/...`` resolves to a path that keeps the ``C:`` drive letter."""
    provider = _make_provider()
    try:
        resolved = provider._file_uri_to_path("file:///C:/tmp/quoted-image.png")
        # Compare as Path objects rather than raw strings; presumably this
        # tolerates platform-specific separator normalisation — confirm
        # against _file_uri_to_path.
        assert Path(resolved) == Path("C:/tmp/quoted-image.png")
    finally:
        await provider.terminate()

Expand All @@ -793,9 +900,8 @@ async def test_file_uri_to_path_preserves_windows_drive_letter():
async def test_file_uri_to_path_preserves_windows_netloc_drive_letter():
    """``file://C:/...`` (drive letter in the netloc slot) keeps the ``C:`` drive."""
    provider = _make_provider()
    try:
        resolved = provider._file_uri_to_path("file://C:/tmp/quoted-image.png")
        # Compare as Path objects rather than raw strings; presumably this
        # tolerates platform-specific separator normalisation — confirm
        # against _file_uri_to_path.
        assert Path(resolved) == Path("C:/tmp/quoted-image.png")
    finally:
        await provider.terminate()

Expand All @@ -804,9 +910,8 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter():
async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path():
    """``file://server/share/...`` resolves to the ``//server/share/...`` UNC form."""
    provider = _make_provider()
    try:
        resolved = provider._file_uri_to_path("file://server/share/quoted-image.png")
        # Compare as Path objects rather than raw strings; presumably this
        # tolerates platform-specific separator normalisation — confirm
        # against _file_uri_to_path.
        assert Path(resolved) == Path("//server/share/quoted-image.png")
    finally:
        await provider.terminate()

Expand Down Expand Up @@ -977,7 +1082,10 @@ async def test_prepare_chat_payload_materializes_context_localhost_file_uri_imag
image_path = tmp_path / "quoted-image.png"
PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path)

localhost_uri = f"file://localhost{image_path.as_posix()}"
parsed_local_uri = urlparse(image_path.as_uri())
localhost_uri = urlunparse(
("file", "localhost", parsed_local_uri.path, "", "", "")
)
payloads, _ = await provider._prepare_chat_payload(
prompt=None,
contexts=[
Expand Down