Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions chromadb/api/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ class Function(str, Enum):
RECORD_COUNTER = "record_counter"
"""Counts records in a collection."""

DUMMY_ASYNC = "dummy_async"
"""Async test helper function used for distributed task API coverage."""

# Used only for failure testing - not a real function
_NONEXISTENT_TEST_ONLY = "nonexistent_function"


# Convenience aliases for cleaner imports
STATISTICS_FUNCTION = Function.STATISTICS
RECORD_COUNTER_FUNCTION = Function.RECORD_COUNTER
DUMMY_ASYNC_FUNCTION = Function.DUMMY_ASYNC
189 changes: 187 additions & 2 deletions chromadb/test/distributed/test_task_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from chromadb.api.client import Client as ClientCreator
from chromadb.api.functions import (
DUMMY_ASYNC_FUNCTION,
RECORD_COUNTER_FUNCTION,
STATISTICS_FUNCTION,
Function,
Expand Down Expand Up @@ -360,8 +361,10 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None:
collection.detach_function(attached_fn.name, delete_output_collection=True)


def test_attach_to_output_collection_fails(basic_http_client: System) -> None:
"""Test that attaching a function to an output collection fails"""
def test_attach_to_output_collection_fails_for_sync_upstream(
basic_http_client: System,
) -> None:
"""Test that attaching a function to an output collection still fails when an upstream function is sync"""
client = ClientCreator.from_system(basic_http_client)
client.reset()

Expand All @@ -388,6 +391,188 @@ def test_attach_to_output_collection_fails(basic_http_client: System) -> None:
)


def test_attach_to_output_collection_succeeds_for_async_upstream(
basic_http_client: System,
) -> None:
"""Test that attaching a function to an output collection succeeds when all upstream functions are async"""
client = ClientCreator.from_system(basic_http_client)
client.reset()

input_collection = client.create_collection(name="async_input_collection")
input_collection.add(ids=["id1"], documents=["test"])

_, _ = input_collection.attach_function(
name="async_test_function",
function=DUMMY_ASYNC_FUNCTION,
output_collection="async_output_collection",
params=None,
)
output_collection = client.get_collection(name="async_output_collection")

attached_fn, created = output_collection.attach_function(
name="downstream_test_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="downstream_output_collection",
params=None,
)

assert attached_fn is not None
assert created is True


def test_attach_to_output_collection_fails_for_mixed_sync_and_async_upstream(
basic_http_client: System,
) -> None:
"""Test that attaching to an output collection fails when upstream functions are a mix of sync and async"""
client = ClientCreator.from_system(basic_http_client)
client.reset()

async_input_collection = client.create_collection(
name="mixed_async_input_collection"
)
async_input_collection.add(ids=["id1"], documents=["test"])

sync_input_collection = client.create_collection(name="mixed_sync_input_collection")
sync_input_collection.add(ids=["id2"], documents=["test"])

_, _ = async_input_collection.attach_function(
name="mixed_async_upstream",
function=DUMMY_ASYNC_FUNCTION,
output_collection="mixed_output_collection",
params=None,
)

_, _ = sync_input_collection.attach_function(
name="mixed_sync_upstream",
function=RECORD_COUNTER_FUNCTION,
output_collection="mixed_output_collection",
params=None,
)

output_collection = client.get_collection(name="mixed_output_collection")

with pytest.raises(
ChromaError, match="cannot attach function to an output collection"
):
_ = output_collection.attach_function(
name="mixed_downstream_test_function",
function=RECORD_COUNTER_FUNCTION,
output_collection="mixed_downstream_output_collection",
params=None,
)


def test_attach_to_existing_output_collection_rejects_cycle(
basic_http_client: System,
) -> None:
"""Test that attaching to an existing output collection rejects a cycle like A -> B -> C -> A"""
client = ClientCreator.from_system(basic_http_client)
client.reset()

collection_a = client.create_collection(name="cycle_collection_a")
collection_a.add(ids=["id1"], documents=["doc1"])

_, _ = collection_a.attach_function(
name="a_to_b",
function=DUMMY_ASYNC_FUNCTION,
output_collection="cycle_collection_b",
params=None,
)

collection_b = client.get_collection(name="cycle_collection_b")

_, _ = collection_b.attach_function(
name="b_to_c",
function=DUMMY_ASYNC_FUNCTION,
output_collection="cycle_collection_c",
params=None,
)

collection_c = client.get_collection(name="cycle_collection_c")

with pytest.raises(
ChromaError, match="cannot attach function to an output collection"
):
collection_c.attach_function(
name="c_to_a",
function=RECORD_COUNTER_FUNCTION,
output_collection="cycle_collection_a",
params=None,
)


def test_attach_function_rejects_depth_above_maximum(
basic_http_client: System,
) -> None:
"""Test that attach_function rejects chains deeper than the configured maximum depth"""
client = ClientCreator.from_system(basic_http_client)
client.reset()

current_collection = client.create_collection(name="depth_collection_0")
current_collection.add(ids=["id0"], documents=["doc0"])

for i in range(1, 6):
_, _ = current_collection.attach_function(
name=f"depth_edge_{i}",
function=DUMMY_ASYNC_FUNCTION,
output_collection=f"depth_collection_{i}",
params=None,
)
current_collection = client.get_collection(name=f"depth_collection_{i}")

with pytest.raises(
ChromaError, match="attached function depth exceeds maximum of 5"
):
current_collection.attach_function(
name="depth_edge_6",
function=RECORD_COUNTER_FUNCTION,
output_collection="depth_collection_6",
params=None,
)


def test_attach_function_rejects_when_connecting_two_chains_exceeds_maximum_depth(
basic_http_client: System,
) -> None:
"""Test that attach_function rejects connecting two valid chains if the combined path would exceed the maximum depth"""
client = ClientCreator.from_system(basic_http_client)
client.reset()

left_current = client.create_collection(name="left_depth_collection_0")
left_current.add(ids=["left_id0"], documents=["left_doc0"])

for i in range(1, 3):
_, _ = left_current.attach_function(
name=f"left_depth_edge_{i}",
function=DUMMY_ASYNC_FUNCTION,
output_collection=f"left_depth_collection_{i}",
params=None,
)
left_current = client.get_collection(name=f"left_depth_collection_{i}")

right_current = client.create_collection(name="right_depth_collection_0")
right_current.add(ids=["right_id0"], documents=["right_doc0"])

for i in range(1, 4):
_, _ = right_current.attach_function(
name=f"right_depth_edge_{i}",
function=DUMMY_ASYNC_FUNCTION,
output_collection=f"right_depth_collection_{i}",
params=None,
)
right_current = client.get_collection(name=f"right_depth_collection_{i}")

with pytest.raises(
ChromaError, match="attached function depth exceeds maximum of 5"
):
left_current.attach_function(
name="bridge_two_chains",
function=RECORD_COUNTER_FUNCTION,
output_collection="right_depth_collection_0",
params=None,
)


def test_delete_output_collection_detaches_function(basic_http_client: System) -> None:
"""Test that deleting an output collection also detaches the attached function"""
client = ClientCreator.from_system(basic_http_client)
Expand Down
Loading
Loading