diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 82e048ff17b..7bd04ff4b2b 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -440,6 +440,9 @@ async def streamed_response_generator(chat_id, dia, msg):
 @token_required
 async def agents_completion_openai_compatibility(tenant_id, agent_id):
     req = await get_request_json()
+    err, include_reference_metadata, metadata_fields = _parse_reference_metadata(req)
+    if err:
+        return err
     messages = req.get("messages", [])
     if not messages:
         return get_error_data_result("You must provide at least one message.")
@@ -461,18 +464,32 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
         )
 
     question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
+    metadata = req.get("metadata")
+    if metadata is None:
+        metadata = {}
+    elif not isinstance(metadata, dict):
+        return get_error_data_result("metadata must be an object.")
+    session_id = req.pop("session_id", req.get("id", "")) or metadata.get("id", "")
 
     stream = req.pop("stream", False)
     if stream:
+        body = completion_openai(
+            tenant_id,
+            agent_id,
+            question,
+            session_id=session_id,
+            stream=True,
+            **req,
+        )
+        if include_reference_metadata:
+            async def generate():
+                async for answer in body:
+                    yield _build_agent_openai_response(answer, metadata_fields=metadata_fields)
+
+            body = generate()
+
         resp = Response(
-            completion_openai(
-                tenant_id,
-                agent_id,
-                question,
-                session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
-                stream=True,
-                **req,
-            ),
+            body,
             mimetype="text/event-stream",
         )
         resp.headers.add_header("Cache-control", "no-cache")
@@ -486,10 +503,12 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
             tenant_id,
             agent_id,
             question,
-            session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
+            session_id=session_id,
             stream=False,
             **req,
         ):
+            if include_reference_metadata:
+                response = _build_agent_openai_response(response, metadata_fields=metadata_fields)
             return jsonify(response)
 
         return None
@@ -499,6 +518,9 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
 @token_required
 async def agent_completions(tenant_id, agent_id):
     req = await get_request_json()
+    err, include_reference_metadata, metadata_fields = _parse_reference_metadata(req)
+    if err:
+        return err
     return_trace = bool(req.get("return_trace", False))
 
     if req.get("stream", True):
@@ -511,11 +533,24 @@ async def generate():
                         ans = json.loads(answer[5:])  # remove "data:"
                     except Exception:
                         continue
+                else:
+                    ans = answer
+                    answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+
+                data = ans.get("data") or {}
+                if not isinstance(data, dict):
+                    data = {}
+                    ans["data"] = data
+                if include_reference_metadata and data.get("reference") is not None:
+                    data["reference"] = _build_agent_reference(
+                        data["reference"],
+                        metadata_fields=metadata_fields,
+                    )
+                    answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
 
                 event = ans.get("event")
                 if event == "node_finished":
                     if return_trace:
