Skip to content
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 154 additions & 17 deletions api/apps/sdk/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,20 @@ async def streamed_response_generator(chat_id, dia, msg):
@token_required
async def agents_completion_openai_compatibility(tenant_id, agent_id):
req = await get_request_json()
extra_body = req.get("extra_body", {})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This parsing block is duplicated again below in agent_completions(...), and I think it can be reduced to a tiny shared helper.

Suggested change here:

err, include_reference_metadata, metadata_fields = _parse_reference_metadata(req)
if err:
    return err

Then add this helper near the bottom of the file:

def _parse_reference_metadata(req):
    extra_body = req.get("extra_body") or {}
    if extra_body and not isinstance(extra_body, dict):
        return get_error_data_result("extra_body must be an object."), False, None
    reference_metadata = extra_body.get("reference_metadata") or {}
    if reference_metadata and not isinstance(reference_metadata, dict):
        return get_error_data_result("reference_metadata must be an object."), False, None
    metadata_fields = reference_metadata.get("fields")
    if metadata_fields is not None and not isinstance(metadata_fields, list):
        return get_error_data_result("reference_metadata.fields must be an array."), False, None
    return None, bool(reference_metadata.get("include", False)), metadata_fields

if extra_body is None:
extra_body = {}
elif not isinstance(extra_body, dict):
return get_error_data_result("extra_body must be an object.")
reference_metadata = extra_body.get("reference_metadata", {})
if reference_metadata is None:
reference_metadata = {}
elif not isinstance(reference_metadata, dict):
return get_error_data_result("reference_metadata must be an object.")
include_reference_metadata = bool(reference_metadata.get("include", False))
metadata_fields = reference_metadata.get("fields")
if metadata_fields is not None and not isinstance(metadata_fields, list):
return get_error_data_result("reference_metadata.fields must be an array.")
messages = req.get("messages", [])
if not messages:
return get_error_data_result("You must provide at least one message.")
Expand All @@ -464,15 +478,23 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):

stream = req.pop("stream", False)
if stream:
resp = Response(
completion_openai(
async def generate():
async for answer in completion_openai(
tenant_id,
agent_id,
question,
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
stream=True,
**req,
),
):
yield _build_agent_openai_response(
answer,
include_metadata=include_reference_metadata,
metadata_fields=metadata_fields,
)

resp = Response(
generate(),
mimetype="text/event-stream",
)
resp.headers.add_header("Cache-control", "no-cache")
Expand All @@ -490,6 +512,11 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
stream=False,
**req,
):
response = _build_agent_openai_response(
response,
include_metadata=include_reference_metadata,
metadata_fields=metadata_fields,
)
return jsonify(response)

return None
Expand All @@ -499,6 +526,20 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
@token_required
async def agent_completions(tenant_id, agent_id):
req = await get_request_json()
extra_body = req.get("extra_body", {})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same extra_body.reference_metadata parsing logic again, so I think it should be replaced with the same small helper instead of duplicating the whole block.

Suggested replacement:

err, include_reference_metadata, metadata_fields = _parse_reference_metadata(req)
if err:
    return err

That keeps the validation consistent across both agent endpoints and reduces the amount of code changed for this feature.

if extra_body is None:
extra_body = {}
elif not isinstance(extra_body, dict):
return get_error_data_result("extra_body must be an object.")
reference_metadata = extra_body.get("reference_metadata", {})
if reference_metadata is None:
reference_metadata = {}
elif not isinstance(reference_metadata, dict):
return get_error_data_result("reference_metadata must be an object.")
include_reference_metadata = bool(reference_metadata.get("include", False))
metadata_fields = reference_metadata.get("fields")
if metadata_fields is not None and not isinstance(metadata_fields, list):
return get_error_data_result("reference_metadata.fields must be an array.")
return_trace = bool(req.get("return_trace", False))

if req.get("stream", True):
Expand All @@ -511,19 +552,29 @@ async def generate():
ans = json.loads(answer[5:]) # remove "data:"
except Exception:
continue
else:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can stay local to the agent path without needing the extra _to_sse() helper.

Suggested change here:

data = ans.get("data", {})
if include_reference_metadata and data.get("reference") is not None:
    data["reference"] = _build_agent_reference(data["reference"], metadata_fields=metadata_fields)
    answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"

It avoids adding another formatting helper just for this.

