diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index e85a1d439c5..ab0e1262696 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -15,7 +15,13 @@ # import logging -from quart import jsonify +from quart import jsonify, request +from werkzeug.exceptions import BadRequest as WerkzeugBadRequest + +try: + from quart.exceptions import BadRequest as QuartBadRequest +except ImportError: # pragma: no cover - optional dependency + QuartBadRequest = None from api.db.services.document_service import DocumentService from api.db.services.doc_metadata_service import DocMetadataService @@ -23,14 +29,86 @@ from api.db.services.llm_service import LLMBundle from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from common.metadata_utils import meta_filter, convert_conditions -from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request +from api.utils.api_utils import apikey_required, build_error_result, get_request_json from rag.app.tag import label_question from common.constants import RetCode, LLMType from common import settings -@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821 +logger = logging.getLogger(__name__) + + +async def _read_retrieval_request(): + try: + method = request.method + except RuntimeError: + # Unit tests may call the handler directly without a request context. + method = "POST" + if method == "GET": + query_args = request.args + retrieval_setting = {} + knowledge_id = query_args.get("knowledge_id") + query = query_args.get("query") + use_kg = str(query_args.get("use_kg", "")).lower() in {"1", "true", "yes", "on"} + top_k = query_args.get("top_k") + score_threshold = query_args.get("score_threshold") + try: + if top_k not in (None, ""): + retrieval_setting["top_k"] = int(top_k) + if score_threshold not in (None, ""): + retrieval_setting["score_threshold"] = float(score_threshold) + except (TypeError, ValueError): + raise ValueError("top_k must be integer and score_threshold must be numeric") + safe_query = f"len={len(query)}" if isinstance(query, str) else "len=0" + logger.debug( + "Dify retrieval GET normalization: knowledge_id=%s query=%s use_kg=%s top_k=%s score_threshold=%s", + knowledge_id, + safe_query, + use_kg, + retrieval_setting.get("top_k"), + retrieval_setting.get("score_threshold"), + ) + + req = { + "knowledge_id": knowledge_id, + "query": query, + "use_kg": use_kg, + "retrieval_setting": retrieval_setting, + } + return req + req = await get_request_json() + knowledge_id = req.get("knowledge_id") if isinstance(req, dict) else None + query = req.get("query") if isinstance(req, dict) else None + use_kg = req.get("use_kg", False) if isinstance(req, dict) else False + retrieval_setting = req.get("retrieval_setting", {}) if isinstance(req, dict) else {} + if not isinstance(retrieval_setting, dict): + retrieval_setting = {} + safe_query = f"len={len(query)}" if isinstance(query, str) else "len=0" + logger.debug( + "Dify retrieval GET normalization: knowledge_id=%s query=%s use_kg=%s top_k=%s score_threshold=%s", + knowledge_id, + safe_query, + use_kg, + retrieval_setting.get("top_k"), + retrieval_setting.get("score_threshold"), + ) + return req + + +def _parse_retrieval_options(retrieval_setting): + if retrieval_setting is None: + retrieval_setting = {} + if not isinstance(retrieval_setting, dict): + raise ValueError("retrieval_setting must be an object") + try: + similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) + top = int(retrieval_setting.get("top_k", 1024)) + except (TypeError, ValueError): + raise ValueError("top_k must be integer and score_threshold must be numeric") + return retrieval_setting, similarity_threshold, top + + +@manager.route('/dify/retrieval', methods=['POST', 'GET']) # noqa: F821 @apikey_required -@validate_request("knowledge_id", "query") async def retrieval(tenant_id): """ Dify-compatible retrieval API @@ -40,9 +118,34 @@ async def retrieval(tenant_id): security: - ApiKeyAuth: [] parameters: + - in: query + name: knowledge_id + required: false + type: string + description: Knowledge base ID (for GET requests) + - in: query + name: query + required: false + type: string + description: Query text (for GET requests) + - in: query + name: use_kg + required: false + type: boolean + description: Whether to use knowledge graph (for GET requests) + - in: query + name: top_k + required: false + type: integer + description: Number of results to return (for GET requests) + - in: query + name: score_threshold + required: false + type: number + description: Similarity threshold (for GET requests) - in: body name: body - required: true + required: false schema: type: object required: @@ -115,15 +218,32 @@ async def retrieval(tenant_id): 404: description: Knowledge base or document not found """ - req = await get_request_json() + parse_exception_types = (AttributeError, TypeError, ValueError, WerkzeugBadRequest) + if QuartBadRequest is not None: + parse_exception_types = parse_exception_types + (QuartBadRequest,) + try: + req = await _read_retrieval_request() + except parse_exception_types as e: + return build_error_result( + message=f"invalid or malformed arguments: {str(e)}; ", + code=RetCode.ARGUMENT_ERROR, + ) + missing = [field for field in ("knowledge_id", "query") if not req.get(field)] + if missing: + return build_error_result( + message=f"required arguments are missing: {','.join(missing)}; ", + code=RetCode.ARGUMENT_ERROR, + ) question = req["query"] kb_id = req["knowledge_id"] use_kg = req.get("use_kg", False) - retrieval_setting = req.get("retrieval_setting", {}) - similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) - top = int(retrieval_setting.get("top_k", 1024)) - if top <= 0: - return build_error_result(message="`top_k` must be greater than 0", code=RetCode.DATA_ERROR) + try: + _, similarity_threshold, top = _parse_retrieval_options(req.get("retrieval_setting", {})) + except ValueError as e: + return build_error_result( + message=f"invalid or malformed arguments: {str(e)}; ", + code=RetCode.ARGUMENT_ERROR, + ) metadata_condition = req.get("metadata_condition", {}) or {} metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id]) diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index ac98d9e1d33..8234866e82f 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -352,3 +352,82 @@ async def retrieval(self, *_args, **_kwargs): res = _run(inspect.unwrap(module.retrieval)("tenant-1")) assert res["code"] == module.RetCode.SERVER_ERROR, res assert "boom" in res["message"], res + + +@pytest.mark.p2 +def test_read_retrieval_request_from_get_args(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + method="GET", + args={ + "knowledge_id": "kb-1", + "query": "hello", + "use_kg": "true", + "top_k": "12", + "score_threshold": "0.66", + }, + ), + ) + + req = _run(module._read_retrieval_request()) + assert req["knowledge_id"] == "kb-1", req + assert req["query"] == "hello", req + assert req["use_kg"] is True, req + assert req["retrieval_setting"]["top_k"] == 12, req + assert req["retrieval_setting"]["score_threshold"] == 0.66, req + + +@pytest.mark.p2 +def test_read_retrieval_request_from_post_json(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + payload = {"knowledge_id": "kb-1", "query": "hello"} + monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={})) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload)) + + req = _run(module._read_retrieval_request()) + assert req == payload, req + + +@pytest.mark.p2 +def test_retrieval_argument_error_messages(monkeypatch): + """Guard: distinguish malformed vs missing argument errors.""" + module = _load_dify_retrieval_module(monkeypatch) + + # Case 1: malformed numeric options in retrieval_setting + _set_request_json( + monkeypatch, + module, + { + "knowledge_id": "kb-1", + "query": "hello", + "retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"}, + }, + ) + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + assert "invalid or malformed arguments:" in res["message"], res + + # Case 2: missing required fields (knowledge_id, query) + _set_request_json(monkeypatch, module, {}) + res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing + assert "required arguments are missing:" in res_missing["message"], res_missing + + # Case 3: partially missing required field (query) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"}) + res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query + assert "query" in res_missing_query["message"], res_missing_query + + # Case 4: retrieval_setting wrong type + _set_request_json( + monkeypatch, + module, + {"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"}, + ) + res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type + assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type