-
Notifications
You must be signed in to change notification settings - Fork 9.2k
Fix: Support extra_body.reference_metadata on agent completions (#14308) #14314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
1c93221
8fe183e
d08d10f
c80be78
ed92396
56ecf78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -440,6 +440,20 @@ async def streamed_response_generator(chat_id, dia, msg): | |
| @token_required | ||
| async def agents_completion_openai_compatibility(tenant_id, agent_id): | ||
| req = await get_request_json() | ||
| extra_body = req.get("extra_body", {}) | ||
| if extra_body is None: | ||
| extra_body = {} | ||
| elif not isinstance(extra_body, dict): | ||
| return get_error_data_result("extra_body must be an object.") | ||
| reference_metadata = extra_body.get("reference_metadata", {}) | ||
| if reference_metadata is None: | ||
| reference_metadata = {} | ||
| elif not isinstance(reference_metadata, dict): | ||
| return get_error_data_result("reference_metadata must be an object.") | ||
| include_reference_metadata = bool(reference_metadata.get("include", False)) | ||
| metadata_fields = reference_metadata.get("fields") | ||
| if metadata_fields is not None and not isinstance(metadata_fields, list): | ||
| return get_error_data_result("reference_metadata.fields must be an array.") | ||
| messages = req.get("messages", []) | ||
| if not messages: | ||
| return get_error_data_result("You must provide at least one message.") | ||
|
|
@@ -464,15 +478,23 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id): | |
|
|
||
| stream = req.pop("stream", False) | ||
| if stream: | ||
| resp = Response( | ||
| completion_openai( | ||
| async def generate(): | ||
| async for answer in completion_openai( | ||
| tenant_id, | ||
| agent_id, | ||
| question, | ||
| session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""), | ||
| stream=True, | ||
| **req, | ||
| ), | ||
| ): | ||
| yield _build_agent_openai_response( | ||
| answer, | ||
| include_metadata=include_reference_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
|
|
||
| resp = Response( | ||
| generate(), | ||
| mimetype="text/event-stream", | ||
| ) | ||
| resp.headers.add_header("Cache-control", "no-cache") | ||
|
|
@@ -490,6 +512,11 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id): | |
| stream=False, | ||
| **req, | ||
| ): | ||
| response = _build_agent_openai_response( | ||
| response, | ||
| include_metadata=include_reference_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
| return jsonify(response) | ||
|
|
||
| return None | ||
|
|
@@ -499,6 +526,20 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id): | |
| @token_required | ||
| async def agent_completions(tenant_id, agent_id): | ||
| req = await get_request_json() | ||
| extra_body = req.get("extra_body", {}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the same validation as above. Suggested replacement: that keeps the validation consistent across both agent endpoints and reduces the amount of code changed for this feature. |
||
| if extra_body is None: | ||
| extra_body = {} | ||
| elif not isinstance(extra_body, dict): | ||
| return get_error_data_result("extra_body must be an object.") | ||
| reference_metadata = extra_body.get("reference_metadata", {}) | ||
| if reference_metadata is None: | ||
| reference_metadata = {} | ||
| elif not isinstance(reference_metadata, dict): | ||
| return get_error_data_result("reference_metadata must be an object.") | ||
| include_reference_metadata = bool(reference_metadata.get("include", False)) | ||
| metadata_fields = reference_metadata.get("fields") | ||
| if metadata_fields is not None and not isinstance(metadata_fields, list): | ||
| return get_error_data_result("reference_metadata.fields must be an array.") | ||
| return_trace = bool(req.get("return_trace", False)) | ||
|
|
||
| if req.get("stream", True): | ||
|
|
@@ -511,19 +552,29 @@ async def generate(): | |
| ans = json.loads(answer[5:]) # remove "data:" | ||
| except Exception: | ||
| continue | ||
| else: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this can stay local to the agent path without needing the extra helper. Suggested change here: it avoids adding another formatting helper just for this. |
||
| ans = answer | ||
|
|
||
| data = ans.get("data", {}) | ||
| if data.get("reference") is not None: | ||
| data["reference"] = _build_agent_reference( | ||
| data["reference"], | ||
| include_metadata=include_reference_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
| answer = _to_sse(ans) | ||
|
|
||
| event = ans.get("event") | ||
| if event == "node_finished": | ||
| if return_trace: | ||
| data = ans.get("data", {}) | ||
| trace_items.append( | ||
| { | ||
| "component_id": data.get("component_id"), | ||
| "trace": [copy.deepcopy(data)], | ||
| } | ||
| ) | ||
| ans.setdefault("data", {})["trace"] = trace_items | ||
| answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" | ||
| answer = _to_sse(ans) | ||
| yield answer | ||
|
|
||
| if event not in ["message", "message_end"]: | ||
|
|
@@ -547,13 +598,19 @@ async def generate(): | |
| structured_output = {} | ||
| async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): | ||
| try: | ||
| ans = json.loads(answer[5:]) | ||
| ans = json.loads(answer[5:]) if isinstance(answer, str) else answer | ||
|
|
||
| if ans["event"] == "message": | ||
| full_content += ans["data"]["content"] | ||
|
|
||
| if ans.get("data", {}).get("reference", None): | ||
| reference.update(ans["data"]["reference"]) | ||
| reference.update( | ||
| _build_agent_reference( | ||
| ans["data"]["reference"], | ||
| include_metadata=include_reference_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
| ) | ||
|
|
||
| if ans.get("event") == "node_finished": | ||
| data = ans.get("data", {}) | ||
|
|
@@ -1267,6 +1324,22 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N | |
| if not include_metadata: | ||
| return chunks | ||
|
|
||
| meta_by_doc = _get_reference_metadata_by_doc(chunks, metadata_fields) | ||
| if not meta_by_doc: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this helper section can be made quite a bit smaller. For this issue, we do not really need to refactor the existing helper. I would suggest:
Suggested helper set: I think that keeps the requested feature fully intact, but with a smaller and more targeted patch. |
||
| return chunks | ||
|
|
||
| for chunk in chunks: | ||
| doc_id = chunk.get("document_id") | ||
| if not doc_id: | ||
| continue | ||
| meta = meta_by_doc.get(doc_id) | ||
| if meta: | ||
| chunk["document_metadata"] = meta | ||
|
|
||
| return chunks | ||
|
|
||
|
|
||
| def _get_reference_metadata_by_doc(chunks, metadata_fields=None): | ||
| doc_ids_by_kb = {} | ||
| for chunk in chunks: | ||
| kb_id = chunk.get("dataset_id") | ||
|
|
@@ -1276,7 +1349,7 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N | |
| doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id) | ||
|
|
||
| if not doc_ids_by_kb: | ||
| return chunks | ||
| return {} | ||
|
|
||
| meta_by_doc = {} | ||
| for kb_id, doc_ids in doc_ids_by_kb.items(): | ||
|
|
@@ -1287,18 +1360,82 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N | |
| if metadata_fields is not None: | ||
| metadata_fields = {f for f in metadata_fields if isinstance(f, str)} | ||
| if not metadata_fields: | ||
| return chunks | ||
| return {} | ||
|
|
||
| for chunk in chunks: | ||
| doc_id = chunk.get("document_id") | ||
| if not doc_id: | ||
| continue | ||
| meta = meta_by_doc.get(doc_id) | ||
| if not meta: | ||
| continue | ||
| filtered = {} | ||
| for doc_id, meta in meta_by_doc.items(): | ||
| if metadata_fields is not None: | ||
| meta = {k: v for k, v in meta.items() if k in metadata_fields} | ||
| if meta: | ||
| filtered[doc_id] = meta | ||
| return filtered | ||
|
|
||
|
|
||
| def _build_agent_reference(reference, include_metadata=False, metadata_fields=None): | ||
| if not include_metadata or not isinstance(reference, dict): | ||
| return reference | ||
|
|
||
| meta_by_doc = _get_reference_metadata_by_doc(chunks_format(reference), metadata_fields) | ||
| if not meta_by_doc: | ||
| return reference | ||
|
|
||
| enriched_reference = copy.deepcopy(reference) | ||
| raw_chunks = enriched_reference.get("chunks") | ||
| if isinstance(raw_chunks, dict): | ||
| raw_chunks = raw_chunks.values() | ||
| elif not isinstance(raw_chunks, list): | ||
| return enriched_reference | ||
|
|
||
| for chunk in raw_chunks: | ||
| if not isinstance(chunk, dict): | ||
| continue | ||
| doc_id = chunk.get("doc_id", chunk.get("document_id")) | ||
| meta = meta_by_doc.get(doc_id) | ||
| if meta: | ||
| chunk["document_metadata"] = meta | ||
| return enriched_reference | ||
|
|
||
| return chunks | ||
|
|
||
| def _build_agent_openai_response(answer, include_metadata=False, metadata_fields=None): | ||
| if not include_metadata: | ||
| return answer | ||
|
|
||
| if isinstance(answer, str): | ||
| if answer.strip() == "data: [DONE]" or answer.strip() == "data:[DONE]": | ||
| return answer | ||
| try: | ||
| payload = json.loads(answer[5:]) | ||
| except Exception: | ||
| return answer | ||
| payload = _build_agent_openai_response( | ||
| payload, | ||
| include_metadata=include_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
| return _to_sse(payload) | ||
|
|
||
| if not isinstance(answer, dict): | ||
| return answer | ||
|
|
||
| for choice in answer.get("choices", []): | ||
| if not isinstance(choice, dict): | ||
| continue | ||
| delta = choice.get("delta") | ||
| if isinstance(delta, dict) and delta.get("reference") is not None: | ||
| delta["reference"] = _build_agent_reference( | ||
| delta["reference"], | ||
| include_metadata=include_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
| message = choice.get("message") | ||
| if isinstance(message, dict) and message.get("reference") is not None: | ||
| message["reference"] = _build_agent_reference( | ||
| message["reference"], | ||
| include_metadata=include_metadata, | ||
| metadata_fields=metadata_fields, | ||
| ) | ||
| return answer | ||
|
|
||
|
|
||
| def _to_sse(payload): | ||
| return "data:" + json.dumps(payload, ensure_ascii=False) + "\n\n" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This parsing block is duplicated again below in
agent_completions(...), and I think it can be reduced to a tiny shared helper. Suggested change here:
Then add this helper near the bottom of the file: