Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions addons/processors/iceberg-processor/internal/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ type Decoder interface {
Decode(ctx context.Context, segmentKey, indexKey string, topic string, partition int32) ([]Record, error)
}

// getObjectAPI is the single-method client interface held by s3Decoder.
// Declaring only GetObject lets tests supply a stand-in client instead of a
// concrete S3 client (presumably *s3.Client satisfies it in production —
// confirm against New).
type getObjectAPI interface {
GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
}

// New returns an S3-backed decoder.
func New(cfg config.Config) (Decoder, error) {
loadOptions := []func(*awsconfig.LoadOptions) error{}
Expand Down Expand Up @@ -93,19 +97,12 @@ func New(cfg config.Config) (Decoder, error) {
}

type s3Decoder struct {
client *s3.Client
client getObjectAPI
bucket string
}

func (d *s3Decoder) Decode(ctx context.Context, segmentKey, indexKey string, topic string, partition int32) ([]Record, error) {
indexBytes, err := d.getObject(ctx, indexKey)
if err != nil {
return nil, fmt.Errorf("download index: %w", err)
}
if _, err := parseIndex(indexBytes); err != nil {
return nil, err
}

// Decode only needs the segment payload; segment/index pairing is validated during discovery.
func (d *s3Decoder) Decode(ctx context.Context, segmentKey, _ string, topic string, partition int32) ([]Record, error) {
segmentBytes, err := d.getObject(ctx, segmentKey)
if err != nil {
return nil, fmt.Errorf("download segment: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@ package decoder

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
)

func TestDecodeSegment(t *testing.T) {
Expand Down Expand Up @@ -62,6 +68,30 @@ func TestParseIndex(t *testing.T) {
}
}

// TestDecodeSkipsIndexDownload verifies that Decode requests only the segment
// object even though an index key is passed: the fake client stores just the
// segment, and the recorded requests must contain exactly the segment key.
func TestDecodeSkipsIndexDownload(t *testing.T) {
	// Read the clock once so the segment and the batch it wraps are built from
	// the same millisecond; two separate time.Now() calls can straddle a
	// millisecond boundary and make the fixture differ between runs.
	nowMillis := time.Now().UnixMilli()
	segment := buildSegmentBytes(10, 1, nowMillis, buildRecordBatch(10, nowMillis, []byte("k1"), []byte("v1")))

	// Only the segment object exists in the fake; no index object is stored.
	client := &fakeGetObjectClient{
		objects: map[string][]byte{
			"segment.kfs": segment,
		},
	}
	dec := &s3Decoder{
		client: client,
		bucket: "test-bucket",
	}

	records, err := dec.Decode(context.Background(), "segment.kfs", "segment.index", "orders", 0)
	if err != nil {
		t.Fatalf("Decode: %v", err)
	}
	if len(records) != 1 {
		t.Fatalf("expected 1 record, got %d", len(records))
	}
	// Exactly one request, and it must be for the segment key.
	if len(client.requests) != 1 || client.requests[0] != "segment.kfs" {
		t.Fatalf("unexpected requests: %+v", client.requests)
	}
}

func buildIndexBytes(count int32) []byte {
buf := bytes.NewBuffer(make([]byte, 0, 64))
buf.WriteString(indexMagic)
Expand Down Expand Up @@ -152,3 +182,18 @@ func encodeVarint(value int64) []byte {
}
return out
}

// fakeGetObjectClient is an in-memory test double exposing a GetObject method.
type fakeGetObjectClient struct {
	// objects maps an object key to the bytes served for it.
	objects map[string][]byte
	// requests lists every key asked for, in call order, including misses.
	requests []string
}

// GetObject records the requested key and then serves the stored bytes for it.
// A key with no entry in objects yields a "missing object" error. The context
// and option functions are ignored.
func (f *fakeGetObjectClient) GetObject(ctx context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
	requested := aws.ToString(params.Key)
	// Log the request before the lookup so failed lookups are recorded too.
	f.requests = append(f.requests, requested)

	if payload, found := f.objects[requested]; found {
		body := io.NopCloser(bytes.NewReader(payload))
		return &s3.GetObjectOutput{Body: body}, nil
	}
	return nil, errors.New("missing object")
}
18 changes: 7 additions & 11 deletions addons/processors/sql-processor/internal/decoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ type Decoder interface {
Decode(ctx context.Context, segmentKey, indexKey string, topic string, partition int32) ([]Record, error)
}

// getObjectAPI is the single-method client interface held by s3Decoder.
// Declaring only GetObject lets tests supply a stand-in client instead of a
// concrete S3 client (presumably *s3.Client satisfies it in production —
// confirm against New).
type getObjectAPI interface {
GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
}

func New(cfg config.Config) (Decoder, error) {
loadOptions := []func(*awsconfig.LoadOptions) error{}
if cfg.S3.Region != "" {
Expand Down Expand Up @@ -93,7 +97,7 @@ func New(cfg config.Config) (Decoder, error) {
}

type s3Decoder struct {
client *s3.Client
client getObjectAPI
bucket string
metrics s3Metrics
}
Expand All @@ -116,16 +120,8 @@ func newS3Metrics() s3Metrics {
}
}

func (d *s3Decoder) Decode(ctx context.Context, segmentKey, indexKey string, topic string, partition int32) ([]Record, error) {
indexBytes, err := d.getObject(ctx, "get", indexKey)
if err != nil {
return nil, fmt.Errorf("download index: %w", err)
}
if _, err := parseIndex(indexBytes); err != nil {
d.metrics.decodeErrors.Inc()
return nil, err
}

// Decode only needs the segment payload; segment/index pairing is validated during discovery.
func (d *s3Decoder) Decode(ctx context.Context, segmentKey, _ string, topic string, partition int32) ([]Record, error) {
segmentBytes, err := d.getObject(ctx, "get", segmentKey)
if err != nil {
return nil, fmt.Errorf("download segment: %w", err)
Expand Down
46 changes: 46 additions & 0 deletions addons/processors/sql-processor/internal/decoder/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ package decoder

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
)

func TestParseIndex(t *testing.T) {
Expand Down Expand Up @@ -82,6 +88,31 @@ func TestDecodeBatchCompressed(t *testing.T) {
}
}

func TestDecodeSkipsIndexDownload(t *testing.T) {
segment := buildSegment(buildBatch(5, 1000, buildRecord(0, 0, []byte("k"), []byte("v"))))
client := &fakeGetObjectClient{
objects: map[string][]byte{
"segment.kfs": segment,
},
}
dec := &s3Decoder{
client: client,
bucket: "test-bucket",
metrics: newS3Metrics(),
}

records, err := dec.Decode(context.Background(), "segment.kfs", "segment.index", "orders", 0)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if len(records) != 1 {
t.Fatalf("expected 1 record, got %d", len(records))
}
if len(client.requests) != 1 || client.requests[0] != "segment.kfs" {
t.Fatalf("unexpected requests: %+v", client.requests)
}
}

func buildSegment(batch []byte) []byte {
segment := make([]byte, 0, segmentHeaderLen+len(batch)+segmentFooterLen)
header := make([]byte, segmentHeaderLen)
Expand Down Expand Up @@ -150,3 +181,18 @@ func makeRecordPayload(tsDelta int32, offsetDelta int32, key []byte, value []byt
writeVarint(&body, 0)
return body.Bytes()
}

// fakeGetObjectClient is a map-backed test double exposing a GetObject method.
type fakeGetObjectClient struct {
	// objects holds the bytes returned for each known key.
	objects map[string][]byte
	// requests accumulates every key passed to GetObject, hits and misses alike.
	requests []string
}

// GetObject appends the requested key to requests, then returns the stored
// bytes wrapped in a no-op closer. Unknown keys produce a "missing object"
// error. The context and option functions are not consulted.
func (f *fakeGetObjectClient) GetObject(ctx context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
	wanted := aws.ToString(params.Key)
	f.requests = append(f.requests, wanted)

	content, present := f.objects[wanted]
	switch {
	case !present:
		return nil, errors.New("missing object")
	default:
		return &s3.GetObjectOutput{Body: io.NopCloser(bytes.NewReader(content))}, nil
	}
}
4 changes: 2 additions & 2 deletions cmd/kafscale-cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func executeRestore(ctx context.Context, stdout io.Writer, cfg restoreConfig, s3
return err
}
if len(sourceMeta.Topics) == 0 || sourceMeta.Topics[0].ErrorCode != 0 {
return metadata.ErrUnknownTopic
return fmt.Errorf("source topic metadata: %w", metadata.ErrUnknownTopic)
}

sourcePartitions := make(map[int32]struct{}, len(sourceMeta.Topics[0].Partitions))
Expand Down Expand Up @@ -238,7 +238,7 @@ func executeRestore(ctx context.Context, stdout io.Writer, cfg restoreConfig, s3
return err
}
if len(targetMeta.Topics) == 0 || targetMeta.Topics[0].ErrorCode != 0 {
return metadata.ErrUnknownTopic
return fmt.Errorf("target topic metadata: %w", metadata.ErrUnknownTopic)
}

recoveredByPartition := make(map[int32]storage.RecoveredPartition, len(result.Partitions))
Expand Down
105 changes: 85 additions & 20 deletions cmd/kafscale-cli/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package main
import (
"bytes"
"context"
"encoding/binary"
"errors"
"hash/crc32"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -95,12 +97,7 @@ func TestRunRestoreCommandUsesInjectedClients(t *testing.T) {

s3 := storage.NewMemoryS3Client()
artifact, err := storage.BuildSegment(storage.SegmentWriterConfig{IndexIntervalMessages: 1}, []storage.RecordBatch{
{
BaseOffset: 0,
LastOffsetDelta: 0,
MessageCount: 1,
Bytes: make([]byte, 70),
},
makeStorageRecoveryBatch(0, time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC).UnixMilli(), []int64{0}),
}, time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("BuildSegment: %v", err)
Expand All @@ -123,8 +120,13 @@ func TestRunRestoreCommandUsesInjectedClients(t *testing.T) {
newS3Client = func(context.Context, storage.S3Config) (storage.S3Client, error) {
return s3, nil
}
newEtcdStore = func(context.Context, metadata.ClusterMetadata, metadata.EtcdStoreConfig) (*metadata.EtcdStore, error) {
return store, nil
newEtcdStore = func(ctx context.Context, _ metadata.ClusterMetadata, cfg metadata.EtcdStoreConfig) (*metadata.EtcdStore, error) {
return metadata.NewEtcdStore(ctx, metadata.ClusterMetadata{
Brokers: []protocol.MetadataBroker{
{NodeID: 1, Host: "broker-0", Port: 9092},
},
ControllerID: 1,
}, cfg)
}
newMemoryS3 = func() storage.S3Client { return s3 }

Expand Down Expand Up @@ -237,12 +239,7 @@ func TestExecuteRestoreCreatesRecoveredTopic(t *testing.T) {

s3 := storage.NewMemoryS3Client()
artifact, err := storage.BuildSegment(storage.SegmentWriterConfig{IndexIntervalMessages: 1}, []storage.RecordBatch{
{
BaseOffset: 0,
LastOffsetDelta: 0,
MessageCount: 1,
Bytes: make([]byte, 70),
},
makeStorageRecoveryBatch(0, time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC).UnixMilli(), []int64{0}),
}, time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("BuildSegment: %v", err)
Expand Down Expand Up @@ -460,12 +457,7 @@ func TestExecuteRestoreRollsBackCopiedS3ObjectsOnPartialFailure(t *testing.T) {

mem := storage.NewMemoryS3Client()
artifact, err := storage.BuildSegment(storage.SegmentWriterConfig{IndexIntervalMessages: 1}, []storage.RecordBatch{
{
BaseOffset: 0,
LastOffsetDelta: 0,
MessageCount: 1,
Bytes: make([]byte, 70),
},
makeStorageRecoveryBatch(0, time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC).UnixMilli(), []int64{0}),
}, time.Date(2026, 5, 13, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("BuildSegment: %v", err)
Expand Down Expand Up @@ -508,3 +500,76 @@ func TestExecuteRestoreRollsBackCopiedS3ObjectsOnPartialFailure(t *testing.T) {
t.Fatalf("expected rolled back target topic to be absent, got %+v", meta.Topics)
}
}

// makeStorageRecoveryBatch builds a storage.RecordBatch whose Bytes field is a
// 61-byte big-endian header followed by one encoded record per entry in
// timestampDeltas. Record i is encoded with timestamp delta timestampDeltas[i]
// and offset delta i. The largest firstTimestamp+delta (never less than
// firstTimestamp) is stored as the second timestamp field, and a CRC-32C over
// everything from byte 21 onward is stored at bytes 17-21.
//
// NOTE(review): the byte offsets look like the Kafka record batch v2 header
// layout (magic byte 2 at offset 16) — confirm against the storage package.
func makeStorageRecoveryBatch(baseOffset, firstTimestamp int64, timestampDeltas []int64) storage.RecordBatch {
	const (
		recordBatchHeaderLen = 61
		batchFrameHeaderLen  = 12
	)

	// Encode every record up front, tracking total payload size and the
	// highest absolute timestamp in the same pass.
	encoded := make([][]byte, 0, len(timestampDeltas))
	payloadLen := 0
	maxTimestamp := firstTimestamp
	for idx, delta := range timestampDeltas {
		rec := makeStorageRecoveryRecord(delta, int64(idx))
		encoded = append(encoded, rec)
		payloadLen += len(rec)
		if candidate := firstTimestamp + delta; candidate > maxTimestamp {
			maxTimestamp = candidate
		}
	}
	count := len(encoded)

	raw := make([]byte, recordBatchHeaderLen+payloadLen)
	be := binary.BigEndian
	be.PutUint64(raw[0:8], uint64(baseOffset))
	// Length field excludes the first 12 bytes (offset + length itself).
	be.PutUint32(raw[8:12], uint32(len(raw)-batchFrameHeaderLen))
	raw[16] = 2
	be.PutUint32(raw[23:27], uint32(count-1))
	be.PutUint64(raw[27:35], uint64(firstTimestamp))
	be.PutUint64(raw[35:43], uint64(maxTimestamp))
	// The next three fields are filled with all-ones bit patterns.
	be.PutUint64(raw[43:51], ^uint64(0))
	be.PutUint16(raw[51:53], ^uint16(0))
	be.PutUint32(raw[53:57], ^uint32(0))
	be.PutUint32(raw[57:61], uint32(count))

	// Lay the encoded records back to back after the header.
	cursor := recordBatchHeaderLen
	for _, rec := range encoded {
		cursor += copy(raw[cursor:], rec)
	}

	// The checksum must be computed last: it covers every byte from offset 21
	// onward, including the header fields and records written above.
	be.PutUint32(raw[17:21], crc32.Checksum(raw[21:], crc32.MakeTable(crc32.Castagnoli)))

	return storage.RecordBatch{
		BaseOffset:      baseOffset,
		LastOffsetDelta: int32(count - 1),
		MessageCount:    int32(count),
		Bytes:           raw,
	}
}

// makeStorageRecoveryRecord encodes one length-prefixed record: a varint of
// the body length followed by the body. The body is a single zero byte and
// then the varints timestampDelta, offsetDelta, -1, -1 and 0, in that order.
//
// NOTE(review): the trailing -1, -1, 0 presumably stand for a null key, a null
// value and zero headers in the Kafka record format — confirm.
func makeStorageRecoveryRecord(timestampDelta, offsetDelta int64) []byte {
	var body bytes.Buffer
	body.WriteByte(0)
	for _, field := range []int64{timestampDelta, offsetDelta, -1, -1, 0} {
		body.Write(encodeStorageRecoveryVarint(field))
	}

	// Prefix the body with its own length, encoded the same way as the fields.
	framed := encodeStorageRecoveryVarint(int64(body.Len()))
	return append(framed, body.Bytes()...)
}

// encodeStorageRecoveryVarint returns value as a zigzag-encoded base-128
// varint: 1 to 10 bytes, least-significant group first, with the high bit set
// on every byte except the last.
//
// It delegates to encoding/binary.PutVarint, which produces the same bytes as
// a hand-written zigzag loop.
func encodeStorageRecoveryVarint(value int64) []byte {
	buf := make([]byte, binary.MaxVarintLen64)
	return buf[:binary.PutVarint(buf, value)]
}
Loading
Loading