Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3047,6 +3047,150 @@ func (suite *BulkIngestTests) TestBulkIngestWithStream() {
suite.Equal(int64(5), totalRows)
}

func (suite *BulkIngestTests) TestBulkIngestBindStreamBeforeOptions() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec1 := bldr.NewRecordBatch()
bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{2, 3}, nil)
rec2 := bldr.NewRecordBatch()
defer rec1.Release()
defer rec2.Release()

rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec1, rec2})
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().NoError(stmt.BindStream(context.Background(), rdr))

suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_first"))
suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate))

nRows, err := stmt.ExecuteUpdate(context.Background())
suite.Require().NoError(err)
suite.Equal(int64(3), nRows)

requests := suite.server.GetIngestRequests()
suite.Require().Len(requests, 1)
suite.Equal("bind_first", requests[0].GetTable())
}

func (suite *BulkIngestTests) TestBulkIngestBindBeforeOptions() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{10, 20}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

suite.Require().NoError(stmt.Bind(context.Background(), rec))

suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bind_batch_first"))
suite.Require().NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate))

nRows, err := stmt.ExecuteUpdate(context.Background())
suite.Require().NoError(err)
suite.Equal(int64(2), nRows)

requests := suite.server.GetIngestRequests()
suite.Require().Len(requests, 1)
suite.Equal("bind_batch_first", requests[0].GetTable())
}

func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTarget() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().NoError(stmt.BindStream(context.Background(), rdr))

_, err = stmt.ExecuteUpdate(context.Background())
suite.Require().Error(err)
suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
}

func (suite *BulkIngestTests) TestBulkIngestBindMissingTarget() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

suite.Require().NoError(stmt.Bind(context.Background(), rec))

_, err = stmt.ExecuteUpdate(context.Background())
suite.Require().Error(err)
suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
}

func (suite *BulkIngestTests) TestBulkIngestBindStreamMissingTargetExecuteQuery() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
defer validation.CheckedClose(suite.T(), stmt)

schema := arrow.NewSchema([]arrow.Field{
{Name: "batch_id", Type: arrow.PrimitiveTypes.Int32, Nullable: false},
}, nil)

bldr := array.NewRecordBuilder(memory.DefaultAllocator, schema)
defer bldr.Release()

bldr.Field(0).(*array.Int32Builder).AppendValues([]int32{1}, nil)
rec := bldr.NewRecordBatch()
defer rec.Release()

rdr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec})
suite.Require().NoError(err)
defer rdr.Release()

suite.Require().NoError(stmt.BindStream(context.Background(), rdr))

_, _, err = stmt.ExecuteQuery(context.Background())
suite.Require().Error(err)
suite.Contains(err.Error(), "must set IngestTargetTable before bulk ingestion")
}

func (suite *BulkIngestTests) TestBulkIngestWithoutBind() {
stmt, err := suite.cnxn.NewStatement()
suite.Require().NoError(err)
Expand Down
46 changes: 40 additions & 6 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,14 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
return nil, -1, err
}

// Reject staged binds if no ingest target was provided
if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) {
return nil, -1, adbc.Error{
Msg: "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion",
Code: adbc.StatusInvalidState,
}
}

// Handle bulk ingest
if s.targetTable != "" {
nrec, err = s.executeIngest(ctx)
Expand Down Expand Up @@ -535,6 +543,14 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
return -1, err
}

// Reject staged binds if no ingest target was provided
if s.targetTable == "" && s.prepared == nil && (s.bound != nil || s.streamBind != nil) {
return -1, adbc.Error{
Msg: "[Flight SQL Statement] must set IngestTargetTable before bulk ingestion",
Code: adbc.StatusInvalidState,
}
}

// Handle bulk ingest
if s.targetTable != "" {
return s.executeIngest(ctx)
Expand Down Expand Up @@ -617,9 +633,18 @@ func (s *statement) Bind(_ context.Context, values arrow.RecordBatch) error {
}

if s.prepared == nil {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto here

return adbc.Error{
Msg: "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind",
Code: adbc.StatusInvalidState}
if s.streamBind != nil {
s.streamBind.Release()
s.streamBind = nil
}
if s.bound != nil {
s.bound.Release()
}
s.bound = values
if s.bound != nil {
s.bound.Retain()
}
return nil
}

// calls retain
Expand Down Expand Up @@ -650,9 +675,18 @@ func (s *statement) BindStream(_ context.Context, stream array.RecordReader) err
}

if s.prepared == nil {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we combine this clause with the one above?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure let me know if 3c38988 did the job.

return adbc.Error{
Msg: "[Flight SQL Statement] must call Prepare or set IngestTargetTable before calling Bind",
Code: adbc.StatusInvalidState}
if s.bound != nil {
s.bound.Release()
s.bound = nil
}
if s.streamBind != nil {
s.streamBind.Release()
}
s.streamBind = stream
if s.streamBind != nil {
s.streamBind.Retain()
}
return nil
}

// calls retain
Expand Down
24 changes: 12 additions & 12 deletions go/adbc/ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,19 @@ func IngestStream(ctx context.Context, cnxn Connection, reader array.RecordReade
err = errors.Join(err, stmt.Close())
}()

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set required options
// Set required options before binding
if err = stmt.SetOption(OptionKeyIngestTargetTable, targetTable); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err)
}
if err = stmt.SetOption(OptionKeyIngestMode, ingestMode); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err)
}

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set other options if provided
if opt.Catalog != "" {
if err = stmt.SetOption(OptionValueIngestTargetCatalog, opt.Catalog); err != nil {
Expand Down Expand Up @@ -167,19 +167,19 @@ func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, reader
err = errors.Join(err, stmt.Close(ctx))
}()

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set required options
// Set required options before binding (some drivers require target first)
if err = stmt.SetOption(ctx, OptionKeyIngestTargetTable, targetTable); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(target_table=%s): %w", targetTable, err)
}
if err = stmt.SetOption(ctx, OptionKeyIngestMode, ingestMode); err != nil {
return -1, fmt.Errorf("error during ingestion: SetOption(mode=%s): %w", ingestMode, err)
}

// Bind the record batch stream
if err = stmt.BindStream(ctx, reader); err != nil {
return -1, fmt.Errorf("error during ingestion: BindStream: %w", err)
}

// Set other options if provided
if opt.Catalog != "" {
if err = stmt.SetOption(ctx, OptionValueIngestTargetCatalog, opt.Catalog); err != nil {
Expand Down