ans = answer

data = ans.get("data", {})
if data.get("reference") is not None:
data["reference"] = _build_agent_reference(
data["reference"],
include_metadata=include_reference_metadata,
metadata_fields=metadata_fields,
)
answer = _to_sse(ans)

event = ans.get("event")
if event == "node_finished":
if return_trace:
data = ans.get("data", {})
trace_items.append(
{
"component_id": data.get("component_id"),
"trace": [copy.deepcopy(data)],
}
)
ans.setdefault("data", {})["trace"] = trace_items
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
answer = _to_sse(ans)
yield answer

if event not in ["message", "message_end"]:
Expand All @@ -547,13 +598,19 @@ async def generate():
structured_output = {}
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
try:
ans = json.loads(answer[5:])
ans = json.loads(answer[5:]) if isinstance(answer, str) else answer

if ans["event"] == "message":
full_content += ans["data"]["content"]

if ans.get("data", {}).get("reference", None):
reference.update(ans["data"]["reference"])
reference.update(
_build_agent_reference(
ans["data"]["reference"],
include_metadata=include_reference_metadata,
metadata_fields=metadata_fields,
)
)

if ans.get("event") == "node_finished":
data = ans.get("data", {})
Expand Down Expand Up @@ -1267,6 +1324,22 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N
if not include_metadata:
return chunks

meta_by_doc = _get_reference_metadata_by_doc(chunks, metadata_fields)
if not meta_by_doc:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this helper section can be made quite a bit smaller.

For this issue, we do not really need to refactor _build_reference_chunks() or add the extra helper stack around _get_reference_metadata_by_doc(...) and _to_sse(). The feature can stay isolated to the agent path with a smaller helper set.

I would suggest:

  • keep _build_reference_chunks() unchanged
  • add a tiny _parse_reference_metadata(req)
  • add an agent-only _build_agent_reference(...)
  • add a small _build_agent_openai_response(...)
  • skip _get_reference_metadata_by_doc(...)
  • skip _to_sse()

Suggested helper set:

def _parse_reference_metadata(req):
    extra_body = req.get("extra_body") or {}
    if extra_body and not isinstance(extra_body, dict):
        return get_error_data_result("extra_body must be an object."), False, None
    reference_metadata = extra_body.get("reference_metadata") or {}
    if reference_metadata and not isinstance(reference_metadata, dict):
        return get_error_data_result("reference_metadata must be an object."), False, None
    metadata_fields = reference_metadata.get("fields")
    if metadata_fields is not None and not isinstance(metadata_fields, list):
        return get_error_data_result("reference_metadata.fields must be an array."), False, None
    return None, bool(reference_metadata.get("include", False)), metadata_fields


def _build_agent_reference(reference, metadata_fields=None):
    if not isinstance(reference, dict):
        return reference
    raw_chunks = reference.get("chunks")
    if isinstance(raw_chunks, dict):
        chunks = [chunk for chunk in raw_chunks.values() if isinstance(chunk, dict)]
    elif isinstance(raw_chunks, list):
        chunks = [chunk for chunk in raw_chunks if isinstance(chunk, dict)]
    else:
        return reference

    if metadata_fields is not None:
        metadata_fields = {f for f in metadata_fields if isinstance(f, str)}
        if not metadata_fields:
            return reference

    doc_ids_by_kb = {}
    for chunk in chunks:
        kb_id = chunk.get("kb_id", chunk.get("dataset_id"))
        doc_id = chunk.get("doc_id", chunk.get("document_id"))
        if not kb_id or not doc_id:
            continue
        doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id)

    if not doc_ids_by_kb:
        return reference

    meta_by_doc = {}
    for kb_id, doc_ids in doc_ids_by_kb.items():
        meta_map = DocMetadataService.get_metadata_for_documents(list(doc_ids), kb_id)
        if meta_map:
            meta_by_doc.update(meta_map)

    if not meta_by_doc:
        return reference

    for chunk in chunks:
        doc_id = chunk.get("doc_id", chunk.get("document_id"))
        meta = meta_by_doc.get(doc_id)
        if not meta:
            continue
        if metadata_fields is not None:
            meta = {k: v for k, v in meta.items() if k in metadata_fields}
        if meta:
            chunk["document_metadata"] = meta
    return reference


