diff --git a/.trae/specs/qqofficial-fixes/checklist.md b/.trae/specs/qqofficial-fixes/checklist.md new file mode 100644 index 0000000000..41d9279634 --- /dev/null +++ b/.trae/specs/qqofficial-fixes/checklist.md @@ -0,0 +1,8 @@ +- [x] 检查点 1: 验证 chunk_text 函数是否正确修复,无死循环和重复块 +- [x] 检查点 2: 验证流式 C2C 降级条件是否覆盖所有富媒体类型 +- [x] 检查点 3: 验证频道消息是否支持 URL 图片发送 +- [x] 检查点 4: 验证 MessageReplyLimiter 是否使用 logger 进行日志记录 +- [x] 检查点 5: 验证 MessageReplyLimiter 的并发安全性 +- [x] 检查点 6: 验证未使用的上传辅助函数和缓存是否已清理 +- [x] 检查点 7: 运行项目的测试和 lint 检查,确保代码质量 +- [x] 检查点 8: 验证修复后的代码与现有代码风格和架构保持一致 \ No newline at end of file diff --git a/.trae/specs/qqofficial-fixes/spec.md b/.trae/specs/qqofficial-fixes/spec.md new file mode 100644 index 0000000000..208a0a938c --- /dev/null +++ b/.trae/specs/qqofficial-fixes/spec.md @@ -0,0 +1,84 @@ +# QQOfficial 模块修复 - 产品需求文档 + +## Overview +- **Summary**: 修复 QQOfficial 模块中的多个 bug,包括文本分块逻辑、流式消息降级条件、频道消息图片发送和消息回复限流器等问题 +- **Purpose**: 解决 PR #7176 中提出的代码审查问题,确保 QQOfficial 模块的稳定性和可靠性 +- **Target Users**: 开发团队和使用 QQOfficial 模块的用户 + +## Goals +- 修复 chunk_text 函数的游标更新逻辑,避免死循环和重复块风险 +- 完善流式 C2C 降级条件,覆盖所有富媒体类型 +- 修复频道消息图片发送问题,支持 URL 图片 +- 改进 MessageReplyLimiter 的日志记录和并发安全性 +- 清理未使用的上传辅助函数和缓存 + +## Non-Goals (Out of Scope) +- 重构整个 QQOfficial 模块 +- 添加新功能或特性 +- 修改其他平台适配器的代码 + +## Background & Context +- PR #7176 提出了多个代码审查问题,需要修复 +- 参考 OpenClaw 项目的实现方式进行修复 +- 确保修复后的代码与现有代码风格和架构保持一致 + +## Functional Requirements +- **FR-1**: 修复 chunk_text 函数的游标更新逻辑,确保每次循环 start 都单调前进 +- **FR-2**: 完善流式 C2C 降级条件,当检测到任何富媒体时都降级为非流式发送 +- **FR-3**: 修复频道消息图片发送问题,支持 URL 图片 +- **FR-4**: 改进 MessageReplyLimiter,使用 logger 进行日志记录,避免使用模块级全局变量 +- **FR-5**: 清理未使用的上传辅助函数和缓存 + +## Non-Functional Requirements +- **NFR-1**: 代码质量:修复后的代码应符合项目的代码风格和最佳实践 +- **NFR-2**: 安全性:确保 MessageReplyLimiter 的并发安全性 +- **NFR-3**: 可维护性:清理未使用的代码,提高代码可读性 + +## Constraints +- **Technical**: 保持与现有代码架构的一致性 +- **Dependencies**: 参考 OpenClaw 项目的实现方式 + +## Assumptions +- OpenClaw 项目的实现方式是可靠的参考 +- 修复后的代码应通过项目的测试和 lint 检查 + +## Acceptance Criteria + +### AC-1: 修复 chunk_text 函数 +- **Given**: 长文本需要分块 +- **When**: 调用 chunk_text 函数 +- **Then**: 函数应正确分块,无死循环,无重复块 +- **Verification**: `programmatic` +- **Notes**: 确保每次循环 start 都单调前进 + +### AC-2: 完善流式 C2C 降级条件 +- **Given**: 发送包含语音、视频或文件的流式 C2C 消息 +- **When**: 触发流式消息发送 +- **Then**: 应降级为非流式发送 +- **Verification**: `programmatic` +- **Notes**: 确保所有富媒体类型都被覆盖 + +### AC-3: 修复频道消息图片发送 +- **Given**: 发送包含 URL 图片的频道消息 +- **When**: 调用频道消息发送接口 +- **Then**: 应正确发送 URL 图片 +- **Verification**: `programmatic` +- **Notes**: 区分本地路径和 URL 图片的处理 + +### AC-4: 改进 MessageReplyLimiter +- **Given**: 使用 MessageReplyLimiter 进行消息回复限流 +- **When**: 记录消息回复或检查限流 +- **Then**: 应使用 logger 进行日志记录,且线程安全 +- **Verification**: `programmatic` +- **Notes**: 避免使用模块级全局变量 + +### AC-5: 清理未使用的代码 +- **Given**: 检查上传相关代码 +- **When**: 分析代码使用情况 +- **Then**: 移除或标记未使用的上传辅助函数和缓存 +- **Verification**: `human-judgment` +- **Notes**: 保持代码整洁 + +## Open Questions +- [ ] 是否需要添加单元测试来验证修复效果? +- [ ] 清理未使用代码时是否需要保留某些接口以保持向后兼容? \ No newline at end of file diff --git a/.trae/specs/qqofficial-fixes/tasks.md b/.trae/specs/qqofficial-fixes/tasks.md new file mode 100644 index 0000000000..34a533ce00 --- /dev/null +++ b/.trae/specs/qqofficial-fixes/tasks.md @@ -0,0 +1,66 @@ +# QQOfficial 模块修复 - 实现计划 + +## [x] 任务 1: 修复 chunk_text 函数的游标更新逻辑 +- **优先级**: P0 +- **依赖**: 无 +- **描述**: + - 修改 qqofficial_message_event.py 中的 chunk_text 函数 + - 简化游标更新逻辑,确保每次循环 start 都单调前进 + - 避免使用复杂的 overlap 逻辑和 find 方法 +- **接受标准**: AC-1 +- **测试需求**: + - `programmatic` TR-1.1: 测试长文本分块功能,确保无死循环和重复块 + - `programmatic` TR-1.2: 测试边界条件,如文本长度正好等于限制、小于限制等 +- **注意**: 参考 PR 中的建议,使用 `start = max(breakpoint - overlap, start + 1)` 或类似逻辑 + +## [x] 任务 2: 完善流式 C2C 降级条件 +- **优先级**: P0 +- **依赖**: 无 +- **描述**: + - 修改 qqofficial_message_event.py 中的流式消息降级逻辑 + - 确保当检测到任何富媒体时都降级为非流式发送 + - 覆盖图片、语音、视频和文件等所有富媒体类型 +- **接受标准**: AC-2 +- **测试需求**: + - `programmatic` TR-2.1: 测试包含语音的流式 C2C 消息,应降级为非流式 + - `programmatic` TR-2.2: 测试包含视频的流式 C2C 消息,应降级为非流式 + - `programmatic` TR-2.3: 测试包含文件的流式 C2C 消息,应降级为非流式 +- **注意**: 参考 PR 中的建议,使用 `if stream and (image_source or record_file_path or video_file_source or file_source):` + +## [x] 任务 3: 修复频道消息图片发送问题 +- **优先级**: P0 +- **依赖**: 无 +- **描述**: + - 修改 qqofficial_platform_adapter.py 中的频道消息发送逻辑 + - 支持 URL 图片的发送 + - 区分本地路径和 URL 图片的处理 +- **接受标准**: AC-3 +- **测试需求**: + - `programmatic` TR-3.1: 测试发送包含 URL 图片的频道消息 + - `programmatic` TR-3.2: 测试发送包含本地路径图片的频道消息 +- **注意**: 参考 PR 中的建议,添加对 URL 图片的特殊处理 + +## [x] 任务 4: 改进 MessageReplyLimiter +- **优先级**: P1 +- **依赖**: 无 +- **描述**: + - 修改 rate_limiter.py 中的 MessageReplyLimiter 类 + - 使用 logger 进行日志记录,替代 print + - 改进并发安全性,避免使用模块级全局变量 +- **接受标准**: AC-4 +- **测试需求**: + - `programmatic` TR-4.1: 测试消息回复限流功能 + - `programmatic` TR-4.2: 测试并发场景下的限流器行为 +- **注意**: 参考 OpenClaw 项目的实现方式 + +## [x] 任务 5: 清理未使用的上传辅助函数和缓存 +- **优先级**: P2 +- **依赖**: 无 +- **描述**: + - 检查 chunked_upload.py 中的上传相关代码 + - 移除或标记未使用的上传辅助函数和缓存 + - 保持代码整洁 +- **接受标准**: AC-5 +- **测试需求**: + - `human-judgment` TR-5.1: 检查代码是否整洁,无未使用的函数和缓存 +- **注意**: 确保不影响现有功能 \ No newline at end of file diff --git a/astrbot/core/platform/sources/qqofficial/__init__.py b/astrbot/core/platform/sources/qqofficial/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/astrbot/core/platform/sources/qqofficial/chunked_upload.py b/astrbot/core/platform/sources/qqofficial/chunked_upload.py new file mode 100644 index 0000000000..defd38582f --- /dev/null +++ b/astrbot/core/platform/sources/qqofficial/chunked_upload.py @@ -0,0 +1,991 @@ +""" +分片上传模块 +参照 openclaw-qqbot 的 chunked-upload.ts 实现 + +流程: +1. 申请上传 (upload_prepare) → 获取 upload_id + block_size + 分片预签名链接 +2. 并行上传所有分片 +3. 所有分片完成后,调用完成文件上传接口 → 获取 file_info + +特性: +- 完善的重试机制(分片上传、分片完成、文件完成) +- 上传缓存(相同文件不重复上传) +- 用户友好的错误提示 +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +import time +from dataclasses import dataclass +from typing import Callable, Optional, Dict, Tuple + +import aiohttp + +from astrbot import logger + + +# ============ 常量 ============ + +DEFAULT_CONCURRENT_PARTS = 1 +MAX_CONCURRENT_PARTS = 10 +PART_UPLOAD_TIMEOUT = 300 # 5分钟 +PART_UPLOAD_MAX_RETRIES = 3 +MAX_PART_FINISH_RETRY_TIMEOUT_MS = 10 * 60 * 1000 # 10分钟 +MD5_10M_SIZE = 10002432 # 用于计算 md5_10m + +# 每日上传限额错误码 +UPLOAD_PREPARE_FALLBACK_CODE = 40093002 +PART_FINISH_RETRYABLE_CODES = {40093001} + +# 完成上传重试配置 +COMPLETE_UPLOAD_MAX_RETRIES = 3 +COMPLETE_UPLOAD_BASE_DELAY_MS = 2000 + +# 分片完成重试配置 +PART_FINISH_MAX_RETRIES = 2 +PART_FINISH_BASE_DELAY_MS = 1000 +PART_FINISH_RETRYABLE_DEFAULT_TIMEOUT_MS = 2 * 60 * 1000 +PART_FINISH_RETRYABLE_INTERVAL_MS = 1000 + + +# ============ 异常定义 ============ + + +class UploadDailyLimitExceededError(Exception): + """每日上传限额超限""" + + def __init__(self, file_path: str, file_size: int, message: str): + self.file_path = file_path + self.file_size = file_size + super().__init__(message) + + +class ApiError(Exception): + """API 错误""" + + def __init__( + self, message: str, status: int, path: str, biz_code: Optional[int] = None + ): + self.status = status + self.path = path + self.biz_code = biz_code + super().__init__(message) + + +class ChunkedUploadError(Exception): + """分片上传错误""" + + def __init__( + self, + message: str, + file_path: str, + file_size: int, + cause: Optional[Exception] = None, + ): + self.file_path = file_path + self.file_size = file_size + self.cause = cause + super().__init__(message) + + +# ============ 全局 HTTP 客户端管理器(按 appId 隔离)============ + +import threading + + +class QQBotHttpClientManager: + """ + HTTP 客户端全局管理器 + + 按 appId 隔离客户端实例,实现多机器人共享 Token 缓存。 + - 同一 appId 的多个实例共享同一个 QQBotHttpClient + - Singleflight 模式避免并发重复获取 Token + + 注意:由于客户端创建是轻量操作,使用简单的同步锁即可, + 避免 asyncio.Lock 在非事件循环上下文中的问题。 + """ + + _instance: Optional["QQBotHttpClientManager"] = None + _clients: Dict[str, QQBotHttpClient] = {} + _lock = threading.Lock() + + @classmethod + def get_instance(cls) -> "QQBotHttpClientManager": + """获取单例实例""" + if cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + async def get_client(cls, appid: str, secret: str) -> QQBotHttpClient: + """ + 获取指定 appId 的 HTTP 客户端(按需创建,按 appId 隔离) + + Args: + appid: QQ Bot AppID + secret: QQ Bot Secret + + Returns: + QQBotHttpClient: 该 appId 对应的 HTTP 客户端 + """ + # 使用同步锁保护客户端创建 + with cls._lock: + if appid not in cls._clients: + logger.debug( + f"[QQBotHttpClientManager] Creating new client for appId={appid[:8]}..." + ) + cls._clients[appid] = QQBotHttpClient(appid, secret) + return cls._clients[appid] + + @classmethod + def clear(cls) -> None: + """清除所有客户端实例(用于测试或重置)""" + with cls._lock: + cls._clients.clear() + logger.debug("[QQBotHttpClientManager] All clients cleared") + + @classmethod + def get_stats(cls) -> Dict[str, Dict]: + """获取各客户端状态统计""" + with cls._lock: + return { + appid: { + "has_token": client._token is not None, + "token_expires_in": max( + 0, int(client._token_expires_at - time.time()) + ) + if client._token_expires_at + else None, + } + for appid, client in cls._clients.items() + } + + +# ============ 数据类 ============ + + +@dataclass +class UploadPrepareHashes: + md5: str + sha1: str + md5_10m: str + + +@dataclass +class UploadPart: + index: int + presigned_url: str + + +@dataclass +class UploadPrepareResponse: + upload_id: str + block_size: int + parts: list[UploadPart] + concurrency: int = 1 + retry_timeout: int = 0 + + +@dataclass +class MediaUploadResponse: + file_uuid: str + file_info: str + ttl: int + + +@dataclass +class ChunkedUploadProgress: + completed_parts: int + total_parts: int + uploaded_bytes: int + total_bytes: int + + +# ============ 文件哈希计算 ============ + + +async def compute_file_hashes(file_path: str, file_size: int) -> UploadPrepareHashes: + """ + 计算文件的 MD5、SHA1、md5_10m + + Args: + file_path: 文件路径 + file_size: 文件大小 + + Returns: + UploadPrepareHashes: 文件哈希信息 + """ + md5_hash = hashlib.md5() + sha1_hash = hashlib.sha1() + md5_10m_hash = hashlib.md5() + + need_10m = file_size > MD5_10M_SIZE + bytes_read = 0 + + with open(file_path, "rb") as f: + while True: + chunk = f.read(65536) # 64KB + if not chunk: + break + + md5_hash.update(chunk) + sha1_hash.update(chunk) + + if need_10m: + remaining = MD5_10M_SIZE - bytes_read + if remaining > 0: + md5_10m_hash.update( + chunk[:remaining] if len(chunk) > remaining else chunk + ) + + bytes_read += len(chunk) + + return UploadPrepareHashes( + md5=md5_hash.hexdigest(), + sha1=sha1_hash.hexdigest(), + md5_10m=md5_10m_hash.hexdigest() if need_10m else md5_hash.hexdigest(), + ) + + +def read_file_chunk(file_path: str, offset: int, length: int) -> bytes: + """读取文件的指定区间""" + with open(file_path, "rb") as f: + f.seek(offset) + return f.read(length) + + +# ============ API 请求封装 ============ + + +class QQBotHttpClient: + """QQ Bot HTTP 客户端,直接调用 API""" + + API_BASE = "https://api.sgroup.qq.com" + TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken" + + # User-Agent 标识 + PLUGIN_USER_AGENT = "AstrBot-QQOfficial/1.0 (Python/3.x)" + + def __init__(self, appid: str, secret: str): + self.appid = appid + self.secret = secret + self._token: Optional[str] = None + self._token_expires_at: float = 0 + self._token_fetch_lock = asyncio.Lock() + self._token_fetch_promise: Optional[asyncio.Future[str]] = None + self._session: Optional[aiohttp.ClientSession] = None + self._session_lock = asyncio.Lock() + + async def _get_session(self) -> aiohttp.ClientSession: + """获取或创建共享的 ClientSession""" + if self._session is None or self._session.closed: + async with self._session_lock: + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector( + limit=100, + keepalive_timeout=30, + ) + self._session = aiohttp.ClientSession( + connector=connector, + ) + return self._session + + async def close(self): + """关闭 ClientSession""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def get_access_token(self) -> str: + """ + 获取 AccessToken(带缓存 + singleflight 并发安全) + + 使用 singleflight 模式:当多个请求同时发现 Token 过期时, + 只有第一个请求会真正去获取新 Token,其他请求复用同一个 Promise。 + """ + # 提前5分钟刷新 + if self._token and time.time() < self._token_expires_at - 300: + return self._token + + # Singleflight: 避免并发重复获取 + async with self._token_fetch_lock: + # 双重检查 + if self._token and time.time() < self._token_expires_at - 300: + return self._token + + # 如果已有进行中的获取请求,复用它 + if self._token_fetch_promise is not None: + return await self._token_fetch_promise + + # 创建新的获取请求 + self._token_fetch_promise = asyncio.create_task(self._do_fetch_token()) + try: + token = await self._token_fetch_promise + return token + finally: + self._token_fetch_promise = None + + async def _do_fetch_token(self) -> str: + """实际执行 Token 获取""" + logger.debug(f"[QQBotHttpClient:{self.appid}] Fetching access token...") + + async with aiohttp.ClientSession() as session: + async with session.post( + self.TOKEN_URL, + json={"appId": self.appid, "clientSecret": self.secret}, + headers={ + "Content-Type": "application/json", + "User-Agent": self.PLUGIN_USER_AGENT, + }, + ) as resp: + data = await resp.json() + if "access_token" not in data: + error_msg = data.get("message", str(data)) + logger.error( + f"[QQBotHttpClient:{self.appid}] Token fetch failed: {error_msg}" + ) + raise RuntimeError(f"获取 access_token 失败: {error_msg}") + + self._token = data["access_token"] + expires_in = int(data.get("expires_in", 7200)) + self._token_expires_at = time.time() + expires_in + + logger.debug( + f"[QQBotHttpClient:{self.appid}] Token cached, expires in {expires_in}s" + ) + return self._token + + async def api_request( + self, + method: str, + path: str, + body: Optional[dict] = None, + timeout: float = 300.0, + ) -> dict: + """API 请求封装(带详细日志)""" + token = await self.get_access_token() + url = f"{self.API_BASE}{path}" + headers = { + "Authorization": f"QQBot {token}", + "Content-Type": "application/json", + "User-Agent": self.PLUGIN_USER_AGENT, + } + + # 打印请求信息(隐藏敏感数据) + log_body = dict(body) if body else None + if log_body and "file_data" in log_body: + log_body["file_data"] = f"" + logger.debug(f"[QQBotHttpClient] >>> {method} {path}") + if log_body: + logger.debug(f"[QQBotHttpClient] >>> Body: {log_body}") + + session = await self._get_session() + async with session.request( + method, + url, + json=body, + headers=headers, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as resp: + # 打印响应信息 + trace_id = resp.headers.get("x-tps-trace-id", "") + logger.debug( + f"[QQBotHttpClient] <<< Status: {resp.status} {resp.reason}" + + (f" | TraceId: {trace_id}" if trace_id else "") + ) + + raw = await resp.text() + logger.debug(f"[QQBotHttpClient] <<< Body: {raw[:500]}") + + if not resp.ok: + try: + import json + + err_data = json.loads(raw) if raw else {} + biz_code = err_data.get("code") or err_data.get("err_code") + error_msg = err_data.get("message", "Unknown error") + + logger.error( + f"[QQBotHttpClient] API Error [{path}]: {error_msg} (bizCode={biz_code})" + ) + raise ApiError( + f"API Error [{path}]: {error_msg}", resp.status, path, biz_code + ) + except Exception as e: + if isinstance(e, ApiError): + raise + logger.error( + f"[QQBotHttpClient] API Error [{path}] HTTP {resp.status}: {raw[:200]}" + ) + raise ApiError( + f"API Error [{path}] HTTP {resp.status}: {raw[:200]}", + resp.status, + path, + ) + + import json + + return json.loads(raw) + + async def base64_upload( + self, + file_type: int, + file_data: str, + file_name: Optional[str] = None, + srv_send_msg: bool = False, + target_type: str = "c2c", + target_id: str = "", + ) -> MediaUploadResponse: + """ + Base64 格式上传文件(小文件专用,带长超时) + + 与分片上传不同,Base64 上传直接将文件内容放在请求体中, + 适用于 5MB 以下的文件。超时设置为 300 秒(5分钟)以适应慢速网络。 + + Args: + file_type: 文件类型(1=图片, 2=视频, 3=语音, 4=文件) + file_data: Base64 编码的文件内容 + file_name: 文件名(可选) + srv_send_msg: 是否作为机器人发送 + target_type: 目标类型 ("c2c" 或 "group") + target_id: 用户 openid 或群 openid + + Returns: + MediaUploadResponse: 包含 file_uuid, file_info, ttl + """ + if target_type == "c2c": + path = f"/v2/users/{target_id}/files" + else: + path = f"/v2/groups/{target_id}/files" + + payload = { + "file_type": file_type, + "file_data": file_data, + "srv_send_msg": srv_send_msg, + } + if file_name: + payload["file_name"] = file_name + + logger.info( + f"[QQBotHttpClient] Base64 upload: target={target_type}:{target_id[:16]}, file_type={file_type}, size={len(file_data)} chars" + ) + + data = await self.api_request("POST", path, body=payload, timeout=300.0) + + return MediaUploadResponse( + file_uuid=data["file_uuid"], + file_info=data["file_info"], + ttl=data.get("ttl", 0), + ) + + async def c2c_upload_prepare( + self, + user_id: str, + file_type: int, + file_name: str, + file_size: int, + hashes: UploadPrepareHashes, + ) -> UploadPrepareResponse: + """C2C 申请上传""" + logger.info( + f"[QQBotHttpClient] C2C upload_prepare: user={user_id[:16]}, file={file_name}, size={file_size}" + ) + + data = await self.api_request( + "POST", + f"/v2/users/{user_id}/upload_prepare", + { + "file_type": file_type, + "file_name": file_name, + "file_size": file_size, + "md5": hashes.md5, + "sha1": hashes.sha1, + "md5_10m": hashes.md5_10m, + }, + timeout=60.0, + ) + + logger.info( + f"[QQBotHttpClient] C2C upload_prepare success: upload_id={data['upload_id']}, parts={len(data['parts'])}" + ) + + return UploadPrepareResponse( + upload_id=data["upload_id"], + block_size=int(data["block_size"]), + parts=[ + UploadPart(index=p["index"], presigned_url=p["presigned_url"]) + for p in data["parts"] + ], + concurrency=int(data.get("concurrency", 1)), + retry_timeout=int(data.get("retry_timeout", 0)), + ) + + async def c2c_upload_part_finish( + self, + user_id: str, + upload_id: str, + part_index: int, + block_size: int, + md5: str, + retry_timeout_ms: Optional[int] = None, + ) -> None: + """C2C 完成分片上传(带持续重试)""" + logger.debug(f"[QQBotHttpClient] C2C upload_part_finish: part={part_index}") + await self._part_finish_with_retry( + "POST", + f"/v2/users/{user_id}/upload_part_finish", + { + "upload_id": upload_id, + "part_index": part_index, + "block_size": block_size, + "md5": md5, + }, + retry_timeout_ms, + ) + + async def c2c_complete_upload( + self, user_id: str, upload_id: str + ) -> MediaUploadResponse: + """C2C 完成文件上传(带重试)""" + result = await self._complete_upload_with_retry( + "POST", f"/v2/users/{user_id}/files", {"upload_id": upload_id} + ) + logger.info( + f"[QQBotHttpClient] c2c complete_upload success: upload_id={upload_id}, file_uuid={result.file_uuid}" + ) + return result + + async def group_upload_prepare( + self, + group_id: str, + file_type: int, + file_name: str, + file_size: int, + hashes: UploadPrepareHashes, + ) -> UploadPrepareResponse: + """Group 申请上传""" + logger.info( + f"[QQBotHttpClient] Group upload_prepare: group={group_id[:16]}, file={file_name}, size={file_size}" + ) + + data = await self.api_request( + "POST", + f"/v2/groups/{group_id}/upload_prepare", + { + "file_type": file_type, + "file_name": file_name, + "file_size": file_size, + "md5": hashes.md5, + "sha1": hashes.sha1, + "md5_10m": hashes.md5_10m, + }, + timeout=60.0, + ) + + logger.info( + f"[QQBotHttpClient] Group upload_prepare success: upload_id={data['upload_id']}, parts={len(data['parts'])}" + ) + + return UploadPrepareResponse( + upload_id=data["upload_id"], + block_size=int(data["block_size"]), + parts=[ + UploadPart(index=p["index"], presigned_url=p["presigned_url"]) + for p in data["parts"] + ], + concurrency=int(data.get("concurrency", 1)), + retry_timeout=int(data.get("retry_timeout", 0)), + ) + + async def group_upload_part_finish( + self, + group_id: str, + upload_id: str, + part_index: int, + block_size: int, + md5: str, + retry_timeout_ms: Optional[int] = None, + ) -> None: + """Group 完成分片上传(带持续重试)""" + await self._part_finish_with_retry( + "POST", + f"/v2/groups/{group_id}/upload_part_finish", + { + "upload_id": upload_id, + "part_index": part_index, + "block_size": block_size, + "md5": md5, + }, + retry_timeout_ms, + ) + + async def group_complete_upload( + self, group_id: str, upload_id: str + ) -> MediaUploadResponse: + """Group 完成文件上传(带重试)""" + return await self._complete_upload_with_retry( + "POST", f"/v2/groups/{group_id}/files", {"upload_id": upload_id} + ) + + # ============ 内部重试逻辑 ============ + + async def _part_finish_with_retry( + self, method: str, path: str, body: dict, retry_timeout_ms: Optional[int] = None + ) -> None: + """分片完成接口重试策略""" + PART_FINISH_MAX_RETRIES = 2 + PART_FINISH_BASE_DELAY_MS = 1000 + PART_FINISH_RETRYABLE_DEFAULT_TIMEOUT_MS = 2 * 60 * 1000 + PART_FINISH_RETRYABLE_INTERVAL_MS = 1000 + + last_error: Optional[Exception] = None + + for attempt in range(PART_FINISH_MAX_RETRIES + 1): + try: + await self.api_request(method, path, body, timeout=60.0) + return + except Exception as err: + last_error = err + + # 命中特定错误码 → 进入持续重试模式 + if ( + isinstance(err, ApiError) + and err.biz_code in PART_FINISH_RETRYABLE_CODES + ): + timeout_ms = ( + retry_timeout_ms or PART_FINISH_RETRYABLE_DEFAULT_TIMEOUT_MS + ) + logger.warning( + f"[chunked] PartFinish hit retryable bizCode={err.biz_code}, entering persistent retry (timeout={timeout_ms / 1000}s)" + ) + await self._part_finish_persistent_retry( + method, path, body, timeout_ms + ) + return + + if attempt < PART_FINISH_MAX_RETRIES: + delay = PART_FINISH_BASE_DELAY_MS * (2**attempt) / 1000 + logger.warning( + f"[chunked] PartFinish attempt {attempt + 1} failed, retrying in {delay}s: {str(err)[:200]}" + ) + await asyncio.sleep(delay) + + raise last_error or RuntimeError("PartFinish failed") + + async def _part_finish_persistent_retry( + self, method: str, path: str, body: dict, timeout_ms: int + ) -> None: + """特定错误码的持续重试模式""" + PART_FINISH_RETRYABLE_INTERVAL_MS = 1000 + deadline = time.time() + timeout_ms / 1000 + attempt = 0 + + while time.time() < deadline: + try: + await self.api_request(method, path, body, timeout=60.0) + logger.info( + f"[chunked] PartFinish persistent retry succeeded after {attempt} retries" + ) + return + except Exception as err: + # 如果不再是可重试的错误码,直接抛出 + if not ( + isinstance(err, ApiError) + and err.biz_code in PART_FINISH_RETRYABLE_CODES + ): + logger.error( + f"[chunked] PartFinish persistent retry: error is no longer retryable" + ) + raise + + attempt += 1 + remaining = deadline - time.time() + if remaining <= 0: + break + + logger.warning( + f"[chunked] PartFinish persistent retry #{attempt}: bizCode={err.biz_code}, retrying (remaining={int(remaining)}s)" + ) + await asyncio.sleep(PART_FINISH_RETRYABLE_INTERVAL_MS / 1000) + + raise RuntimeError( + f"upload_part_finish 持续重试超时({timeout_ms / 1000}s, {attempt} 次重试)" + ) + + async def _complete_upload_with_retry( + self, method: str, path: str, body: dict + ) -> MediaUploadResponse: + """完成上传接口重试(无条件重试)""" + COMPLETE_UPLOAD_MAX_RETRIES = 2 + COMPLETE_UPLOAD_BASE_DELAY_MS = 2000 + + last_error: Optional[Exception] = None + + for attempt in range(COMPLETE_UPLOAD_MAX_RETRIES + 1): + try: + data = await self.api_request(method, path, body, timeout=120.0) + return MediaUploadResponse( + file_uuid=data["file_uuid"], + file_info=data["file_info"], + ttl=data.get("ttl", 0), + ) + except Exception as err: + last_error = err + + if attempt < COMPLETE_UPLOAD_MAX_RETRIES: + delay = COMPLETE_UPLOAD_BASE_DELAY_MS * (2**attempt) / 1000 + logger.warning( + f"[chunked] CompleteUpload attempt {attempt + 1} failed, retrying in {delay}s" + ) + await asyncio.sleep(delay) + + raise last_error or RuntimeError("CompleteUpload failed") + + +# ============ 分片上传核心逻辑 ============ + + +async def put_to_presigned_url( + presigned_url: str, + data: bytes, + prefix: str = "[chunked]", + part_index: int = 0, + total_parts: int = 0, +) -> None: + """PUT 分片数据到预签名 URL(带重试)""" + last_error: Optional[Exception] = None + + for attempt in range(PART_UPLOAD_MAX_RETRIES + 1): + try: + timeout = aiohttp.ClientTimeout(total=PART_UPLOAD_TIMEOUT, connect=60) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.put( + presigned_url, data=data, headers={"Content-Length": str(len(data))} + ) as resp: + if not resp.ok: + body = await resp.text() + raise RuntimeError( + f"COS PUT failed: {resp.status} {body[:200]}" + ) + + logger.debug( + f"{prefix} Part {part_index}/{total_parts}: uploaded {len(data)} bytes" + ) + return + except Exception as e: + last_error = e + if attempt < PART_UPLOAD_MAX_RETRIES: + delay = 1000 * (2**attempt) / 1000 # 1s, 2s + logger.warning( + f"{prefix} Part {part_index}/{total_parts}: attempt {attempt + 1} failed, retrying in {delay}s: {str(e)[:100]}" + ) + await asyncio.sleep(delay) + + raise last_error or RuntimeError("Upload failed") + + +async def chunked_upload_c2c( + http_client: QQBotHttpClient, + user_id: str, + file_path: str, + file_type: int, + on_progress: Optional[Callable[[ChunkedUploadProgress], None]] = None, + log_prefix: str = "[chunked]", +) -> MediaUploadResponse: + """C2C 大文件分片上传""" + prefix = log_prefix + + # 1. 读取文件信息 + file_size = os.path.getsize(file_path) + file_name = os.path.basename(file_path) + + logger.info( + f"{prefix} Starting chunked upload: file={file_name}, size={file_size}, type={file_type}" + ) + + # 2. 计算文件哈希 + logger.debug(f"{prefix} Computing file hashes...") + hashes = await compute_file_hashes(file_path, file_size) + logger.debug(f"{prefix} File hashes: md5={hashes.md5[:16]}...") + + # 3. 申请上传 + try: + prepare_resp = await http_client.c2c_upload_prepare( + user_id, file_type, file_name, file_size, hashes + ) + except ApiError as e: + if e.biz_code == UPLOAD_PREPARE_FALLBACK_CODE: + raise UploadDailyLimitExceededError(file_path, file_size, str(e)) + raise + + upload_id = prepare_resp.upload_id + block_size = prepare_resp.block_size + parts = prepare_resp.parts + concurrency = min( + prepare_resp.concurrency or DEFAULT_CONCURRENT_PARTS, MAX_CONCURRENT_PARTS + ) + retry_timeout_ms = ( + prepare_resp.retry_timeout * 1000 if prepare_resp.retry_timeout else None + ) + + logger.info( + f"{prefix} Upload prepared: upload_id={upload_id}, block_size={block_size}, parts={len(parts)}, concurrency={concurrency}" + ) + + # 4. 并行上传所有分片 + completed_parts = 0 + uploaded_bytes = 0 + + async def upload_part(part: UploadPart) -> None: + nonlocal completed_parts, uploaded_bytes + + part_index = part.index + offset = (part_index - 1) * block_size + length = min(block_size, file_size - offset) + + # 读取分片数据 + part_data = read_file_chunk(file_path, offset, length) + part_md5 = hashlib.md5(part_data).hexdigest() + + logger.debug( + f"{prefix} Part {part_index}/{len(parts)}: uploading {length} bytes" + ) + + # PUT 到预签名 URL + await put_to_presigned_url( + part.presigned_url, part_data, prefix, part_index, len(parts) + ) + + # 通知平台分片完成(带重试) + await http_client.c2c_upload_part_finish( + user_id, upload_id, part_index, length, part_md5, retry_timeout_ms + ) + + # 更新进度 + completed_parts += 1 + uploaded_bytes += length + + if on_progress: + on_progress( + ChunkedUploadProgress( + completed_parts=completed_parts, + total_parts=len(parts), + uploaded_bytes=uploaded_bytes, + total_bytes=file_size, + ) + ) + + # 按并发数分批执行 + for i in range(0, len(parts), concurrency): + batch = parts[i : i + concurrency] + await asyncio.gather(*[upload_part(p) for p in batch]) + + logger.info(f"{prefix} All {len(parts)} parts uploaded, completing...") + + # 5. 完成文件上传 + result = await http_client.c2c_complete_upload(user_id, upload_id) + logger.info( + f"{prefix} Upload completed: file_uuid={result.file_uuid}, ttl={result.ttl}s" + ) + + return result + + +async def chunked_upload_group( + http_client: QQBotHttpClient, + group_id: str, + file_path: str, + file_type: int, + on_progress: Optional[Callable[[ChunkedUploadProgress], None]] = None, + log_prefix: str = "[chunked]", +) -> MediaUploadResponse: + """Group 大文件分片上传""" + prefix = log_prefix + + # 1. 读取文件信息 + file_size = os.path.getsize(file_path) + file_name = os.path.basename(file_path) + + logger.info( + f"{prefix} Starting chunked upload (group): file={file_name}, size={file_size}, type={file_type}" + ) + + # 2. 计算文件哈希 + logger.debug(f"{prefix} Computing file hashes...") + hashes = await compute_file_hashes(file_path, file_size) + + # 3. 申请上传 + try: + prepare_resp = await http_client.group_upload_prepare( + group_id, file_type, file_name, file_size, hashes + ) + except ApiError as e: + if e.biz_code == UPLOAD_PREPARE_FALLBACK_CODE: + raise UploadDailyLimitExceededError(file_path, file_size, str(e)) + raise + + upload_id = prepare_resp.upload_id + block_size = prepare_resp.block_size + parts = prepare_resp.parts + concurrency = min( + prepare_resp.concurrency or DEFAULT_CONCURRENT_PARTS, MAX_CONCURRENT_PARTS + ) + retry_timeout_ms = ( + prepare_resp.retry_timeout * 1000 if prepare_resp.retry_timeout else None + ) + + logger.info( + f"{prefix} Upload prepared: upload_id={upload_id}, block_size={block_size}, parts={len(parts)}" + ) + + # 4. 并行上传所有分片 + completed_parts = 0 + uploaded_bytes = 0 + + async def upload_part(part: UploadPart) -> None: + nonlocal completed_parts, uploaded_bytes + + part_index = part.index + offset = (part_index - 1) * block_size + length = min(block_size, file_size - offset) + + part_data = read_file_chunk(file_path, offset, length) + part_md5 = hashlib.md5(part_data).hexdigest() + + await put_to_presigned_url( + part.presigned_url, part_data, prefix, part_index, len(parts) + ) + await http_client.group_upload_part_finish( + group_id, upload_id, part_index, length, part_md5, retry_timeout_ms + ) + + completed_parts += 1 + uploaded_bytes += length + + if on_progress: + on_progress( + ChunkedUploadProgress( + completed_parts=completed_parts, + total_parts=len(parts), + uploaded_bytes=uploaded_bytes, + total_bytes=file_size, + ) + ) + + for i in range(0, len(parts), concurrency): + batch = parts[i : i + concurrency] + await asyncio.gather(*[upload_part(p) for p in batch]) + + logger.info(f"{prefix} All {len(parts)} parts uploaded, completing...") + + # 5. 完成文件上传 + result = await http_client.group_complete_upload(group_id, upload_id) + logger.info( + f"{prefix} Upload completed: file_uuid={result.file_uuid}, ttl={result.ttl}s" + ) + + return result diff --git a/astrbot/core/platform/sources/qqofficial/file_utils.py b/astrbot/core/platform/sources/qqofficial/file_utils.py new file mode 100644 index 0000000000..315f28cc04 --- /dev/null +++ b/astrbot/core/platform/sources/qqofficial/file_utils.py @@ -0,0 +1,111 @@ +""" +文件工具模块 +参照 openclaw-qqbot 的 file-utils.ts 实现 +""" + +import os +from typing import Optional + + +# ============ 文件类型与大小限制 ============ + + +class MediaFileType: + IMAGE = 1 + VIDEO = 2 + VOICE = 3 + FILE = 4 + + +# QQ Bot API 上传大小限制(字节)- 与 openclaw-qqbot 一致 +MAX_UPLOAD_SIZES = { + MediaFileType.IMAGE: 30 * 1024 * 1024, # 30MB + MediaFileType.VIDEO: 100 * 1024 * 1024, # 100MB + MediaFileType.VOICE: 20 * 1024 * 1024, # 20MB + MediaFileType.FILE: 100 * 1024 * 1024, # 100MB +} + +FILE_TYPE_NAMES = { + MediaFileType.IMAGE: "图片", + MediaFileType.VIDEO: "视频", + MediaFileType.VOICE: "语音", + MediaFileType.FILE: "文件", +} + + +def format_file_size(size_bytes: int) -> str: + """格式化文件大小""" + if size_bytes < 1024: + return f"{size_bytes}B" + elif size_bytes < 1024 * 1024: + return f"{size_bytes / 1024:.1f}KB" + elif size_bytes < 1024 * 1024 * 1024: + return f"{size_bytes / (1024 * 1024):.1f}MB" + else: + return f"{size_bytes / (1024 * 1024 * 1024):.2f}GB" + + +def get_max_upload_size(file_type: int) -> int: + """获取文件类型对应的最大上传大小""" + return MAX_UPLOAD_SIZES.get(file_type, 100 * 1024 * 1024) + + +def get_file_type_name(file_type: int) -> str: + """获取文件类型名称""" + return FILE_TYPE_NAMES.get(file_type, "文件") + + +async def file_exists_async(file_path: str) -> bool: + """异步检查文件是否存在""" + return os.path.exists(file_path) + + +async def get_file_size_async(file_path: str) -> int: + """异步获取文件大小""" + try: + return os.path.getsize(file_path) + except OSError: + return 0 + + +def is_image_file(file_path: str, mime_type: Optional[str] = None) -> bool: + """判断是否为图片文件""" + if mime_type and mime_type.startswith("image/"): + return True + + ext = os.path.splitext(file_path)[1].lower() + return ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} + + +def is_video_file(file_path: str, mime_type: Optional[str] = None) -> bool: + """判断是否为视频文件""" + if mime_type and mime_type.startswith("video/"): + return True + + ext = os.path.splitext(file_path)[1].lower() + return ext in {".mp4", ".mov", ".avi", ".mkv", ".webm", ".flv", ".wmv"} + + +def is_audio_file(file_path: str, mime_type: Optional[str] = None) -> bool: + """判断是否为音频文件""" + if mime_type and mime_type.startswith("audio/"): + return True + + ext = os.path.splitext(file_path)[1].lower() + return ext in {".mp3", ".wav", ".ogg", ".m4a", ".amr", ".silk", ".aac", ".flac"} + + +def get_file_extension(file_path: str) -> str: + """ + 获取文件扩展名(去除查询参数和 hash) + + Args: + file_path: 文件路径或 URL + + Returns: + 文件扩展名(小写,包含点号) + """ + # 去除查询参数和 hash + clean_path = file_path.split("?")[0].split("#")[0] + ext = os.path.splitext(clean_path)[1].lower() + return ext diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 97b2b2fb49..e03c27a203 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -3,7 +3,7 @@ import os import random import uuid -from typing import cast +from typing import Callable, cast, Optional, Dict, List, Tuple import aiofiles import botpy @@ -21,20 +21,39 @@ from astrbot.api.message_components import File, Image, Plain, Record, Video from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.core.utils.astrbot_path import get_astrbot_temp_path -from astrbot.core.utils.io import download_image_by_url, file_to_base64 +from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk +# 导入分片上传模块 +from .chunked_upload import ( + QQBotHttpClient, + QQBotHttpClientManager, + chunked_upload_c2c, + chunked_upload_group, + ChunkedUploadProgress, + UploadDailyLimitExceededError, + ApiError as ChunkedApiError, +) -def _patch_qq_botpy_formdata() -> None: - """Patch qq-botpy for aiohttp>=3.12 compatibility. +# 导入限流器 +from .rate_limiter import ( + MessageReplyLimiter, + check_message_reply_limit, + record_message_reply, +) - qq-botpy 1.2.1 defines botpy.http._FormData._gen_form_data() and expects - aiohttp.FormData to have a private flag named _is_processed, which is no - longer present in newer aiohttp versions. - """ +# 导入文件工具 +from .file_utils import ( + format_file_size, + get_max_upload_size, + get_file_type_name, +) + +def _patch_qq_botpy_formdata() -> None: + """Patch qq-botpy for aiohttp>=3.12 compatibility.""" try: - from botpy.http import _FormData # type: ignore + from botpy.http import _FormData if not hasattr(_FormData, "_is_processed"): setattr(_FormData, "_is_processed", False) @@ -45,13 +64,70 @@ def _patch_qq_botpy_formdata() -> None: _patch_qq_botpy_formdata() +# ============ 文本分块常量 ============ +TEXT_CHUNK_LIMIT = 2000 # QQ 单条消息文本限制 +TEXT_CHUNK_OVERLAP = 50 # 分块重叠字符数(避免句子被切断) + + +def chunk_text( + text: str, limit: int = TEXT_CHUNK_LIMIT, overlap: int = TEXT_CHUNK_OVERLAP +) -> List[str]: + """ + 将长文本分割为多个小块 + + Args: + text: 原始文本 + limit: 单块最大字符数 + overlap: 块之间重叠字符数 + + Returns: + 文本块列表 + """ + if not text or len(text) <= limit: + return [text] if text else [] + + chunks = [] + start = 0 + + while start < len(text): + end = start + limit + + if end >= len(text): + # 最后一个块 + chunks.append(text[start:]) + break + + # 尝试找到一个合适断点(换行符、句号、逗号等) + breakpoint = end + for bp in range(end - 1, max(start, end - 100), -1): + char = text[bp] + if char in "\n。.,,;;!!??": + breakpoint = bp + 1 + break + + chunk = text[start:breakpoint] + chunks.append(chunk) + + # 下一个块的起始位置(考虑重叠) + start = max(breakpoint - overlap, start + 1) + + return chunks + + class QQOfficialMessageEvent(AstrMessageEvent): MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown" IMAGE_FILE_TYPE = 1 VIDEO_FILE_TYPE = 2 VOICE_FILE_TYPE = 3 FILE_FILE_TYPE = 4 - STREAM_MARKDOWN_NEWLINE_ERROR = "流式消息md分片需要\\n结束" + STREAM_MARKDOWN_NEWLINE_ERROR = "流式消息md片段需要\\n结束" + + # 分片上传阈值:超过此大小使用分片上传 + CHUNKED_UPLOAD_THRESHOLD = 1024 * 1024 # 1MB + + # 消息回复限制配置 + MESSAGE_REPLY_LIMIT = 4 + MESSAGE_REPLY_TTL_MS = 60 * 60 * 1000 # 1小时 def __init__( self, @@ -60,42 +136,146 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, + appid: str = "", + secret: str = "", ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot self.send_buffer = None + # 凭据配置 + self.appid = appid + self.secret = secret + + # 分片上传 HTTP 客户端(延迟初始化) + self._http_client: Optional[QQBotHttpClient] = None + + # 限流器实例 + self._rate_limiter = MessageReplyLimiter() + + # 临时文件跟踪(用于清理) + self._temp_files: list[str] = [] + + # 媒体上传失败的兜底 URL + self._upload_failed_media: Dict[str, str] = {} + + def set_credentials(self, appid: str, secret: str) -> None: + """设置 QQ Bot 凭据(用于分片上传)""" + self.appid = appid + self.secret = secret + + def _cleanup_temp_files(self) -> None: + """清理临时文件""" + if not self._temp_files: + return + + cleaned = 0 + for temp_file in self._temp_files: + try: + if os.path.exists(temp_file): + os.remove(temp_file) + cleaned += 1 + logger.debug(f"[QQOfficial] Cleaned temp file: {temp_file}") + except Exception as e: + logger.warning( + f"[QQOfficial] Failed to clean temp file {temp_file}: {e}" + ) + + if cleaned > 0: + logger.debug( + f"[QQOfficial] Cleaned {cleaned}/{len(self._temp_files)} temp files" + ) + + self._temp_files.clear() + + async def _get_http_client(self) -> QQBotHttpClient: + """ + 获取分片上传 HTTP 客户端 + + 使用全局管理器按 appId 隔离客户端,实现多机器人共享 Token 缓存。 + 同一 appId 的多个实例会共享同一个 HTTP 客户端和 Token。 + """ + if self._http_client is None: + if not self.appid or not self.secret: + raise RuntimeError("QQ Bot 凭据未配置 (缺少 appid 或 secret)") + # 使用全局管理器获取客户端(按 appId 隔离) + self._http_client = await QQBotHttpClientManager.get_client( + self.appid, self.secret + ) + return self._http_client + + def _check_reply_limit( + self, msg_id: str + ) -> Tuple[bool, Optional[str], Optional[str]]: + """ + 检查消息回复是否受到限流 + + Returns: + Tuple[是否使用被动回复, 降级原因, 提示信息] + """ + if not msg_id: + return (False, "no_msg_id", "无消息ID,使用主动消息") + + limit_check = check_message_reply_limit(msg_id) + + if not limit_check.allowed: + if limit_check.should_fallback_to_proactive: + return (False, limit_check.fallback_reason, limit_check.message) + + return (True, None, None) + + def _should_use_passive_reply(self, source) -> Tuple[bool, Optional[str]]: + """ + 判断是否应该使用被动回复 + + Args: + source: 消息源对象 + + Returns: + Tuple[是否使用被动回复, 降级原因] + """ + msg_id = self.message_obj.message_id + + # 频道消息和私信不支持被动回复 + if isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)): + return (False, "channel_dm_no_passive") + + # 检查限流 + use_passive, reason, hint = self._check_reply_limit(msg_id) + + if not use_passive and hint: + logger.warning(f"[QQOfficial] {hint}") + + return (use_passive, reason) + async def send(self, message: MessageChain) -> None: self.send_buffer = message await self._post_send() async def send_streaming(self, generator, use_fallback: bool = False): """流式输出仅支持消息列表私聊(C2C),其他消息源退化为普通发送""" - # 先标记事件层“已执行发送操作”,避免异常路径遗漏 await super().send_streaming(generator, use_fallback) - # QQ C2C 流式协议:开始/中间分片使用 state=1,结束分片使用 state=10 stream_payload = {"state": 1, "id": None, "index": 0, "reset": False} - last_edit_time = 0 # 上次发送分片的时间 - throttle_interval = 1 # 分片间最短间隔 (秒) + last_edit_time = 0 + throttle_interval = 1 ret = None - source = ( - self.message_obj.raw_message - ) # 提前获取,避免 generator 为空时 NameError + + # 记录初始消息源类型(用于流式结束时的判断) + original_source = self.message_obj.raw_message + is_c2c_source = isinstance(original_source, botpy.message.C2CMessage) + try: async for chain in generator: source = self.message_obj.raw_message if not isinstance(source, botpy.message.C2CMessage): - # 非 C2C 场景:直接累积,最后统一发 + # 非 C2C 消息,累积到 send_buffer if not self.send_buffer: self.send_buffer = chain else: self.send_buffer.chain.extend(chain.chain) continue - # ---- C2C 流式场景 ---- - - # tool_call break 信号:工具开始执行,先把已有 buffer 以 state=10 结束当前流式段 if chain.type == "break": if self.send_buffer: stream_payload["state"] = 10 @@ -103,7 +283,6 @@ async def send_streaming(self, generator, use_fallback: bool = False): ret_id = self._extract_response_message_id(ret) if ret_id is not None: stream_payload["id"] = ret_id - # 重置 stream_payload,为下一段流式做准备 stream_payload = { "state": 1, "id": None, @@ -113,13 +292,11 @@ async def send_streaming(self, generator, use_fallback: bool = False): last_edit_time = 0 continue - # 累积内容 if not self.send_buffer: self.send_buffer = chain else: self.send_buffer.chain.extend(chain.chain) - # 节流:按时间间隔发送中间分片 current_time = asyncio.get_running_loop().time() if current_time - last_edit_time >= throttle_interval: ret = cast( @@ -131,26 +308,29 @@ async def send_streaming(self, generator, use_fallback: bool = False): if ret_id is not None: stream_payload["id"] = ret_id last_edit_time = asyncio.get_running_loop().time() - self.send_buffer = None # 清空已发送的分片,避免下次重复发送旧内容 + self.send_buffer = None - if isinstance(source, botpy.message.C2CMessage): - # 结束流式对话,发送 buffer 中剩余内容 - stream_payload["state"] = 10 - ret = await self._post_send(stream=stream_payload) - else: - ret = await self._post_send() + # 流式消息结束处理 + if self.send_buffer: + # 使用初始消息源类型判断,而非生成器最后一个元素 + if is_c2c_source: + stream_payload["state"] = 10 + ret = await self._post_send(stream=stream_payload) + else: + # 非 C2C 消息,直接发送累积的消息 + ret = await self._post_send() except Exception as e: logger.error(f"发送流式消息时出错: {e}", exc_info=True) - # 避免累计内容在异常后被整包重复发送:仅清理缓存,不做非流式整包兜底 - # 如需兜底,应该只发送未发送 delta(后续可继续优化) self.send_buffer = None - return None + # 清理临时文件 + self._cleanup_temp_files() + + return ret @staticmethod def _extract_response_message_id(ret) -> str | None: - """兼容 qq-botpy 返回 Message 对象或 dict 两种形态。""" if ret is None: return None if isinstance(ret, dict): @@ -175,34 +355,43 @@ async def _post_send(self, stream: dict | None = None): logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}") return None + # ========== P0-2 & P0-3: 消息限流检查和自动降级 ========== + msg_id = self.message_obj.message_id + use_passive, fallback_reason = self._should_use_passive_reply(source) + + if not use_passive: + # 降级为主动消息,移除 msg_id + effective_msg_id = None + logger.info(f"[QQOfficial] 消息回复降级为主动消息,原因: {fallback_reason}") + else: + effective_msg_id = msg_id + + # ========== 解析消息内容 ========== ( plain_text, - image_base64, - image_path, + image_source, record_file_path, video_file_source, file_source, file_name, ) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) - # C2C 流式仅用于文本分片,富媒体时降级为普通发送,避免平台侧流式校验报错。 - if stream and (image_base64 or record_file_path): + # C2C 流式仅用于文本分片,富媒体时降级为普通发送 + if stream and ( + image_source or record_file_path or video_file_source or file_source + ): logger.debug("[QQOfficial] 检测到富媒体,降级为非流式发送。") stream = None if ( not plain_text - and not image_base64 - and not image_path + and not image_source and not record_file_path and not video_file_source and not file_source ): return None - # QQ C2C 流式 API 说明: - # - 开始/中间分片(state=1):增量追加内容,不需要 \n(加了会导致强制换行) - # - 最终分片(state=10):结束流,content 必须以 \n 结尾(QQ API 要求) if ( stream and stream.get("state") == 10 @@ -211,11 +400,22 @@ async def _post_send(self, stream: dict | None = None): ): plain_text = plain_text + "\n" + # ========== P1-2: 长文本分块处理 ========== + # 检查是否需要分块 + needs_chunking = len(plain_text) > TEXT_CHUNK_LIMIT if plain_text else False + text_chunks = [] + + if needs_chunking and not stream: + text_chunks = chunk_text(plain_text) + logger.info( + f"[QQOfficial] 文本长度 {len(plain_text)} 超过限制,将分 {len(text_chunks)} 块发送" + ) + + # 构建 payload(使用 effective_msg_id 而不是直接使用 message_id) payload: dict = { - # "content": plain_text, "markdown": MarkdownPayload(content=plain_text) if plain_text else None, "msg_type": 2, - "msg_id": self.message_obj.message_id, + "msg_id": effective_msg_id, # P0-3: 使用可能降级后的 msg_id } if not isinstance(source, botpy.message.Message | botpy.message.DirectMessage): @@ -223,24 +423,38 @@ async def _post_send(self, stream: dict | None = None): ret = None + # ========== P1-1 & P1-3 & P1-4: 媒体处理增强 ========== + # 媒体上传失败标记 + media_upload_failed = False + upload_error_hint = None + + + match source: case botpy.message.GroupMessage(): if not source.group_openid: logger.error("[QQOfficial] GroupMessage 缺少 group_openid") return None - if image_base64: - media = await self.upload_group_and_c2c_image( - image_base64, + if image_source: + media = await self._upload_image_enhanced( + image_source, self.IMAGE_FILE_TYPE, group_openid=source.group_openid, ) - payload["media"] = media - payload["msg_type"] = 7 - payload.pop("markdown", None) - payload["content"] = plain_text or None - if record_file_path: # group record msg - media = await self.upload_group_and_c2c_media( + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + # P1-3: 保留文本内容,不要删除 + payload["content"] = plain_text if plain_text else None + else: + # P1-1: 媒体上传失败标记 + media_upload_failed = True + upload_error_hint = "图片" + + if record_file_path and not media_upload_failed: + media = await self._upload_media_enhanced( record_file_path, self.VOICE_FILE_TYPE, group_openid=source.group_openid, @@ -249,9 +463,14 @@ async def _post_send(self, stream: dict | None = None): payload["media"] = media payload["msg_type"] = 7 payload.pop("markdown", None) - payload["content"] = plain_text or None - if video_file_source: - media = await self.upload_group_and_c2c_media( + payload["content"] = plain_text if plain_text else None + else: + media_upload_failed = True + if not upload_error_hint: + upload_error_hint = "语音" + + if video_file_source and not media_upload_failed: + media = await self._upload_media_enhanced( video_file_source, self.VIDEO_FILE_TYPE, group_openid=source.group_openid, @@ -260,9 +479,15 @@ async def _post_send(self, stream: dict | None = None): payload["media"] = media payload["msg_type"] = 7 payload.pop("markdown", None) - payload["content"] = plain_text or None - if file_source: - media = await self.upload_group_and_c2c_media( + payload["content"] = plain_text if plain_text else None + payload.pop("msg_id", None) # 视频消息不需要 msg_id + else: + media_upload_failed = True + if not upload_error_hint: + upload_error_hint = "视频" + + if file_source and not media_upload_failed: + media = await self._upload_media_enhanced( file_source, self.FILE_FILE_TYPE, file_name=file_name, @@ -272,10 +497,24 @@ async def _post_send(self, stream: dict | None = None): payload["media"] = media payload["msg_type"] = 7 payload.pop("markdown", None) - payload["content"] = plain_text or None + payload["content"] = plain_text if plain_text else None + payload.pop("msg_id", None) # 文件消息不需要 msg_id + else: + media_upload_failed = True + if not upload_error_hint: + upload_error_hint = "文件" + + # P1-1: 如果有文本内容且媒体上传失败,添加提示 + if media_upload_failed and plain_text: + hint = f"[提示: {upload_error_hint}发送失败]" + if not plain_text.endswith(hint): + payload["content"] = plain_text + "\n" + hint + payload["msg_type"] = 0 + payload.pop("markdown", None) + ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.bot.api.post_group_message( - group_openid=source.group_openid, # type: ignore + group_openid=source.group_openid, **retry_payload, ), payload=payload, @@ -283,19 +522,29 @@ async def _post_send(self, stream: dict | None = None): stream=stream, ) + # P0-2: 记录消息回复(如果使用了被动回复) + if use_passive and effective_msg_id: + record_message_reply(effective_msg_id) + case botpy.message.C2CMessage(): - if image_base64: - media = await self.upload_group_and_c2c_image( - image_base64, + if image_source: + media = await self._upload_image_enhanced( + image_source, self.IMAGE_FILE_TYPE, openid=source.author.user_openid, ) - payload["media"] = media - payload["msg_type"] = 7 - payload.pop("markdown", None) - payload["content"] = plain_text or None - if record_file_path: # c2c record - media = await self.upload_group_and_c2c_media( + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("markdown", None) + # P1-3: 保留文本内容 + payload["content"] = plain_text if plain_text else None + else: + media_upload_failed = True + upload_error_hint = "图片" + + if record_file_path and not media_upload_failed: + media = await self._upload_media_enhanced( record_file_path, self.VOICE_FILE_TYPE, openid=source.author.user_openid, @@ -304,9 +553,14 @@ async def _post_send(self, stream: dict | None = None): payload["media"] = media payload["msg_type"] = 7 payload.pop("markdown", None) - payload["content"] = plain_text or None - if video_file_source: - media = await self.upload_group_and_c2c_media( + payload["content"] = plain_text if plain_text else None + else: + media_upload_failed = True + if not upload_error_hint: + upload_error_hint = "语音" + + if video_file_source and not media_upload_failed: + media = await self._upload_media_enhanced( video_file_source, self.VIDEO_FILE_TYPE, openid=source.author.user_openid, @@ -315,9 +569,14 @@ async def _post_send(self, stream: dict | None = None): payload["media"] = media payload["msg_type"] = 7 payload.pop("markdown", None) - payload["content"] = plain_text or None - if file_source: - media = await self.upload_group_and_c2c_media( + payload["content"] = plain_text if plain_text else None + else: + media_upload_failed = True + if not upload_error_hint: + upload_error_hint = "视频" + + if file_source and not media_upload_failed: + media = await self._upload_media_enhanced( file_source, self.FILE_FILE_TYPE, file_name=file_name, @@ -327,10 +586,61 @@ async def _post_send(self, stream: dict | None = None): payload["media"] = media payload["msg_type"] = 7 payload.pop("markdown", None) - payload["content"] = plain_text or None - if stream: + payload["content"] = plain_text if plain_text else None + else: + media_upload_failed = True + if not upload_error_hint: + upload_error_hint = "文件" + + # P1-1: 如果有文本内容且媒体上传失败,添加提示 + if media_upload_failed and plain_text: + hint = f"[提示: {upload_error_hint}发送失败]" + if not plain_text.endswith(hint): + payload["content"] = plain_text + "\n" + hint + payload["msg_type"] = 0 + payload.pop("markdown", None) + + # P1-2: 分块发送(如果有多个文本块) + if text_chunks and len(text_chunks) > 1: + logger.info(f"[QQOfficial] 开始分块发送 {len(text_chunks)} 条消息") + for i, chunk_text in enumerate(text_chunks): + chunk_payload = payload.copy() + chunk_payload["msg_id"] = effective_msg_id if i == 0 else None + chunk_payload["markdown"] = MarkdownPayload(content=chunk_text) + chunk_payload["content"] = chunk_text + chunk_payload["msg_type"] = 2 + + try: + ret = await self._send_with_markdown_fallback( + send_func=lambda p: self.post_c2c_message( + openid=source.author.user_openid, + **p, + ), + payload=chunk_payload, + plain_text=chunk_text, + stream=None, + ) + logger.debug( + f"[QQOfficial] 块 {i + 1}/{len(text_chunks)} 发送成功" + ) + + # 记录被动回复 + if i == 0 and use_passive and effective_msg_id: + record_message_reply(effective_msg_id) + + # 避免发送过快 + if i < len(text_chunks) - 1: + await asyncio.sleep(0.5) + except Exception as e: + logger.error(f"[QQOfficial] 块 {i + 1} 发送失败: {e}") + # 继续发送其他块 + + self.send_buffer = None + return ret + elif stream: ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.post_c2c_message( + self.bot, openid=source.author.user_openid, **retry_payload, stream=stream, @@ -342,6 +652,7 @@ async def _post_send(self, stream: dict | None = None): else: ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.post_c2c_message( + self.bot, openid=source.author.user_openid, **retry_payload, ), @@ -349,12 +660,16 @@ async def _post_send(self, stream: dict | None = None): plain_text=plain_text, stream=stream, ) + + # P0-2: 记录消息回复(如果使用了被动回复) + if use_passive and effective_msg_id: + record_message_reply(effective_msg_id) + logger.debug(f"Message sent to C2C: {ret}") case botpy.message.Message(): - if image_path: - payload["file_image"] = image_path - # Guild text-channel send API (/channels/{channel_id}/messages) does not use v2 msg_type. + if image_source and os.path.exists(image_source): + payload["file_image"] = image_source payload.pop("msg_type", None) ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.bot.api.post_message( @@ -367,9 +682,8 @@ async def _post_send(self, stream: dict | None = None): ) case botpy.message.DirectMessage(): - if image_path: - payload["file_image"] = image_path - # Guild DM send API (/dms/{guild_id}/messages) does not use v2 msg_type. + if image_source and os.path.exists(image_source): + payload["file_image"] = image_source payload.pop("msg_type", None) ret = await self._send_with_markdown_fallback( send_func=lambda retry_payload: self.bot.api.post_dms( @@ -385,11 +699,342 @@ async def _post_send(self, stream: dict | None = None): pass await super().send(self.send_buffer) - self.send_buffer = None + # 清理临时文件 + self._cleanup_temp_files() + return ret + async def _upload_image_enhanced( + self, + image_source: str, + file_type: int, + **kwargs, + ) -> botpy.types.message.Media | None: + """ + 增强版图片上传:根据文件大小自动选择 base64 直传或分片上传 + P1-4: 完善 URL 下载的错误处理 + """ + # 判断文件大小 + file_path = None + file_size = 0 + download_error = None + + try: + if os.path.exists(image_source): + file_path = image_source + file_size = os.path.getsize(file_path) + elif image_source.startswith("http"): + # P1-4: URL 图片:先下载,增加错误处理 + try: + file_path = await download_image_by_url(image_source) + if file_path and os.path.exists(file_path): + file_size = os.path.getsize(file_path) + else: + download_error = "下载图片失败" + logger.error(f"[QQOfficial] 下载图片失败: {image_source}") + except Exception as e: + download_error = f"下载图片出错: {str(e)}" + logger.error(f"[QQOfficial] 下载图片异常: {e}") + elif image_source.startswith("base64://"): + # Base64 数据,保存为临时文件 + try: + b64_data = image_source[9:] + temp_dir = get_astrbot_temp_path() + temp_path = os.path.join( + temp_dir, f"qqofficial_{uuid.uuid4().hex}.png" + ) + with open(temp_path, "wb") as f: + f.write(base64.b64decode(b64_data)) + file_path = temp_path + self._temp_files.append(temp_path) + file_size = os.path.getsize(file_path) + except Exception as e: + download_error = f"解析 Base64 图片失败: {str(e)}" + logger.error(f"[QQOfficial] 解析 Base64 图片异常: {e}") + else: + download_error = f"不支持的图片来源: {image_source[:50]}..." + logger.warning(f"[QQOfficial] {download_error}") + + # 如果下载失败但有 URL,记录用于兜底 + if download_error and image_source.startswith("http"): + self._upload_failed_media["image"] = image_source + logger.debug( + f"[QQOfficial] 保存图片 URL 用于兜底: {image_source[:50]}..." + ) + except Exception as e: + logger.error(f"[QQOfficial] 处理图片文件时出错: {e}") + return None + + # 检查文件大小限制 + max_size = get_max_upload_size(file_type) + if file_size > max_size: + type_name = get_file_type_name(file_type) + size_mb = file_size / (1024 * 1024) + limit_mb = max_size / (1024 * 1024) + logger.error( + f"[QQOfficial] {type_name}过大({size_mb:.1f}MB),超过{limit_mb:.0f}MB限制" + ) + return None + + # 始终使用分片上传(与 openclaw-qqbot 行为一致) + # openclaw-qqbot 不使用 base64 上传,所有图片都通过分片上传 + if file_path and os.path.exists(file_path): + return await self._chunked_upload( + file_path, + file_type, + openid=kwargs.get("openid"), + group_openid=kwargs.get("group_openid"), + on_progress=kwargs.get("on_progress"), + ) + else: + logger.error( + f"[QQOfficial] 图片文件不存在: {image_source[:50] if image_source else 'None'}..." + ) + return None + + async def _upload_media_enhanced( + self, + file_source: str, + file_type: int, + srv_send_msg: bool = False, + file_name: str | None = None, + **kwargs, + ) -> Media | None: + """ + 增强版媒体上传:始终使用分片上传(与 openclaw-qqbot 行为一致) + """ + file_path = None + file_size = 0 + + if os.path.exists(file_source): + file_path = file_source + file_size = os.path.getsize(file_path) + else: + # URL 或其他来源 - 记录用于兜底 + if file_source.startswith("http"): + self._upload_failed_media[f"media_{file_type}"] = file_source + file_size = 0 + + # 检查文件大小限制 + max_size = get_max_upload_size(file_type) + if file_size > max_size: + type_name = get_file_type_name(file_type) + size_mb = file_size / (1024 * 1024) + limit_mb = max_size / (1024 * 1024) + logger.error( + f"[QQOfficial] {type_name}过大({size_mb:.1f}MB),超过{limit_mb:.0f}MB限制" + ) + return None + + # 始终使用分片上传(与 openclaw-qqbot 行为一致) + if file_path: + return await self._chunked_upload( + file_path, + file_type, + openid=kwargs.get("openid"), + group_openid=kwargs.get("group_openid"), + on_progress=kwargs.get("on_progress"), + ) + else: + logger.error(f"[QQOfficial] 媒体文件不存在: {file_source}") + return None + + async def _chunked_upload( + self, + file_path: str, + file_type: int, + openid: Optional[str] = None, + group_openid: Optional[str] = None, + on_progress: Optional[Callable[[ChunkedUploadProgress], None]] = None, + ) -> Media | None: + """ + 分片上传(大文件) + + Args: + file_path: 文件路径 + file_type: 文件类型(1=图片, 2=视频, 3=语音, 4=文件) + openid: 用户 openid(C2C 目标) + group_openid: 群 openid(Group 目标) + on_progress: 进度回调函数 + + Returns: + Media 对象,失败返回 None + """ + + # 创建默认进度回调(类似 TypeScript 版本的日志) + def default_progress_callback(progress: ChunkedUploadProgress) -> None: + file_type_name = get_file_type_name(file_type) + logger.debug( + f"[QQOfficial] chunked upload progress: " + f"{progress.completed_parts}/{progress.total_parts} parts, " + f"{format_file_size(progress.uploaded_bytes)}/{format_file_size(progress.total_bytes)}" + ) + + # 使用传入的回调或默认回调 + progress_callback = ( + on_progress if on_progress is not None else default_progress_callback + ) + + try: + http_client = await self._get_http_client() + log_prefix = "[QQOfficial:chunked]" + file_name = os.path.basename(file_path) + file_size = os.path.getsize(file_path) + + # 判断目标是 C2C 还是 Group + if openid: + logger.info( + f"{log_prefix} Starting C2C chunked upload: " + f"file={file_name}, size={format_file_size(file_size)}, type={file_type}" + ) + result = await chunked_upload_c2c( + http_client, + openid, + file_path, + file_type, + on_progress=progress_callback, + log_prefix=log_prefix, + ) + elif group_openid: + logger.info( + f"{log_prefix} Starting group chunked upload: " + f"file={file_name}, size={format_file_size(file_size)}, type={file_type}" + ) + result = await chunked_upload_group( + http_client, + group_openid, + file_path, + file_type, + on_progress=progress_callback, + log_prefix=log_prefix, + ) + else: + raise ValueError( + "Invalid upload parameters: must provide openid or group_openid" + ) + + return Media( + file_uuid=result.file_uuid, + file_info=result.file_info, + ttl=result.ttl, + ) + + except UploadDailyLimitExceededError as e: + # P1-1: 每日上传限额超限 + logger.error(f"[QQOfficial] 每日上传限额超限: {e}") + return None + except ChunkedApiError as e: + # P1-1: API 错误处理 + logger.error(f"[QQOfficial] 分片上传 API 错误: {e}") + return None + except Exception as e: + logger.error(f"[QQOfficial] 分片上传失败: {e}", exc_info=True) + return None + + async def _base64_upload( + self, + file_source: str, + file_type: int, + srv_send_msg: bool = False, + file_name: str | None = None, + **kwargs, + ) -> Media | None: + """ + Base64 直传(小文件)- 使用自定义 HTTP 客户端确保超时配置 + + 使用 QQBotHttpClient 的 base64_upload 方法,该方法配置了 120 秒超时, + 相比 botpy 默认超时更适合文件上传场景。 + """ + # 处理文件数据 + file_data = None + if os.path.exists(file_source): + try: + async with aiofiles.open(file_source, "rb") as f: + file_content = await f.read() + file_data = base64.b64encode(file_content).decode("utf-8") + except Exception as e: + logger.error(f"[QQOfficial] 读取文件失败: {e}") + return None + elif file_source.startswith("http"): + # 对于 URL,使用 botpy 的方式(因为 URL 上传不需要 base64 编码) + pass # 降级到原有逻辑 + else: + logger.error(f"[QQOfficial] 不支持的图片来源: {file_source[:50]}...") + return None + + # 如果是 URL,降级到原有逻辑 + if file_data is None: + payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} + if file_name: + payload["file_name"] = file_name + payload["url"] = file_source + + if "openid" in kwargs: + payload["openid"] = kwargs["openid"] + route = Route( + "POST", "/v2/users/{openid}/files", openid=kwargs["openid"] + ) + elif "group_openid" in kwargs: + payload["group_openid"] = kwargs["group_openid"] + route = Route( + "POST", + "/v2/groups/{group_openid}/files", + group_openid=kwargs["group_openid"], + ) + else: + return None + + try: + result = await self.bot.api._http.request(route, json=payload) + if result and isinstance(result, dict): + return Media( + file_uuid=result["file_uuid"], + file_info=result["file_info"], + ttl=result.get("ttl", 0), + ) + except Exception as e: + logger.error(f"[QQOfficial] URL上传请求错误: {e}") + return None + + # 使用自定义 HTTP 客户端上传 + try: + http_client = await self._get_http_client() + + if "openid" in kwargs: + result = await http_client.base64_upload( + file_type=file_type, + file_data=file_data, + file_name=file_name, + srv_send_msg=srv_send_msg, + target_type="c2c", + target_id=kwargs["openid"], + ) + elif "group_openid" in kwargs: + result = await http_client.base64_upload( + file_type=file_type, + file_data=file_data, + file_name=file_name, + srv_send_msg=srv_send_msg, + target_type="group", + target_id=kwargs["group_openid"], + ) + else: + return None + + return Media( + file_uuid=result.file_uuid, + file_info=result.file_info, + ttl=result.ttl, + ) + except ChunkedApiError as e: + logger.error(f"[QQOfficial] Base64上传 API 错误: {e}") + except Exception as e: + logger.error(f"[QQOfficial] Base64上传失败: {e}", exc_info=True) + + return None + async def _send_with_markdown_fallback( self, send_func, @@ -400,8 +1045,6 @@ async def _send_with_markdown_fallback( try: return await send_func(payload) except botpy.errors.ServerError as err: - # QQ 流式 markdown 分片校验:内容必须以换行结尾。 - # 某些边界场景服务端仍可能判定失败,这里做一次修正重试。 if stream and self.STREAM_MARKDOWN_NEWLINE_ERROR in str(err): retry_payload = payload.copy() @@ -441,105 +1084,66 @@ async def _send_with_markdown_fallback( fallback_payload["content"] = fallback_content + "\n" return await send_func(fallback_payload) + @staticmethod async def upload_group_and_c2c_image( - self, - image_base64: str, + send_helper, + image_source: str, file_type: int, **kwargs, ) -> botpy.types.message.Media: - payload = { - "file_data": image_base64, - "file_type": file_type, - "srv_send_msg": False, - } - - result = None - if "openid" in kwargs: - payload["openid"] = kwargs["openid"] - route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) - result = await self.bot.api._http.request(route, json=payload) - elif "group_openid" in kwargs: - payload["group_openid"] = kwargs["group_openid"] - route = Route( - "POST", - "/v2/groups/{group_openid}/files", - group_openid=kwargs["group_openid"], - ) - result = await self.bot.api._http.request(route, json=payload) - else: - raise ValueError("Invalid upload parameters") + """兼容旧接口:上传图片 - if not isinstance(result, dict): - raise RuntimeError( - f"Failed to upload image, response is not dict: {result}" - ) - - return Media( - file_uuid=result["file_uuid"], - file_info=result["file_info"], - ttl=result.get("ttl", 0), + Args: + send_helper: 发送辅助对象(包含 bot 属性) + image_source: 图片来源,可以是文件路径、URL 或 base64:// 数据 + """ + bot = getattr(send_helper, "bot", send_helper) + event = QQOfficialMessageEvent.__new__(QQOfficialMessageEvent) + event.bot = bot + event._http_client = None + event._temp_files = [] + event._upload_failed_media = {} + appid = getattr(bot, "_appid", "") or getattr(bot, "appid", "") + secret = getattr(bot, "_secret", "") or getattr(bot, "secret", "") + event.appid = appid + event.secret = secret + return await event._upload_image_enhanced( + image_source, + file_type, + **kwargs, ) + @staticmethod async def upload_group_and_c2c_media( - self, + send_helper, file_source: str, file_type: int, srv_send_msg: bool = False, file_name: str | None = None, **kwargs, ) -> Media | None: - """上传媒体文件""" - # 构建基础payload - payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} - if file_name: - payload["file_name"] = file_name - - # 处理文件数据 - if os.path.exists(file_source): - # 读取本地文件 - async with aiofiles.open(file_source, "rb") as f: - file_content = await f.read() - # use base64 encode - payload["file_data"] = base64.b64encode(file_content).decode("utf-8") - else: - # 使用URL - payload["url"] = file_source - - # 添加接收者信息和确定路由 - if "openid" in kwargs: - payload["openid"] = kwargs["openid"] - route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) - elif "group_openid" in kwargs: - payload["group_openid"] = kwargs["group_openid"] - route = Route( - "POST", - "/v2/groups/{group_openid}/files", - group_openid=kwargs["group_openid"], - ) - else: - return None - - try: - # 使用底层HTTP请求 - result = await self.bot.api._http.request(route, json=payload) - - if result: - if not isinstance(result, dict): - logger.error(f"上传文件响应格式错误: {result}") - return None - - return Media( - file_uuid=result["file_uuid"], - file_info=result["file_info"], - ttl=result.get("ttl", 0), - ) - except Exception as e: - logger.error(f"上传请求错误: {e}") - - return None + """兼容旧接口:上传媒体""" + bot = getattr(send_helper, "bot", send_helper) + event = QQOfficialMessageEvent.__new__(QQOfficialMessageEvent) + event.bot = bot + event._http_client = None + event._temp_files = [] + event._upload_failed_media = {} + appid = getattr(bot, "_appid", "") or getattr(bot, "appid", "") + secret = getattr(bot, "_secret", "") or getattr(bot, "secret", "") + event.appid = appid + event.secret = secret + return await event._upload_media_enhanced( + file_source, + file_type, + srv_send_msg, + file_name, + **kwargs, + ) + @staticmethod async def post_c2c_message( - self, + send_helper, openid: str, msg_type: int = 0, content: str | None = None, @@ -554,16 +1158,27 @@ async def post_c2c_message( keyboard: message.Keyboard | None = None, stream: dict | None = None, ) -> message.Message: - payload = locals() - payload.pop("self", None) - # QQ API does not accept stream.id=None; remove it when not yet assigned + bot = getattr(send_helper, "bot", send_helper) + payload = { + "msg_type": msg_type, + "content": content, + "embed": embed, + "ark": ark, + "message_reference": message_reference, + "media": media, + "msg_id": msg_id, + "msg_seq": msg_seq, + "event_id": event_id, + "markdown": markdown, + "keyboard": keyboard, + } if "stream" in payload and payload["stream"] is not None: stream_data = dict(payload["stream"]) if stream_data.get("id") is None: stream_data.pop("id", None) payload["stream"] = stream_data route = Route("POST", "/v2/users/{openid}/messages", openid=openid) - result = await self.bot.api._http.request(route, json=payload) + result = await bot.api._http.request(route, json=payload) if result is None: logger.warning("[QQOfficial] post_c2c_message: API 返回 None,跳过本次发送") @@ -577,55 +1192,67 @@ async def post_c2c_message( @staticmethod async def _parse_to_qqofficial(message: MessageChain): plain_text = "" - image_base64 = None # only one img supported - image_file_path = None + image_source = None # 图片来源(路径或 URL) record_file_path = None video_file_source = None file_source = None file_name = None + for i in message.chain: if isinstance(i, Plain): plain_text += i.text - elif isinstance(i, Image) and not image_base64: + elif isinstance(i, Image) and not image_source: if i.file and i.file.startswith("file:///"): - image_base64 = file_to_base64(i.file[8:]) - image_file_path = i.file[8:] + image_source = i.file[8:] elif i.file and i.file.startswith("http"): - image_file_path = await download_image_by_url(i.file) - image_base64 = file_to_base64(image_file_path) + image_source = i.file # P1-4: 保留 URL 供后续处理 elif i.file and i.file.startswith("base64://"): - image_base64 = i.file + # Base64 数据,保存为临时文件 + b64_data = i.file[9:] + temp_dir = get_astrbot_temp_path() + temp_path = os.path.join( + temp_dir, f"qqofficial_{uuid.uuid4().hex}.png" + ) + try: + with open(temp_path, "wb") as f: + f.write(base64.b64decode(b64_data)) + image_source = temp_path + except Exception as e: + logger.error(f"[QQOfficial] 保存 Base64 图片失败: {e}") + image_source = i.file # 保留原始数据 elif i.file: - image_base64 = file_to_base64(i.file) + image_source = i.file else: raise ValueError("Unsupported image file format") - image_base64 = image_base64.removeprefix("base64://") + elif isinstance(i, Record): if i.file: - record_wav_path = await i.convert_to_file_path() # wav 路径 + record_wav_path = await i.convert_to_file_path() temp_dir = get_astrbot_temp_path() - record_tecent_silk_path = os.path.join( + record_silk_path = os.path.join( temp_dir, f"qqofficial_{uuid.uuid4()}.silk", ) try: duration = await wav_to_tencent_silk( record_wav_path, - record_tecent_silk_path, + record_silk_path, ) if duration > 0: - record_file_path = record_tecent_silk_path + record_file_path = record_silk_path else: record_file_path = None logger.error("转换音频格式时出错:音频时长不大于0") except Exception as e: logger.error(f"处理语音时出错: {e}") record_file_path = None + elif isinstance(i, Video) and not video_file_source: if i.file.startswith("file:///"): video_file_source = i.file[8:] else: video_file_source = i.file + elif isinstance(i, File) and not file_source: file_name = i.name if i.file_: @@ -637,12 +1264,13 @@ async def _parse_to_qqofficial(message: MessageChain): file_source = file_path elif i.url: file_source = i.url + else: logger.debug(f"qq_official 忽略 {i.type}") + return ( plain_text, - image_base64, - image_file_path, + image_source, record_file_path, video_file_source, file_source, diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 3037ab2d8d..67b02502f1 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -1,597 +1,553 @@ -from __future__ import annotations - -import asyncio -import logging -import os -import random -import time -import uuid -from pathlib import Path -from types import SimpleNamespace -from typing import Any, cast - -import botpy -import botpy.message -from botpy import Client -from botpy.gateway import BotWebSocket - -from astrbot import logger -from astrbot.api.event import MessageChain -from astrbot.api.message_components import At, File, Image, Plain, Record, Video -from astrbot.api.platform import ( - AstrBotMessage, - MessageMember, - MessageType, - Platform, - PlatformMetadata, -) -from astrbot.core.message.components import BaseMessageComponent -from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path -from astrbot.core.utils.io import download_file - -from ...register import register_platform_adapter -from .qqofficial_message_event import QQOfficialMessageEvent - -# remove logger handler -for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - - -class ManagedBotWebSocket(BotWebSocket): - def __init__(self, session, connection: Any, client: botClient): - super().__init__(session, connection) - self._client = client - - async def on_closed(self, close_status_code, close_msg): - if self._client.is_shutting_down: - logger.debug("[QQOfficial] Ignore websocket reconnect during shutdown.") - return - await super().on_closed(close_status_code, close_msg) - - async def close(self) -> None: - self._can_reconnect = False - if self._conn is not None and not self._conn.closed: - await self._conn.close() - - -# QQ 机器人官方框架 -class botClient(Client): - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._shutting_down = False - self._active_websockets: set[ManagedBotWebSocket] = set() - - def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: - self.platform = platform - - @property - def is_shutting_down(self) -> bool: - return self._shutting_down or self.is_closed() - - # 收到群消息 - async def on_group_at_message_create( - self, message: botpy.message.GroupMessage - ) -> None: - abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( - message, - MessageType.GROUP_MESSAGE, - ) - abm.group_id = cast(str, message.group_openid) - abm.session_id = abm.group_id - self.platform.remember_session_scene(abm.session_id, "group") - self._commit(abm) - - # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message) -> None: - abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( - message, - MessageType.GROUP_MESSAGE, - ) - abm.group_id = message.channel_id - abm.session_id = abm.group_id - self.platform.remember_session_scene(abm.session_id, "channel") - self._commit(abm) - - # 收到私聊消息 - async def on_direct_message_create( - self, message: botpy.message.DirectMessage - ) -> None: - abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( - message, - MessageType.FRIEND_MESSAGE, - ) - abm.session_id = abm.sender.user_id - self.platform.remember_session_scene(abm.session_id, "friend") - self._commit(abm) - - # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: - abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( - message, - MessageType.FRIEND_MESSAGE, - ) - abm.session_id = abm.sender.user_id - self.platform.remember_session_scene(abm.session_id, "friend") - self._commit(abm) - - def _commit(self, abm: AstrBotMessage) -> None: - self.platform.remember_session_message_id(abm.session_id, abm.message_id) - self.platform.commit_event( - QQOfficialMessageEvent( - abm.message_str, - abm, - self.platform.meta(), - abm.session_id, - self.platform.client, - ), - ) - - async def bot_connect(self, session) -> None: - logger.info("[QQOfficial] Websocket session starting.") - - websocket = ManagedBotWebSocket(session, self._connection, self) - self._active_websockets.add(websocket) - try: - await websocket.ws_connect() - except Exception as e: - if not self.is_shutting_down: - await websocket.on_error(e) - finally: - self._active_websockets.discard(websocket) - - async def shutdown(self) -> None: - if self.is_shutting_down: - return - - self._shutting_down = True - await asyncio.gather( - *(websocket.close() for websocket in list(self._active_websockets)), - return_exceptions=True, - ) - await self.close() - - -@register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器") -class QQOfficialPlatformAdapter(Platform): - def __init__( - self, - platform_config: dict, - platform_settings: dict, - event_queue: asyncio.Queue, - ) -> None: - super().__init__(platform_config, event_queue) - - self.appid = platform_config["appid"] - self.secret = platform_config["secret"] - qq_group = platform_config["enable_group_c2c"] - guild_dm = platform_config["enable_guild_direct_message"] - - if qq_group: - self.intents = botpy.Intents( - public_messages=True, - public_guild_messages=True, - direct_message=guild_dm, - ) - else: - self.intents = botpy.Intents( - public_guild_messages=True, - direct_message=guild_dm, - ) - self.client = botClient( - intents=self.intents, - bot_log=False, - timeout=20, - ) - - self.client.set_platform(self) - - self._session_last_message_id: dict[str, str] = {} - self._session_scene: dict[str, str] = {} - - self.test_mode = os.environ.get("TEST_MODE", "off") == "on" - - async def send_by_session( - self, - session: MessageSesion, - message_chain: MessageChain, - ) -> None: - await self._send_by_session_common(session, message_chain) - - async def _send_by_session_common( - self, - session: MessageSesion, - message_chain: MessageChain, - ) -> None: - ( - plain_text, - image_base64, - image_path, - record_file_path, - video_file_source, - file_source, - file_name, - ) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain) - if ( - not plain_text - and not image_path - and not image_base64 - and not record_file_path - and not video_file_source - and not file_source - ): - return - - msg_id = self._session_last_message_id.get(session.session_id) - if not msg_id: - logger.warning( - "[QQOfficial] No cached msg_id for session: %s, skip send_by_session", - session.session_id, - ) - return - - payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id} - ret: Any = None - send_helper = SimpleNamespace(bot=self.client) - - if session.message_type == MessageType.GROUP_MESSAGE: - scene = self._session_scene.get(session.session_id) - if scene == "group": - payload["msg_seq"] = random.randint(1, 10000) - if image_base64: - media = await QQOfficialMessageEvent.upload_group_and_c2c_image( - send_helper, # type: ignore - image_base64, - QQOfficialMessageEvent.IMAGE_FILE_TYPE, - group_openid=session.session_id, - ) - payload["media"] = media - payload["msg_type"] = 7 - if record_file_path: - media = await QQOfficialMessageEvent.upload_group_and_c2c_media( - send_helper, # type: ignore - record_file_path, - QQOfficialMessageEvent.VOICE_FILE_TYPE, - group_openid=session.session_id, - ) - if media: - payload["media"] = media - payload["msg_type"] = 7 - if video_file_source: - media = await QQOfficialMessageEvent.upload_group_and_c2c_media( - send_helper, # type: ignore - video_file_source, - QQOfficialMessageEvent.VIDEO_FILE_TYPE, - group_openid=session.session_id, - ) - if media: - payload["media"] = media - payload["msg_type"] = 7 - payload.pop("msg_id", None) - if file_source: - media = await QQOfficialMessageEvent.upload_group_and_c2c_media( - send_helper, # type: ignore - file_source, - QQOfficialMessageEvent.FILE_FILE_TYPE, - file_name=file_name, - group_openid=session.session_id, - ) - if media: - payload["media"] = media - payload["msg_type"] = 7 - payload.pop("msg_id", None) - ret = await self.client.api.post_group_message( - group_openid=session.session_id, - **payload, - ) - else: - if image_path: - payload["file_image"] = image_path - ret = await self.client.api.post_message( - channel_id=session.session_id, - **payload, - ) - - elif session.message_type == MessageType.FRIEND_MESSAGE: - # 参考 https://bot.q.qq.com/wiki/develop/pythonsdk/api/message/post_message.html - # msg_id 缺失时认为是主动推送,而似乎至少在私聊上主动推送是没有被限制的,这里直接移除 msg_id 可以避免越权或 msg_id 不可用的bug - payload.pop("msg_id", None) - payload["msg_seq"] = random.randint(1, 10000) - if image_base64: - media = await QQOfficialMessageEvent.upload_group_and_c2c_image( - send_helper, # type: ignore - image_base64, - QQOfficialMessageEvent.IMAGE_FILE_TYPE, - openid=session.session_id, - ) - payload["media"] = media - payload["msg_type"] = 7 - if record_file_path: - media = await QQOfficialMessageEvent.upload_group_and_c2c_media( - send_helper, # type: ignore - record_file_path, - QQOfficialMessageEvent.VOICE_FILE_TYPE, - openid=session.session_id, - ) - if media: - payload["media"] = media - payload["msg_type"] = 7 - if video_file_source: - media = await QQOfficialMessageEvent.upload_group_and_c2c_media( - send_helper, # type: ignore - video_file_source, - QQOfficialMessageEvent.VIDEO_FILE_TYPE, - openid=session.session_id, - ) - if media: - payload["media"] = media - payload["msg_type"] = 7 - if file_source: - media = await QQOfficialMessageEvent.upload_group_and_c2c_media( - send_helper, # type: ignore - file_source, - QQOfficialMessageEvent.FILE_FILE_TYPE, - file_name=file_name, - openid=session.session_id, - ) - if media: - payload["media"] = media - payload["msg_type"] = 7 - - ret = await QQOfficialMessageEvent.post_c2c_message( - send_helper, # type: ignore - openid=session.session_id, - **payload, - ) - else: - logger.warning( - "[QQOfficial] Unsupported message type for send_by_session: %s", - session.message_type, - ) - return - - sent_message_id = self._extract_message_id(ret) - if sent_message_id: - self.remember_session_message_id(session.session_id, sent_message_id) - await super().send_by_session(session, message_chain) - - def remember_session_message_id(self, session_id: str, message_id: str) -> None: - if not session_id or not message_id: - return - self._session_last_message_id[session_id] = message_id - - def remember_session_scene(self, session_id: str, scene: str) -> None: - if not session_id or not scene: - return - self._session_scene[session_id] = scene - - def _extract_message_id(self, ret: Any) -> str | None: - if isinstance(ret, dict): - message_id = ret.get("id") - return str(message_id) if message_id else None - message_id = getattr(ret, "id", None) - if message_id: - return str(message_id) - return None - - def meta(self) -> PlatformMetadata: - return PlatformMetadata( - name="qq_official", - description="QQ 机器人官方 API 适配器", - id=cast(str, self.config.get("id")), - support_proactive_message=True, - ) - - @staticmethod - def _normalize_attachment_url(url: str | None) -> str: - if not url: - return "" - if url.startswith("http://") or url.startswith("https://"): - return url - return f"https://{url}" - - @staticmethod - async def _prepare_audio_attachment( - url: str, - filename: str, - ) -> Record: - temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) - - ext = Path(filename).suffix.lower() - source_ext = ext or ".audio" - source_path = temp_dir / f"qqofficial_{uuid.uuid4().hex}{source_ext}" - await download_file(url, str(source_path)) - - return Record(file=str(source_path), url=str(source_path)) - - @staticmethod - async def _append_attachments( - msg: list[BaseMessageComponent], - attachments: list | None, - ) -> None: - if not attachments: - return - - for attachment in attachments: - content_type = cast( - str, - getattr(attachment, "content_type", "") or "", - ).lower() - url = QQOfficialPlatformAdapter._normalize_attachment_url( - cast(str | None, getattr(attachment, "url", None)) - ) - if not url: - continue - - if content_type.startswith("image"): - msg.append(Image.fromURL(url)) - else: - filename = cast( - str, - getattr(attachment, "filename", None) - or getattr(attachment, "name", None) - or "attachment", - ) - ext = Path(filename).suffix.lower() - image_exts = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} - audio_exts = { - ".mp3", - ".wav", - ".ogg", - ".m4a", - ".amr", - ".silk", - } - video_exts = { - ".mp4", - ".mov", - ".avi", - ".mkv", - ".webm", - } - - if content_type.startswith("voice") or ext in audio_exts: - try: - msg.append( - await QQOfficialPlatformAdapter._prepare_audio_attachment( - url, - filename, - ) - ) - except Exception as e: - logger.warning( - "[QQOfficial] Failed to prepare audio attachment %s: %s", - url, - e, - ) - msg.append(Record.fromURL(url)) - elif content_type.startswith("video") or ext in video_exts: - msg.append(Video.fromURL(url)) - elif content_type.startswith("image") or ext in image_exts: - msg.append(Image.fromURL(url)) - else: - msg.append(File(name=filename, file=url, url=url)) - - @staticmethod - def _parse_face_message(content: str) -> str: - """Parse QQ official face message format and convert to readable text. - - QQ official face message format: - - - The ext field contains base64-encoded JSON with a 'text' field - describing the emoji (e.g., '[满头问号]'). - - Args: - content: The message content that may contain face tags. - - Returns: - Content with face tags replaced by readable emoji descriptions. - """ - import base64 - import json - import re - - def replace_face(match): - face_tag = match.group(0) - # Extract ext field from the face tag - ext_match = re.search(r'ext="([^"]*)"', face_tag) - if ext_match: - try: - ext_encoded = ext_match.group(1) - # Decode base64 and parse JSON - ext_decoded = base64.b64decode(ext_encoded).decode("utf-8") - ext_data = json.loads(ext_decoded) - emoji_text = ext_data.get("text", "") - if emoji_text: - return f"[表情:{emoji_text}]" - except Exception: - pass - # Fallback if parsing fails - return "[表情]" - - # Match face tags: - return re.sub(r"]*>", replace_face, content) - - @staticmethod - async def _parse_from_qqofficial( - message: botpy.message.Message - | botpy.message.GroupMessage - | botpy.message.DirectMessage - | botpy.message.C2CMessage, - message_type: MessageType, - ) -> AstrBotMessage: - abm = AstrBotMessage() - abm.type = message_type - abm.timestamp = int(time.time()) - abm.raw_message = message - abm.message_id = message.id - # abm.tag = "qq_official" - msg: list[BaseMessageComponent] = [] - - if isinstance(message, botpy.message.GroupMessage) or isinstance( - message, - botpy.message.C2CMessage, - ): - if isinstance(message, botpy.message.GroupMessage): - abm.sender = MessageMember(message.author.member_openid, "") - abm.group_id = message.group_openid - else: - abm.sender = MessageMember(message.author.user_openid, "") - # Parse face messages to readable text - abm.message_str = QQOfficialPlatformAdapter._parse_face_message( - message.content.strip() - ) - abm.self_id = "unknown_selfid" - msg.append(At(qq="qq_official")) - msg.append(Plain(abm.message_str)) - await QQOfficialPlatformAdapter._append_attachments( - msg, message.attachments - ) - abm.message = msg - - elif isinstance(message, botpy.message.Message) or isinstance( - message, - botpy.message.DirectMessage, - ): - if isinstance(message, botpy.message.Message): - abm.self_id = str(message.mentions[0].id) - else: - abm.self_id = "" - - plain_content = QQOfficialPlatformAdapter._parse_face_message( - message.content.replace( - "<@!" + str(abm.self_id) + ">", - "", - ).strip() - ) - - await QQOfficialPlatformAdapter._append_attachments( - msg, message.attachments - ) - abm.message = msg - abm.message_str = plain_content - abm.sender = MessageMember( - str(message.author.id), - str(message.author.username), - ) - msg.append(At(qq="qq_official")) - msg.append(Plain(plain_content)) - - if isinstance(message, botpy.message.Message): - abm.group_id = message.channel_id - else: - raise ValueError(f"Unknown message type: {message_type}") - abm.self_id = "qq_official" - return abm - - def run(self): - return self.client.start(appid=self.appid, secret=self.secret) - - def get_client(self) -> botClient: - return self.client - - async def terminate(self) -> None: - await self.client.shutdown() - logger.info("QQ 官方机器人接口 适配器已被关闭") +from __future__ import annotations + +import asyncio +import logging +import os +import random +import time +import uuid +from pathlib import Path +from types import SimpleNamespace +from typing import Any, cast + +import botpy +import botpy.message +from botpy import Client + +from astrbot import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import At, File, Image, Plain, Record, Video +from astrbot.api.platform import ( + AstrBotMessage, + MessageMember, + MessageType, + Platform, + PlatformMetadata, +) +from astrbot.core.message.components import BaseMessageComponent +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.io import download_file + +from ...register import register_platform_adapter +from .qqofficial_message_event import QQOfficialMessageEvent + +# remove logger handler +for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + +# QQ 机器人官方框架 +class botClient(Client): + def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: + self.platform = platform + + # 收到群消息 + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: + abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( + message, + MessageType.GROUP_MESSAGE, + ) + abm.group_id = cast(str, message.group_openid) + abm.session_id = abm.group_id + self.platform.remember_session_scene(abm.session_id, "group") + self._commit(abm) + + # 收到频道消息 + async def on_at_message_create(self, message: botpy.message.Message) -> None: + abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( + message, + MessageType.GROUP_MESSAGE, + ) + abm.group_id = message.channel_id + abm.session_id = abm.group_id + self.platform.remember_session_scene(abm.session_id, "channel") + self._commit(abm) + + # 收到私聊消息 + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: + abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( + message, + MessageType.FRIEND_MESSAGE, + ) + abm.session_id = abm.sender.user_id + self.platform.remember_session_scene(abm.session_id, "friend") + self._commit(abm) + + # 收到 C2C 消息 + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: + abm = await QQOfficialPlatformAdapter._parse_from_qqofficial( + message, + MessageType.FRIEND_MESSAGE, + ) + abm.session_id = abm.sender.user_id + self.platform.remember_session_scene(abm.session_id, "friend") + self._commit(abm) + + def _commit(self, abm: AstrBotMessage) -> None: + self.platform.remember_session_message_id(abm.session_id, abm.message_id) + self.platform.commit_event( + QQOfficialMessageEvent( + abm.message_str, + abm, + self.platform.meta(), + abm.session_id, + self.platform.client, + appid=self.platform.appid, + secret=self.platform.secret, + ), + ) + + +@register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器") +class QQOfficialPlatformAdapter(Platform): + def __init__( + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, + ) -> None: + super().__init__(platform_config, event_queue) + + self.appid = platform_config["appid"] + self.secret = platform_config["secret"] + qq_group = platform_config["enable_group_c2c"] + guild_dm = platform_config["enable_guild_direct_message"] + + if qq_group: + self.intents = botpy.Intents( + public_messages=True, + public_guild_messages=True, + direct_message=guild_dm, + ) + else: + self.intents = botpy.Intents( + public_guild_messages=True, + direct_message=guild_dm, + ) + self.client = botClient( + intents=self.intents, + bot_log=False, + timeout=20, + ) + self.client._appid = self.appid + self.client._secret = self.secret + + self.client.set_platform(self) + + self._session_last_message_id: dict[str, str] = {} + self._session_scene: dict[str, str] = {} + + self.test_mode = os.environ.get("TEST_MODE", "off") == "on" + + async def send_by_session( + self, + session: MessageSesion, + message_chain: MessageChain, + ) -> None: + await self._send_by_session_common(session, message_chain) + + async def _send_by_session_common( + self, + session: MessageSesion, + message_chain: MessageChain, + ) -> None: + ( + plain_text, + image_source, + record_file_path, + video_file_source, + file_source, + file_name, + ) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain) + if ( + not plain_text + and not image_source + and not record_file_path + and not video_file_source + and not file_source + ): + return + + msg_id = self._session_last_message_id.get(session.session_id) + if not msg_id: + logger.warning( + "[QQOfficial] No cached msg_id for session: %s, skip send_by_session", + session.session_id, + ) + return + + payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id} + ret: Any = None + send_helper = SimpleNamespace(bot=self.client) + + if session.message_type == MessageType.GROUP_MESSAGE: + scene = self._session_scene.get(session.session_id) + if scene == "group": + payload["msg_seq"] = random.randint(1, 10000) + if image_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_image( + send_helper, # type: ignore + image_source, + QQOfficialMessageEvent.IMAGE_FILE_TYPE, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if record_file_path: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + record_file_path, + QQOfficialMessageEvent.VOICE_FILE_TYPE, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if video_file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + video_file_source, + QQOfficialMessageEvent.VIDEO_FILE_TYPE, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("msg_id", None) + if file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + file_source, + QQOfficialMessageEvent.FILE_FILE_TYPE, + file_name=file_name, + group_openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + payload.pop("msg_id", None) + ret = await self.client.api.post_group_message( + group_openid=session.session_id, + **payload, + ) + else: + if image_source: + if os.path.exists(image_source): + payload["file_image"] = image_source + elif image_source.startswith("http"): + payload["image_url"] = image_source + ret = await self.client.api.post_message( + channel_id=session.session_id, + **payload, + ) + + elif session.message_type == MessageType.FRIEND_MESSAGE: + # 参考 https://bot.q.qq.com/wiki/develop/pythonsdk/api/message/post_message.html + # msg_id 缺失时认为是主动推送,而似乎至少在私聊上主动推送是没有被限制的,这里直接移除 msg_id 可以避免越权或 msg_id 不可用的bug + payload.pop("msg_id", None) + payload["msg_seq"] = random.randint(1, 10000) + if image_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_image( + send_helper, # type: ignore + image_source, + QQOfficialMessageEvent.IMAGE_FILE_TYPE, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if record_file_path: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + record_file_path, + QQOfficialMessageEvent.VOICE_FILE_TYPE, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if video_file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + video_file_source, + QQOfficialMessageEvent.VIDEO_FILE_TYPE, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + if file_source: + media = await QQOfficialMessageEvent.upload_group_and_c2c_media( + send_helper, # type: ignore + file_source, + QQOfficialMessageEvent.FILE_FILE_TYPE, + file_name=file_name, + openid=session.session_id, + ) + if media: + payload["media"] = media + payload["msg_type"] = 7 + + ret = await QQOfficialMessageEvent.post_c2c_message( + send_helper, # type: ignore + openid=session.session_id, + **payload, + ) + else: + logger.warning( + "[QQOfficial] Unsupported message type for send_by_session: %s", + session.message_type, + ) + return + + sent_message_id = self._extract_message_id(ret) + if sent_message_id: + self.remember_session_message_id(session.session_id, sent_message_id) + await super().send_by_session(session, message_chain) + + def remember_session_message_id(self, session_id: str, message_id: str) -> None: + if not session_id or not message_id: + return + self._session_last_message_id[session_id] = message_id + + def remember_session_scene(self, session_id: str, scene: str) -> None: + if not session_id or not scene: + return + self._session_scene[session_id] = scene + + def _extract_message_id(self, ret: Any) -> str | None: + if isinstance(ret, dict): + message_id = ret.get("id") + return str(message_id) if message_id else None + message_id = getattr(ret, "id", None) + if message_id: + return str(message_id) + return None + + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + name="qq_official", + description="QQ 机器人官方 API 适配器", + id=cast(str, self.config.get("id")), + support_proactive_message=True, + ) + + @staticmethod + def _normalize_attachment_url(url: str | None) -> str: + if not url: + return "" + if url.startswith("http://") or url.startswith("https://"): + return url + return f"https://{url}" + + @staticmethod + async def _prepare_audio_attachment( + url: str, + filename: str, + ) -> Record: + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + + ext = Path(filename).suffix.lower() + source_ext = ext or ".audio" + source_path = temp_dir / f"qqofficial_{uuid.uuid4().hex}{source_ext}" + await download_file(url, str(source_path)) + + return Record(file=str(source_path), url=str(source_path)) + + @staticmethod + async def _append_attachments( + msg: list[BaseMessageComponent], + attachments: list | None, + ) -> None: + if not attachments: + return + + for attachment in attachments: + content_type = cast( + str, + getattr(attachment, "content_type", "") or "", + ).lower() + url = QQOfficialPlatformAdapter._normalize_attachment_url( + cast(str | None, getattr(attachment, "url", None)) + ) + if not url: + continue + + if content_type.startswith("image"): + msg.append(Image.fromURL(url)) + else: + filename = cast( + str, + getattr(attachment, "filename", None) + or getattr(attachment, "name", None) + or "attachment", + ) + ext = Path(filename).suffix.lower() + image_exts = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} + audio_exts = { + ".mp3", + ".wav", + ".ogg", + ".m4a", + ".amr", + ".silk", + } + video_exts = { + ".mp4", + ".mov", + ".avi", + ".mkv", + ".webm", + } + + if content_type.startswith("voice") or ext in audio_exts: + try: + msg.append( + await QQOfficialPlatformAdapter._prepare_audio_attachment( + url, + filename, + ) + ) + except Exception as e: + logger.warning( + "[QQOfficial] Failed to prepare audio attachment %s: %s", + url, + e, + ) + msg.append(Record.fromURL(url)) + elif content_type.startswith("video") or ext in video_exts: + msg.append(Video.fromURL(url)) + elif content_type.startswith("image") or ext in image_exts: + msg.append(Image.fromURL(url)) + else: + msg.append(File(name=filename, file=url, url=url)) + + @staticmethod + def _parse_face_message(content: str) -> str: + """Parse QQ official face message format and convert to readable text. + + QQ official face message format: + + + The ext field contains base64-encoded JSON with a 'text' field + describing the emoji (e.g., '[满头问号]'). + + Args: + content: The message content that may contain face tags. + + Returns: + Content with face tags replaced by readable emoji descriptions. + """ + import base64 + import json + import re + + def replace_face(match): + face_tag = match.group(0) + # Extract ext field from the face tag + ext_match = re.search(r'ext="([^"]*)"', face_tag) + if ext_match: + try: + ext_encoded = ext_match.group(1) + # Decode base64 and parse JSON + ext_decoded = base64.b64decode(ext_encoded).decode("utf-8") + ext_data = json.loads(ext_decoded) + emoji_text = ext_data.get("text", "") + if emoji_text: + return f"[表情:{emoji_text}]" + except Exception: + pass + # Fallback if parsing fails + return "[表情]" + + # Match face tags: + return re.sub(r"]*>", replace_face, content) + + @staticmethod + async def _parse_from_qqofficial( + message: botpy.message.Message + | botpy.message.GroupMessage + | botpy.message.DirectMessage + | botpy.message.C2CMessage, + message_type: MessageType, + ) -> AstrBotMessage: + abm = AstrBotMessage() + abm.type = message_type + abm.timestamp = int(time.time()) + abm.raw_message = message + abm.message_id = message.id + # abm.tag = "qq_official" + msg: list[BaseMessageComponent] = [] + + if isinstance(message, botpy.message.GroupMessage) or isinstance( + message, + botpy.message.C2CMessage, + ): + if isinstance(message, botpy.message.GroupMessage): + abm.sender = MessageMember(message.author.member_openid, "") + abm.group_id = message.group_openid + else: + abm.sender = MessageMember(message.author.user_openid, "") + # Parse face messages to readable text + abm.message_str = QQOfficialPlatformAdapter._parse_face_message( + message.content.strip() + ) + abm.self_id = "unknown_selfid" + msg.append(At(qq="qq_official")) + msg.append(Plain(abm.message_str)) + await QQOfficialPlatformAdapter._append_attachments( + msg, message.attachments + ) + abm.message = msg + + elif isinstance(message, botpy.message.Message) or isinstance( + message, + botpy.message.DirectMessage, + ): + if isinstance(message, botpy.message.Message): + abm.self_id = str(message.mentions[0].id) + else: + abm.self_id = "" + + plain_content = QQOfficialPlatformAdapter._parse_face_message( + message.content.replace( + "<@!" + str(abm.self_id) + ">", + "", + ).strip() + ) + + await QQOfficialPlatformAdapter._append_attachments( + msg, message.attachments + ) + abm.message = msg + abm.message_str = plain_content + abm.sender = MessageMember( + str(message.author.id), + str(message.author.username), + ) + msg.append(At(qq="qq_official")) + msg.append(Plain(plain_content)) + + if isinstance(message, botpy.message.Message): + abm.group_id = message.channel_id + else: + raise ValueError(f"Unknown message type: {message_type}") + abm.self_id = "qq_official" + return abm + + def run(self): + return self.client.start(appid=self.appid, secret=self.secret) + + def get_client(self) -> botClient: + return self.client + + async def terminate(self) -> None: + await self.client.close() + logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial/rate_limiter.py b/astrbot/core/platform/sources/qqofficial/rate_limiter.py new file mode 100644 index 0000000000..852d3fd465 --- /dev/null +++ b/astrbot/core/platform/sources/qqofficial/rate_limiter.py @@ -0,0 +1,226 @@ +""" +消息回复限流器 +参照 openclaw-qqbot 的 outbound.ts 实现 + +规则: +- 同一 message_id 1小时内最多回复 4 次 +- 超过 1 小时 message_id 失效,需要降级为主动消息 +""" + +import time +import threading +from dataclasses import dataclass, field +from typing import Dict, Optional + +from astrbot import logger + + +@dataclass +class MessageReplyRecord: + """消息回复记录""" + + count: int = 0 + first_reply_at: float = 0.0 + + +@dataclass +class ReplyLimitResult: + """限流检查结果""" + + # 是否允许被动回复 + allowed: bool + # 剩余被动回复次数 + remaining: int + # 是否需要降级为主动消息 + should_fallback_to_proactive: bool + # 降级原因 + fallback_reason: Optional[str] = None + # 提示消息 + message: Optional[str] = None + + +class MessageReplyLimiter: + """ + 消息回复限流器 + + 规则: + - 同一 message_id 1小时内最多回复 4 次 + - 超过 1 小时 message_id 失效,需要降级为主动消息 + """ + + # 同一 message_id 1小时内最多回复次数 + MESSAGE_REPLY_LIMIT = 4 + + # message_id 有效期(毫秒)- 1小时 + MESSAGE_REPLY_TTL_MS = 60 * 60 * 1000 + + # 最大追踪消息数(避免内存泄漏) + MAX_TRACKED_MESSAGES = 10000 + + def __init__(self): + self._tracker: Dict[str, MessageReplyRecord] = {} + self._lock = threading.RLock() + + def check_limit(self, message_id: str) -> ReplyLimitResult: + """ + 检查是否可以回复该消息(限流检查) + + Args: + message_id: 消息ID + + Returns: + ReplyLimitResult: 限流检查结果 + """ + now = time.time() * 1000 # 转换为毫秒 + + with self._lock: + record = self._tracker.get(message_id) + + # 定期清理过期记录(避免内存泄漏) + if len(self._tracker) > self.MAX_TRACKED_MESSAGES: + self._cleanup_expired_records(now) + + # 新消息,首次回复 + if not record: + return ReplyLimitResult( + allowed=True, + remaining=self.MESSAGE_REPLY_LIMIT, + should_fallback_to_proactive=False, + ) + + # 检查是否超过1小时(message_id 过期) + if now - record.first_reply_at > self.MESSAGE_REPLY_TTL_MS: + # 超过1小时,被动回复不可用,需要降级为主动消息 + return ReplyLimitResult( + allowed=False, + remaining=0, + should_fallback_to_proactive=True, + fallback_reason="expired", + message="消息已超过1小时有效期,将使用主动消息发送", + ) + + # 检查是否超过回复次数限制 + remaining = self.MESSAGE_REPLY_LIMIT - record.count + if remaining <= 0: + return ReplyLimitResult( + allowed=False, + remaining=0, + should_fallback_to_proactive=True, + fallback_reason="limit_exceeded", + message=f"该消息已达到1小时内最大回复次数({self.MESSAGE_REPLY_LIMIT}次),将使用主动消息发送", + ) + + return ReplyLimitResult( + allowed=True, + remaining=remaining, + should_fallback_to_proactive=False, + ) + + def record_reply(self, message_id: str) -> None: + """ + 记录一次消息回复 + + Args: + message_id: 消息ID + """ + now = time.time() * 1000 + + with self._lock: + record = self._tracker.get(message_id) + + if not record: + self._tracker[message_id] = MessageReplyRecord( + count=1, first_reply_at=now + ) + else: + # 检查是否过期,过期则重新计数 + if now - record.first_reply_at > self.MESSAGE_REPLY_TTL_MS: + self._tracker[message_id] = MessageReplyRecord( + count=1, first_reply_at=now + ) + else: + record.count += 1 + + record = self._tracker.get(message_id) + if record: + logger.debug( + f"[QQOfficial] recordReply: {message_id}, count={record.count}" + ) + + def get_stats(self) -> Dict[str, int]: + """ + 获取消息回复统计信息 + + Returns: + Dict: 包含 tracked_messages 和 total_replies + """ + with self._lock: + total_replies = sum(r.count for r in self._tracker.values()) + return { + "tracked_messages": len(self._tracker), + "total_replies": total_replies, + } + + def get_config(self) -> Dict[str, int]: + """ + 获取消息回复限制配置(供外部查询) + + Returns: + Dict: 包含 limit, ttl_ms, ttl_hours + """ + return { + "limit": self.MESSAGE_REPLY_LIMIT, + "ttl_ms": self.MESSAGE_REPLY_TTL_MS, + "ttl_hours": self.MESSAGE_REPLY_TTL_MS // (60 * 60 * 1000), + } + + def _cleanup_expired_records(self, now: float) -> None: + """清理过期记录""" + expired_keys = [ + msg_id + for msg_id, rec in self._tracker.items() + if now - rec.first_reply_at > self.MESSAGE_REPLY_TTL_MS + ] + for key in expired_keys: + del self._tracker[key] + if expired_keys: + logger.debug( + f"[QQOfficial] Cleaned up {len(expired_keys)} expired message records" + ) + + +# 全局限流器实例 +_global_limiter: Optional[MessageReplyLimiter] = None +_global_limiter_lock = threading.RLock() + + +def get_rate_limiter() -> MessageReplyLimiter: + """获取全局限流器实例""" + global _global_limiter + with _global_limiter_lock: + if _global_limiter is None: + _global_limiter = MessageReplyLimiter() + return _global_limiter + + +def check_message_reply_limit(message_id: str) -> ReplyLimitResult: + """ + 检查是否可以回复该消息(便捷函数) + + Args: + message_id: 消息ID + + Returns: + ReplyLimitResult: 限流检查结果 + """ + return get_rate_limiter().check_limit(message_id) + + +def record_message_reply(message_id: str) -> None: + """ + 记录一次消息回复(便捷函数) + + Args: + message_id: 消息ID + """ + get_rate_limiter().record_reply(message_id)