Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions api/apps/sdk/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No
enrich_chunks_with_document_metadata(chunks, metadata_fields)


def _dataset_access_actor_id(tenant_id: str, authenticated_user_id: str | None = None) -> str:
return authenticated_user_id or tenant_id


@manager.route("/datasets/<dataset_id>/documents/<document_id>", methods=["GET"]) # noqa: F821
@token_required
async def download(tenant_id, dataset_id, document_id):
async def download(tenant_id, dataset_id, document_id, authenticated_user_id=None):
"""
Download a document from a dataset.
---
Expand Down Expand Up @@ -90,7 +94,10 @@ async def download(tenant_id, dataset_id, document_id):
"""
if not document_id:
return get_error_data_result(message="Specify document_id please.")
if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id):
if not KnowledgebaseService.accessible(
kb_id=dataset_id,
user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id),
):
Comment on lines +97 to +100
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add logging for dataset access authorization checks.

This security-critical authorization check determines access to private datasets but lacks logging. Adding a log statement would improve security audit trails and debugging.

🔒 Proposed logging addition
+    actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id)
+    logging.debug("Checking dataset access: dataset_id=%s actor_id=%s (authenticated_user_id=%s)", dataset_id, actor_id, authenticated_user_id)
     if not KnowledgebaseService.accessible(
         kb_id=dataset_id,
-        user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id),
+        user_id=actor_id,
     ):
+        logging.warning("Dataset access denied: dataset_id=%s actor_id=%s", dataset_id, actor_id)
         return get_error_data_result(message=f"You do not own the dataset {dataset_id}.")

As per coding guidelines, "**/*.py: Add logging for new flows".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@api/apps/sdk/doc.py` around lines 97 - 100, Add structured logging around the
KnowledgebaseService.accessible check: log the attempted access with dataset_id,
the actor returned by _dataset_access_actor_id(tenant_id,
authenticated_user_id), and the authorization result (allowed/denied). Use the
module logger (e.g., logger or security logger) and emit a warning or info-level
entry when access is denied (include tenant_id and authenticated_user_id as
context), so the authorization decision for KnowledgebaseService.accessible is
recorded for audit and debugging.

return get_error_data_result(message=f"You do not own the dataset {dataset_id}.")
doc = DocumentService.query(kb_id=dataset_id, id=document_id)
if not doc:
Expand Down Expand Up @@ -161,7 +168,7 @@ async def download_doc(document_id):

@manager.route("/datasets/<dataset_id>/chunks", methods=["POST"]) # noqa: F821
@token_required
async def parse(tenant_id, dataset_id):
async def parse(tenant_id, dataset_id, authenticated_user_id=None):
"""
Start parsing documents into chunks.
---
Expand Down Expand Up @@ -198,7 +205,10 @@ async def parse(tenant_id, dataset_id):
schema:
type: object
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
if not KnowledgebaseService.accessible(
kb_id=dataset_id,
user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id),
):
Comment on lines +208 to +211
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add logging for dataset access authorization checks.

Similar to the download function, this authorization check lacks logging. Consider adding the same logging pattern here for consistency.

As per coding guidelines, "**/*.py: Add logging for new flows".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@api/apps/sdk/doc.py` around lines 208 - 211, The authorization check using
KnowledgebaseService.accessible(kb_id=dataset_id,
user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id)) lacks
logging; add logging similar to the download flow: log an info/debug message
before the check indicating the dataset_id and actor_id being validated and log
a warning/error when access is denied (including dataset_id and actor_id) so
operators can trace authorization failures; update the same function that
performs this check to call the logger and follow the existing logging format
used by the download function.

return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await get_request_json()
if not req.get("document_ids"):
Expand Down Expand Up @@ -252,7 +262,7 @@ async def parse(tenant_id, dataset_id):