def _build_agent_openai_response(answer, metadata_fields=None):
    if isinstance(answer, str):
        if answer.strip() in {"data:[DONE]", "data: [DONE]"}:
            return answer
        if not answer.startswith("data:"):
            return answer
        try:
            payload = json.loads(answer[5:].strip())
        except Exception:
            return answer
        payload = _build_agent_openai_response(payload, metadata_fields=metadata_fields)
        return "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"

    if not isinstance(answer, dict):
        return answer

    for choice in answer.get("choices", []):
        if not isinstance(choice, dict):
            continue
        delta = choice.get("delta")
        if isinstance(delta, dict) and delta.get("reference") is not None:
            delta["reference"] = _build_agent_reference(delta["reference"], metadata_fields=metadata_fields)
        message = choice.get("message")
        if isinstance(message, dict) and message.get("reference") is not None:
            message["reference"] = _build_agent_reference(message["reference"], metadata_fields=metadata_fields)
    return answer

I think that keeps the requested feature fully intact, but with a smaller and more targeted patch.

return chunks

for chunk in chunks:
doc_id = chunk.get("document_id")
if not doc_id:
continue
meta = meta_by_doc.get(doc_id)
if meta:
chunk["document_metadata"] = meta

return chunks


def _get_reference_metadata_by_doc(chunks, metadata_fields=None):
doc_ids_by_kb = {}
for chunk in chunks:
kb_id = chunk.get("dataset_id")
Expand All @@ -1276,7 +1349,7 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N
doc_ids_by_kb.setdefault(kb_id, set()).add(doc_id)

if not doc_ids_by_kb:
return chunks
return {}

meta_by_doc = {}
for kb_id, doc_ids in doc_ids_by_kb.items():
Expand All @@ -1287,18 +1360,82 @@ def _build_reference_chunks(reference, include_metadata=False, metadata_fields=N
if metadata_fields is not None:
metadata_fields = {f for f in metadata_fields if isinstance(f, str)}
if not metadata_fields:
return chunks
return {}

for chunk in chunks:
doc_id = chunk.get("document_id")
if not doc_id:
continue
meta = meta_by_doc.get(doc_id)
if not meta:
continue
filtered = {}
for doc_id, meta in meta_by_doc.items():
if metadata_fields is not None:
meta = {k: v for k, v in meta.items() if k in metadata_fields}
if meta:
filtered[doc_id] = meta
return filtered


def _build_agent_reference(reference, include_metadata=False, metadata_fields=None):
if not include_metadata or not isinstance(reference, dict):
return reference

meta_by_doc = _get_reference_metadata_by_doc(chunks_format(reference), metadata_fields)
if not meta_by_doc:
return reference

enriched_reference = copy.deepcopy(reference)
raw_chunks = enriched_reference.get("chunks")
if isinstance(raw_chunks, dict):
raw_chunks = raw_chunks.values()
elif not isinstance(raw_chunks, list):
return enriched_reference

for chunk in raw_chunks:
if not isinstance(chunk, dict):
continue
doc_id = chunk.get("doc_id", chunk.get("document_id"))
meta = meta_by_doc.get(doc_id)
if meta:
chunk["document_metadata"] = meta
return enriched_reference

return chunks

def _build_agent_openai_response(answer, include_metadata=False, metadata_fields=None):
if not include_metadata:
return answer

if isinstance(answer, str):
if answer.strip() == "data: [DONE]" or answer.strip() == "data:[DONE]":
return answer
try:
payload = json.loads(answer[5:])
except Exception:
return answer
payload = _build_agent_openai_response(
payload,
include_metadata=include_metadata,
metadata_fields=metadata_fields,
)
return _to_sse(payload)

if not isinstance(answer, dict):
return answer

for choice in answer.get("choices", []):
if not isinstance(choice, dict):
continue
delta = choice.get("delta")
if isinstance(delta, dict) and delta.get("reference") is not None:
delta["reference"] = _build_agent_reference(
delta["reference"],
include_metadata=include_metadata,
metadata_fields=metadata_fields,
)
message = choice.get("message")
if isinstance(message, dict) and message.get("reference") is not None:
message["reference"] = _build_agent_reference(
message["reference"],
include_metadata=include_metadata,
metadata_fields=metadata_fields,
)
return answer


def _to_sse(payload):
return "data:" + json.dumps(payload, ensure_ascii=False) + "\n\n"