Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 131 additions & 11 deletions api/apps/sdk/dify_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,100 @@
#
import logging

from quart import jsonify
from quart import jsonify, request
from werkzeug.exceptions import BadRequest as WerkzeugBadRequest

try:
from quart.exceptions import BadRequest as QuartBadRequest
except ImportError: # pragma: no cover - optional dependency
QuartBadRequest = None

from api.db.services.document_service import DocumentService
from api.db.services.doc_metadata_service import DocMetadataService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from common.metadata_utils import meta_filter, convert_conditions
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
from api.utils.api_utils import apikey_required, build_error_result, get_request_json
from rag.app.tag import label_question
from common.constants import RetCode, LLMType
from common import settings

@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821
logger = logging.getLogger(__name__)


async def _read_retrieval_request():
try:
method = request.method
except RuntimeError:
# Unit tests may call the handler directly without a request context.
method = "POST"
if method == "GET":
query_args = request.args
retrieval_setting = {}
knowledge_id = query_args.get("knowledge_id")
query = query_args.get("query")
use_kg = str(query_args.get("use_kg", "")).lower() in {"1", "true", "yes", "on"}
top_k = query_args.get("top_k")
score_threshold = query_args.get("score_threshold")
try:
if top_k not in (None, ""):
retrieval_setting["top_k"] = int(top_k)
if score_threshold not in (None, ""):
retrieval_setting["score_threshold"] = float(score_threshold)
except (TypeError, ValueError):
raise ValueError("top_k must be integer and score_threshold must be numeric")
safe_query = f"len={len(query)}" if isinstance(query, str) else "len=0"
logger.debug(
"Dify retrieval GET normalization: knowledge_id=%s query=%s use_kg=%s top_k=%s score_threshold=%s",
knowledge_id,
safe_query,
use_kg,
retrieval_setting.get("top_k"),
retrieval_setting.get("score_threshold"),
)

req = {
"knowledge_id": knowledge_id,
"query": query,
"use_kg": use_kg,
"retrieval_setting": retrieval_setting,
}
return req
req = await get_request_json()
knowledge_id = req.get("knowledge_id") if isinstance(req, dict) else None
query = req.get("query") if isinstance(req, dict) else None
use_kg = req.get("use_kg", False) if isinstance(req, dict) else False
retrieval_setting = req.get("retrieval_setting", {}) if isinstance(req, dict) else {}
if not isinstance(retrieval_setting, dict):
retrieval_setting = {}
safe_query = f"len={len(query)}" if isinstance(query, str) else "len=0"
logger.debug(
"Dify retrieval GET normalization: knowledge_id=%s query=%s use_kg=%s top_k=%s score_threshold=%s",
knowledge_id,
safe_query,
use_kg,
retrieval_setting.get("top_k"),
retrieval_setting.get("score_threshold"),
)
return req


def _parse_retrieval_options(retrieval_setting):
if retrieval_setting is None:
retrieval_setting = {}
if not isinstance(retrieval_setting, dict):
raise ValueError("retrieval_setting must be an object")
try:
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
except (TypeError, ValueError):
raise ValueError("top_k must be integer and score_threshold must be numeric")
return retrieval_setting, similarity_threshold, top


@manager.route('/dify/retrieval', methods=['POST', 'GET']) # noqa: F821
@apikey_required
@validate_request("knowledge_id", "query")
async def retrieval(tenant_id):
"""
Dify-compatible retrieval API
Expand All @@ -40,9 +118,34 @@ async def retrieval(tenant_id):
security:
- ApiKeyAuth: []
parameters:
- in: query
name: knowledge_id
required: false
type: string
description: Knowledge base ID (for GET requests)
- in: query
name: query
required: false
type: string
description: Query text (for GET requests)
- in: query
name: use_kg
required: false
type: boolean
description: Whether to use knowledge graph (for GET requests)
- in: query
name: top_k
required: false
type: integer
description: Number of results to return (for GET requests)
- in: query
name: score_threshold
required: false
type: number
description: Similarity threshold (for GET requests)
- in: body
name: body
required: true
required: false
schema:
type: object
required:
Expand Down Expand Up @@ -115,15 +218,32 @@ async def retrieval(tenant_id):
404:
description: Knowledge base or document not found
"""
req = await get_request_json()
parse_exception_types = (AttributeError, TypeError, ValueError, WerkzeugBadRequest)
if QuartBadRequest is not None:
parse_exception_types = parse_exception_types + (QuartBadRequest,)
try:
req = await _read_retrieval_request()
except parse_exception_types as e:
return build_error_result(
message=f"invalid or malformed arguments: {str(e)}; ",
code=RetCode.ARGUMENT_ERROR,
)
missing = [field for field in ("knowledge_id", "query") if not req.get(field)]
if missing:
return build_error_result(
message=f"required arguments are missing: {','.join(missing)}; ",
code=RetCode.ARGUMENT_ERROR,
)
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)
retrieval_setting = req.get("retrieval_setting", {})
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
if top <= 0:
return build_error_result(message="`top_k` must be greater than 0", code=RetCode.DATA_ERROR)
try:
_, similarity_threshold, top = _parse_retrieval_options(req.get("retrieval_setting", {}))
except ValueError as e:
return build_error_result(
message=f"invalid or malformed arguments: {str(e)}; ",
code=RetCode.ARGUMENT_ERROR,
)
metadata_condition = req.get("metadata_condition", {}) or {}
metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,82 @@ async def retrieval(self, *_args, **_kwargs):
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res["code"] == module.RetCode.SERVER_ERROR, res
assert "boom" in res["message"], res


@pytest.mark.p2
def test_read_retrieval_request_from_get_args(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
monkeypatch.setattr(
module,
"request",
SimpleNamespace(
method="GET",
args={
"knowledge_id": "kb-1",
"query": "hello",
"use_kg": "true",
"top_k": "12",
"score_threshold": "0.66",
},
),
)

req = _run(module._read_retrieval_request())
assert req["knowledge_id"] == "kb-1", req
assert req["query"] == "hello", req
assert req["use_kg"] is True, req
assert req["retrieval_setting"]["top_k"] == 12, req
assert req["retrieval_setting"]["score_threshold"] == 0.66, req


@pytest.mark.p2
def test_read_retrieval_request_from_post_json(monkeypatch):
module = _load_dify_retrieval_module(monkeypatch)
payload = {"knowledge_id": "kb-1", "query": "hello"}
monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={}))
monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload))

req = _run(module._read_retrieval_request())
assert req == payload, req


@pytest.mark.p2
def test_retrieval_argument_error_messages(monkeypatch):
"""Guard: distinguish malformed vs missing argument errors."""
module = _load_dify_retrieval_module(monkeypatch)

# Case 1: malformed numeric options in retrieval_setting
_set_request_json(
monkeypatch,
module,
{
"knowledge_id": "kb-1",
"query": "hello",
"retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"},
},
)
res = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res["code"] == module.RetCode.ARGUMENT_ERROR, res
assert "invalid or malformed arguments:" in res["message"], res

# Case 2: missing required fields (knowledge_id, query)
_set_request_json(monkeypatch, module, {})
res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing
assert "required arguments are missing:" in res_missing["message"], res_missing

# Case 3: partially missing required field (query)
_set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"})
res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query
assert "query" in res_missing_query["message"], res_missing_query

# Case 4: retrieval_setting wrong type
_set_request_json(
monkeypatch,
module,
{"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"},
)
res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1"))
assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type
assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type
Loading