diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index 7300a55a9f7..69303fcefaf 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -46,6 +46,7 @@ from common import settings from common.constants import ParserType, RetCode, TaskStatus, SANDBOX_ARTIFACT_BUCKET +from common.doc_store.doc_store_base import OrderByExpr from common.metadata_utils import convert_conditions, meta_filter, turn2jsonschema from common.misc_utils import get_uuid, thread_pool_exec from api.utils.file_utils import filename_type, thumbnail @@ -733,7 +734,7 @@ def list_docs(dataset_id, tenant_id): renamed_doc_list = [map_doc_keys(doc) for doc in payload] for doc_item in renamed_doc_list: if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): - doc_item["thumbnail"] = f"/api/v1/documents/images/{dataset_id}-{doc_item['thumbnail']}" + doc_item["thumbnail"] = _document_thumbnail_url(doc_item["id"]) if doc_item.get("source_type"): doc_item["source_type"] = doc_item["source_type"].split("/")[0] if doc_item["parser_config"].get("metadata"): @@ -835,6 +836,54 @@ def _get_docs_with_request(req, dataset_id:str): return RetCode.SUCCESS, "", docs, total +def _document_thumbnail_url(doc_id: str) -> str: + return f"/api/v1/documents/{doc_id}/thumbnail" + + +def _apply_image_response_headers(response, filename: str, default_content_type: str): + ext = Path(filename).suffix.lower().lstrip(".") or None + content_type = CONTENT_TYPE_MAP.get(ext, default_content_type) if ext else default_content_type + apply_safe_file_response_headers(response, content_type, ext) + + +def _get_accessible_chunk_image_doc_id(image_id: str) -> str | None: + arr = image_id.split("-", 1) + if len(arr) != 2: + return None + + kb_id, _ = arr + if not KnowledgebaseService.accessible(kb_id, current_user.id): + return None + + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + return None + + index_name = 
search.index_name(kb.tenant_id) + if not settings.docStoreConn.index_exist(index_name, kb_id): + return None + + result = settings.docStoreConn.search( + ["doc_id"], + [], + {"img_id": image_id}, + [], + OrderByExpr(), + 0, + 1, + index_name, + [kb_id], + ) + fields = settings.docStoreConn.get_fields(result, ["doc_id"]) + if not fields: + return None + + doc_id = next(iter(fields.values())).get("doc_id") + if doc_id and DocumentService.accessible(doc_id, current_user.id): + return doc_id + return None + + def _get_doc_filters_with_request(req, dataset_id: str): """Get aggregated document filters with request parameters from a dataset.""" q = req.args @@ -1165,6 +1214,7 @@ async def update_metadata_config(tenant_id, dataset_id, document_id): @manager.route("/thumbnails", methods=["GET"]) # noqa: F821 +@login_required def list_thumbnails(): """ Get thumbnails for documents. @@ -1191,11 +1241,12 @@ def list_thumbnails(): return get_json_result(data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR) try: - docs = DocumentService.get_thumbnails(doc_ids) + authorized_doc_ids = [doc_id for doc_id in doc_ids if DocumentService.accessible(doc_id, current_user.id)] + docs = DocumentService.get_thumbnails(authorized_doc_ids) for doc_item in docs: if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): - doc_item["thumbnail"] = f"/api/v1/documents/images/{doc_item['kb_id']}-{doc_item['thumbnail']}" + doc_item["thumbnail"] = _document_thumbnail_url(doc_item["id"]) return get_json_result(data={d["id"]: d["thumbnail"] for d in docs}) except Exception as e: @@ -1615,7 +1666,54 @@ def _run_sync(): return get_error_data_result(message="Internal server error") +@manager.route("/documents/<doc_id>/thumbnail", methods=["GET"]) # noqa: F821 +@login_required +async def get_document_thumbnail(doc_id): + """ + Get a document thumbnail by document ID. 
+ --- + tags: + - Documents + security: + - ApiKeyAuth: [] + parameters: + - name: doc_id + in: path + required: true + schema: + type: string + description: The document ID. + responses: + 200: + description: Thumbnail image file + content: + image/png: + schema: + type: string + format: binary + """ + try: + e, doc = DocumentService.get_by_id(doc_id) + if not e: + return get_data_error_result(message="Document not found!") + + if not KnowledgebaseService.accessible(doc.kb_id, current_user.id): + logging.warning("get_document_thumbnail: access denied for doc_id=%s user_id=%s", doc_id, current_user.id) + return get_data_error_result(message="Document not found!") + + if not doc.thumbnail or doc.thumbnail.startswith(IMG_BASE64_PREFIX): + return get_data_error_result(message="Image not found.") + + data = await thread_pool_exec(settings.STORAGE_IMPL.get, doc.kb_id, doc.thumbnail) + response = await make_response(data) + _apply_image_response_headers(response, doc.thumbnail, "image/png") + return response + except Exception as e: + return server_error_response(e) + + @manager.route("/documents/images/<image_id>", methods=["GET"]) # noqa: F821 +@login_required async def get_document_image(image_id): """ Get a document image by ID. 
@@ -1639,13 +1737,19 @@ async def get_document_image(image_id): format: binary """ try: - arr = image_id.split("-") + arr = image_id.split("-", 1) if len(arr) != 2: return get_data_error_result(message="Image not found.") - bkt, nm = image_id.split("-") + + doc_id = _get_accessible_chunk_image_doc_id(image_id) + if not doc_id: + logging.warning("get_document_image: access denied for image_id=%s user_id=%s", image_id, current_user.id) + return get_data_error_result(message="Image not found.") + + bkt, nm = arr data = await thread_pool_exec(settings.STORAGE_IMPL.get, bkt, nm) response = await make_response(data) - response.headers.set("Content-Type", "image/JPEG") + _apply_image_response_headers(response, nm, "image/jpeg") return response except Exception as e: return server_error_response(e) diff --git a/test/testcases/test_web_api/test_common.py b/test/testcases/test_web_api/test_common.py index 170d530af1a..459982700f3 100644 --- a/test/testcases/test_web_api/test_common.py +++ b/test/testcases/test_web_api/test_common.py @@ -419,6 +419,11 @@ def document_get(auth, document_id, *, headers=HEADERS, data=None): return res +def document_thumbnail(auth, document_id, *, headers=HEADERS, data=None): + res = requests.get(url=f"{HOST_ADDRESS}/api/{VERSION}/documents/{document_id}/thumbnail", headers=headers, auth=auth, data=data) + return res + + def document_download(auth, attachment_id, *, ext="markdown", headers=HEADERS, data=None): res = requests.get( url=f"{HOST_ADDRESS}/api/{VERSION}/documents/{attachment_id}/download", diff --git a/test/testcases/test_web_api/test_document_app/conftest.py b/test/testcases/test_web_api/test_document_app/conftest.py index 0e719a15276..5f18710cff2 100644 --- a/test/testcases/test_web_api/test_document_app/conftest.py +++ b/test/testcases/test_web_api/test_document_app/conftest.py @@ -78,13 +78,34 @@ def cleanup(): return dataset_id, bulk_upload_documents(WebApiAuth, dataset_id, 3, ragflow_tmp_dir) -@pytest.fixture() -def 
document_app_module(monkeypatch): - repo_root = Path(__file__).resolve().parents[4] +def _check_duplicate_ids(ids, *_args, **_kwargs): + return list(dict.fromkeys(ids)), [] + + +def _stub_document_api_dependencies(monkeypatch, repo_root): common_pkg = ModuleType("common") common_pkg.__path__ = [str(repo_root / "common")] monkeypatch.setitem(sys.modules, "common", common_pkg) + common_settings_mod = ModuleType("common.settings") + common_settings_mod.STORAGE_IMPL = SimpleNamespace(get=lambda *_args, **_kwargs: b"", obj_exist=lambda *_args, **_kwargs: False) + common_settings_mod.docStoreConn = SimpleNamespace( + index_exist=lambda *_args, **_kwargs: False, + search=lambda *_args, **_kwargs: {}, + get_fields=lambda *_args, **_kwargs: {}, + ) + monkeypatch.setitem(sys.modules, "common.settings", common_settings_mod) + + metadata_utils_mod = ModuleType("common.metadata_utils") + metadata_utils_mod.convert_conditions = lambda *_args, **_kwargs: {} + metadata_utils_mod.meta_filter = lambda *_args, **_kwargs: True + metadata_utils_mod.turn2jsonschema = lambda value: value + monkeypatch.setitem(sys.modules, "common.metadata_utils", metadata_utils_mod) + + rag_nlp_mod = ModuleType("rag.nlp") + rag_nlp_mod.search = SimpleNamespace(index_name=lambda tenant_id: f"ragflow_{tenant_id}") + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp_mod) + deepdoc_pkg = ModuleType("deepdoc") deepdoc_parser_pkg = ModuleType("deepdoc.parser") deepdoc_parser_pkg.__path__ = [] @@ -95,112 +116,33 @@ class _StubPdfParser: class _StubExcelParser: pass - deepdoc_parser_pkg.PdfParser = _StubPdfParser - deepdoc_pkg.parser = deepdoc_parser_pkg - monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) - monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) - deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") - deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser - monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) - 
deepdoc_html_module = ModuleType("deepdoc.parser.html_parser") - class _StubHtmlParser: pass - deepdoc_html_module.RAGFlowHtmlParser = _StubHtmlParser - monkeypatch.setitem(sys.modules, "deepdoc.parser.html_parser", deepdoc_html_module) - deepdoc_mineru_module = ModuleType("deepdoc.parser.mineru_parser") - class _StubMinerUParser: pass - deepdoc_mineru_module.MinerUParser = _StubMinerUParser - monkeypatch.setitem(sys.modules, "deepdoc.parser.mineru_parser", deepdoc_mineru_module) - deepdoc_paddleocr_module = ModuleType("deepdoc.parser.paddleocr_parser") - class _StubPaddleOCRParser: pass - deepdoc_paddleocr_module.PaddleOCRParser = _StubPaddleOCRParser - monkeypatch.setitem(sys.modules, "deepdoc.parser.paddleocr_parser", deepdoc_paddleocr_module) - monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) - - stub_apps = ModuleType("api.apps") - stub_apps.__path__ = [str(repo_root / "api" / "apps")] - stub_apps.current_user = SimpleNamespace(id="user-1") - stub_apps.login_required = lambda func: func - monkeypatch.setitem(sys.modules, "api.apps", stub_apps) - - stub_apps_services = ModuleType("api.apps.services") - stub_apps_services.__path__ = [str(repo_root / "api" / "apps" / "services")] - monkeypatch.setitem(sys.modules, "api.apps.services", stub_apps_services) - - document_api_service_mod = ModuleType("api.apps.services.document_api_service") - document_api_service_mod.validate_document_update_fields = lambda *_args, **_kwargs: (None, None) - document_api_service_mod.map_doc_keys = lambda doc: doc.to_dict() if hasattr(doc, "to_dict") else doc - - def _map_doc_keys_with_run_status(doc, run_status="0"): - payload = doc if isinstance(doc, dict) else doc.to_dict() - return {**payload, "run": run_status} - - document_api_service_mod.map_doc_keys_with_run_status = _map_doc_keys_with_run_status - document_api_service_mod.update_document_name_only = lambda *_args, **_kwargs: None - document_api_service_mod.update_chunk_method = lambda *_args, **_kwargs: None 
- document_api_service_mod.update_document_status_only = lambda *_args, **_kwargs: None - document_api_service_mod.reset_document_for_reparse = lambda *_args, **_kwargs: None - monkeypatch.setitem(sys.modules, "api.apps.services.document_api_service", document_api_service_mod) - - module_path = repo_root / "api" / "apps" / "restful_apis" / "document_api.py" - spec = importlib.util.spec_from_file_location("test_document_app_unit", module_path) - module = importlib.util.module_from_spec(spec) - module.manager = _DummyManager() - spec.loader.exec_module(module) - return module - - -@pytest.fixture() -def document_rest_api_module(monkeypatch): - repo_root = Path(__file__).resolve().parents[4] - common_pkg = ModuleType("common") - common_pkg.__path__ = [str(repo_root / "common")] - monkeypatch.setitem(sys.modules, "common", common_pkg) - - deepdoc_pkg = ModuleType("deepdoc") - deepdoc_parser_pkg = ModuleType("deepdoc.parser") - deepdoc_parser_pkg.__path__ = [] - - class _StubPdfParser: - pass - - class _StubExcelParser: - pass - deepdoc_parser_pkg.PdfParser = _StubPdfParser deepdoc_pkg.parser = deepdoc_parser_pkg monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) + deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) - deepdoc_html_module = ModuleType("deepdoc.parser.html_parser") - - class _StubHtmlParser: - pass + deepdoc_html_module = ModuleType("deepdoc.parser.html_parser") deepdoc_html_module.RAGFlowHtmlParser = _StubHtmlParser monkeypatch.setitem(sys.modules, "deepdoc.parser.html_parser", deepdoc_html_module) - deepdoc_mineru_module = ModuleType("deepdoc.parser.mineru_parser") - - class _StubMinerUParser: - pass + deepdoc_mineru_module = ModuleType("deepdoc.parser.mineru_parser") deepdoc_mineru_module.MinerUParser = _StubMinerUParser 
monkeypatch.setitem(sys.modules, "deepdoc.parser.mineru_parser", deepdoc_mineru_module) - deepdoc_paddleocr_module = ModuleType("deepdoc.parser.paddleocr_parser") - - class _StubPaddleOCRParser: - pass + deepdoc_paddleocr_module = ModuleType("deepdoc.parser.paddleocr_parser") deepdoc_paddleocr_module.PaddleOCRParser = _StubPaddleOCRParser monkeypatch.setitem(sys.modules, "deepdoc.parser.paddleocr_parser", deepdoc_paddleocr_module) monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) @@ -218,6 +160,7 @@ class _StubPaddleOCRParser: document_api_service_mod = ModuleType("api.apps.services.document_api_service") document_api_service_mod.validate_document_update_fields = lambda *_args, **_kwargs: (None, None) document_api_service_mod.map_doc_keys = lambda doc: doc.to_dict() if hasattr(doc, "to_dict") else doc + def _map_doc_keys_with_run_status(doc, run_status="0"): payload = doc if isinstance(doc, dict) else doc.to_dict() return {**payload, "run": run_status} @@ -229,11 +172,117 @@ def _map_doc_keys_with_run_status(doc, run_status="0"): document_api_service_mod.reset_document_for_reparse = lambda *_args, **_kwargs: None monkeypatch.setitem(sys.modules, "api.apps.services.document_api_service", document_api_service_mod) + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.Task = type("Task", (), {}) + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + doc_metadata_service_mod = ModuleType("api.db.services.doc_metadata_service") + doc_metadata_service_mod.DocMetadataService = SimpleNamespace(get_metadata_for_documents=lambda *_args, **_kwargs: {}) + monkeypatch.setitem(sys.modules, "api.db.services.doc_metadata_service", doc_metadata_service_mod) + + document_service_mod = ModuleType("api.db.services.document_service") + document_service_mod.DocumentService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_id=lambda _doc_id: (False, None), + accessible=lambda *_args, **_kwargs: False, + get_by_kb_id=lambda *_args, 
**_kwargs: ([], 0), + get_thumbnails=lambda _doc_ids: [], + update_parser_config=lambda *_args, **_kwargs: None, + update_by_id=lambda *_args, **_kwargs: True, + ) + monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) + + file2document_service_mod = ModuleType("api.db.services.file2document_service") + file2document_service_mod.File2DocumentService = SimpleNamespace(get_storage_address=lambda **_kwargs: ("bucket", "name")) + monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_service_mod) + + file_service_mod = ModuleType("api.db.services.file_service") + file_service_mod.FileService = SimpleNamespace(get_by_id=lambda *_args, **_kwargs: (False, None)) + monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod) + + knowledgebase_service_mod = ModuleType("api.db.services.knowledgebase_service") + knowledgebase_service_mod.KnowledgebaseService = SimpleNamespace( + query=lambda **_kwargs: [], + get_by_tenant_ids=lambda *_args, **_kwargs: ([], 0), + get_by_id=lambda _dataset_id: (False, None), + accessible=lambda *_args, **_kwargs: False, + ) + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", knowledgebase_service_mod) + + task_service_mod = ModuleType("api.db.services.task_service") + task_service_mod.TaskService = SimpleNamespace(query=lambda **_kwargs: []) + task_service_mod.cancel_all_task_of = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_mod) + + check_team_permission_mod = ModuleType("api.common.check_team_permission") + check_team_permission_mod.check_kb_team_permission = lambda *_args, **_kwargs: True + monkeypatch.setitem(sys.modules, "api.common.check_team_permission", check_team_permission_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + + async def _default_request_json(): + return {} + + def _ok_result(*, data=None, message="success", code=0): + return 
{"code": code, "message": message, "data": data} + + def _error_result(*, message="Sorry! Data missing!", code=102): + return {"code": code, "message": message} + + def _server_error_response(error): + return {"code": 500, "message": str(error)} + + api_utils_mod.get_request_json = _default_request_json + api_utils_mod.get_data_error_result = _error_result + api_utils_mod.get_error_data_result = _error_result + api_utils_mod.get_result = _ok_result + api_utils_mod.get_json_result = _ok_result + api_utils_mod.server_error_response = _server_error_response + api_utils_mod.add_tenant_id_to_kwargs = lambda func: func + api_utils_mod.get_error_argument_result = _error_result + api_utils_mod.check_duplicate_ids = _check_duplicate_ids + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + file_utils_mod = ModuleType("api.utils.file_utils") + file_utils_mod.filename_type = lambda *_args, **_kwargs: "txt" + file_utils_mod.thumbnail = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "api.utils.file_utils", file_utils_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + web_utils_mod.CONTENT_TYPE_MAP = {"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg", "txt": "text/plain"} + web_utils_mod.html2pdf = lambda *_args, **_kwargs: b"" + web_utils_mod.is_valid_url = lambda *_args, **_kwargs: True + web_utils_mod.apply_safe_file_response_headers = lambda response, content_type, extension=None: response.headers.update({"content_type": content_type, "extension": extension}) + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + user_service_mod.UserService = SimpleNamespace(query=lambda **_kwargs: []) + user_service_mod.UserTenantService = SimpleNamespace(query=lambda **_kwargs: []) + user_service_mod.TenantService = SimpleNamespace(query=lambda **_kwargs: []) + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) 
+ + +def _load_document_api_module(repo_root, module_name): module_path = repo_root / "api" / "apps" / "restful_apis" / "document_api.py" - spec = importlib.util.spec_from_file_location("test_document_api_unit", module_path) + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() spec.loader.exec_module(module) + return module + + +@pytest.fixture() +def document_app_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + _stub_document_api_dependencies(monkeypatch, repo_root) + return _load_document_api_module(repo_root, "test_document_app_unit") + + +@pytest.fixture() +def document_rest_api_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + _stub_document_api_dependencies(monkeypatch, repo_root) + module = _load_document_api_module(repo_root, "test_document_api_unit") monkeypatch.setattr( module.KnowledgebaseService, "get_by_id", diff --git a/test/testcases/test_web_api/test_document_app/test_document_metadata.py b/test/testcases/test_web_api/test_document_app/test_document_metadata.py index 71bf32d5658..53ffe474370 100644 --- a/test/testcases/test_web_api/test_document_app/test_document_metadata.py +++ b/test/testcases/test_web_api/test_document_app/test_document_metadata.py @@ -21,6 +21,7 @@ document_change_status, document_filter, document_infos, + document_thumbnail, document_metadata_summary, document_metadata_update, document_update_metadata_setting, @@ -54,6 +55,14 @@ def test_infos_auth_invalid(self, invalid_auth, expected_code, expected_fragment assert res["code"] == expected_code, res assert expected_fragment in res["message"], res + @pytest.mark.p2 + @pytest.mark.parametrize("invalid_auth, expected_code, expected_fragment", INVALID_AUTH_CASES) + def test_thumbnail_auth_invalid(self, invalid_auth, expected_code, expected_fragment, add_document_func): + _, doc_id = add_document_func + res = 
document_thumbnail(invalid_auth, doc_id) + assert res.status_code == expected_code, res.text + assert expected_fragment in res.text, res.text + ## The inputs has been changed to add 'doc_ids' ## TODO: #@pytest.mark.p2 @@ -451,8 +460,60 @@ async def raise_error(*_args, **_kwargs): assert "download boom" in res["message"] - @pytest.mark.skip(reason="Moved to /api/v1/documents/images/") - def test_get_image_success_and_exception_unit(self, document_app_module, monkeypatch): + def test_get_document_thumbnail_success_and_exception_unit(self, document_app_module, monkeypatch): + module = document_app_module + + class _Headers(dict): + def set(self, key, value): + self[key] = value + + class _ImageResponse: + def __init__(self, data): + self.data = data + self.headers = _Headers() + + async def fake_thread_pool_exec(*_args, **_kwargs): + return b"image-bytes" + + async def fake_make_response(data): + return _ImageResponse(data) + + kb_accessible_calls = [] + monkeypatch.setattr( + module.DocumentService, + "get_by_id", + lambda _doc_id: (True, SimpleNamespace(kb_id="kb-1", thumbnail="thumbnail_doc-1.png")), + ) + def fake_kb_accessible_denied(kb_id, user_id): + kb_accessible_calls.append((kb_id, user_id)) + return False + + monkeypatch.setattr(module.KnowledgebaseService, "accessible", fake_kb_accessible_denied) + res = _run(module.get_document_thumbnail("doc-1")) + assert res["code"] == RetCode.DATA_ERROR + assert "Document not found!" 
in res["message"] + assert kb_accessible_calls == [("kb-1", "user-1")] + + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda _kb_id, _user_id: True) + monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec) + monkeypatch.setattr(module, "make_response", fake_make_response) + monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"image-bytes")) + res = _run(module.get_document_thumbnail("doc-1")) + assert isinstance(res, _ImageResponse) + assert res.data == b"image-bytes" + assert res.headers["content_type"] == "image/png" + assert res.headers["extension"] == "png" + + async def raise_error(*_args, **_kwargs): + raise RuntimeError("image boom") + + monkeypatch.setattr(module, "thread_pool_exec", raise_error) + monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) + res = _run(module.get_document_thumbnail("doc-1")) + assert res["code"] == 500 + assert "image boom" in res["message"] + + def test_get_document_image_authorization_success_and_exception_unit(self, document_app_module, monkeypatch): module = document_app_module class _Headers(dict): @@ -470,20 +531,27 @@ async def fake_thread_pool_exec(*_args, **_kwargs): async def fake_make_response(data): return _ImageResponse(data) + monkeypatch.setattr(module, "_get_accessible_chunk_image_doc_id", lambda _image_id: None) + res = _run(module.get_document_image("bucket-name.jpg")) + assert res["code"] == RetCode.DATA_ERROR + assert "Image not found." 
in res["message"] + + monkeypatch.setattr(module, "_get_accessible_chunk_image_doc_id", lambda _image_id: "doc-1") monkeypatch.setattr(module, "thread_pool_exec", fake_thread_pool_exec) monkeypatch.setattr(module, "make_response", fake_make_response) monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=lambda *_args, **_kwargs: b"image-bytes")) - res = _run(module.get_image("bucket-name")) + res = _run(module.get_document_image("bucket-name.jpg")) assert isinstance(res, _ImageResponse) assert res.data == b"image-bytes" - assert res.headers["Content-Type"] == "image/JPEG" + assert res.headers["content_type"] == "image/jpeg" + assert res.headers["extension"] == "jpg" async def raise_error(*_args, **_kwargs): raise RuntimeError("image boom") monkeypatch.setattr(module, "thread_pool_exec", raise_error) monkeypatch.setattr(module, "server_error_response", lambda e: {"code": 500, "message": str(e)}) - res = _run(module.get_image("bucket-name")) + res = _run(module.get_document_image("bucket-name.jpg")) assert res["code"] == 500 assert "image boom" in res["message"]