-
Notifications
You must be signed in to change notification settings - Fork 9.1k
Refa: align chat and search restful APIs #14229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f7cdd34
aeb2923
586fd55
f47a343
ac49dcb
9ce4d04
b33d5c0
b1488ca
e992200
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| import re | ||
| import tempfile | ||
| from copy import deepcopy | ||
| from types import SimpleNamespace | ||
|
|
||
| from quart import Response, request | ||
|
|
||
|
|
@@ -30,7 +31,7 @@ | |
| ) | ||
| from api.db.services.chunk_feedback_service import ChunkFeedbackService | ||
| from api.db.services.conversation_service import ConversationService, structure_answer | ||
| from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap | ||
| from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap | ||
| from api.db.services.knowledgebase_service import KnowledgebaseService | ||
| from api.db.services.llm_service import LLMBundle | ||
| from api.db.services.search_service import SearchService | ||
|
|
@@ -67,6 +68,15 @@ | |
| "tts": False, | ||
| "refine_multiturn": True, | ||
| } | ||
| _DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = { | ||
| "system": "", | ||
| "prologue": "", | ||
| "parameters": [], | ||
| "empty_response": "", | ||
| "quote": False, | ||
| "tts": False, | ||
| "refine_multiturn": True, | ||
| } | ||
| _DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"} | ||
| _READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"} | ||
| _PERSISTED_FIELDS = set(DialogService.model._meta.fields) | ||
|
|
@@ -124,6 +134,39 @@ def _ensure_owned_chat(chat_id): | |
| ) | ||
|
|
||
|
|
||
| def _build_default_completion_dialog(): | ||
| return SimpleNamespace( | ||
| tenant_id=current_user.id, | ||
| llm_id="", | ||
| tenant_llm_id=None, | ||
| llm_setting={}, | ||
| prompt_config=deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG), | ||
| kb_ids=[], | ||
| top_n=6, | ||
| top_k=1024, | ||
| rerank_id="", | ||
| similarity_threshold=0.1, | ||
| vector_similarity_weight=0.3, | ||
| meta_data_filter=None, | ||
| ) | ||
|
|
||
|
|
||
| def _create_session_for_completion(chat_id, dialog, user_id): | ||
| conv = { | ||
| "id": get_uuid(), | ||
| "dialog_id": chat_id, | ||
| "name": "New session", | ||
| "message": [{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}], | ||
| "user_id": user_id, | ||
| "reference": [], | ||
| } | ||
| ConversationService.save(**conv) | ||
| ok, conv_obj = ConversationService.get_by_id(conv["id"]) | ||
| if not ok: | ||
| raise LookupError("Fail to create a session!") | ||
| return conv_obj | ||
|
|
||
|
|
||
| def _validate_llm_id(llm_id, tenant_id, llm_setting=None): | ||
| if not llm_id: | ||
| return None | ||
|
|
@@ -671,7 +714,7 @@ async def get_session(chat_id, session_id): | |
| return server_error_response(ex) | ||
|
|
||
|
|
||
| @manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821 | ||
| @manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PATCH"]) # noqa: F821 | ||
| @login_required | ||
| async def update_session(chat_id, session_id): | ||
| if not _ensure_owned_chat(chat_id): | ||
|
|
@@ -829,7 +872,7 @@ async def update_message_feedback(chat_id, session_id, msg_id): | |
| return server_error_response(ex) | ||
|
|
||
|
|
||
| @manager.route("/chats/tts", methods=["POST"]) # noqa: F821 | ||
| @manager.route("/chat/audio/speech", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| async def tts(): | ||
| req = await get_request_json() | ||
|
|
@@ -857,9 +900,9 @@ def stream_audio(): | |
| return resp | ||
|
|
||
|
|
||
| @manager.route("/chats/transcriptions", methods=["POST"]) # noqa: F821 | ||
| @manager.route("/chat/audio/transcription", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| async def transcriptions(): | ||
| async def transcription(): | ||
| req = await request.form | ||
| stream_mode = req.get("stream", "false").lower() == "true" | ||
| files = await request.files | ||
|
|
@@ -915,7 +958,7 @@ async def event_stream(): | |
| return Response(event_stream(), content_type="text/event-stream") | ||
|
|
||
|
|
||
| @manager.route("/chats/mindmap", methods=["POST"]) # noqa: F821 | ||
| @manager.route("/chat/mindmap", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| @validate_request("question", "kb_ids") | ||
| async def mindmap(): | ||
|
|
@@ -933,10 +976,10 @@ async def mindmap(): | |
| return get_json_result(data=mind_map) | ||
|
|
||
|
|
||
| @manager.route("/chats/related_questions", methods=["POST"]) # noqa: F821 | ||
| @manager.route("/chat/recommendation", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| @validate_request("question") | ||
| async def related_questions(): | ||
| async def recommendation(): | ||
| req = await get_request_json() | ||
|
|
||
| search_id = req.get("search_id", "") | ||
|
|
@@ -971,10 +1014,10 @@ async def related_questions(): | |
| return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) | ||
|
|
||
|
|
||
| @manager.route("/chats/<chat_id>/sessions/<session_id>/completions", methods=["POST"]) # noqa: F821 | ||
| @manager.route("/chat/completions", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| @validate_request("messages") | ||
| async def session_completion(chat_id, session_id): | ||
| async def session_completion(): | ||
| req = await get_request_json() | ||
| msg = [] | ||
| for m in req["messages"]: | ||
|
|
@@ -984,6 +1027,8 @@ async def session_completion(chat_id, session_id): | |
| continue | ||
| msg.append(m) | ||
| message_id = msg[-1].get("id") if msg else None | ||
| chat_id = req.pop("chat_id", "") or "" | ||
| session_id = req.pop("session_id", "") or "" | ||
| chat_model_id = req.pop("llm_id", "") | ||
|
|
||
| chat_model_config = {} | ||
|
|
@@ -993,38 +1038,63 @@ async def session_completion(chat_id, session_id): | |
| chat_model_config[model_config] = config | ||
|
|
||
| try: | ||
| e, conv = ConversationService.get_by_id(session_id) | ||
| if not e: | ||
| return get_data_error_result(message="Session not found!") | ||
| if conv.dialog_id != chat_id: | ||
| return get_data_error_result(message="Session does not belong to this chat!") | ||
| conv.message = deepcopy(req["messages"]) | ||
| e, dia = DialogService.get_by_id(chat_id) | ||
| if not e: | ||
| return get_data_error_result(message="Chat not found!") | ||
| conv = None | ||
| if session_id and not chat_id: | ||
| return get_data_error_result(message="`chat_id` is required when `session_id` is provided.") | ||
|
|
||
| if chat_id: | ||
| if not _ensure_owned_chat(chat_id): | ||
| return get_json_result( | ||
| data=False, | ||
| message="No authorization.", | ||
| code=RetCode.AUTHENTICATION_ERROR, | ||
| ) | ||
| e, dia = DialogService.get_by_id(chat_id) | ||
| if not e: | ||
| return get_data_error_result(message="Chat not found!") | ||
| if session_id: | ||
| e, conv = ConversationService.get_by_id(session_id) | ||
| if not e: | ||
| return get_data_error_result(message="Session not found!") | ||
| if conv.dialog_id != chat_id: | ||
| return get_data_error_result(message="Session does not belong to this chat!") | ||
| else: | ||
| conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id)) | ||
| session_id = conv.id | ||
|
Comment on lines
+1061
to
+1063
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't persist a new session before the completion request is known-good. A brand-new session is saved before the requested model is validated and before the completion itself runs, so a request that is rejected or fails afterwards leaves an empty session behind. Also applies to: 1077-1079, 1098-1099, 1116-1117 🤖 Prompt for AI Agents |
||
| conv.message = deepcopy(req["messages"]) | ||
|
Comment on lines
+1055
to
+1064
Contributor
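There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keep the stored session history authoritative when `session_id` is provided. This loads the existing session and then overwrites `conv.message` with the `messages` array from the request, so the client-supplied payload replaces the history stored on the server. 🤖 Prompt for AI Agents |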
||
| else: | ||
| dia = _build_default_completion_dialog() | ||
| dia.llm_setting = chat_model_config | ||
|
|
||
| del req["messages"] | ||
|
|
||
| if not conv.reference: | ||
| conv.reference = [] | ||
| conv.reference = [r for r in conv.reference if r] | ||
| conv.reference.append({"chunks": [], "doc_aggs": []}) | ||
| if conv is not None: | ||
| if not conv.reference: | ||
| conv.reference = [] | ||
| conv.reference = [r for r in conv.reference if r] | ||
| conv.reference.append({"chunks": [], "doc_aggs": []}) | ||
|
|
||
| if chat_model_id: | ||
| if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): | ||
| return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") | ||
| dia.llm_id = chat_model_id | ||
| dia.llm_setting = chat_model_config | ||
|
|
||
| is_embedded = bool(chat_model_id) | ||
| stream_mode = req.pop("stream", True) | ||
|
|
||
| def _format_answer(ans): | ||
| formatted = structure_answer(conv, ans, message_id, session_id) | ||
| if chat_id: | ||
| formatted["chat_id"] = chat_id | ||
| return formatted | ||
|
|
||
| async def stream(): | ||
| nonlocal dia, msg, req, conv | ||
| try: | ||
| async for ans in async_chat(dia, msg, True, **req): | ||
| ans = structure_answer(conv, ans, message_id, conv.id) | ||
| ans = _format_answer(ans) | ||
| yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" | ||
| if not is_embedded: | ||
| if conv is not None: | ||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | ||
| except Exception as ex: | ||
| logging.exception(ex) | ||
|
|
@@ -1041,40 +1111,10 @@ async def stream(): | |
|
|
||
| answer = None | ||
| async for ans in async_chat(dia, msg, **req): | ||
| answer = structure_answer(conv, ans, message_id, conv.id) | ||
| if not is_embedded: | ||
| answer = _format_answer(ans) | ||
| if conv is not None: | ||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | ||
| break | ||
| return get_json_result(data=answer) | ||
| except Exception as ex: | ||
| return server_error_response(ex) | ||
|
|
||
|
|
||
| @manager.route("/chats/ask", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| @validate_request("question", "kb_ids") | ||
| async def ask(): | ||
| req = await get_request_json() | ||
| uid = current_user.id | ||
|
|
||
| search_id = req.get("search_id", "") | ||
| search_config = {} | ||
| if search_id: | ||
| if search_app := SearchService.get_detail(search_id): | ||
| search_config = search_app.get("search_config", {}) | ||
|
|
||
| async def stream(): | ||
| nonlocal req, uid | ||
| try: | ||
| async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): | ||
| yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" | ||
| except Exception as ex: | ||
| yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n" | ||
| yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" | ||
|
|
||
| resp = Response(stream(), mimetype="text/event-stream") | ||
| resp.headers.add_header("Cache-control", "no-cache") | ||
| resp.headers.add_header("Connection", "keep-alive") | ||
| resp.headers.add_header("X-Accel-Buffering", "no") | ||
| resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") | ||
| return resp | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,10 @@ | |
| # limitations under the License. | ||
| # | ||
|
|
||
| from quart import request | ||
| import json | ||
|
|
||
| from quart import Response, request | ||
| from api.db.services.dialog_service import async_ask | ||
| from api.apps import current_user, login_required | ||
|
|
||
| from api.constants import DATASET_NAME_LIMIT | ||
|
|
@@ -168,3 +171,45 @@ def delete_search(search_id): | |
| return get_json_result(data=True) | ||
| except Exception as e: | ||
| return server_error_response(e) | ||
|
|
||
|
|
||
| @manager.route("/searches/<search_id>/completion", methods=["POST"]) # noqa: F821 | ||
| @login_required | ||
| @validate_request("question") | ||
| async def completion(search_id): | ||
| if not SearchService.accessible4deletion(search_id, current_user.id): | ||
| return get_json_result( | ||
| data=False, | ||
| message="No authorization.", | ||
| code=RetCode.AUTHENTICATION_ERROR, | ||
| ) | ||
|
|
||
| req = await get_request_json() | ||
| uid = current_user.id | ||
| search_app = SearchService.get_detail(search_id) | ||
| if not search_app: | ||
| return get_data_error_result(message=f"Cannot find search {search_id}") | ||
|
|
||
| search_config = search_app.get("search_config", {}) | ||
| kb_ids = search_config.get("kb_ids") or req.get("kb_ids") or [] | ||
| if not kb_ids: | ||
| return get_data_error_result(message="`kb_ids` is required.") | ||
|
|
||
| async def stream(): | ||
| nonlocal req, uid, kb_ids, search_config | ||
| try: | ||
| async for ans in async_ask(req["question"], kb_ids, uid, search_config=search_config): | ||
| yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" | ||
| except Exception as ex: | ||
| yield "data:" + json.dumps( | ||
| {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, | ||
| ensure_ascii=False, | ||
| ) + "\n\n" | ||
|
Comment on lines
+203
to
+207
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Avoid exposing raw exception details to clients. The error handler yields `str(ex)` to the client in both the `message` and `answer` fields, which can leak internal details. 🛡️ Proposed fix to sanitize error response
except Exception as ex:
+ logger.exception(f"Error in search completion for search_id={search_id}")
yield "data:" + json.dumps(
- {"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}},
+ {"code": 500, "message": "An internal error occurred.", "data": {"answer": "**ERROR**: An internal error occurred.", "reference": []}},
ensure_ascii=False,
) + "\n\n"🤖 Prompt for AI Agents |
||
| yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" | ||
|
|
||
| resp = Response(stream(), mimetype="text/event-stream") | ||
| resp.headers.add_header("Cache-control", "no-cache") | ||
| resp.headers.add_header("Connection", "keep-alive") | ||
| resp.headers.add_header("X-Accel-Buffering", "no") | ||
| resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") | ||
| return resp | ||
|
Comment on lines
+176
to
+215
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add logging for the new completion flow. The new endpoint lacks logging for request handling and error conditions. As per coding guidelines, new flows should include logging for observability and debugging. 🔧 Proposed fix to add logging+import logging
+
+logger = logging.getLogger(__name__)
+
@manager.route("/searches/<search_id>/completion", methods=["POST"]) # noqa: F821
@login_required
@validate_request("question")
async def completion(search_id):
+ logger.info(f"Search completion request for search_id={search_id}")
if not SearchService.accessible4deletion(search_id, current_user.id):
+ logger.warning(f"Unauthorized access attempt to search_id={search_id} by user={current_user.id}")
return get_json_result(
data=False,
message="No authorization.",
code=RetCode.AUTHENTICATION_ERROR,
)As per coding guidelines: 🤖 Prompt for AI Agents |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,13 +72,19 @@ See [Converse with chat assistant](../../references/http_api_reference.md#conver | |
|
|
||
| ```json {9} | ||
| curl --request POST \ | ||
| --url http://{address}/api/v1/chats/{chat_id}/completions \ | ||
| --url http://{address}/api/v1/chat/completions \ | ||
| --header 'Content-Type: application/json' \ | ||
| --header 'Authorization: Bearer <YOUR_API_KEY>' \ | ||
| --data-binary ' | ||
| { | ||
| "question": "xxxxxxxxx", | ||
| "chat_id": "{chat_id}", | ||
| "stream": true, | ||
| "messages": [ | ||
| { | ||
| "role": "user", | ||
| "content": "xxxxxxxxx" | ||
| } | ||
| ], | ||
| "style":"hilarious" | ||
| }' | ||
|
Comment on lines
73
to
89
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Document the `session_id` round trip. This example only sends `chat_id`, so the request starts a new session instead of continuing an existing one. 💡 Suggested doc tweak
```json {9}
curl --request POST \
--url http://{address}/api/v1/chat/completions \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--data-binary '
{
"chat_id": "{chat_id}",
"stream": true,
"messages": [
{
"role": "user",
"content": "xxxxxxxxx"
}
],
"style":"hilarious"
}'
+The first successful response returns a `session_id`; include it in later requests to continue the same session. 🤖 Prompt for AI Agents |
||
| ``` | ||
|
|
@@ -109,4 +115,3 @@ while True: | |
| print(ans.content[len(cont):], end='', flush=True) | ||
| cont = ans.content | ||
| ``` | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.