Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 98 additions & 58 deletions api/apps/restful_apis/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import re
import tempfile
from copy import deepcopy
from types import SimpleNamespace

from quart import Response, request

Expand All @@ -30,7 +31,7 @@
)
from api.db.services.chunk_feedback_service import ChunkFeedbackService
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
Expand Down Expand Up @@ -67,6 +68,15 @@
"tts": False,
"refine_multiturn": True,
}
# Prompt configuration used for "direct" chat completions, i.e. requests to
# /chat/completions that carry no chat_id and therefore have no persisted
# Dialog record to read a prompt_config from.
_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = {
    "system": "",  # no system prompt by default
    "prologue": "",  # no opening assistant message
    "parameters": [],  # no template variables
    "empty_response": "",  # no canned reply for empty retrieval results
    "quote": False,  # do not attach citation quotes
    "tts": False,  # text-to-speech disabled
    "refine_multiturn": True,  # rewrite follow-up questions using history
}
_DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"}
_READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"}
_PERSISTED_FIELDS = set(DialogService.model._meta.fields)
Expand Down Expand Up @@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id):
)


def _build_default_completion_dialog():
    """Build an in-memory, dialog-like stand-in for direct chat completions.

    Used when ``/chat/completions`` is called without a ``chat_id``: the
    completion pipeline expects a Dialog-shaped object, so this returns a
    ``SimpleNamespace`` mirroring the attributes it reads. The object is
    never persisted.
    """
    defaults = {
        "tenant_id": current_user.id,
        "llm_id": "",
        "tenant_llm_id": None,
        "llm_setting": {},
        # Deep-copied so per-request mutation cannot leak into the shared default.
        "prompt_config": deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG),
        "kb_ids": [],
        "top_n": 6,
        "top_k": 1024,
        "rerank_id": "",
        "similarity_threshold": 0.1,
        "vector_similarity_weight": 0.3,
        "meta_data_filter": None,
    }
    return SimpleNamespace(**defaults)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _create_session_for_completion(chat_id, dialog, user_id):
    """Persist a brand-new conversation under *chat_id* and return the saved record.

    The session is seeded with the dialog's prologue as its first assistant
    message and an empty reference list.

    Raises:
        LookupError: if the record cannot be read back after saving.
    """
    session_id = get_uuid()
    ConversationService.save(
        id=session_id,
        dialog_id=chat_id,
        name="New session",
        message=[{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}],
        user_id=user_id,
        reference=[],
    )
    # Read the row back to return a full model object rather than the raw dict.
    found, session = ConversationService.get_by_id(session_id)
    if not found:
        raise LookupError("Fail to create a session!")
    return session


def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
if not llm_id:
return None
Expand Down Expand Up @@ -671,7 +714,7 @@ async def get_session(chat_id, session_id):
return server_error_response(ex)


@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PATCH"]) # noqa: F821
@login_required
async def update_session(chat_id, session_id):
if not _ensure_owned_chat(chat_id):
Expand Down Expand Up @@ -829,7 +872,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
return server_error_response(ex)


@manager.route("/chats/tts", methods=["POST"]) # noqa: F821
@manager.route("/chat/audio/speech", methods=["POST"]) # noqa: F821
@login_required
async def tts():
req = await get_request_json()
Expand Down Expand Up @@ -857,9 +900,9 @@ def stream_audio():
return resp


@manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821
@manager.route("/chat/audio/transcription", methods=["POST"]) # noqa: F821
@login_required
async def transcriptions():
async def transcription():
req = await request.form
stream_mode = req.get("stream", "false").lower() == "true"
files = await request.files
Expand Down Expand Up @@ -915,7 +958,7 @@ async def event_stream():
return Response(event_stream(), content_type="text/event-stream")


@manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821
@manager.route("/chat/mindmap", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
async def mindmap():
Expand All @@ -933,10 +976,10 @@ async def mindmap():
return get_json_result(data=mind_map)


@manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821
@manager.route("/chat/recommendation", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
async def related_questions():
async def recommendation():
req = await get_request_json()

search_id = req.get("search_id", "")
Expand Down Expand Up @@ -971,10 +1014,10 @@ async def related_questions():
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])


@manager.route("/chats/<chat_id>/sessions/<session_id>/completions", methods=["POST"]) # noqa: F821
@manager.route("/chat/completions", methods=["POST"]) # noqa: F821
@login_required
@validate_request("messages")
async def session_completion(chat_id, session_id):
async def session_completion():
req = await get_request_json()
msg = []
for m in req["messages"]:
Expand All @@ -984,6 +1027,8 @@ async def session_completion(chat_id, session_id):
continue
msg.append(m)
message_id = msg[-1].get("id") if msg else None
chat_id = req.pop("chat_id", "") or ""
session_id = req.pop("session_id", "") or ""
chat_model_id = req.pop("llm_id", "")

chat_model_config = {}
Expand All @@ -993,38 +1038,63 @@ async def session_completion(chat_id, session_id):
chat_model_config[model_config] = config

try:
e, conv = ConversationService.get_by_id(session_id)
if not e:
return get_data_error_result(message="Session not found!")
if conv.dialog_id != chat_id:
return get_data_error_result(message="Session does not belong to this chat!")
conv.message = deepcopy(req["messages"])
e, dia = DialogService.get_by_id(chat_id)
if not e:
return get_data_error_result(message="Chat not found!")
conv = None
if session_id and not chat_id:
return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")

if chat_id:
if not _ensure_owned_chat(chat_id):
return get_json_result(
data=False,
message="No authorization.",
code=RetCode.AUTHENTICATION_ERROR,
)
e, dia = DialogService.get_by_id(chat_id)
if not e:
return get_data_error_result(message="Chat not found!")
if session_id:
e, conv = ConversationService.get_by_id(session_id)
if not e:
return get_data_error_result(message="Session not found!")
if conv.dialog_id != chat_id:
return get_data_error_result(message="Session does not belong to this chat!")
else:
conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
session_id = conv.id
Comment on lines +1061 to +1063
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't persist a new session before the completion request is known-good.

A brand-new session is saved before llm_id validation and before the first successful model response is written back. Invalid model overrides or downstream chat failures will therefore leave behind empty "New session" records.

Also applies to: 1077-1079, 1098-1099, 1116-1117

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/apps/restful_apis/chat_api.py` around lines 1061 - 1063, Currently a new
session is created and persisted immediately via _create_session_for_completion
(producing session_id) before validating llm_id and before a successful model
response; change this so you only instantiate an in-memory session object (do
not save) when branching to _create_session_for_completion, validate llm_id and
attempt the first model call, and only call the persistence path that writes the
"New session" record (the code that currently uses conv.id / session_id) after
the model response succeeds; apply the same change to the other occurrences
around the blocks referenced (the similar calls at the same pattern near the
later branches), ensuring any user_id derivation (req.get("user_id",
current_user.id)) is preserved but persistence is deferred until success.

conv.message = deepcopy(req["messages"])
Comment on lines +1055 to +1064
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Keep the stored session history authoritative when session_id is present.

This loads the existing Conversation and then immediately replaces conv.message with req["messages"]. Any client that sends only the new turn—or sends a trimmed history—will silently erase persisted context and desync the saved references from the message list.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/apps/restful_apis/chat_api.py` around lines 1055 - 1064, When an existing
session is loaded via ConversationService.get_by_id, do not unconditionally
overwrite conv.message with req["messages"]; instead preserve the authoritative
stored history and only incorporate new turns from the request. Update the code
around ConversationService.get_by_id / conv.message so that when session_id is
present you: (a) if req["messages"] appears incremental (e.g., its last item(s)
are not in conv.message) append only the new messages to conv.message, or (b) if
req["messages"] appears to be a full history, first verify it shares the same
prefix as conv.message and only replace/extend when it is strictly longer and
consistent. For new sessions created via _create_session_for_completion keep the
existing behavior of setting conv.message from req["messages"]. Ensure
session_id, chat_id and conv.id references remain unchanged.

else:
dia = _build_default_completion_dialog()
dia.llm_setting = chat_model_config

del req["messages"]

if not conv.reference:
conv.reference = []
conv.reference = [r for r in conv.reference if r]
conv.reference.append({"chunks": [], "doc_aggs": []})
if conv is not None:
if not conv.reference:
conv.reference = []
conv.reference = [r for r in conv.reference if r]
conv.reference.append({"chunks": [], "doc_aggs": []})

if chat_model_id:
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
dia.llm_id = chat_model_id
dia.llm_setting = chat_model_config

is_embedded = bool(chat_model_id)
stream_mode = req.pop("stream", True)

def _format_answer(ans):
formatted = structure_answer(conv, ans, message_id, session_id)
if chat_id:
formatted["chat_id"] = chat_id
return formatted

async def stream():
nonlocal dia, msg, req, conv
try:
async for ans in async_chat(dia, msg, True, **req):
ans = structure_answer(conv, ans, message_id, conv.id)
ans = _format_answer(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
if not is_embedded:
if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as ex:
logging.exception(ex)
Expand All @@ -1041,40 +1111,10 @@ async def stream():

answer = None
async for ans in async_chat(dia, msg, **req):
answer = structure_answer(conv, ans, message_id, conv.id)
if not is_embedded:
answer = _format_answer(ans)
if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
break
return get_json_result(data=answer)
except Exception as ex:
return server_error_response(ex)


@manager.route("/chats/ask", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question", "kb_ids")
async def ask():
req = await get_request_json()
uid = current_user.id

search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})

async def stream():
nonlocal req, uid
try:
async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as ex:
yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
47 changes: 46 additions & 1 deletion api/apps/restful_apis/search_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
# limitations under the License.
#

from quart import request
import json

from quart import Response, request
from api.db.services.dialog_service import async_ask
from api.apps import current_user, login_required

from api.constants import DATASET_NAME_LIMIT
Expand Down Expand Up @@ -168,3 +171,45 @@ def delete_search(search_id):
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)


@manager.route("/searches/<search_id>/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
async def completion(search_id):
if not SearchService.accessible4deletion(search_id, current_user.id):
return get_json_result(
data=False,
message="No authorization.",
code=RetCode.AUTHENTICATION_ERROR,
)

req = await get_request_json()
uid = current_user.id
search_app = SearchService.get_detail(search_id)
if not search_app:
return get_data_error_result(message=f"Cannot find search {search_id}")

search_config = search_app.get("search_config", {})
kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or []
if not kb_ids:
return get_data_error_result(message="`kb_ids` is required.")

async def stream():
nonlocal req, uid, kb_ids, search_config
try:
async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as ex:
yield "data:" + json.dumps(
{"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}},
ensure_ascii=False,
) + "\n\n"
Comment on lines +203 to +207
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid exposing raw exception details to clients.

The error handler yields str(ex) directly to the client, which may leak internal implementation details or sensitive information. Consider using a generic error message while logging the full exception server-side.

🛡️ Proposed fix to sanitize error response
         except Exception as ex:
+            logger.exception(f"Error in search completion for search_id={search_id}")
             yield "data:" + json.dumps(
-                {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}},
+                {"code": 500, "message": "An internal error occurred.", "data": {"answer": "**ERROR**: An internal error occurred.", "reference": []}},
                 ensure_ascii=False,
             ) + "\n\n"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/apps/restful_apis/search_api.py` around lines 203 - 207, The except block
currently yields str(ex) directly to clients (see the except Exception as ex
handler in search_api.py), which can leak internal details; change it to return
a generic error payload (e.g., code 500 and a non-sensitive "Internal server
error" message and sanitized data.answer) while logging the full exception
server-side using logger.exception or similar inside the same except block so
the stack trace is preserved for debugging; update the yield in that handler to
use the generic message and keep logging of ex separate and detailed.

yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

resp = Response(stream(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
Comment on lines +176 to +215
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add logging for the new completion flow.

The new endpoint lacks logging for request handling and error conditions. As per coding guidelines, new flows should include logging for observability and debugging.

🔧 Proposed fix to add logging
+import logging
+
+logger = logging.getLogger(__name__)
+
 `@manager.route`("/searches/<search_id>/completion", methods=["POST"])  # noqa: F821
 `@login_required`
 `@validate_request`("question")
 async def completion(search_id):
+    logger.info(f"Search completion request for search_id={search_id}")
     if not SearchService.accessible4deletion(search_id, current_user.id):
+        logger.warning(f"Unauthorized access attempt to search_id={search_id} by user={current_user.id}")
         return get_json_result(
             data=False,
             message="No authorization.",
             code=RetCode.AUTHENTICATION_ERROR,
         )

As per coding guidelines: **/*.py: Add logging for new flows.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@api/apps/restful_apis/search_api.py` around lines 176 - 215, Add structured
logging for the new completion flow: ensure a module-level logger (logger =
logging.getLogger(__name__)) exists, then instrument the completion handler and
inner stream generator to log key events — entry to completion (log search_id
and current_user.id), authorization failure (when
SearchService.accessible4deletion returns False), missing search_app (when
get_detail returns None), missing kb_ids, and before starting the async_ask
loop; inside the stream() except block log the full exception with
logger.exception or logger.error(..., exc_info=True) when async_ask yields
errors and also log completion of the stream (success path) — reference function
names completion, stream, async_ask and service SearchService/get_detail to
locate where to add these logs.

11 changes: 8 additions & 3 deletions docs/guides/chat/set_chat_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,19 @@ See [Converse with chat assistant](../../references/http_api_reference.md#conver

```json {9}
curl --request POST \
--url http://{address}/api/v1/chats/{chat_id}/completions \
--url http://{address}/api/v1/chat/completions \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--data-binary '
{
"question": "xxxxxxxxx",
"chat_id": "{chat_id}",
"stream": true,
"messages": [
{
"role": "user",
"content": "xxxxxxxxx"
}
],
"style":"hilarious"
}'
Comment on lines 73 to 89
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Document the session_id follow-up step here.

This example only sends chat_id, which means each call starts a brand-new session under the new /chat/completions flow. Please add a note that callers need to capture the returned session_id and send it on subsequent requests if they want multi-turn continuity.

💡 Suggested doc tweak
 ```json {9}
 curl --request POST \
      --url http://{address}/api/v1/chat/completions \
      --header 'Content-Type: application/json' \
      --header 'Authorization: Bearer <YOUR_API_KEY>' \
      --data-binary '
      {
           "chat_id": "{chat_id}",
           "stream": true,
           "messages": [
               {
                   "role": "user",
                   "content": "xxxxxxxxx"
               }
           ],
           "style":"hilarious"
      }'

+The first successful response returns a session_id. Reuse that session_id in later
+/chat/completions calls to continue the same conversation; otherwise a new session is created.

</details>

<!-- suggestion_start -->

<details>
<summary>📝 Committable suggestion</summary>

> ‼️ **IMPORTANT**
> Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/guides/chat/set_chat_variables.md` around lines 73 - 89, Add a note to
the /chat/completions example explaining that the API returns a session_id on
the first successful call and that callers must capture and include that
session_id in subsequent /chat/completions requests to maintain multi-turn
continuity (otherwise each call with only chat_id will start a new session);
reference the response field name session_id and the request endpoint
/chat/completions so readers know where to read and reuse it.

```
Expand Down Expand Up @@ -109,4 +115,3 @@ while True:
print(ans.content[len(cont):], end='', flush=True)
cont = ans.content
```

Loading
Loading