Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions rag/utils/azure_sas_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __close__(self):
self.conn = None

def health(self):
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
return self.conn.upload_blob(name=f"{_bucket}/{fnm}", data=BytesIO(binary), length=len(binary))
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
return self.conn.upload_blob(name=f"{bucket}/{fnm}", data=BytesIO(binary), length=len(binary))

def put(self, bucket, fnm, binary, tenant_id=None):
blob_name = f"{bucket}/{fnm}"
Expand All @@ -62,10 +62,11 @@ def put(self, bucket, fnm, binary, tenant_id=None):
time.sleep(1)

def rm(self, bucket, fnm):
blob_name = f"{bucket}/{fnm}"
try:
self.conn.delete_blob(f"{bucket}/{fnm}")
self.conn.delete_blob(blob_name)
except Exception:
logging.exception(f"Fail rm {bucket}/{fnm}")
logging.exception(f"Fail rm {blob_name}")

def get(self, bucket, fnm):
blob_name = f"{bucket}/{fnm}"
Expand All @@ -82,7 +83,7 @@ def get(self, bucket, fnm):
def obj_exist(self, bucket, fnm):
blob_name = f"{bucket}/{fnm}"
try:
return self.conn.get_blob_client(f"{blob_name}").exists()
return self.conn.get_blob_client(blob_name).exists()
except Exception:
logging.exception(f"Fail put {blob_name}")
return False
Expand All @@ -91,7 +92,7 @@ def get_presigned_url(self, bucket, fnm, expires):
blob_name = f"{bucket}/{fnm}"
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, blob_name, expires)
return self.conn.get_blob_client(blob_name).url
except Exception:
logging.exception(f"fail get {blob_name}")
self.__open__()
Expand Down
17 changes: 9 additions & 8 deletions rag/utils/azure_spn_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ def __close__(self):
self.conn = None

def health(self):
_bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
f = self.conn.create_file(f"{_bucket}/{fnm}")
bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
f = self.conn.create_file(f"{bucket}/{fnm}")
f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary))

def put(self, bucket, fnm, binary, tenant_id=None):
blob = f"{bucket}/{fnm}"
for _ in range(3):
try:
f = self.conn.create_file(f"{blob}")
f = self.conn.create_file(blob)
f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary))
except Exception:
Expand All @@ -85,15 +85,15 @@ def put(self, bucket, fnm, binary, tenant_id=None):
def rm(self, bucket, fnm):
blob = f"{bucket}/{fnm}"
try:
self.conn.delete_file(f"{blob}")
self.conn.delete_file(blob)
except Exception:
logging.exception(f"Fail rm {blob}")

def get(self, bucket, fnm):
blob = f"{bucket}/{fnm}"
for _ in range(1):
try:
client = self.conn.get_file_client(f"{blob}")
client = self.conn.get_file_client(blob)
r = client.download_file()
return r.read()
except Exception:
Expand All @@ -105,17 +105,18 @@ def get(self, bucket, fnm):
def obj_exist(self, bucket, fnm):
blob = f"{bucket}/{fnm}"
try:
client = self.conn.get_blob_client(f"{blob}")
client = self.conn.get_file_client(blob)
return client.exists()
except Exception:
logging.exception(f"Fail put {blob}")
return False

