Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions astrbot/core/provider/sources/bailian_rerank_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def _build_payload(
normalized_model = self.model.strip().lower()
normalized_top_n = top_n if top_n is not None and top_n > 0 else None

# qwen3-rerank follows a model-specific payload:
# query/documents/top_n/instruct should be at the top level.
if normalized_model == self.QWEN3_RERANK_MODEL:
is_compatible_api = "compatible-api" in self.base_url
if normalized_model == self.QWEN3_RERANK_MODEL and is_compatible_api:
payload = {
"model": self.model,
"query": query,
Expand All @@ -107,16 +106,33 @@ def _build_payload(
)
return payload

base = {"model": self.model, "input": {"query": query, "documents": documents}}

params = {
k: v
for k, v in [
("top_n", normalized_top_n),
("return_documents", True if self.return_documents else None),
]
if v is not None
}
if is_compatible_api:
base = {
"model": self.model,
"query": query,
"documents": documents,
}
params = {
k: v
for k, v in [
("top_n", normalized_top_n),
("return_documents", True if self.return_documents else None),
]
if v is not None
}
else:
base = {
"model": self.model,
"input": {"query": query, "documents": documents},
}
params = {
k: v
for k, v in [
("top_n", normalized_top_n),
("return_documents", True if self.return_documents else None),
]
if v is not None
}

if params:
base["parameters"] = params
Expand All @@ -136,16 +152,20 @@ def _parse_results(self, data: dict) -> list[RerankResult]:
BailianAPIError: API返回错误
KeyError: 结果缺少必要字段
"""
# 检查响应状态
if data.get("code", "200") != "200":
raise BailianAPIError(
f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
)
is_compatible_api = "compatible-api" in self.base_url

results = data.get("output", {}).get("results", [])
if not results:
logger.warning(f"百炼 Rerank 返回空结果: {data}")
return []
if is_compatible_api:
if data.get("code"):
raise BailianAPIError(
f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
)
results = data.get("results", [])
else:
if data.get("code", "200") != "200":
raise BailianAPIError(
f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
)
results = data.get("output", {}).get("results", [])

# 转换为RerankResult对象,使用.get()避免KeyError
rerank_results = []
Expand Down
Loading