Skip to content
Open
432 changes: 375 additions & 57 deletions astrbot/builtin_stars/web_searcher/main.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion astrbot/core/astr_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.pipeline.context_utils import call_event_hook
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.web_search_utils import WEB_SEARCH_REFERENCE_TOOLS


class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
Expand Down Expand Up @@ -59,7 +60,7 @@ async def on_tool_end(
platform_name = run_context.context.event.get_platform_name()
if (
platform_name == "webchat"
and tool.name in ["web_search_tavily", "web_search_bocha"]
and tool.name in WEB_SEARCH_REFERENCE_TOOLS
and len(run_context.messages) > 0
and tool_result
and len(tool_result.content)
Expand Down
39 changes: 38 additions & 1 deletion astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@
"web_search": False,
"websearch_provider": "default",
"websearch_tavily_key": [],
"websearch_tavily_base_url": "https://api.tavily.com",
"websearch_bocha_key": [],
"websearch_baidu_app_builder_key": "",
"websearch_exa_key": [],
"websearch_exa_base_url": "https://api.exa.ai",
"web_search_link": False,
"display_reasoning_text": False,
"identifier": False,
Expand Down Expand Up @@ -3084,7 +3087,13 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.websearch_provider": {
"description": "网页搜索提供商",
"type": "string",
"options": ["default", "tavily", "baidu_ai_search", "bocha"],
"options": [
"default",
"tavily",
"baidu_ai_search",
"bocha",
"exa",
],
"condition": {
"provider_settings.web_search": True,
},
Expand Down Expand Up @@ -3117,6 +3126,34 @@ class ChatProviderTemplate(TypedDict):
"provider_settings.websearch_provider": "baidu_ai_search",
},
},
"provider_settings.websearch_tavily_base_url": {
"description": "Tavily API Base URL",
"type": "string",
"hint": "默认为 https://api.tavily.com,可改为代理地址。",
"condition": {
"provider_settings.websearch_provider": "tavily",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_exa_key": {
"description": "Exa API Key",
"type": "list",
"items": {"type": "string"},
"hint": "可添加多个 Key 进行轮询。",
"condition": {
"provider_settings.websearch_provider": "exa",
"provider_settings.web_search": True,
},
},
"provider_settings.websearch_exa_base_url": {
"description": "Exa API Base URL",
"type": "string",
"hint": "默认为 https://api.exa.ai,可改为代理地址。",
"condition": {
"provider_settings.websearch_provider": "exa",
"provider_settings.web_search": True,
},
},
"provider_settings.web_search_link": {
"description": "显示来源引用",
"type": "bool",
Expand Down
8 changes: 7 additions & 1 deletion astrbot/core/knowledge_base/kb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,18 @@ async def upload_from_url(
"Error: Tavily API key is not configured in provider_settings."
)

tavily_base_url = config.get("provider_settings", {}).get(
"websearch_tavily_base_url", "https://api.tavily.com"
)

# 阶段1: 从 URL 提取内容
if progress_callback:
await progress_callback("extracting", 0, 100)

try:
text_content = await extract_text_from_url(url, tavily_keys)
text_content = await extract_text_from_url(
url, tavily_keys, tavily_base_url
)
except Exception as e:
logger.error(f"Failed to extract content from URL {url}: {e}")
raise OSError(f"Failed to extract content from URL {url}: {e}") from e
Expand Down
21 changes: 17 additions & 4 deletions astrbot/core/knowledge_base/parsers/url_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,33 @@

import aiohttp

from astrbot.core.utils.web_search_utils import normalize_web_search_base_url


class URLExtractor:
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""

def __init__(self, tavily_keys: list[str]) -> None:
def __init__(
self, tavily_keys: list[str], tavily_base_url: str = "https://api.tavily.com"
) -> None:
"""
初始化 URL 提取器

Args:
tavily_keys: Tavily API 密钥列表
tavily_base_url: Tavily API 基础 URL
"""
if not tavily_keys:
raise ValueError("Error: Tavily API keys are not configured.")

self.tavily_keys = tavily_keys
self.tavily_key_index = 0
self.tavily_key_lock = asyncio.Lock()
self.tavily_base_url = normalize_web_search_base_url(
tavily_base_url,
default="https://api.tavily.com",
provider_name="Tavily",
)

async def _get_tavily_key(self) -> str:
"""并发安全的从列表中获取并轮换Tavily API密钥。"""
Expand Down Expand Up @@ -47,7 +57,7 @@ async def extract_text_from_url(self, url: str) -> str:
raise ValueError("Error: url must be a non-empty string.")

tavily_key = await self._get_tavily_key()
api_url = "https://api.tavily.com/extract"
api_url = f"{self.tavily_base_url}/extract"
headers = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
Expand Down Expand Up @@ -88,16 +98,19 @@ async def extract_text_from_url(self, url: str) -> str:


# 为了向后兼容,提供一个简单的函数接口
async def extract_text_from_url(url: str, tavily_keys: list[str]) -> str:
async def extract_text_from_url(
url: str, tavily_keys: list[str], tavily_base_url: str = "https://api.tavily.com"
) -> str:
"""
简单的函数接口,用于从 URL 提取文本内容

Args:
url: 要提取内容的网页 URL
tavily_keys: Tavily API 密钥列表
tavily_base_url: Tavily API 基础 URL

Returns:
提取的文本内容
"""
extractor = URLExtractor(tavily_keys)
extractor = URLExtractor(tavily_keys, tavily_base_url)
return await extractor.extract_text_from_url(url)
130 changes: 130 additions & 0 deletions astrbot/core/utils/web_search_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import json
import re
from typing import Any
from urllib.parse import urlparse

# Names of the web-search tools whose results carry citeable reference
# entries; helpers below only inspect tool calls whose name appears here.
WEB_SEARCH_REFERENCE_TOOLS = (
    "web_search_tavily",
    "web_search_bocha",
    "web_search_exa",
    "exa_find_similar",
)


def normalize_web_search_base_url(
base_url: str | None,
*,
default: str,
provider_name: str,
) -> str:
normalized = (base_url or "").strip()
if not normalized:
normalized = default
normalized = normalized.rstrip("/")

parsed = urlparse(normalized)
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise ValueError(
f"Error: {provider_name} API Base URL must start with http:// or https://.",
)
return normalized


def _iter_web_search_result_items(
    accumulated_parts: list[dict[str, Any]],
):
    """Yield each result dict produced by a web-search reference tool.

    Scans *accumulated_parts* for ``tool_call`` parts, keeps only calls to
    tools listed in ``WEB_SEARCH_REFERENCE_TOOLS`` that carry a result,
    JSON-decodes string results (silently skipping malformed JSON), and
    yields every dict entry under the result's ``"results"`` key.
    """
    for part in accumulated_parts:
        if part.get("type") != "tool_call":
            continue
        calls = part.get("tool_calls") or []

        for call in calls:
            if call.get("name") not in WEB_SEARCH_REFERENCE_TOOLS:
                continue
            raw = call.get("result")
            if not raw:
                continue

            if isinstance(raw, str):
                try:
                    decoded = json.loads(raw)
                except json.JSONDecodeError:
                    # Malformed tool output: skip rather than fail the stream.
                    continue
            else:
                decoded = raw

            if not isinstance(decoded, dict):
                continue

            for entry in decoded.get("results", []):
                if isinstance(entry, dict):
                    yield entry


def _extract_ref_indices(accumulated_text: str) -> list[str]:
ref_indices: list[str] = []
seen_indices: set[str] = set()

for match in re.finditer(r"<ref>(.*?)</ref>", accumulated_text):
ref_index = match.group(1).strip()
if not ref_index or ref_index in seen_indices:
continue
ref_indices.append(ref_index)
seen_indices.add(ref_index)

return ref_indices


def collect_web_search_ref_items(
    accumulated_parts: list[dict[str, Any]],
    favicon_cache: dict[str, str] | None = None,
) -> list[dict[str, Any]]:
    """Build de-duplicated reference payloads from web-search tool results.

    Each payload carries ``index``/``url``/``title``/``snippet``; a
    ``favicon`` field is attached when the item's URL has a cached favicon.
    Items without an ``index`` are skipped, duplicates keep only the first
    occurrence, and first-appearance order is preserved.
    """
    refs: list[dict[str, Any]] = []
    known: set[str] = set()

    for result in _iter_web_search_result_items(accumulated_parts):
        index = result.get("index")
        if not index or index in known:
            continue

        entry: dict[str, Any] = {
            "index": index,
            "url": result.get("url"),
            "title": result.get("title"),
            "snippet": result.get("snippet"),
        }
        if favicon_cache and entry["url"] in favicon_cache:
            entry["favicon"] = favicon_cache[entry["url"]]

        known.add(index)
        refs.append(entry)

    return refs


def build_web_search_refs(
    accumulated_text: str,
    accumulated_parts: list[dict[str, Any]],
    favicon_cache: dict[str, str] | None = None,
) -> dict:
    """Assemble the reference payload for a finished response.

    Returns ``{}`` when no web-search references exist. Otherwise returns
    ``{"used": [...]}`` where the list follows the order of ``<ref>``
    markers in *accumulated_text*; when the text cites none of the
    collected references, all of them are returned as a fallback.
    """
    all_refs = collect_web_search_ref_items(accumulated_parts, favicon_cache)
    if not all_refs:
        return {}

    by_index = {entry["index"]: entry for entry in all_refs}
    cited = [
        by_index[idx]
        for idx in _extract_ref_indices(accumulated_text)
        if idx in by_index
    ]

    # Fall back to every collected reference when nothing was cited inline.
    return {"used": cited or all_refs}


def collect_web_search_results(accumulated_parts: list[dict[str, Any]]) -> dict:
    """Map each web-search reference index to its url/title/snippet details.

    Indices are unique by construction (``collect_web_search_ref_items``
    de-duplicates), so the comprehension cannot overwrite entries.
    """
    return {
        ref["index"]: {
            "url": ref.get("url"),
            "title": ref.get("title"),
            "snippet": ref.get("snippet"),
        }
        for ref in collect_web_search_ref_items(accumulated_parts)
    }
Loading
Loading