def get_presigned_url(self, bucket, fnm, expires):
f_path = f"{bucket}/{fnm}"
blob = f"{bucket}/{fnm}"
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, f_path, expires)
client = self.conn.get_file_client(blob)
return client.url
except Exception:
logging.exception(f"fail get {bucket}/{fnm}")
self.__open__()
Expand Down
185 changes: 185 additions & 0 deletions test/unit_test/rag/utils/test_azure_blob_bucket_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for Azure Blob storage path construction (issue #14159).

Both AzureSpn and AzureSas implementations must prepend the bucket
parameter to file paths so that files with the same name from different
datasets do not overwrite each other in flat blob storage.
"""
import importlib
import sys
import types
from unittest.mock import MagicMock

import pytest


def _install_stubs():
"""Replace heavyweight runtime modules so the connection modules can be
imported in isolation without the full ragflow runtime or the real
Comment on lines +31 to +33
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Isolate sys.modules stubbing to avoid cross-test pollution.

These fixtures replace global module entries but never restore them. That can make unrelated tests fail depending on run order. Please switch to monkeypatch.setitem (or restore snapshot on teardown) so patches are automatically reverted.

Also applies to: 74-83, 86-97

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@test/unit_test/rag/utils/test_azure_blob_bucket_prefix.py` around lines 31 -
33, The helper _install_stubs currently mutates sys.modules directly and never
restores entries, causing cross-test pollution; change the stubbing to use
pytest's monkeypatch.setitem to insert the fake modules (e.g., replace
sys.modules["ragflow.runtime"] etc.) so patches are automatically reverted after
each test, and update the three other stub blocks (the ones creating fake
modules around the same area) to use monkeypatch.setitem instead of direct
sys.modules assignment or add explicit teardown that restores originals if
monkeypatch is unavailable; locate uses in function _install_stubs and the other
stub creation sites and replace sys.modules[...] = fake_module with
monkeypatch.setitem(sys.modules, key, fake_module).

`azure` SDK being installed."""

decorator_mod = types.ModuleType("common.decorator")
decorator_mod.singleton = lambda cls: cls

settings_mod = types.ModuleType("common.settings")
settings_mod.AZURE = {
"account_url": "https://example.dfs.core.windows.net",
"client_id": "x",
"secret": "x",
"tenant_id": "x",
"container_name": "c",
"cloud": "public",
"container_url": "https://example.blob.core.windows.net/c",
"sas_token": "sig=x",
}

common_pkg = types.ModuleType("common")
common_pkg.decorator = decorator_mod
common_pkg.settings = settings_mod

azure_pkg = types.ModuleType("azure")
azure_identity = types.ModuleType("azure.identity")
azure_identity.ClientSecretCredential = MagicMock()
azure_identity.AzureAuthorityHosts = types.SimpleNamespace(
AZURE_PUBLIC_CLOUD="public",
AZURE_CHINA="china",
AZURE_GOVERNMENT="gov",
AZURE_GERMANY="de",
)
azure_storage = types.ModuleType("azure.storage")
azure_fdl = types.ModuleType("azure.storage.filedatalake")
azure_fdl.FileSystemClient = MagicMock()
azure_blob = types.ModuleType("azure.storage.blob")
azure_blob.ContainerClient = MagicMock()
azure_pkg.identity = azure_identity
azure_pkg.storage = azure_storage
azure_storage.filedatalake = azure_fdl
azure_storage.blob = azure_blob

sys.modules.update({
"common": common_pkg,
"common.decorator": decorator_mod,
"common.settings": settings_mod,
"azure": azure_pkg,
"azure.identity": azure_identity,
"azure.storage": azure_storage,
"azure.storage.filedatalake": azure_fdl,
"azure.storage.blob": azure_blob,
})


@pytest.fixture(scope="module")
def spn_module():
    """Yield a freshly imported ``rag.utils.azure_spn_conn`` backed by stubs."""
    _install_stubs()
    target = "rag.utils.azure_spn_conn"
    # Drop any cached import so the module binds against the stubbed deps.
    sys.modules.pop(target, None)
    return importlib.import_module(target)


@pytest.fixture(scope="module")
def sas_module():
    """Yield a freshly imported ``rag.utils.azure_sas_conn`` backed by stubs."""
    _install_stubs()
    target = "rag.utils.azure_sas_conn"
    # Drop any cached import so the module binds against the stubbed deps.
    sys.modules.pop(target, None)
    return importlib.import_module(target)


def _make_instance(module, cls_name):
"""Build an instance with a mocked underlying connection, bypassing
__init__ so we don't need real Azure credentials or connectivity."""
cls = getattr(module, cls_name)
inst = cls.__new__(cls)
inst.conn = MagicMock()
return inst


class TestAzureSpnBucketPrefix:
    """RAGFlowAzureSpnBlob must include the bucket as a path prefix in all
    operations so that identical filenames from different datasets are
    isolated."""

    def test_put_uses_bucket_prefix(self, spn_module):
        spn = _make_instance(spn_module, "RAGFlowAzureSpnBlob")
        spn.put("kb_a", "doc.pdf", b"data")
        spn.conn.create_file.assert_called_once_with("kb_a/doc.pdf")

    def test_get_uses_bucket_prefix(self, spn_module):
        spn = _make_instance(spn_module, "RAGFlowAzureSpnBlob")
        spn.get("kb_a", "doc.pdf")
        spn.conn.get_file_client.assert_called_once_with("kb_a/doc.pdf")

    def test_rm_uses_bucket_prefix(self, spn_module):
        spn = _make_instance(spn_module, "RAGFlowAzureSpnBlob")
        spn.rm("kb_a", "doc.pdf")
        spn.conn.delete_file.assert_called_once_with("kb_a/doc.pdf")

    def test_obj_exist_uses_bucket_prefix(self, spn_module):
        spn = _make_instance(spn_module, "RAGFlowAzureSpnBlob")
        spn.obj_exist("kb_a", "doc.pdf")
        spn.conn.get_file_client.assert_called_once_with("kb_a/doc.pdf")

    def test_get_presigned_url_uses_bucket_prefix(self, spn_module):
        # The SPN implementation now resolves the URL via
        # conn.get_file_client(<bucket>/<fnm>).url; it no longer calls a
        # get_presigned_url helper on the connection, so assert on the
        # file-client lookup instead.
        spn = _make_instance(spn_module, "RAGFlowAzureSpnBlob")
        spn.get_presigned_url("kb_a", "doc.pdf", 3600)
        spn.conn.get_file_client.assert_called_once_with("kb_a/doc.pdf")

    def test_same_filename_in_different_buckets_does_not_collide(self, spn_module):
        """Regression test for issue #14159: two datasets uploading a file
        with the same name must produce two distinct storage paths."""
        spn = _make_instance(spn_module, "RAGFlowAzureSpnBlob")
        spn.put("kb_a", "report.pdf", b"data_a")
        spn.put("kb_b", "report.pdf", b"data_b")
        called_paths = [c.args[0] for c in spn.conn.create_file.call_args_list]
        assert called_paths == ["kb_a/report.pdf", "kb_b/report.pdf"]
        assert called_paths[0] != called_paths[1]


class TestAzureSasBucketPrefix:
    """Same contract for RAGFlowAzureSasBlob."""

    def test_put_uses_bucket_prefix(self, sas_module):
        sas = _make_instance(sas_module, "RAGFlowAzureSasBlob")
        sas.put("kb_a", "doc.pdf", b"data")
        kwargs = sas.conn.upload_blob.call_args.kwargs
        assert kwargs["name"] == "kb_a/doc.pdf"

    def test_get_uses_bucket_prefix(self, sas_module):
        sas = _make_instance(sas_module, "RAGFlowAzureSasBlob")
        sas.get("kb_a", "doc.pdf")
        sas.conn.download_blob.assert_called_once_with("kb_a/doc.pdf")

    def test_rm_uses_bucket_prefix(self, sas_module):
        sas = _make_instance(sas_module, "RAGFlowAzureSasBlob")
        sas.rm("kb_a", "doc.pdf")
        sas.conn.delete_blob.assert_called_once_with("kb_a/doc.pdf")

    def test_obj_exist_uses_bucket_prefix(self, sas_module):
        sas = _make_instance(sas_module, "RAGFlowAzureSasBlob")
        sas.obj_exist("kb_a", "doc.pdf")
        sas.conn.get_blob_client.assert_called_once_with("kb_a/doc.pdf")

    def test_get_presigned_url_uses_bucket_prefix(self, sas_module):
        # The SAS implementation now resolves the URL via
        # conn.get_blob_client(<bucket>/<fnm>).url; it no longer calls a
        # get_presigned_url helper on the connection, so assert on the
        # blob-client lookup instead.
        sas = _make_instance(sas_module, "RAGFlowAzureSasBlob")
        sas.get_presigned_url("kb_a", "doc.pdf", 3600)
        sas.conn.get_blob_client.assert_called_once_with("kb_a/doc.pdf")

    def test_same_filename_in_different_buckets_does_not_collide(self, sas_module):
        sas = _make_instance(sas_module, "RAGFlowAzureSasBlob")
        sas.put("kb_a", "report.pdf", b"data_a")
        sas.put("kb_b", "report.pdf", b"data_b")
        names = [c.kwargs["name"] for c in sas.conn.upload_blob.call_args_list]
        assert names == ["kb_a/report.pdf", "kb_b/report.pdf"]
        assert names[0] != names[1]
Loading