From 99a379ec5c471849ee0a9d2f7c4a93db0324a51f Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Mon, 25 May 2026 15:06:30 -0700 Subject: [PATCH 1/2] [ENH]: Allow cascading functions --- chromadb/api/functions.py | 4 + chromadb/test/distributed/test_task_api.py | 78 +++++++- go/pkg/sysdb/coordinator/create_task_test.go | 178 +++++++++++++++++++ go/pkg/sysdb/coordinator/task.go | 40 ++++- 4 files changed, 295 insertions(+), 5 deletions(-) diff --git a/chromadb/api/functions.py b/chromadb/api/functions.py index 86647172b35..963d8327a06 100644 --- a/chromadb/api/functions.py +++ b/chromadb/api/functions.py @@ -24,6 +24,9 @@ class Function(str, Enum): RECORD_COUNTER = "record_counter" """Counts records in a collection.""" + DUMMY_ASYNC = "dummy_async" + """Async test helper function used for distributed task API coverage.""" + # Used only for failure testing - not a real function _NONEXISTENT_TEST_ONLY = "nonexistent_function" @@ -31,3 +34,4 @@ class Function(str, Enum): # Convenience aliases for cleaner imports STATISTICS_FUNCTION = Function.STATISTICS RECORD_COUNTER_FUNCTION = Function.RECORD_COUNTER +DUMMY_ASYNC_FUNCTION = Function.DUMMY_ASYNC diff --git a/chromadb/test/distributed/test_task_api.py b/chromadb/test/distributed/test_task_api.py index 487553efd2b..c4f3281d7ae 100644 --- a/chromadb/test/distributed/test_task_api.py +++ b/chromadb/test/distributed/test_task_api.py @@ -8,6 +8,7 @@ import pytest from chromadb.api.client import Client as ClientCreator from chromadb.api.functions import ( + DUMMY_ASYNC_FUNCTION, RECORD_COUNTER_FUNCTION, STATISTICS_FUNCTION, Function, @@ -360,8 +361,10 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None: collection.detach_function(attached_fn.name, delete_output_collection=True) -def test_attach_to_output_collection_fails(basic_http_client: System) -> None: - """Test that attaching a function to an output collection fails""" +def test_attach_to_output_collection_fails_for_sync_upstream( + basic_http_client: System, +) -> None: + """Test that attaching a function to an output collection still fails when an upstream function is sync""" client = ClientCreator.from_system(basic_http_client) client.reset() @@ -388,6 +391,77 @@ def test_attach_to_output_collection_fails(basic_http_client: System) -> None: ) +def test_attach_to_output_collection_succeeds_for_async_upstream( + basic_http_client: System, +) -> None: + """Test that attaching a function to an output collection succeeds when all upstream functions are async""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + input_collection = client.create_collection(name="async_input_collection") + input_collection.add(ids=["id1"], documents=["test"]) + + _, _ = input_collection.attach_function( + name="async_test_function", + function=DUMMY_ASYNC_FUNCTION, + output_collection="async_output_collection", + params=None, + ) + output_collection = client.get_collection(name="async_output_collection") + + attached_fn, created = output_collection.attach_function( + name="downstream_test_function", + function=RECORD_COUNTER_FUNCTION, + output_collection="downstream_output_collection", + params=None, + ) + + assert attached_fn is not None + assert created is True + + +def test_attach_to_output_collection_fails_for_mixed_sync_and_async_upstream( + basic_http_client: System, +) -> None: + """Test that attaching to an output collection fails when upstream functions are a mix of sync and async""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + async_input_collection = client.create_collection( + name="mixed_async_input_collection" + ) + async_input_collection.add(ids=["id1"], documents=["test"]) + + sync_input_collection = client.create_collection(name="mixed_sync_input_collection") + sync_input_collection.add(ids=["id2"], documents=["test"]) + + _, _ = async_input_collection.attach_function( + name="mixed_async_upstream", + function=DUMMY_ASYNC_FUNCTION, + output_collection="mixed_output_collection", + params=None, + ) + + _, _ = sync_input_collection.attach_function( + name="mixed_sync_upstream", + function=RECORD_COUNTER_FUNCTION, + output_collection="mixed_output_collection", + params=None, + ) + + output_collection = client.get_collection(name="mixed_output_collection") + + with pytest.raises( + ChromaError, match="cannot attach function to an output collection" + ): + _ = output_collection.attach_function( + name="mixed_downstream_test_function", + function=RECORD_COUNTER_FUNCTION, + output_collection="mixed_downstream_output_collection", + params=None, + ) + + def test_delete_output_collection_detaches_function(basic_http_client: System) -> None: """Test that deleting an output collection also detaches the attached function""" client = ClientCreator.from_system(basic_http_client) diff --git a/go/pkg/sysdb/coordinator/create_task_test.go b/go/pkg/sysdb/coordinator/create_task_test.go index 554db25728f..a541fe40c26 100644 --- a/go/pkg/sysdb/coordinator/create_task_test.go +++ b/go/pkg/sysdb/coordinator/create_task_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/chroma-core/chroma/go/pkg/common" "github.com/chroma-core/chroma/go/pkg/memberlist_manager" "github.com/chroma-core/chroma/go/pkg/proto/coordinatorpb" "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" @@ -329,6 +330,183 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea suite.mockDatabaseDb.AssertExpectations(suite.T()) } +func (suite *AttachFunctionTestSuite) TestAttachFunction_AllowsOutputCollectionInputWhenAllUpstreamFunctionsAreAsync() { + ctx := context.Background() + + attachedFunctionName := "async-chain-attached-function" + inputCollectionID := "output-of-async-upstream" + outputCollectionName := "next-output-collection" + functionName := "record_counter" + tenantID := "test-tenant" + databaseName := "test-database" + databaseID := "database-uuid" + functionID := dbmodel.FunctionRecordCounter + upstreamFunctionID := uuid.New() + minRecordsForInvocation := uint64(100) + + request := &coordinatorpb.AttachFunctionRequest{ + Name: attachedFunctionName, + InputCollectionId: inputCollectionID, + OutputCollectionName: outputCollectionName, + FunctionName: functionName, + TenantId: tenantID, + Database: databaseName, + MinRecordsForInvocation: minRecordsForInvocation, + } + + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollectionID, (*string)(nil), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{}, nil).Once() + + suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByName", functionName). + Return(&dbmodel.Function{ID: functionID, Name: functionName, IsAsync: false}, nil).Once() + + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("GetCollections", + []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() + + existingUpstream := &dbmodel.AttachedFunction{ + ID: uuid.New(), + Name: "upstream-async-function", + InputCollectionID: "root-input-collection", + OutputCollectionID: func() *string { + id := inputCollectionID + return &id + }(), + FunctionID: upstreamFunctionID, + } + + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), &inputCollectionID, []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{existingUpstream}, nil).Once() + + suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByIDs", []uuid.UUID{upstreamFunctionID}). + Return([]*dbmodel.Function{{ID: upstreamFunctionID, Name: "async-upstream", IsAsync: true}}, nil).Once() + + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("GetCollections", + []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() + + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("Insert", mock.MatchedBy(func(attachedFunction *dbmodel.AttachedFunction) bool { + return attachedFunction.Name == attachedFunctionName && + attachedFunction.InputCollectionID == inputCollectionID && + attachedFunction.OutputCollectionName == outputCollectionName && + attachedFunction.FunctionID == functionID && + attachedFunction.TenantID == tenantID && + attachedFunction.DatabaseID == databaseID && + attachedFunction.MinRecordsForInvocation == int64(minRecordsForInvocation) + })).Return(nil).Once() + + suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(args mock.Arguments) { + txFunc := args.Get(1).(func(context.Context) error) + err := txFunc(context.Background()) + suite.NoError(err) + }).Return(nil).Once() + + response, err := suite.coordinator.AttachFunction(ctx, request) + + suite.NoError(err) + suite.NotNil(response) + suite.NotEmpty(response.AttachedFunction.Id) + + suite.mockMetaDomain.AssertExpectations(suite.T()) + suite.mockAttachedFunctionDb.AssertExpectations(suite.T()) + suite.mockFunctionDb.AssertExpectations(suite.T()) + suite.mockDatabaseDb.AssertExpectations(suite.T()) + suite.mockCollectionDb.AssertExpectations(suite.T()) + suite.mockTxImpl.AssertExpectations(suite.T()) +} + +func (suite *AttachFunctionTestSuite) TestAttachFunction_RejectsOutputCollectionInputWhenAnyUpstreamFunctionIsSync() { + ctx := context.Background() + + attachedFunctionName := "blocked-chain-attached-function" + inputCollectionID := "output-of-mixed-upstream" + outputCollectionName := "next-output-collection" + functionName := "record_counter" + tenantID := "test-tenant" + databaseName := "test-database" + databaseID := "database-uuid" + functionID := dbmodel.FunctionRecordCounter + upstreamFunctionID := uuid.New() + minRecordsForInvocation := uint64(100) + + request := &coordinatorpb.AttachFunctionRequest{ + Name: attachedFunctionName, + InputCollectionId: inputCollectionID, + OutputCollectionName: outputCollectionName, + FunctionName: functionName, + TenantId: tenantID, + Database: databaseName, + MinRecordsForInvocation: minRecordsForInvocation, + } + + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollectionID, (*string)(nil), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{}, nil).Once() + + suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() + suite.mockDatabaseDb.On("GetDatabases", tenantID, databaseName). + Return([]*dbmodel.Database{{ID: databaseID, Name: databaseName}}, nil).Once() + + suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByName", functionName). + Return(&dbmodel.Function{ID: functionID, Name: functionName, IsAsync: false}, nil).Once() + + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() + suite.mockCollectionDb.On("GetCollections", + []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() + + existingUpstream := &dbmodel.AttachedFunction{ + ID: uuid.New(), + Name: "upstream-sync-function", + InputCollectionID: "root-input-collection", + OutputCollectionID: func() *string { + id := inputCollectionID + return &id + }(), + FunctionID: upstreamFunctionID, + } + + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), &inputCollectionID, []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{existingUpstream}, nil).Once() + + suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() + suite.mockFunctionDb.On("GetByIDs", []uuid.UUID{upstreamFunctionID}). + Return([]*dbmodel.Function{{ID: upstreamFunctionID, Name: "sync-upstream", IsAsync: false}}, nil).Once() + + suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). + Run(func(args mock.Arguments) { + txFunc := args.Get(1).(func(context.Context) error) + err := txFunc(context.Background()) + suite.ErrorIs(err, common.ErrCannotAttachToOutputCollection) + }).Return(common.ErrCannotAttachToOutputCollection).Once() + + response, err := suite.coordinator.AttachFunction(ctx, request) + + suite.ErrorIs(err, common.ErrCannotAttachToOutputCollection) + suite.Nil(response) + + suite.mockMetaDomain.AssertExpectations(suite.T()) + suite.mockAttachedFunctionDb.AssertExpectations(suite.T()) + suite.mockFunctionDb.AssertExpectations(suite.T()) + suite.mockDatabaseDb.AssertExpectations(suite.T()) + suite.mockCollectionDb.AssertExpectations(suite.T()) + suite.mockTxImpl.AssertExpectations(suite.T()) +} + // TestAttachFunction_RecoveryFlow tests the realistic recovery scenario: // - First AttachFunction: Phase 1 succeeds (attached function created), Phase 2 fails (heap error) // - Attached function left in incomplete state (lowest_live_nonce = NULL) diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index dad11895c34..16689b4609d 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -174,9 +174,43 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att return err } if len(attachedFunctionsUsingAsOutput) > 0 { - log.Error("AttachFunction: cannot attach function to a collection that is already an output collection", - zap.String("collection_id", req.InputCollectionId)) - return common.ErrCannotAttachToOutputCollection + functionIDs := make([]uuid.UUID, 0, len(attachedFunctionsUsingAsOutput)) + seenFunctionIDs := make(map[uuid.UUID]struct{}, len(attachedFunctionsUsingAsOutput)) + for _, attachedFunction := range attachedFunctionsUsingAsOutput { + if _, ok := seenFunctionIDs[attachedFunction.FunctionID]; ok { + continue + } + seenFunctionIDs[attachedFunction.FunctionID] = struct{}{} + functionIDs = append(functionIDs, attachedFunction.FunctionID) + } + + functions, err := s.catalog.metaDomain.FunctionDb(txCtx).GetByIDs(functionIDs) + if err != nil { + log.Error("AttachFunction: failed to load functions for output collection validation", zap.Error(err)) + return err + } + + functionsByID := make(map[uuid.UUID]*dbmodel.Function, len(functions)) + for _, existingFunction := range functions { + functionsByID[existingFunction.ID] = existingFunction + } + + for _, attachedFunction := range attachedFunctionsUsingAsOutput { + existingFunction, ok := functionsByID[attachedFunction.FunctionID] + if !ok { + log.Error("AttachFunction: attached function references unknown function during output collection validation", + zap.String("collection_id", req.InputCollectionId), + zap.Stringer("function_id", attachedFunction.FunctionID)) + return common.ErrFunctionNotFound + } + if !existingFunction.IsAsync { + log.Error("AttachFunction: cannot attach function to a collection that is already an output collection with sync upstream functions", + zap.String("collection_id", req.InputCollectionId), + zap.Stringer("function_id", attachedFunction.FunctionID), + zap.String("function_name", existingFunction.Name)) + return common.ErrCannotAttachToOutputCollection + } + } } // Check if output collection already exists From 39d1c5f707b549a17422a2a8a4a5193d9c34958b Mon Sep 17 00:00:00 2001 From: Tanuj Nayak Date: Mon, 25 May 2026 15:44:44 -0700 Subject: [PATCH 2/2] add cycle detection --- chromadb/test/distributed/test_task_api.py | 111 ++++++ go/pkg/sysdb/coordinator/create_task_test.go | 71 ++-- go/pkg/sysdb/coordinator/task.go | 384 ++++++++++++++++++- go/pkg/sysdb/coordinator/task_graph_test.go | 136 +++++++ 4 files changed, 668 insertions(+), 34 deletions(-) create mode 100644 go/pkg/sysdb/coordinator/task_graph_test.go diff --git a/chromadb/test/distributed/test_task_api.py b/chromadb/test/distributed/test_task_api.py index c4f3281d7ae..17d2a13140e 100644 --- a/chromadb/test/distributed/test_task_api.py +++ b/chromadb/test/distributed/test_task_api.py @@ -462,6 +462,117 @@ def test_attach_to_output_collection_fails_for_mixed_sync_and_async_upstream( ) +def test_attach_to_existing_output_collection_rejects_cycle( + basic_http_client: System, +) -> None: + """Test that attaching to an existing output collection rejects a cycle like A -> B -> C -> A""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + collection_a = client.create_collection(name="cycle_collection_a") + collection_a.add(ids=["id1"], documents=["doc1"]) + + _, _ = collection_a.attach_function( + name="a_to_b", + function=DUMMY_ASYNC_FUNCTION, + output_collection="cycle_collection_b", + params=None, + ) + + collection_b = client.get_collection(name="cycle_collection_b") + + _, _ = collection_b.attach_function( + name="b_to_c", + function=DUMMY_ASYNC_FUNCTION, + output_collection="cycle_collection_c", + params=None, + ) + + collection_c = client.get_collection(name="cycle_collection_c") + + with pytest.raises( + ChromaError, match="cannot attach function to an output collection" + ): + collection_c.attach_function( + name="c_to_a", + function=RECORD_COUNTER_FUNCTION, + output_collection="cycle_collection_a", + params=None, + ) + + +def test_attach_function_rejects_depth_above_maximum( + basic_http_client: System, +) -> None: + """Test that attach_function rejects chains deeper than the configured maximum depth""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + current_collection = client.create_collection(name="depth_collection_0") + current_collection.add(ids=["id0"], documents=["doc0"]) + + for i in range(1, 6): + _, _ = current_collection.attach_function( + name=f"depth_edge_{i}", + function=DUMMY_ASYNC_FUNCTION, + output_collection=f"depth_collection_{i}", + params=None, + ) + current_collection = client.get_collection(name=f"depth_collection_{i}") + + with pytest.raises( + ChromaError, match="attached function depth exceeds maximum of 5" + ): + current_collection.attach_function( + name="depth_edge_6", + function=RECORD_COUNTER_FUNCTION, + output_collection="depth_collection_6", + params=None, + ) + + +def test_attach_function_rejects_when_connecting_two_chains_exceeds_maximum_depth( + basic_http_client: System, +) -> None: + """Test that attach_function rejects connecting two valid chains if the combined path would exceed the maximum depth""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + left_current = client.create_collection(name="left_depth_collection_0") + left_current.add(ids=["left_id0"], documents=["left_doc0"]) + + for i in range(1, 3): + _, _ = left_current.attach_function( + name=f"left_depth_edge_{i}", + function=DUMMY_ASYNC_FUNCTION, + output_collection=f"left_depth_collection_{i}", + params=None, + ) + left_current = client.get_collection(name=f"left_depth_collection_{i}") + + right_current = client.create_collection(name="right_depth_collection_0") + right_current.add(ids=["right_id0"], documents=["right_doc0"]) + + for i in range(1, 4): + _, _ = right_current.attach_function( + name=f"right_depth_edge_{i}", + function=DUMMY_ASYNC_FUNCTION, + output_collection=f"right_depth_collection_{i}", + params=None, + ) + right_current = client.get_collection(name=f"right_depth_collection_{i}") + + with pytest.raises( + ChromaError, match="attached function depth exceeds maximum of 5" + ): + left_current.attach_function( + name="bridge_two_chains", + function=RECORD_COUNTER_FUNCTION, + output_collection="right_depth_collection_0", + params=None, + ) + + def test_delete_output_collection_detaches_function(basic_http_client: System) -> None: """Test that deleting an output collection also detaches the attached function""" client = ClientCreator.from_system(basic_http_client) diff --git a/go/pkg/sysdb/coordinator/create_task_test.go b/go/pkg/sysdb/coordinator/create_task_test.go index a541fe40c26..5ad5c1f8555 100644 --- a/go/pkg/sysdb/coordinator/create_task_test.go +++ b/go/pkg/sysdb/coordinator/create_task_test.go @@ -33,6 +33,12 @@ type MockMemberlistStore struct { mock.Mock } +func matchStringPtr(want string) interface{} { + return mock.MatchedBy(func(got *string) bool { + return got != nil && *got == want + }) +} + func (m *MockMemberlistStore) GetMemberlist(ctx context.Context) (memberlist memberlist_manager.Memberlist, resourceVersion string, err error) { args := m.Called(ctx) if args.Get(0) == nil { @@ -70,7 +76,7 @@ func (suite *AttachFunctionTestSuite) setupAttachFunctionMocks(ctx context.Conte // Phase 1: Create attached function in transaction // Check if any attached function exists for this collection suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollectionID, (*string)(nil), []uuid.UUID(nil), false). + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), matchStringPtr(inputCollectionID), (*string)(nil), []uuid.UUID(nil), false). Return([]*dbmodel.AttachedFunction{}, nil).Once() suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() @@ -81,7 +87,6 @@ func (suite *AttachFunctionTestSuite) setupAttachFunctionMocks(ctx context.Conte suite.mockFunctionDb.On("GetByName", functionName). Return(&dbmodel.Function{ID: functionID, Name: functionName, IsAsync: false}, nil).Once() - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() @@ -166,7 +171,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() { // Setup mocks that will be called within the transaction (using mock.Anything for context) // Check if any attached function exists for this collection suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollectionID, (*string)(nil), []uuid.UUID(nil), false). + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), matchStringPtr(inputCollectionID), (*string)(nil), []uuid.UUID(nil), false). Return([]*dbmodel.AttachedFunction{}, nil).Once() // Look up database @@ -180,20 +185,18 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() { Return(&dbmodel.Function{ID: functionID, Name: functionName, IsAsync: false}, nil).Once() // Check input collection exists - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() // Check if input collection is being used as an output collection suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), &inputCollectionID, []uuid.UUID(nil), false). - Return([]*dbmodel.AttachedFunction{}, nil).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), matchStringPtr(inputCollectionID), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{}, nil).Twice() // Check if output collection already exists - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", - []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + []string(nil), matchStringPtr(outputCollectionName), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() // Insert attached function with lowest_live_nonce = NULL @@ -208,6 +211,9 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() { attachedFunction.DatabaseID == databaseID && attachedFunction.MinRecordsForInvocation == int64(MinRecordsForInvocation) })).Return(nil).Once() + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Maybe() + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Maybe() + suite.mockCollectionDb.On("LockCollection", mock.AnythingOfType("string")).Return((*bool)(nil), nil).Maybe() // Mock the Transaction call itself - it will execute the function suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). @@ -295,9 +301,9 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea txCtx := context.Background() // Inside transaction: check for existing attached functions - suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() + suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Maybe() inputCollID := inputCollectionID - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollID, (*string)(nil), []uuid.UUID(nil), false). + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), matchStringPtr(inputCollID), (*string)(nil), []uuid.UUID(nil), false). Return([]*dbmodel.AttachedFunction{existingAttachedFunction}, nil).Once() // Note: validateAttachedFunctionMatchesRequest uses dbmodel.GetFunctionNameByID (static lookup), @@ -355,7 +361,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_AllowsOutputCollectionI } suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollectionID, (*string)(nil), []uuid.UUID(nil), false). + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), matchStringPtr(inputCollectionID), (*string)(nil), []uuid.UUID(nil), false). Return([]*dbmodel.AttachedFunction{}, nil).Once() suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() @@ -366,7 +372,6 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_AllowsOutputCollectionI suite.mockFunctionDb.On("GetByName", functionName). Return(&dbmodel.Function{ID: functionID, Name: functionName, IsAsync: false}, nil).Once() - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() @@ -383,16 +388,17 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_AllowsOutputCollectionI } suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), &inputCollectionID, []uuid.UUID(nil), false). - Return([]*dbmodel.AttachedFunction{existingUpstream}, nil).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), matchStringPtr(inputCollectionID), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{existingUpstream}, nil).Twice() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), matchStringPtr("root-input-collection"), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{}, nil).Twice() suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() suite.mockFunctionDb.On("GetByIDs", []uuid.UUID{upstreamFunctionID}). Return([]*dbmodel.Function{{ID: upstreamFunctionID, Name: "async-upstream", IsAsync: true}}, nil).Once() - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", - []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + []string(nil), matchStringPtr(outputCollectionName), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() @@ -405,6 +411,9 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_AllowsOutputCollectionI attachedFunction.DatabaseID == databaseID && attachedFunction.MinRecordsForInvocation == int64(minRecordsForInvocation) })).Return(nil).Once() + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Maybe() + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Maybe() + suite.mockCollectionDb.On("LockCollection", mock.AnythingOfType("string")).Return((*bool)(nil), nil).Maybe() suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). Run(func(args mock.Arguments) { @@ -480,13 +489,22 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RejectsOutputCollection } suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), &inputCollectionID, []uuid.UUID(nil), false). - Return([]*dbmodel.AttachedFunction{existingUpstream}, nil).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), matchStringPtr(inputCollectionID), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{existingUpstream}, nil).Twice() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), matchStringPtr("root-input-collection"), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{}, nil).Twice() suite.mockMetaDomain.On("FunctionDb", mock.Anything).Return(suite.mockFunctionDb).Once() suite.mockFunctionDb.On("GetByIDs", []uuid.UUID{upstreamFunctionID}). Return([]*dbmodel.Function{{ID: upstreamFunctionID, Name: "sync-upstream", IsAsync: false}}, nil).Once() + suite.mockCollectionDb.On("GetCollections", + []string(nil), matchStringPtr(outputCollectionName), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Maybe() + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Maybe() + suite.mockCollectionDb.On("LockCollection", mock.AnythingOfType("string")).Return((*bool)(nil), nil).Maybe() + suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). Run(func(args mock.Arguments) { txFunc := args.Get(1).(func(context.Context) error) @@ -543,7 +561,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() { // Phase 1: Create attached function in transaction suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollectionID, (*string)(nil), []uuid.UUID(nil), false). + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), matchStringPtr(inputCollectionID), (*string)(nil), []uuid.UUID(nil), false). Return([]*dbmodel.AttachedFunction{}, nil).Once() suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() @@ -554,24 +572,25 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() { suite.mockFunctionDb.On("GetByName", functionName). Return(&dbmodel.Function{ID: functionID, Name: functionName, IsAsync: false}, nil).Once() - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", []string{inputCollectionID}, (*string)(nil), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{{Collection: &dbmodel.Collection{ID: inputCollectionID}}}, nil).Once() // Check if input collection is being used as an output collection suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), &inputCollectionID, []uuid.UUID(nil), false). - Return([]*dbmodel.AttachedFunction{}, nil).Once() + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), (*string)(nil), matchStringPtr(inputCollectionID), []uuid.UUID(nil), false). + Return([]*dbmodel.AttachedFunction{}, nil).Twice() // Check if output collection already exists - suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Once() suite.mockCollectionDb.On("GetCollections", - []string(nil), &outputCollectionName, tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). + []string(nil), matchStringPtr(outputCollectionName), tenantID, databaseName, (*int32)(nil), (*int32)(nil), false). Return([]*dbmodel.CollectionAndMetadata{}, nil).Once() suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() suite.mockAttachedFunctionDb.On("Insert", mock.Anything).Return(nil).Once() + suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Maybe() + suite.mockMetaDomain.On("CollectionDb", mock.Anything).Return(suite.mockCollectionDb).Maybe() + suite.mockCollectionDb.On("LockCollection", mock.AnythingOfType("string")).Return((*bool)(nil), nil).Maybe() suite.mockTxImpl.On("Transaction", ctx, mock.AnythingOfType("func(context.Context) error")). Run(func(args mock.Arguments) { @@ -610,9 +629,9 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() { txCtx := context.Background() // Inside transaction: check for existing attached functions - suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() + suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Maybe() inputCollID := inputCollectionID - suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), &inputCollID, (*string)(nil), []uuid.UUID(nil), false). + suite.mockAttachedFunctionDb.On("GetAttachedFunctions", (*uuid.UUID)(nil), (*string)(nil), matchStringPtr(inputCollID), (*string)(nil), []uuid.UUID(nil), false). Return([]*dbmodel.AttachedFunction{incompleteAttachedFunction}, nil).Once() // Note: validateAttachedFunctionMatchesRequest uses dbmodel.GetFunctionNameByID (static lookup), diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index 16689b4609d..79c50f1b965 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" @@ -23,6 +24,8 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) +const maxAttachedFunctionDepth = 5 + // validateAttachedFunctionMatchesRequest validates that an existing attached function's parameters match the request parameters. // Returns (true, nil) if all parameters match (idempotent request). // Returns (false, nil) if parameters don't match. @@ -84,6 +87,311 @@ func (s *Coordinator) validateAttachedFunctionMatchesRequest(ctx context.Context return true, nil } +func (s *Coordinator) resolveAttachedFunctionOutputCollectionID(ctx context.Context, attachedFunction *dbmodel.AttachedFunction, databaseName string) (*string, error) { + if attachedFunction.OutputCollectionID != nil { + return attachedFunction.OutputCollectionID, nil + } + + existingCollections, err := s.catalog.metaDomain.CollectionDb(ctx).GetCollections(nil, &attachedFunction.OutputCollectionName, attachedFunction.TenantID, databaseName, nil, nil, false) + if err != nil { + return nil, err + } + if len(existingCollections) == 0 { + return nil, nil + } + + outputCollectionID := existingCollections[0].Collection.ID + return &outputCollectionID, nil +} + +type attachedFunctionGraphState struct { + coordinator *Coordinator + ctx context.Context + databaseName string + // upstreamFunctions caches functions that flow into the to-be input collection. + upstreamFunctions map[string][]*dbmodel.AttachedFunction + // downstreamFunctions caches functions that fan out from the to-be output collection. + downstreamFunctions map[string][]*dbmodel.AttachedFunction +} + +func newAttachedFunctionGraphState(ctx context.Context, coordinator *Coordinator, databaseName string) *attachedFunctionGraphState { + return &attachedFunctionGraphState{ + coordinator: coordinator, + ctx: ctx, + databaseName: databaseName, + upstreamFunctions: make(map[string][]*dbmodel.AttachedFunction), + downstreamFunctions: make(map[string][]*dbmodel.AttachedFunction), + } +} + +func (g *attachedFunctionGraphState) incoming(collectionID string) ([]*dbmodel.AttachedFunction, error) { + if incoming, ok := g.upstreamFunctions[collectionID]; ok { + return incoming, nil + } + + incoming, err := g.coordinator.catalog.metaDomain.AttachedFunctionDb(g.ctx).GetAttachedFunctions(nil, nil, nil, &collectionID, nil, false) + if err != nil { + return nil, err + } + g.upstreamFunctions[collectionID] = incoming + return incoming, nil +} + +func (g *attachedFunctionGraphState) outgoing(collectionID string) ([]*dbmodel.AttachedFunction, error) { + if outgoing, ok := g.downstreamFunctions[collectionID]; ok { + return outgoing, nil + } + + outgoing, err := g.coordinator.catalog.metaDomain.AttachedFunctionDb(g.ctx).GetAttachedFunctions(nil, nil, &collectionID, nil, nil, false) + if err != nil { + return nil, err + } + g.downstreamFunctions[collectionID] = outgoing + return outgoing, nil +} + +func (g *attachedFunctionGraphState) outputCollectionID(attachedFunction *dbmodel.AttachedFunction) (*string, error) { + return g.coordinator.resolveAttachedFunctionOutputCollectionID(g.ctx, attachedFunction, g.databaseName) +} + +func (g *attachedFunctionGraphState) materializeIncoming(collectionID string, remainingDepth int, visited map[string]struct{}) error { + if remainingDepth < 0 { + return nil + } + if _, ok := visited[collectionID]; ok { + return nil + } + visited[collectionID] = struct{}{} + + incoming, err := g.incoming(collectionID) + if err != nil { + return err + } + for _, attachedFunction := range incoming { + if err := g.materializeIncoming(attachedFunction.InputCollectionID, remainingDepth-1, visited); err != nil { + return err + } + } + return nil +} + +func (g *attachedFunctionGraphState) materializeOutgoing(collectionID string, remainingDepth int, visited map[string]struct{}) error { + if remainingDepth < 0 { + return nil + } + if _, ok := visited[collectionID]; ok { + return nil + } + visited[collectionID] = struct{}{} + + outgoing, err := g.outgoing(collectionID) + if err != nil { + return err + } + for _, attachedFunction := range outgoing { + outputCollectionID, err := g.outputCollectionID(attachedFunction) + if err != nil { + return err + } + if outputCollectionID == nil { + continue + } + if err := g.materializeOutgoing(*outputCollectionID, remainingDepth-1, visited); err != nil { + return err + } + } + return nil +} + +func (g *attachedFunctionGraphState) maxPathLength( + collectionID string, + memo map[string]int, + visiting map[string]struct{}, + cycleMessage string, + neighbors func(string) ([]string, error), +) (int, error) { + if depth, ok := memo[collectionID]; ok { + return depth, nil + } + if _, ok := visiting[collectionID]; ok { + return 0, status.Errorf(codes.FailedPrecondition, cycleMessage) + } + + visiting[collectionID] = struct{}{} + defer delete(visiting, collectionID) + + nextCollections, err := neighbors(collectionID) + if err != nil { + return 0, err + } + if len(nextCollections) == 0 { + memo[collectionID] = 0 + return 0, nil + } + + maxDepth := 0 + for _, nextCollectionID := range nextCollections { + childDepth, err := g.maxPathLength(nextCollectionID, memo, visiting, cycleMessage, neighbors) + if err != nil { + return 0, err + } + if childDepth+1 > maxDepth { + maxDepth = childDepth + 1 + } + } + + memo[collectionID] = maxDepth + return maxDepth, nil +} + +func (g *attachedFunctionGraphState) incomingCollectionIDs(collectionID string) ([]string, error) { + incoming, err := g.incoming(collectionID) + if err != nil { + return nil, err + } + + nextCollections := make([]string, 0, len(incoming)) + for _, attachedFunction := range incoming { + nextCollections = append(nextCollections, attachedFunction.InputCollectionID) + } + return nextCollections, nil +} + +func (g *attachedFunctionGraphState) outgoingCollectionIDs(collectionID string) ([]string, error) { + outgoing, err := g.outgoing(collectionID) + if err != nil { + return nil, err + } + + nextCollections := make([]string, 0, len(outgoing)) + for _, attachedFunction := range outgoing { + outputCollectionID, err := g.outputCollectionID(attachedFunction) + if err != nil { + return nil, err + } + if outputCollectionID != nil { + nextCollections = append(nextCollections, *outputCollectionID) + } + } + return nextCollections, nil +} + +func (g *attachedFunctionGraphState) collectionDepth(collectionID string, memo map[string]int, visiting map[string]struct{}) (int, error) { + return g.maxPathLength( + collectionID, + memo, + visiting, + "attached function cycle detected while computing depth", + g.incomingCollectionIDs, + ) +} + +func (g *attachedFunctionGraphState) collectionTailDepth(collectionID string, memo map[string]int, visiting map[string]struct{}) (int, error) { + return g.maxPathLength( + collectionID, + memo, + visiting, + "attached function cycle detected while computing downstream depth", + g.outgoingCollectionIDs, + ) +} + +func (g *attachedFunctionGraphState) reaches(startCollectionID string, targetCollectionID string) (bool, error) { + if startCollectionID == targetCollectionID { + return true, nil + } + + queue := []string{startCollectionID} + visited := map[string]struct{}{startCollectionID: {}} + + for len(queue) > 0 { + currentCollectionID := queue[0] + queue = queue[1:] + + outgoing, err := g.outgoing(currentCollectionID) + if err != nil { + return false, err + } + for _, attachedFunction := range outgoing { + outputCollectionID, err := g.outputCollectionID(attachedFunction) + if err != nil { + return false, err + } + if outputCollectionID == nil { + continue + } + if *outputCollectionID == targetCollectionID { + return true, nil + } + if _, ok := visited[*outputCollectionID]; !ok { + visited[*outputCollectionID] = struct{}{} + queue = append(queue, *outputCollectionID) + } + } + } + + return false, nil +} + +func (g *attachedFunctionGraphState) allCollectionIDs() []string { + collectionIDs := make(map[string]struct{}) + + for outputCollectionID, incoming := range g.upstreamFunctions { + collectionIDs[outputCollectionID] = struct{}{} + for _, attachedFunction := range incoming { + collectionIDs[attachedFunction.InputCollectionID] = struct{}{} + } + } + + for inputCollectionID, outgoing := range g.downstreamFunctions { + collectionIDs[inputCollectionID] = struct{}{} + for _, attachedFunction := range outgoing { + if attachedFunction.OutputCollectionID != nil { + collectionIDs[*attachedFunction.OutputCollectionID] = struct{}{} + } + } + } + + result := make([]string, 0, len(collectionIDs)) + for collectionID := range collectionIDs { + result = append(result, collectionID) + } + slices.Sort(result) + return result +} + +func (s *Coordinator) buildAttachFunctionGraph(ctx context.Context, inputCollectionID string, outputCollectionID string, databaseName string) (*attachedFunctionGraphState, error) { + graphState := newAttachedFunctionGraphState(ctx, s, databaseName) + if err := graphState.materializeIncoming(inputCollectionID, maxAttachedFunctionDepth, map[string]struct{}{}); err != nil { + return nil, err + } + if outputCollectionID != "" { + if err := graphState.materializeOutgoing(outputCollectionID, maxAttachedFunctionDepth, map[string]struct{}{}); err != nil { + return nil, err + } + } + return graphState, nil +} + +func (s *Coordinator) lockAttachFunctionGraph(ctx context.Context, graphState *attachedFunctionGraphState, inputCollectionID string, outputCollectionID string) error { + collectionIDsToLock := graphState.allCollectionIDs() + if len(collectionIDsToLock) == 0 { + collectionIDsToLock = []string{inputCollectionID} + if outputCollectionID != "" && outputCollectionID != inputCollectionID { + collectionIDsToLock = append(collectionIDsToLock, outputCollectionID) + slices.Sort(collectionIDsToLock) + } + } + + for _, collectionID := range collectionIDsToLock { + _, err := s.catalog.metaDomain.CollectionDb(ctx).LockCollection(collectionID) + if err != nil { + return err + } + } + return nil +} + // AttachFunction creates an output collection and attached function in a single transaction func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.AttachFunctionRequest) (*coordinatorpb.AttachFunctionResponse, error) { log := log.With(zap.String("method", "AttachFunction")) @@ -166,14 +474,40 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att return common.ErrCollectionNotFound } - // Check if input collection is being used as an output collection by any attached function - inputCollectionIDStr := req.InputCollectionId - attachedFunctionsUsingAsOutput, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetAttachedFunctions(nil, nil, nil, &inputCollectionIDStr, nil, false) + // Check if output collection already exists so we can materialize and then lock the full graph in a stable order. + existingOutputCollection, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections(nil, &req.OutputCollectionName, req.TenantId, req.Database, nil, nil, false) + if err != nil { + log.Error("AttachFunction: failed to check for existing output collection", zap.Error(err)) + return err + } + + var existingOutputCollectionID string + if len(existingOutputCollection) > 0 { + existingOutputCollectionID = existingOutputCollection[0].Collection.ID + } + + graphState, err := s.buildAttachFunctionGraph(txCtx, req.InputCollectionId, existingOutputCollectionID, req.Database) if err != nil { - log.Error("AttachFunction: failed to check if input collection is used as output", zap.Error(err)) + log.Error("AttachFunction: failed to materialize attached function graph", zap.Error(err)) return err } + if err := s.lockAttachFunctionGraph(txCtx, graphState, req.InputCollectionId, existingOutputCollectionID); err != nil { + log.Error("AttachFunction: failed to lock attached function graph", zap.Error(err)) + return err + } + + // Rebuild the graph under locks before validating/inserting. + graphState, err = s.buildAttachFunctionGraph(txCtx, req.InputCollectionId, existingOutputCollectionID, req.Database) + if err != nil { + log.Error("AttachFunction: failed to rebuild attached function graph under locks", zap.Error(err)) + return err + } + + // Validate that the input collection can accept another upstream edge. + inputCollectionIDStr := req.InputCollectionId + attachedFunctionsUsingAsOutput := graphState.upstreamFunctions[inputCollectionIDStr] if len(attachedFunctionsUsingAsOutput) > 0 { + // Load each referenced function once so we can check its execution mode. functionIDs := make([]uuid.UUID, 0, len(attachedFunctionsUsingAsOutput)) seenFunctionIDs := make(map[uuid.UUID]struct{}, len(attachedFunctionsUsingAsOutput)) for _, attachedFunction := range attachedFunctionsUsingAsOutput { @@ -213,14 +547,48 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att } } - // Check if output collection already exists - existingOutputCollections, err := s.catalog.metaDomain.CollectionDb(txCtx).GetCollections(nil, &req.OutputCollectionName, req.TenantId, req.Database, nil, nil, false) + // Validate the output side of the new edge against the existing graph. + outputTailDepth := 0 + if len(existingOutputCollection) > 0 { + wouldCreateCycle, err := graphState.reaches(existingOutputCollectionID, req.InputCollectionId) + if err != nil { + log.Error("AttachFunction: failed while checking for attached function cycles", zap.Error(err)) + return err + } + if wouldCreateCycle { + log.Error("AttachFunction: cannot attach function because it would create a cycle", + zap.String("input_collection_id", req.InputCollectionId), + zap.String("output_collection_name", req.OutputCollectionName), + zap.String("output_collection_id", existingOutputCollectionID)) + return common.ErrCannotAttachToOutputCollection + } + + outputTailDepth, err = graphState.collectionTailDepth(existingOutputCollectionID, map[string]int{}, map[string]struct{}{}) + if err != nil { + log.Error("AttachFunction: failed to compute output collection downstream depth", zap.Error(err)) + return err + } + } + + // Enforce the maximum chain length after splicing in the new function. + inputCollectionDepth, err := graphState.collectionDepth(req.InputCollectionId, map[string]int{}, map[string]struct{}{}) if err != nil { - log.Error("AttachFunction: failed to check for existing output collection", zap.Error(err)) + log.Error("AttachFunction: failed to compute input collection depth", zap.Error(err)) return err } - if len(existingOutputCollections) > 0 { + totalAttachedFunctionDepth := inputCollectionDepth + 1 + outputTailDepth + if totalAttachedFunctionDepth > maxAttachedFunctionDepth { + log.Error("AttachFunction: attached function depth exceeds maximum", + zap.String("input_collection_id", req.InputCollectionId), + zap.Int("input_collection_depth", inputCollectionDepth), + zap.Int("output_tail_depth", outputTailDepth), + zap.Int("total_attached_function_depth", totalAttachedFunctionDepth), + zap.Int("max_attached_function_depth", maxAttachedFunctionDepth)) + return status.Errorf(codes.InvalidArgument, "attached function depth exceeds maximum of %d", maxAttachedFunctionDepth) + } + + if len(existingOutputCollection) > 0 { // Output collection exists - we now allow reusing any existing collection log.Info("AttachFunction: output collection already exists, will reuse it", zap.String("output_collection_name", req.OutputCollectionName)) diff --git a/go/pkg/sysdb/coordinator/task_graph_test.go b/go/pkg/sysdb/coordinator/task_graph_test.go new file mode 100644 index 00000000000..de82856baee --- /dev/null +++ b/go/pkg/sysdb/coordinator/task_graph_test.go @@ -0,0 +1,136 @@ +package coordinator + +import ( + "testing" + + "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func testAttachedFunctionEdge(inputCollectionID string, outputCollectionID *string) *dbmodel.AttachedFunction { + return &dbmodel.AttachedFunction{ + InputCollectionID: inputCollectionID, + OutputCollectionID: outputCollectionID, + } +} + +func TestAttachedFunctionGraphStateCollectionIDs(t *testing.T) { + t.Parallel() + + outputA := "output-a" + outputB := "output-b" + + graphState := &attachedFunctionGraphState{ + coordinator: &Coordinator{}, + upstreamFunctions: map[string][]*dbmodel.AttachedFunction{ + "sink": { + testAttachedFunctionEdge("input-a", &outputA), + testAttachedFunctionEdge("input-b", &outputB), + }, + }, + downstreamFunctions: map[string][]*dbmodel.AttachedFunction{ + "source": { + testAttachedFunctionEdge("source", &outputA), + testAttachedFunctionEdge("source", &outputB), + }, + }, + } + + incoming, err := graphState.incomingCollectionIDs("sink") + require.NoError(t, err) + require.Equal(t, []string{"input-a", "input-b"}, incoming) + + outgoing, err := graphState.outgoingCollectionIDs("source") + require.NoError(t, err) + require.Equal(t, []string{"output-a", "output-b"}, outgoing) +} + +func TestAttachedFunctionGraphStateDepthAndReachability(t *testing.T) { + t.Parallel() + + collectionB := "collection-b" + collectionC := "collection-c" + collectionD := "collection-d" + + graphState := &attachedFunctionGraphState{ + coordinator: &Coordinator{}, + upstreamFunctions: map[string][]*dbmodel.AttachedFunction{ + collectionB: { + testAttachedFunctionEdge("collection-a", &collectionB), + }, + collectionC: { + testAttachedFunctionEdge(collectionB, &collectionC), + }, + collectionD: { + testAttachedFunctionEdge("collection-a", &collectionD), + }, + "collection-a": {}, + }, + downstreamFunctions: map[string][]*dbmodel.AttachedFunction{ + "collection-a": { + testAttachedFunctionEdge("collection-a", &collectionB), + testAttachedFunctionEdge("collection-a", &collectionD), + }, + collectionB: { + testAttachedFunctionEdge(collectionB, &collectionC), + }, + collectionC: {}, + collectionD: {}, + }, + } + + depth, err := graphState.collectionDepth(collectionC, map[string]int{}, map[string]struct{}{}) + require.NoError(t, err) + require.Equal(t, 2, depth) + + tailDepth, err := graphState.collectionTailDepth("collection-a", map[string]int{}, map[string]struct{}{}) + require.NoError(t, err) + require.Equal(t, 2, tailDepth) + + reaches, err := graphState.reaches("collection-a", collectionC) + require.NoError(t, err) + require.True(t, reaches) + + reaches, err = graphState.reaches(collectionB, collectionD) + require.NoError(t, err) + require.False(t, reaches) +} + +func TestAttachedFunctionGraphStateDetectsCycles(t *testing.T) { + t.Parallel() + + collectionA := "collection-a" + collectionB := "collection-b" + + graphState := &attachedFunctionGraphState{ + coordinator: &Coordinator{}, + upstreamFunctions: map[string][]*dbmodel.AttachedFunction{ + collectionA: { + testAttachedFunctionEdge(collectionB, &collectionA), + }, + collectionB: { + testAttachedFunctionEdge(collectionA, &collectionB), + }, + }, + downstreamFunctions: map[string][]*dbmodel.AttachedFunction{ + collectionA: { + testAttachedFunctionEdge(collectionA, &collectionB), + }, + collectionB: { + testAttachedFunctionEdge(collectionB, &collectionA), + }, + }, + } + + _, err := graphState.collectionDepth(collectionA, map[string]int{}, map[string]struct{}{}) + require.Error(t, err) + require.Equal(t, codes.FailedPrecondition, status.Code(err)) + require.Contains(t, err.Error(), "attached function cycle detected while computing depth") + + _, err = graphState.collectionTailDepth(collectionA, map[string]int{}, map[string]struct{}{}) + require.Error(t, err) + require.Equal(t, codes.FailedPrecondition, status.Code(err)) + require.Contains(t, err.Error(), "attached function cycle detected while computing downstream depth") +}