@manager.route("/datasets/<dataset_id>/chunks", methods=["DELETE"]) # noqa: F821
@token_required
async def stop_parsing(tenant_id, dataset_id):
async def stop_parsing(tenant_id, dataset_id, authenticated_user_id=None):
"""
Stop parsing documents into chunks.
---
Expand Down Expand Up @@ -289,7 +299,10 @@ async def stop_parsing(tenant_id, dataset_id):
schema:
type: object
"""
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
if not KnowledgebaseService.accessible(
kb_id=dataset_id,
user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id),
):
Comment on lines +302 to +305
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add logging for dataset access authorization checks.

Similar to the download and parse functions, this authorization check lacks logging. Consider adding the same logging pattern for consistency.

As per coding guidelines, "**/*.py: Add logging for new flows".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@api/apps/sdk/doc.py` around lines 302 - 305, The authorization check using
KnowledgebaseService.accessible(kb_id=dataset_id,
user_id=_dataset_access_actor_id(tenant_id, authenticated_user_id)) is missing
logging; add the same logging pattern used in the download/parse flows to record
the check input (dataset_id, tenant_id, authenticated_user_id) and the
authorization outcome. Specifically, before/around the if, emit a log entry
(matching the existing logger level/format used in download/parse) that includes
dataset_id, tenant_id, actor id from _dataset_access_actor_id(...), and whether
access was granted/denied so the decision is auditable and consistent with other
flows.

return get_error_data_result(message=f"You don't own the dataset {dataset_id}.")
req = await get_request_json()

Expand Down Expand Up @@ -329,7 +342,7 @@ async def stop_parsing(tenant_id, dataset_id):

@manager.route("/retrieval", methods=["POST"]) # noqa: F821
@token_required
async def retrieval_test(tenant_id):
async def retrieval_test(tenant_id, authenticated_user_id=None):
"""
Retrieve chunks based on a query.
---
Expand Down Expand Up @@ -416,8 +429,9 @@ async def retrieval_test(tenant_id):
kb_ids = req["dataset_ids"]
if not isinstance(kb_ids, list):
return get_error_data_result("`dataset_ids` should be a list")
actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id):
return get_error_data_result(f"You don't own the dataset {id}.")
Comment on lines +432 to 435
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add logging for dataset access authorization checks.

This function checks access to potentially multiple datasets but lacks logging. Since this loops over dataset IDs, consider logging once before the loop (with the full list) and/or within the loop for denied access.

🔍 Proposed logging addition
     actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id)
+    logging.debug("Checking access to datasets: dataset_ids=%s actor_id=%s (authenticated_user_id=%s)", kb_ids, actor_id, authenticated_user_id)
     for id in kb_ids:
         if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id):
+            logging.warning("Dataset access denied: dataset_id=%s actor_id=%s", id, actor_id)
             return get_error_data_result(f"You don't own the dataset {id}.")

As per coding guidelines, "**/*.py: Add logging for new flows".

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id):
return get_error_data_result(f"You don't own the dataset {id}.")
actor_id = _dataset_access_actor_id(tenant_id, authenticated_user_id)
logging.debug("Checking access to datasets: dataset_ids=%s actor_id=%s (authenticated_user_id=%s)", kb_ids, actor_id, authenticated_user_id)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=actor_id):
logging.warning("Dataset access denied: dataset_id=%s actor_id=%s", id, actor_id)
return get_error_data_result(f"You don't own the dataset {id}.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@api/apps/sdk/doc.py` around lines 432 - 435, The dataset access check loop
lacks logging; add structured logs using the existing actor id and kb_ids: log a
single INFO-level message before the loop mentioning kb_ids and actor_id (from
_dataset_access_actor_id), and inside the loop log a WARN/ERROR when
KnowledgebaseService.accessible(kb_id=id, user_id=actor_id) returns False
including the denied kb id and actor_id (and authenticated_user_id if available)
before returning get_error_data_result so denied access events are recorded for
debugging and audit.

kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison
Expand Down
4 changes: 4 additions & 0 deletions api/utils/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da


