-
Notifications
You must be signed in to change notification settings - Fork 9.2k
fix(api): close private dataset doc auth bypass #14749
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -50,9 +50,13 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No | |||||||||||||||||||||||
| enrich_chunks_with_document_metadata(chunks, metadata_fields) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _dataset_access_actor_id(tenant_id: str, authenticated_user_id: str | None = None) -> str: | ||||||||||||||||||||||||
| return authenticated_user_id or tenant_id | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821 | ||||||||||||||||||||||||
| @token_required | ||||||||||||||||||||||||
| async def download(tenant_id, dataset_id, document_id): | ||||||||||||||||||||||||
| async def download(tenant_id, dataset_id, document_id, authenticated_user_id=None): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Download a document from a dataset. | ||||||||||||||||||||||||
| --- | ||||||||||||||||||||||||
|
|
@@ -90,7 +94,10 @@ async def download(tenant_id, dataset_id, document_id): | |||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| if not document_id: | ||||||||||||||||||||||||
| return get_error_data_result(message="Specify document_id please.") | ||||||||||||||||||||||||
| if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible( | ||||||||||||||||||||||||
| kb_id=dataset_id, | ||||||||||||||||||||||||
| user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id), | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| return get_error_data_result(message=f"You do not own the dataset {dataset_id}.") | ||||||||||||||||||||||||
| doc = DocumentService.query(kb_id=dataset_id, id=document_id) | ||||||||||||||||||||||||
| if not doc: | ||||||||||||||||||||||||
|
|
@@ -161,7 +168,7 @@ async def download_doc(document_id): | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821 | ||||||||||||||||||||||||
| @token_required | ||||||||||||||||||||||||
| async def parse(tenant_id, dataset_id): | ||||||||||||||||||||||||
| async def parse(tenant_id, dataset_id, authenticated_user_id=None): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Start parsing documents into chunks. | ||||||||||||||||||||||||
| --- | ||||||||||||||||||||||||
|
|
@@ -198,7 +205,10 @@ async def parse(tenant_id, dataset_id): | |||||||||||||||||||||||
| schema: | ||||||||||||||||||||||||
| type: object | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible( | ||||||||||||||||||||||||
| kb_id=dataset_id, | ||||||||||||||||||||||||
| user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id), | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
|
Comment on lines
+208
to
+211
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add logging for dataset access authorization checks. Similar to the As per coding guidelines, " 🤖 Prompt for AI Agents |
||||||||||||||||||||||||
| return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") | ||||||||||||||||||||||||
| req = await get_request_json() | ||||||||||||||||||||||||
| if not req.get("document_ids"): | ||||||||||||||||||||||||
|
|
@@ -252,7 +262,7 @@ async def parse(tenant_id, dataset_id): | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821 | ||||||||||||||||||||||||
| @token_required | ||||||||||||||||||||||||
| async def stop_parsing(tenant_id, dataset_id): | ||||||||||||||||||||||||
| async def stop_parsing(tenant_id, dataset_id, authenticated_user_id=None): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Stop parsing documents into chunks. | ||||||||||||||||||||||||
| --- | ||||||||||||||||||||||||
|
|
@@ -289,7 +299,10 @@ async def stop_parsing(tenant_id, dataset_id): | |||||||||||||||||||||||
| schema: | ||||||||||||||||||||||||
| type: object | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible( | ||||||||||||||||||||||||
| kb_id=dataset_id, | ||||||||||||||||||||||||
| user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id), | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
|
Comment on lines
+302
to
+305
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add logging for dataset access authorization checks. Similar to the As per coding guidelines, " 🤖 Prompt for AI Agents |
||||||||||||||||||||||||
| return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") | ||||||||||||||||||||||||
| req = await get_request_json() | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -329,7 +342,7 @@ async def stop_parsing(tenant_id, dataset_id): | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @manager.route("/retrieval", methods=["POST"]) # noqa: F821 | ||||||||||||||||||||||||
| @token_required | ||||||||||||||||||||||||
| async def retrieval_test(tenant_id): | ||||||||||||||||||||||||
| async def retrieval_test(tenant_id, authenticated_user_id=None): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Retrieve chunks based on a query. | ||||||||||||||||||||||||
| --- | ||||||||||||||||||||||||
|
|
@@ -416,8 +429,9 @@ async def retrieval_test(tenant_id): | |||||||||||||||||||||||
| kb_ids = req["dataset_ids"] | ||||||||||||||||||||||||
| if not isinstance(kb_ids, list): | ||||||||||||||||||||||||
| return get_error_data_result("`dataset_ids` should be a list") | ||||||||||||||||||||||||
| actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id) | ||||||||||||||||||||||||
| for id in kb_ids: | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id): | ||||||||||||||||||||||||
| if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id): | ||||||||||||||||||||||||
| return get_error_data_result(f"You don't own the dataset {id}.") | ||||||||||||||||||||||||
|
Comment on lines
+432
to
435
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add logging for dataset access authorization checks. This function checks access to potentially multiple datasets but lacks logging. Since this loops over dataset IDs, consider logging once before the loop (with the full list) and/or within the loop for denied access. 🔍 Proposed logging addition actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id)
+ logging.debug("Checking access to datasets: dataset_ids=%s actor_id=%s (authenticated_user_id=%s)", kb_ids, actor_id, authenticated_user_id)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id):
+ logging.warning("Dataset access denied: dataset_id=%s actor_id=%s", id, actor_id)
return get_error_data_result(f"You don't own the dataset {id}.")As per coding guidelines, " 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||
| kbs = KnowledgebaseService.get_by_ids(kb_ids) | ||||||||||||||||||||||||
| embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -286,6 +286,8 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da | |
|
|
||
|
|
||
| def token_required(func): | ||
| accepts_authenticated_user_id = "authenticated_user_id" in inspect.signature(func).parameters | ||
|
|
||
| @wraps(func) | ||
| async def wrapper(*args, **kwargs): | ||
| # Validate the token (API Key) | ||
|
|
@@ -334,6 +336,8 @@ async def wrapper(*args, **kwargs): | |
| tenants = UserTenantService.query(user_id=user[0].id) | ||
| if tenants: | ||
| kwargs["tenant_id"] = tenants[0].tenant_id | ||
| if accepts_authenticated_user_id: | ||
| kwargs["authenticated_user_id"] = user[0].id | ||
|
Comment on lines
+339
to
+340
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add logging for the authenticated user injection flow. This security-relevant code path injects the authenticated user ID for dataset authorization checks but lacks logging. Adding a log statement would improve observability and audit trails for private dataset access. 📊 Proposed logging addition kwargs["tenant_id"] = tenants[0].tenant_id
if accepts_authenticated_user_id:
+ logging.debug("JWT authentication: injecting authenticated_user_id=%s for tenant_id=%s", user[0].id, tenants[0].tenant_id)
kwargs["authenticated_user_id"] = user[0].idAs per coding guidelines, " 🤖 Prompt for AI Agents |
||
| result = func(*args, **kwargs) | ||
| if inspect.iscoroutine(result): | ||
| return await result | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| # | ||
| # Copyright 2026 The InfiniFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| import asyncio | ||
| import importlib.util | ||
| import json | ||
| import sys | ||
| from enum import Enum | ||
| from pathlib import Path | ||
| from types import ModuleType, SimpleNamespace | ||
|
|
||
| from quart import Quart | ||
|
|
||
|
|
||
| def _run(coro): | ||
| return asyncio.run(coro) | ||
|
|
||
|
|
||
| def _load_api_utils_module(monkeypatch): | ||
| repo_root = Path(__file__).resolve().parents[4] | ||
|
|
||
| common_pkg = ModuleType("common") | ||
| common_pkg.__path__ = [str(repo_root / "common")] | ||
| monkeypatch.setitem(sys.modules, "common", common_pkg) | ||
|
|
||
| class _RetCode(int, Enum): | ||
| SUCCESS = 0 | ||
| EXCEPTION_ERROR = 100 | ||
| ARGUMENT_ERROR = 101 | ||
| DATA_ERROR = 102 | ||
| OPERATING_ERROR = 103 | ||
| PERMISSION_ERROR = 108 | ||
| AUTHENTICATION_ERROR = 109 | ||
| FORBIDDEN = 403 | ||
| UNAUTHORIZED = 401 | ||
|
|
||
| class _ActiveEnum(str, Enum): | ||
| ACTIVE = "1" | ||
| INACTIVE = "0" | ||
|
|
||
| class _StatusEnum(str, Enum): | ||
| VALID = "1" | ||
| INVALID = "0" | ||
|
|
||
| common_constants_mod = ModuleType("common.constants") | ||
| common_constants_mod.RetCode = _RetCode | ||
| common_constants_mod.ActiveEnum = _ActiveEnum | ||
| common_constants_mod.StatusEnum = _StatusEnum | ||
| monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) | ||
|
|
||
| common_settings_mod = ModuleType("common.settings") | ||
| common_settings_mod.get_secret_key = lambda: "test-secret" | ||
| monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) | ||
| common_pkg.settings = common_settings_mod | ||
|
|
||
| common_misc_utils_mod = ModuleType("common.misc_utils") | ||
| common_misc_utils_mod.thread_pool_exec = lambda func, *args, **kwargs: func(*args, **kwargs) | ||
| monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc_utils_mod) | ||
|
|
||
| common_connection_utils_mod = ModuleType("common.connection_utils") | ||
| common_connection_utils_mod.timeout = lambda *_args, **_kwargs: None | ||
| monkeypatch.setitem(sys.modules, "common.connection_utils", common_connection_utils_mod) | ||
|
|
||
| common_mcp_tool_call_conn_mod = ModuleType("common.mcp_tool_call_conn") | ||
| common_mcp_tool_call_conn_mod.MCPToolCallSession = object | ||
| common_mcp_tool_call_conn_mod.close_multiple_mcp_toolcall_sessions = lambda *_args, **_kwargs: None | ||
| monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", common_mcp_tool_call_conn_mod) | ||
|
|
||
| api_db_models_mod = ModuleType("api.db.db_models") | ||
|
|
||
| class _APIToken: | ||
| @staticmethod | ||
| def query(**_kwargs): | ||
| return [] | ||
|
|
||
| api_db_models_mod.APIToken = _APIToken | ||
| monkeypatch.setitem(sys.modules, "api.db.db_models", api_db_models_mod) | ||
|
|
||
| tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") | ||
| tenant_llm_service_mod.LLMFactoriesService = object | ||
| monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) | ||
|
|
||
| user_service_mod = ModuleType("api.db.services.user_service") | ||
| user_service_mod.UserService = SimpleNamespace(query=lambda **_kwargs: []) | ||
| user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: []) | ||
| monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) | ||
|
|
||
| json_encode_mod = ModuleType("api.utils.json_encode") | ||
| json_encode_mod.CustomJSONEncoder = json.JSONEncoder | ||
| monkeypatch.setitem(sys.modules, "api.utils.json_encode", json_encode_mod) | ||
|
|
||
| module_name = "test_api_utils_token_required_module" | ||
| module_path = repo_root / "api" / "utils" / "api_utils.py" | ||
| spec = importlib.util.spec_from_file_location(module_name, module_path) | ||
| module = importlib.util.module_from_spec(spec) | ||
| monkeypatch.setitem(sys.modules, module_name, module) | ||
| spec.loader.exec_module(module) | ||
| return module | ||
|
|
||
|
|
||
| def test_token_required_injects_authenticated_user_id_for_login_tokens(monkeypatch): | ||
| module = _load_api_utils_module(monkeypatch) | ||
| app = Quart(__name__) | ||
|
|
||
| monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: []) | ||
|
|
||
| user_service_mod = sys.modules["api.db.services.user_service"] | ||
| user_service_mod.UserService = SimpleNamespace( | ||
| query=lambda **kwargs: [SimpleNamespace(id="user-2")] if kwargs.get("access_token") == "raw-login-token" else [], | ||
| ) | ||
| user_service_mod.UserTenantService = SimpleNamespace( | ||
| query=lambda **kwargs: [SimpleNamespace(tenant_id="tenant-1")] if kwargs.get("user_id") == "user-2" else [], | ||
| ) | ||
|
|
||
| from itsdangerous.url_safe import URLSafeTimedSerializer | ||
|
|
||
| monkeypatch.setattr(URLSafeTimedSerializer, "loads", lambda self, _token: "raw-login-token") | ||
|
|
||
| @module.token_required | ||
| async def _handler(tenant_id=None, authenticated_user_id=None): | ||
| return { | ||
| "tenant_id": tenant_id, | ||
| "authenticated_user_id": authenticated_user_id, | ||
| } | ||
|
|
||
| async def _case(): | ||
| async with app.test_request_context("/", headers={"Authorization": "Bearer login-token"}): | ||
| return await _handler() | ||
|
|
||
| assert _run(_case()) == { | ||
| "tenant_id": "tenant-1", | ||
| "authenticated_user_id": "user-2", | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add logging for dataset access authorization checks.
This security-critical authorization check determines access to private datasets but lacks logging. Adding a log statement would improve security audit trails and debugging.
🔒 Proposed logging addition
As per coding guidelines, "
**/*.py: Add logging for new flows".🤖 Prompt for AI Agents