-                        data = ans.get("data", {})
                         trace_items.append(
                             {
                                 "component_id": data.get("component_id"),
@@ -547,17 +582,25 @@ async def generate():
     structured_output = {}
     async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
         try:
-            ans = json.loads(answer[5:])
+            ans = json.loads(answer[5:]) if isinstance(answer, str) else answer
+            data = ans.get("data") or {}
+            if not isinstance(data, dict):
+                data = {}
+                ans["data"] = data
 
             if ans["event"] == "message":
-                full_content += ans["data"]["content"]
+                full_content += data.get("content", "")
 
-            if ans.get("data", {}).get("reference", None):
-                reference.update(ans["data"]["reference"])
+            if data.get("reference", None):
+                ref = data["reference"]
+                if include_reference_metadata:
+                    ref = _build_agent_reference(ref, metadata_fields=metadata_fields)
+                reference.update(ref)
 
             if ans.get("event") == "node_finished":
-                data = ans.get("data", {})
                 node_out = data.get("outputs", {})
+                if not isinstance(node_out, dict):
+                    node_out = {}
                 component_id = data.get("component_id")
                 if component_id is not None and "structured" in node_out:
                     structured_output[component_id] = copy.deepcopy(node_out["structured"])
@@ -1302,3 +1345,156 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N
             chunk["document_metadata"] = meta
 
     return chunks
+
+
+def _parse_reference_metadata(req):
+    """Validate the optional ``extra_body.reference_metadata`` request options.
+
+    Args:
+        req: Parsed request body; only its ``extra_body`` entry is read.
+
+    Returns:
+        A tuple ``(err, include_reference_metadata, metadata_fields)``. On a
+        validation failure ``err`` is the value of ``get_error_data_result(...)``
+        and the other two items are ``False`` and ``None``. On success ``err``
+        is ``None``, ``include_reference_metadata`` is the boolean ``include``
+        option (``False`` when absent) and ``metadata_fields`` is the ``fields``
+        option (a list, or ``None`` when absent).
+    """
+    extra_body = req.get("extra_body", {})
+    if extra_body is None:
+        extra_body = {}
+    elif not isinstance(extra_body, dict):
+        return get_error_data_result("extra_body must be an object."), False, None
+
+    reference_metadata = extra_body.get("reference_metadata", {})
+    if reference_metadata is None:
+        reference_metadata = {}
+    elif not isinstance(reference_metadata, dict):
+        return get_error_data_result("reference_metadata must be an object."), False, None
+
+    metadata_fields = reference_metadata.get("fields")
+    if metadata_fields is not None and not isinstance(metadata_fields, list):
+        return get_error_data_result("reference_metadata.fields must be an array."), False, None
+
+    include_reference_metadata = reference_metadata.get("include", False)
+    if not isinstance(include_reference_metadata, bool):
+        return get_error_data_result("reference_metadata.include must be a boolean."), False, None
+
+    return None, include_reference_metadata, metadata_fields
+
+
+def _build_agent_reference(reference, metadata_fields=None):
+    """Attach ``document_metadata`` to the chunks of an agent reference, in place.
+
+    Chunks are taken from ``reference["chunks"]`` (a list, or a dict whose
+    values are the chunks). Document ids are grouped by ``kb_id``/``dataset_id``
+    and fetched with one ``DocMetadataService.get_metadata_for_documents`` call
+    per group.
+
+    Args:
+        reference: Reference payload. It is returned untouched unless it is a
+            dict holding a list or dict under ``"chunks"``.
+        metadata_fields: Optional iterable of metadata keys to keep. ``None``
+            keeps every key; non-string entries are dropped, and when no string
+            entry is left no metadata is attached at all.
+
+    Returns:
+        The ``reference`` object that was passed in.
+    """
+    if not isinstance(reference, dict):
+        return reference
+
+    raw_chunks = reference.get("chunks")
+    if isinstance(raw_chunks, dict):
+        chunks = [chunk for chunk in raw_chunks.values() if isinstance(chunk, dict)]
+    elif isinstance(raw_chunks, list):
+        chunks = [chunk for chunk in raw_chunks if isinstance(chunk, dict)]
+    else:
+        return reference
+
+    # Normalise the filter to a set of strings; an effectively empty filter means nothing can be attached.
+    if metadata_fields is not None:
+        metadata_fields = {f for f in metadata_fields if isinstance(f, str)}
+        if not metadata_fields:
+            return reference
+
+    # Group document ids per knowledge base so each one needs a single metadata lookup.
+    doc_ids_by_kb = {}
+    for chunk in chunks:
+        kb_id = chunk.get("kb_id") or chunk.get("dataset_id")
+        doc_id = chunk.get("doc_id") or chunk.get("document_id")
+        if not kb_id or not doc_id:
+            continue
+        doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id)
+
+    if not doc_ids_by_kb:
+        return reference
+
+    meta_by_doc = {}
+    for kb_id, doc_ids in doc_ids_by_kb.items():
+        meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id)
+        if meta_map:
+            meta_by_doc.update(meta_map)
+
+    if not meta_by_doc:
+        return reference
+
+    for chunk in chunks:
+        doc_id = chunk.get("doc_id") or chunk.get("document_id")
+        meta = meta_by_doc.get(doc_id)
+        if not meta:
+            continue
+        if metadata_fields is not None:
+            meta = {k: v for k, v in meta.items() if k in metadata_fields}
+        if meta:
+            chunk["document_metadata"] = meta
+    return reference
+
+
+def _build_agent_openai_response(answer, metadata_fields=None):
+    """Add document metadata to the references of an OpenAI-style completion payload.
+
+    Args:
+        answer: A ``data:``-prefixed SSE string or an already decoded payload
+            dict. Any other value is returned unchanged.
+        metadata_fields: Forwarded to ``_build_agent_reference``.
+
+    Returns:
+        For a dict, the same dict with the ``reference`` of every choice's
+        ``delta`` and ``message`` passed through ``_build_agent_reference``.
+        For a ``data:`` string holding JSON, a re-serialised ``data:`` frame
+        terminated by a blank line. ``[DONE]`` markers, strings without the
+        ``data:`` prefix and frames that fail to parse are returned as-is.
+    """
+    if isinstance(answer, str):
+        if answer.strip() in {"data:[DONE]", "data: [DONE]"}:
+            return answer
+        if not answer.startswith("data:"):
+            return answer
+        try:
+            payload = json.loads(answer[5:].strip())
+        except Exception:
+            # Best effort: a frame that is not JSON is forwarded untouched.
+            return answer
+        payload = _build_agent_openai_response(payload, metadata_fields=metadata_fields)
+        return "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
+
+    if not isinstance(answer, dict):
+        return answer
+
+    choices = answer.get("choices")
+    if not isinstance(choices, (list, tuple)):
+        # ``"choices": null`` or any other non-sequence value carries nothing to enrich.
+        return answer
+
+    for choice in choices:
+        if not isinstance(choice, dict):
+            continue
+        delta = choice.get("delta")
+        if isinstance(delta, dict) and delta.get("reference") is not None:
+            delta["reference"] = _build_agent_reference(delta["reference"], metadata_fields=metadata_fields)
+        message = choice.get("message")
+        if isinstance(message, dict) and message.get("reference") is not None:
+            message["reference"] = _build_agent_reference(message["reference"], metadata_fields=metadata_fields)
+    return answer