def token_required(func):
accepts_authenticated_user_id = "authenticated_user_id" in inspect.signature(func).parameters

@wraps(func)
async def wrapper(*args, **kwargs):
# Validate the token (API Key)
Expand Down Expand Up @@ -334,6 +336,8 @@ async def wrapper(*args, **kwargs):
tenants = UserTenantService.query(user_id=user[0].id)
if tenants:
kwargs["tenant_id"] = tenants[0].tenant_id
if accepts_authenticated_user_id:
kwargs["authenticated_user_id"] = user[0].id
Comment on lines +339 to +340
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add logging for the authenticated user injection flow.

This security-relevant code path injects the authenticated user ID for dataset authorization checks but lacks logging. Adding a log statement would improve observability and audit trails for private dataset access.

📊 Proposed logging addition
                    kwargs["tenant_id"] = tenants[0].tenant_id
                    if accepts_authenticated_user_id:
+                       logging.debug("JWT authentication: injecting authenticated_user_id=%s for tenant_id=%s", user[0].id, tenants[0].tenant_id)
                        kwargs["authenticated_user_id"] = user[0].id

As per coding guidelines, "**/*.py: Add logging for new flows".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@api/utils/api_utils.py` around lines 339 - 340, The code path that injects
authenticated_user_id into kwargs (the block checking
accepts_authenticated_user_id and setting kwargs["authenticated_user_id"] =
user[0].id) needs an audit log entry; add a concise log statement (using the
module logger or existing logger variable) immediately before or after the
assignment that records the action and key context such as the injected user id
and the target dataset/request identifier (if available) and use an appropriate
level (info/debug) while avoiding sensitive data; update any import or logger
initialization (e.g., logger = logging.getLogger(__name__)) if not already
present.

result = func(*args, **kwargs)
if inspect.iscoroutine(result):
return await result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,11 @@ def test_download_and_download_doc_errors(self, monkeypatch):
_patch_storage(monkeypatch, module, file_stream=b"")
res = _run(module.download.__wrapped__("tenant-1", "ds-1", ""))
assert res["message"] == "Specify document_id please."
monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [])
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: False)
res = _run(module.download.__wrapped__("tenant-1", "ds-1", "doc-1"))
assert "do not own the dataset" in res["message"]

monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [1])
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: True)
monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [])
res = _run(module.download.__wrapped__("tenant-1", "ds-1", "doc-1"))
assert "not own the document" in res["message"]
Expand Down Expand Up @@ -597,6 +597,38 @@ def test_parse_branches(self, monkeypatch):
assert res["code"] == module.RetCode.DATA_ERROR
assert "Duplicate document ids" in res["message"]

def test_sdk_routes_use_authenticated_user_for_dataset_access(self, monkeypatch):
    """SDK routes must authorize with the authenticated user id, not the tenant id.

    Forces ``KnowledgebaseService.accessible`` to deny every request while
    recording its keyword arguments, drives the four SDK routes with
    ``authenticated_user_id="user-2"``, and finally asserts that each access
    check was made with ``user-2`` rather than the tenant id ``tenant-1``.
    """
    module = _load_doc_module(monkeypatch)
    # Every invocation of the access check is captured here for the final assert.
    access_calls = []

    def _accessible(**kwargs):
        # Always deny so each route returns its "don't own" error immediately.
        access_calls.append(kwargs)
        return False

    monkeypatch.setattr(module.KnowledgebaseService, "accessible", _accessible)

    # download: denial message uses "do not own" phrasing.
    res = _run(module.download.__wrapped__("tenant-1", "ds-1", "doc-1", authenticated_user_id="user-2"))
    assert "do not own the dataset" in res["message"]

    # parse: needs a request body with document_ids to reach the access check.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"document_ids": ["doc-1"]}))
    res = _run(module.parse.__wrapped__("tenant-1", "ds-1", authenticated_user_id="user-2"))
    assert "don't own the dataset" in res["message"]

    # stop_parsing: same body shape as parse.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"document_ids": ["doc-1"]}))
    res = _run(module.stop_parsing.__wrapped__("tenant-1", "ds-1", authenticated_user_id="user-2"))
    assert "don't own the dataset" in res["message"]

    # retrieval_test: checks access per dataset id in the request body.
    monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"dataset_ids": ["ds-1"], "question": "q"}))
    res = _run(module.retrieval_test.__wrapped__("tenant-1", authenticated_user_id="user-2"))
    assert "don't own the dataset ds-1" in res["message"]

    # All four routes consulted the access check with the authenticated user id.
    assert access_calls == [
        {"kb_id": "ds-1", "user_id": "user-2"},
        {"kb_id": "ds-1", "user_id": "user-2"},
        {"kb_id": "ds-1", "user_id": "user-2"},
        {"kb_id": "ds-1", "user_id": "user-2"},
    ]

def test_stop_parsing_branches(self, monkeypatch):
module = _load_doc_module(monkeypatch)
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: False)
Expand Down
146 changes: 146 additions & 0 deletions test/unit_test/api/utils/test_api_utils_token_required.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import asyncio
import importlib.util
import json
import sys
from enum import Enum
from pathlib import Path
from types import ModuleType, SimpleNamespace

from quart import Quart


def _run(coro):
    """Drive *coro* to completion on a fresh event loop and return its result."""
    result = asyncio.run(coro)
    return result


def _load_api_utils_module(monkeypatch):
    """Load ``api/utils/api_utils.py`` in isolation with stubbed dependencies.

    Registers lightweight stand-ins in ``sys.modules`` for every ``common.*``
    and ``api.*`` module that api_utils imports at module level, then executes
    the real source file under a private module name so the test never pulls
    in the full application import graph.

    Args:
        monkeypatch: pytest fixture; using ``setitem`` on ``sys.modules``
            ensures all module tampering is undone after the test.

    Returns:
        The freshly executed api_utils module object.
    """
    # Path layout assumption: this test file lives 4 levels below the repo
    # root (test/unit_test/api/utils/...) — TODO confirm if the file moves.
    repo_root = Path(__file__).resolve().parents[4]

    # Register a namespace-style "common" package pointing at the real
    # directory so submodule imports resolve against our stubs below.
    common_pkg = ModuleType("common")
    common_pkg.__path__ = [str(repo_root / "common")]
    monkeypatch.setitem(sys.modules, "common", common_pkg)

    # Minimal enum stand-ins mirroring the constants api_utils expects.
    class _RetCode(int, Enum):
        SUCCESS = 0
        EXCEPTION_ERROR = 100
        ARGUMENT_ERROR = 101
        DATA_ERROR = 102
        OPERATING_ERROR = 103
        PERMISSION_ERROR = 108
        AUTHENTICATION_ERROR = 109
        FORBIDDEN = 403
        UNAUTHORIZED = 401

    class _ActiveEnum(str, Enum):
        ACTIVE = "1"
        INACTIVE = "0"

    class _StatusEnum(str, Enum):
        VALID = "1"
        INVALID = "0"

    common_constants_mod = ModuleType("common.constants")
    common_constants_mod.RetCode = _RetCode
    common_constants_mod.ActiveEnum = _ActiveEnum
    common_constants_mod.StatusEnum = _StatusEnum
    monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod)

    # Fixed secret so token deserialization is deterministic in tests.
    common_settings_mod = ModuleType("common.settings")
    common_settings_mod.get_secret_key = lambda: "test-secret"
    monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod)
    common_pkg.settings = common_settings_mod

    # Run thread-pool work synchronously to keep tests single-threaded.
    common_misc_utils_mod = ModuleType("common.misc_utils")
    common_misc_utils_mod.thread_pool_exec = lambda func, *args, **kwargs: func(*args, **kwargs)
    monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc_utils_mod)

    # No-op timeout decorator stub.
    common_connection_utils_mod = ModuleType("common.connection_utils")
    common_connection_utils_mod.timeout = lambda *_args, **_kwargs: None
    monkeypatch.setitem(sys.modules, "common.connection_utils", common_connection_utils_mod)

    # MCP tool-call stubs: api_utils only needs the names to exist.
    common_mcp_tool_call_conn_mod = ModuleType("common.mcp_tool_call_conn")
    common_mcp_tool_call_conn_mod.MCPToolCallSession = object
    common_mcp_tool_call_conn_mod.close_multiple_mcp_toolcall_sessions = lambda *_args, **_kwargs: None
    monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", common_mcp_tool_call_conn_mod)

    # Stub DB model: default APIToken lookup finds nothing; individual tests
    # monkeypatch `query` as needed.
    api_db_models_mod = ModuleType("api.db.db_models")

    class _APIToken:
        @staticmethod
        def query(**_kwargs):
            return []

    api_db_models_mod.APIToken = _APIToken
    monkeypatch.setitem(sys.modules, "api.db.db_models", api_db_models_mod)

    tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service")
    tenant_llm_service_mod.LLMFactoriesService = object
    monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod)

    # User services default to "no match"; tests override per scenario.
    user_service_mod = ModuleType("api.db.services.user_service")
    user_service_mod.UserService = SimpleNamespace(query=lambda **_kwargs: [])
    user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: [])
    monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod)

    # Plain JSONEncoder is sufficient for the payloads used here.
    json_encode_mod = ModuleType("api.utils.json_encode")
    json_encode_mod.CustomJSONEncoder = json.JSONEncoder
    monkeypatch.setitem(sys.modules, "api.utils.json_encode", json_encode_mod)

    # Execute the real api_utils source under a private module name so the
    # canonical "api.utils.api_utils" entry (if any) is left untouched.
    module_name = "test_api_utils_token_required_module"
    module_path = repo_root / "api" / "utils" / "api_utils.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module


def test_token_required_injects_authenticated_user_id_for_login_tokens(monkeypatch):
    """End-to-end check of the login-token path through ``token_required``.

    With no matching API key, the decorator falls back to resolving the
    bearer token as a login token; it should then inject both ``tenant_id``
    and — because the handler declares the parameter —
    ``authenticated_user_id`` into the wrapped handler.
    """
    module = _load_api_utils_module(monkeypatch)
    app = Quart(__name__)

    # No API-key match: forces the decorator onto the login-token branch.
    monkeypatch.setattr(module.APIToken, "query", lambda **_kwargs: [])

    # UserService resolves the deserialized token to user-2, and
    # UserTenantService maps user-2 to tenant-1.
    user_service_mod = sys.modules["api.db.services.user_service"]
    user_service_mod.UserService = SimpleNamespace(
        query=lambda **kwargs: [SimpleNamespace(id="user-2")] if kwargs.get("access_token") == "raw-login-token" else [],
    )
    user_service_mod.UserTenantService = SimpleNamespace(
        query=lambda **kwargs: [SimpleNamespace(tenant_id="tenant-1")] if kwargs.get("user_id") == "user-2" else [],
    )

    from itsdangerous.url_safe import URLSafeTimedSerializer

    # Bypass real token cryptography: any token deserializes to the raw
    # login token the stubbed UserService recognizes.
    monkeypatch.setattr(URLSafeTimedSerializer, "loads", lambda self, _token: "raw-login-token")

    # Handler declares authenticated_user_id, so the decorator should fill it.
    @module.token_required
    async def _handler(tenant_id=None, authenticated_user_id=None):
        return {
            "tenant_id": tenant_id,
            "authenticated_user_id": authenticated_user_id,
        }

    async def _case():
        # Simulate an authenticated request carrying a bearer login token.
        async with app.test_request_context("/", headers={"Authorization": "Bearer login-token"}):
            return await _handler()

    assert _run(_case()) == {
        "tenant_id": "tenant-1",
        "authenticated_user_id": "user-2",
    }
Loading