diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index cf297c4b250..c1ede1f0ee8 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -50,9 +50,13 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No enrich_chunks_with_document_metadata(chunks, metadata_fields) +def _dataset_access_actor_id(tenant_id: str, authenticated_user_id: str | None = None) -> str: + return authenticated_user_id or tenant_id + + @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @token_required -async def download(tenant_id, dataset_id, document_id): +async def download(tenant_id, dataset_id, document_id, authenticated_user_id=None): """ Download a document from a dataset. --- @@ -90,7 +94,10 @@ async def download(tenant_id, dataset_id, document_id): """ if not document_id: return get_error_data_result(message="Specify document_id please.") - if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): + if not KnowledgebaseService.accessible( + kb_id=dataset_id, + user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id), + ): return get_error_data_result(message=f"You do not own the dataset {dataset_id}.") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: @@ -161,7 +168,7 @@ async def download_doc(document_id): @manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 @token_required -async def parse(tenant_id, dataset_id): +async def parse(tenant_id, dataset_id, authenticated_user_id=None): """ Start parsing documents into chunks. --- @@ -198,7 +205,10 @@ async def parse(tenant_id, dataset_id): schema: type: object """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + if not KnowledgebaseService.accessible( + kb_id=dataset_id, + user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id), + ): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") req = await get_request_json() if not req.get("document_ids"): @@ -252,7 +262,7 @@ async def parse(tenant_id, dataset_id): @manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 @token_required -async def stop_parsing(tenant_id, dataset_id): +async def stop_parsing(tenant_id, dataset_id, authenticated_user_id=None): """ Stop parsing documents into chunks. --- @@ -289,7 +299,10 @@ async def stop_parsing(tenant_id, dataset_id): schema: type: object """ - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + if not KnowledgebaseService.accessible( + kb_id=dataset_id, + user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id), + ): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") req = await get_request_json() @@ -329,7 +342,7 @@ async def stop_parsing(tenant_id, dataset_id): @manager.route("/retrieval", methods=["POST"]) # noqa: F821 @token_required -async def retrieval_test(tenant_id): +async def retrieval_test(tenant_id, authenticated_user_id=None): """ Retrieve chunks based on a query. --- @@ -416,8 +429,9 @@ async def retrieval_test(tenant_id): kb_ids = req["dataset_ids"] if not isinstance(kb_ids, list): return get_error_data_result("`dataset_ids` should be a list") + actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id) for id in kb_ids: - if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): + if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id): return get_error_data_result(f"You don't own the dataset {id}.") kbs = KnowledgebaseService.get_by_ids(kb_ids) embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index a041ee0819f..516cf4a1a21 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -286,6 +286,8 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da def token_required(func): + accepts_authenticated_user_id = "authenticated_user_id" in inspect.signature(func).parameters + @wraps(func) async def wrapper(*args, **kwargs): # Validate the token (API Key) @@ -334,6 +336,8 @@ async def wrapper(*args, **kwargs): tenants = UserTenantService.query(user_id=user[0].id) if tenants: kwargs["tenant_id"] = tenants[0].tenant_id + if accepts_authenticated_user_id: + kwargs["authenticated_user_id"] = user[0].id result = func(*args, **kwargs) if inspect.iscoroutine(result): return await result diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index ca440d4ae0f..4225095382e 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -493,11 +493,11 @@ def test_download_and_download_doc_errors(self, monkeypatch): _patch_storage(monkeypatch, module, file_stream=b"") res = _run(module.download.__wrapped__("tenant-1", "ds-1", "")) assert res["message"] == "Specify document_id please." - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: False) res = _run(module.download.__wrapped__("tenant-1", "ds-1", "doc-1")) assert "do not own the dataset" in res["message"] - monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [1]) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: True) monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: []) res = _run(module.download.__wrapped__("tenant-1", "ds-1", "doc-1")) assert "not own the document" in res["message"] @@ -597,6 +597,38 @@ def test_parse_branches(self, monkeypatch): assert res["code"] == module.RetCode.DATA_ERROR assert "Duplicate document ids" in res["message"] + def test_sdk_routes_use_authenticated_user_for_dataset_access(self, monkeypatch): + module = _load_doc_module(monkeypatch) + access_calls = [] + + def _accessible(**kwargs): + access_calls.append(kwargs) + return False + + monkeypatch.setattr(module.KnowledgebaseService, "accessible", _accessible) + + res = _run(module.download.__wrapped__("tenant-1", "ds-1", "doc-1", authenticated_user_id="user-2")) + assert "do not own the dataset" in res["message"] + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"document_ids": ["doc-1"]})) + res = _run(module.parse.__wrapped__("tenant-1", "ds-1", authenticated_user_id="user-2")) + assert "don't own the dataset" in res["message"] + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"document_ids": ["doc-1"]})) + res = _run(module.stop_parsing.__wrapped__("tenant-1", "ds-1", authenticated_user_id="user-2")) + assert "don't own the dataset" in res["message"] + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"dataset_ids": ["ds-1"], "question": "q"})) + res = _run(module.retrieval_test.__wrapped__("tenant-1", authenticated_user_id="user-2")) + assert "don't own the dataset ds-1" in res["message"] + + assert access_calls == [ + {"kb_id": "ds-1", "user_id": "user-2"}, + {"kb_id": "ds-1", "user_id": "user-2"}, + {"kb_id": "ds-1", "user_id": "user-2"}, + {"kb_id": "ds-1", "user_id": "user-2"}, + ] + def test_stop_parsing_branches(self, monkeypatch): module = _load_doc_module(monkeypatch) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: False) diff --git a/test/unit_test/api/utils/test_api_utils_token_required.py b/test/unit_test/api/utils/test_api_utils_token_required.py new file mode 100644 index 00000000000..9ca2112a90e --- /dev/null +++ b/test/unit_test/api/utils/test_api_utils_token_required.py @@ -0,0 +1,146 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import json +import sys +from enum import Enum +from pathlib import Path +from types import ModuleType, SimpleNamespace + +from quart import Quart + + +def _run(coro): + return asyncio.run(coro) + + +def _load_api_utils_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + class _RetCode(int, Enum): + SUCCESS = 0 + EXCEPTION_ERROR = 100 + ARGUMENT_ERROR = 101 + DATA_ERROR = 102 + OPERATING_ERROR = 103 + PERMISSION_ERROR = 108 + AUTHENTICATION_ERROR = 109 + FORBIDDEN = 403 + UNAUTHORIZED = 401 + + class _ActiveEnum(str, Enum): + ACTIVE = "1" + INACTIVE = "0" + + class _StatusEnum(str, Enum): + VALID = "1" + INVALID = "0" + + common_constants_mod = ModuleType("common.constants") + common_constants_mod.RetCode = _RetCode + common_constants_mod.ActiveEnum = _ActiveEnum + common_constants_mod.StatusEnum = _StatusEnum + monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) + + common_settings_mod = ModuleType("common.settings") + common_settings_mod.get_secret_key = lambda: "test-secret" + monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + common_pkg.settings = common_settings_mod + + common_misc_utils_mod = ModuleType("common.misc_utils") + common_misc_utils_mod.thread_pool_exec = lambda func, *args, **kwargs: func(*args, **kwargs) + monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc_utils_mod) + + common_connection_utils_mod = ModuleType("common.connection_utils") + common_connection_utils_mod.timeout = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "common.connection_utils", common_connection_utils_mod) + + common_mcp_tool_call_conn_mod = ModuleType("common.mcp_tool_call_conn") + common_mcp_tool_call_conn_mod.MCPToolCallSession = object + common_mcp_tool_call_conn_mod.close_multiple_mcp_toolcall_sessions = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", common_mcp_tool_call_conn_mod) + + api_db_models_mod = ModuleType("api.db.db_models") + + class _APIToken: + @staticmethod + def query(**_kwargs): + return [] + + api_db_models_mod.APIToken = _APIToken + monkeypatch.setitem(sys.modules, "api.db.db_models", api_db_models_mod) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + tenant_llm_service_mod.LLMFactoriesService = object + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + user_service_mod.UserService = SimpleNamespace(query=lambda **_kwargs: []) + user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: []) + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + json_encode_mod = ModuleType("api.utils.json_encode") + json_encode_mod.CustomJSONEncoder = json.JSONEncoder + monkeypatch.setitem(sys.modules, "api.utils.json_encode", json_encode_mod) + + module_name = "test_api_utils_token_required_module" + module_path = repo_root / "api" / "utils" / "api_utils.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +def test_token_required_injects_authenticated_user_id_for_login_tokens(monkeypatch): + module = _load_api_utils_module(monkeypatch) + app = Quart(__name__) + + monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: []) + + user_service_mod = sys.modules["api.db.services.user_service"] + user_service_mod.UserService = SimpleNamespace( + query=lambda **kwargs: [SimpleNamespace(id="user-2")] if kwargs.get("access_token") == "raw-login-token" else [], + ) + user_service_mod.UserTenantService = SimpleNamespace( + query=lambda **kwargs: [SimpleNamespace(tenant_id="tenant-1")] if kwargs.get("user_id") == "user-2" else [], + ) + + from itsdangerous.url_safe import URLSafeTimedSerializer + + monkeypatch.setattr(URLSafeTimedSerializer, "loads", lambda self, _token: "raw-login-token") + + @module.token_required + async def _handler(tenant_id=None, authenticated_user_id=None): + return { + "tenant_id": tenant_id, + "authenticated_user_id": authenticated_user_id, + } + + async def _case(): + async with app.test_request_context("/", headers={"Authorization": "Bearer login-token"}): + return await _handler() + + assert _run(_case()) == { + "tenant_id": "tenant-1", + "authenticated_user_id": "user-2", + }