diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 5d6289e5734..7992cdb6105 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -678,17 +678,10 @@ def get_tenant_id_by_name(cls, name): @classmethod @DB.connection_context() def accessible(cls, doc_id, user_id): - docs = ( - cls.model.select(cls.model.id) - .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)) - .join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)) - .where(cls.model.id == doc_id, UserTenant.user_id == user_id) - .paginate(0, 1) - ) - docs = docs.dicts() - if not docs: + e, doc = cls.get_by_id(doc_id) + if not e: return False - return True + return KnowledgebaseService.accessible(doc.kb_id, user_id) @classmethod @DB.connection_context() diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index c66d66a6821..a164287fa4e 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -18,7 +18,7 @@ from peewee import fn, JOIN from api.db import TenantPermission -from api.db.db_models import DB, Document, Knowledgebase, User, UserTenant, UserCanvas +from api.db.db_models import DB, Document, Knowledgebase, User, UserCanvas from api.db.services.common_service import CommonService from common.time_utils import current_timestamp, datetime_format from api.db.services import duplicate_name @@ -485,13 +485,21 @@ def accessible(cls, kb_id, user_id): # user_id: User ID # Returns: # Boolean indicating accessibility - docs = cls.model.select( - cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) - docs = docs.dicts() - if not docs: + e, kb = cls.get_by_id(kb_id) + if not e: return False - return True + + if kb.status != StatusEnum.VALID.value: + return False + + if kb.tenant_id == user_id: + return True + + if kb.permission != TenantPermission.TEAM.value: + return False + + joined_tenants = TenantService.get_joined_tenants_by_user_id(user_id) + return any(tenant["tenant_id"] == kb.tenant_id for tenant in joined_tenants) @classmethod @DB.connection_context() @@ -502,10 +510,10 @@ def get_kb_by_id(cls, kb_id, user_id): # user_id: User ID # Returns: # List containing dataset information - kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) - kbs = kbs.dicts() - return list(kbs) + e, kb = cls.get_by_id(kb_id) + if not e or not cls.accessible(kb_id, user_id): + return [] + return [kb.to_dict()] @classmethod @DB.connection_context() @@ -516,10 +524,11 @@ def get_kb_by_name(cls, kb_name, user_id): # user_id: User ID # Returns: # List containing dataset information - kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) - ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) - kbs = kbs.dicts() - return list(kbs) + kbs = cls.query(name=kb_name, status=StatusEnum.VALID.value) + for kb in kbs: + if cls.accessible(kb.id, user_id): + return [kb.to_dict()] + return [] @classmethod @DB.connection_context() diff --git a/test/unit_test/api/db/services/test_dataset_access_permissions.py b/test/unit_test/api/db/services/test_dataset_access_permissions.py new file mode 100644 index 00000000000..e3db6d0f2af --- /dev/null +++ b/test/unit_test/api/db/services/test_dataset_access_permissions.py @@ -0,0 +1,119 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +import types +import warnings +from types import SimpleNamespace + +# xgboost imports pkg_resources and emits a deprecation warning that is promoted +# to error in our pytest configuration; ignore it for this unit test module. +warnings.filterwarnings( + "ignore", + message="pkg_resources is deprecated as an API.*", + category=UserWarning, +) + + +def _install_cv2_stub_if_unavailable(): + try: + import cv2 # noqa: F401 + return + except Exception: + pass + + stub = types.ModuleType("cv2") + + stub.INTER_LINEAR = 1 + stub.INTER_CUBIC = 2 + stub.BORDER_CONSTANT = 0 + stub.BORDER_REPLICATE = 1 + stub.COLOR_BGR2RGB = 0 + stub.COLOR_BGR2GRAY = 1 + stub.COLOR_GRAY2BGR = 2 + stub.IMREAD_IGNORE_ORIENTATION = 128 + stub.IMREAD_COLOR = 1 + stub.RETR_LIST = 1 + stub.CHAIN_APPROX_SIMPLE = 2 + + def _missing(*_args, **_kwargs): + raise RuntimeError("cv2 runtime call is unavailable in this test environment") + + def _module_getattr(name): + if name.isupper(): + return 0 + return _missing + + stub.__getattr__ = _module_getattr + sys.modules["cv2"] = stub + + +_install_cv2_stub_if_unavailable() + +from api.db import TenantPermission +from api.db.services.document_service import DocumentService +from api.db.services.knowledgebase_service import KnowledgebaseService +from common.constants import StatusEnum + + +def _unwrapped_kb_accessible(): + return KnowledgebaseService.accessible.__func__.__wrapped__ + + +def _unwrapped_doc_accessible(): + return DocumentService.accessible.__func__.__wrapped__ + + +def test_private_dataset_is_not_accessible_to_other_tenant_member(monkeypatch): + kb = SimpleNamespace( + id="kb-private", + tenant_id="owner-1", + permission=TenantPermission.ME.value, + status=StatusEnum.VALID.value, + ) + + monkeypatch.setattr(KnowledgebaseService, "get_by_id", classmethod(lambda cls, kb_id: (True, kb))) + monkeypatch.setattr( + "api.db.services.knowledgebase_service.TenantService.get_joined_tenants_by_user_id", + lambda _user_id: [{"tenant_id": "owner-1"}], + ) + + assert _unwrapped_kb_accessible()(KnowledgebaseService, "kb-private", "member-2") is False + + +def test_team_dataset_is_accessible_to_joined_tenant_member(monkeypatch): + kb = SimpleNamespace( + id="kb-team", + tenant_id="owner-1", + permission=TenantPermission.TEAM.value, + status=StatusEnum.VALID.value, + ) + + monkeypatch.setattr(KnowledgebaseService, "get_by_id", classmethod(lambda cls, kb_id: (True, kb))) + monkeypatch.setattr( + "api.db.services.knowledgebase_service.TenantService.get_joined_tenants_by_user_id", + lambda _user_id: [{"tenant_id": "owner-1"}], + ) + + assert _unwrapped_kb_accessible()(KnowledgebaseService, "kb-team", "member-2") is True + + +def test_document_access_respects_dataset_permission(monkeypatch): + doc = SimpleNamespace(id="doc-1", kb_id="kb-private") + + monkeypatch.setattr(DocumentService, "get_by_id", classmethod(lambda cls, doc_id: (True, doc))) + monkeypatch.setattr(KnowledgebaseService, "accessible", classmethod(lambda cls, kb_id, user_id: False)) + + assert _unwrapped_doc_accessible()(DocumentService, "doc-1", "member-2") is False