Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cache/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# etcd cache

Experimental etcd client cache library.

**Note:** gRPC proxy is not supported. The cache relies on `RequestProgress` RPCs, which the gRPC proxy does not forward.
108 changes: 86 additions & 22 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

pb "go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
clientv3 "go.etcd.io/etcd/client/v3"
)

var (
	// ErrUnsupportedRequest is returned when the request carries parameters
	// the cache cannot serve.
	ErrUnsupportedRequest = errors.New("cache: unsupported request parameters")
	// ErrKeyRangeInvalid is returned when the requested key or key-range is
	// invalid (empty or reversed) or lies outside c.prefix.
	ErrKeyRangeInvalid = errors.New("cache: invalid or out‑of‑range key range")
	// ErrCacheTimeout is returned when the cache timed out waiting for the
	// requested revision.
	ErrCacheTimeout = errors.New("cache: timed out waiting for revision")
)

// Cache buffers a single etcd Watch for a given key-prefix and fans events out to local watchers.
//
// Note: gRPC proxy is not supported. Cache relies on RequestProgress RPCs,
// which the gRPC proxy does not forward.
type Cache struct {
prefix string // prefix is the key-prefix this shard is responsible for ("" = root).
cfg Config // immutable runtime configuration
watcher clientv3.Watcher
kv clientv3.KV
demux *demux // demux fans incoming events out to active watchers and manages resync.
store *store // last‑observed snapshot
ready *ready
stop context.CancelFunc
waitGroup sync.WaitGroup
internalCtx context.Context
prefix string // prefix is the key-prefix this shard is responsible for ("" = root).
cfg Config // immutable runtime configuration
watcher clientv3.Watcher
kv clientv3.KV
demux *demux // demux fans incoming events out to active watchers and manages resync.
store *store // last‑observed snapshot
ready *ready
stop context.CancelFunc
waitGroup sync.WaitGroup
internalCtx context.Context
progressRequestor progressRequestor
}

// New builds a cache shard that watches only the requested prefix.
// For the root cache pass "".
//
// Note: gRPC proxy is not supported. Cache relies on RequestProgress RPCs,
// which the gRPC proxy does not forward.
func New(client *clientv3.Client, prefix string, opts ...Option) (*Cache, error) {
cfg := defaultConfig()
for _, opt := range opts {
Expand All @@ -65,23 +75,28 @@ func New(client *clientv3.Client, prefix string, opts ...Option) (*Cache, error)
internalCtx, cancel := context.WithCancel(context.Background())

cache := &Cache{
prefix: prefix,
cfg: cfg,
watcher: client.Watcher,
kv: client.KV,
store: newStore(cfg.BTreeDegree, cfg.HistoryWindowSize),
ready: newReady(),
stop: cancel,
internalCtx: internalCtx,
prefix: prefix,
cfg: cfg,
watcher: client.Watcher,
kv: client.KV,
store: newStore(cfg.BTreeDegree, cfg.HistoryWindowSize),
ready: newReady(),
stop: cancel,
internalCtx: internalCtx,
progressRequestor: newConditionalProgressRequestor(client.Watcher, realClock{}, cfg.ProgressRequestInterval),
}

cache.demux = NewDemux(internalCtx, &cache.waitGroup, cfg.HistoryWindowSize, cfg.ResyncInterval)

cache.waitGroup.Add(1)
cache.waitGroup.Add(2)
go func() {
defer cache.waitGroup.Done()
cache.getWatchLoop()
}()
go func() {
defer cache.waitGroup.Done()
cache.progressRequestor.run(internalCtx)
}()

return cache, nil
}
Expand Down Expand Up @@ -161,6 +176,19 @@ func (c *Cache) Get(ctx context.Context, key string, opts ...clientv3.OpOption)
endKey := op.RangeBytes()
requestedRev := op.Rev()

if !op.IsSerializable() {
serverRev, err := c.serverRevision(ctx)
if err != nil {
return nil, err
}
if requestedRev > serverRev {
return nil, rpctypes.ErrFutureRev
}
if err = c.waitTillRevision(ctx, serverRev); err != nil {
return nil, err
}
}

kvs, latestRev, err := c.store.Get(startKey, endKey, requestedRev)
if err != nil {
return nil, err
Expand Down Expand Up @@ -196,6 +224,45 @@ func (c *Cache) WaitForRevision(ctx context.Context, rev int64) error {
}
}

// serverRevision asks etcd for its current store revision by issuing a
// minimal count-only, limit-1 range request over the cache prefix and
// reading the revision off the response header.
func (c *Cache) serverRevision(ctx context.Context) (int64, error) {
	probeKey := c.prefix
	if probeKey == "" {
		probeKey = "/"
	}
	resp, err := c.kv.Get(ctx, probeKey, clientv3.WithCountOnly(), clientv3.WithLimit(1))
	if err != nil {
		return 0, err
	}
	return resp.Header.Revision, nil
}

// waitTillRevision blocks until the cache store has observed at least rev.
//
// While waiting it registers with the progress requestor so the background
// loop sends RequestProgress RPCs, nudging the watch stream forward even when
// no regular events arrive. It returns nil once the store catches up,
// ErrCacheTimeout after cfg.WaitTimeout, or ctx.Err() if the caller's context
// ends first.
func (c *Cache) waitTillRevision(ctx context.Context, rev int64) error {
	// Fast path: already caught up, no need to register a waiter.
	if c.store.LatestRev() >= rev {
		return nil
	}

	c.progressRequestor.add()
	defer c.progressRequestor.remove()

	ticker := time.NewTicker(revisionPollInterval)
	defer ticker.Stop()
	// Use an explicit timer instead of time.After so the timer is released as
	// soon as we return, rather than lingering until WaitTimeout elapses.
	timer := time.NewTimer(c.cfg.WaitTimeout)
	defer timer.Stop()

	// TODO: rewrite from periodic polling to passive notification
	for {
		if c.store.LatestRev() >= rev {
			return nil
		}
		select {
		case <-ticker.C:
		case <-timer.C:
			return ErrCacheTimeout
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// Close cancels the private context and blocks until all goroutines return.
func (c *Cache) Close() {
c.stop()
Expand Down Expand Up @@ -358,9 +425,6 @@ func (c *Cache) validateGet(key string, op clientv3.Op) (KeyPredicate, error) {
return nil, fmt.Errorf("%w: MinCreateRev(%d) not supported", ErrUnsupportedRequest, op.MinCreateRev())
case op.MaxCreateRev() != 0:
return nil, fmt.Errorf("%w: MaxCreateRev(%d) not supported", ErrUnsupportedRequest, op.MaxCreateRev())
// cache now only serves serializable reads of the latest revision (rev == 0).
case !op.IsSerializable():
return nil, fmt.Errorf("%w: non-serializable request", ErrUnsupportedRequest)
}

startKey := []byte(key)
Expand Down
180 changes: 178 additions & 2 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package cache

import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -501,6 +503,7 @@ type mockWatcher struct {
wg sync.WaitGroup
mu sync.Mutex
lastStartRev int64
progressErr error
}

func newMockWatcher(buf int) *mockWatcher {
Expand All @@ -522,7 +525,7 @@ func (m *mockWatcher) Watch(ctx context.Context, _ string, opts ...clientv3.OpOp
return out
}

func (m *mockWatcher) RequestProgress(_ context.Context) error { return nil }
// RequestProgress returns the stubbed progress error (nil unless a test sets
// progressErr to simulate a failing RequestProgress RPC).
func (m *mockWatcher) RequestProgress(_ context.Context) error {
	return m.progressErr
}

func (m *mockWatcher) Close() error {
m.closeOnce.Do(func() { close(m.responses) })
Expand Down Expand Up @@ -600,6 +603,7 @@ func (m *mockWatcher) streamResponses(ctx context.Context, out chan<- clientv3.W
type kvStub struct {
queued []*clientv3.GetResponse
defaultResp *clientv3.GetResponse
defaultErr error
}

func newKVStub(resps ...*clientv3.GetResponse) *kvStub {
Expand All @@ -610,7 +614,11 @@ func newKVStub(resps ...*clientv3.GetResponse) *kvStub {
}
}

func (s *kvStub) Get(ctx context.Context, key string, _ ...clientv3.OpOption) (*clientv3.GetResponse, error) {
func (s *kvStub) Get(_ context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
if s.defaultErr != nil {
return nil, s.defaultErr
}

if len(s.queued) > 0 {
next := s.queued[0]
s.queued = s.queued[1:]
Expand Down Expand Up @@ -692,3 +700,171 @@ func verifySnapshot(t *testing.T, cache *Cache, want []*mvccpb.KeyValue) {
t.Fatalf("cache snapshot mismatch (-want +got):\n%s", diff)
}
}

// noopProgressNotifier is a test double whose RequestProgress always succeeds.
type noopProgressNotifier struct{}

// RequestProgress reports success unconditionally.
func (n *noopProgressNotifier) RequestProgress(context.Context) error { return nil }

// newTestProgressRequestor builds a conditionalProgressRequestor backed by a
// no-op notifier and the real clock, with a 100ms request interval.
func newTestProgressRequestor() *conditionalProgressRequestor {
	const interval = 100 * time.Millisecond
	notifier := &noopProgressNotifier{}
	return newConditionalProgressRequestor(notifier, realClock{}, interval)
}

// newCacheForWaitTest assembles a minimal Cache for waitTillRevision tests:
// its kv stub always reports serverRev, and its store is pre-seeded to
// localRev when localRev > 0. The store is returned as well so tests can
// advance it while a wait is in flight.
func newCacheForWaitTest(serverRev, localRev int64, pr progressRequestor) (*Cache, *store) {
	cfg := defaultConfig()
	st := newStore(cfg.BTreeDegree, cfg.HistoryWindowSize)
	if localRev > 0 {
		st.Restore(nil, localRev)
	}
	stubKV := &kvStub{
		defaultResp: &clientv3.GetResponse{Header: &pb.ResponseHeader{Revision: serverRev}},
	}
	c := &Cache{
		prefix:            "/",
		cfg:               cfg,
		kv:                stubKV,
		store:             st,
		progressRequestor: pr,
	}
	return c, st
}

func TestWaitTillRevision(t *testing.T) {
t.Run("cache_already_caught_up", func(t *testing.T) {
c, _ := newCacheForWaitTest(10, 10, newTestProgressRequestor())

if err := c.waitTillRevision(context.Background(), 10); err != nil {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: if waitTillRevision has a bug this test will block for very long. Would be good to have some timeout based on expected time.

t.Fatalf("unexpected error: %v", err)
}
})

t.Run("local_rev_sufficient_skips_server_call", func(t *testing.T) {
cfg := defaultConfig()
st := newStore(cfg.BTreeDegree, cfg.HistoryWindowSize)
st.Restore(nil, 10)
c := &Cache{
kv: &kvStub{defaultErr: fmt.Errorf("should not be called")},
store: st,
prefix: "/",
progressRequestor: newTestProgressRequestor(),
cfg: cfg,
}

if err := c.waitTillRevision(context.Background(), 5); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})

t.Run("cache_catches_up", func(t *testing.T) {
c, st := newCacheForWaitTest(15, 5, newTestProgressRequestor())

go func() {
time.Sleep(200 * time.Millisecond)
st.Restore(nil, 10)
}()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := c.waitTillRevision(ctx, 10); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})

t.Run("rev_zero_cache_caught_up", func(t *testing.T) {
c, _ := newCacheForWaitTest(10, 10, newTestProgressRequestor())

if err := c.waitTillRevision(context.Background(), 0); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})

t.Run("rev_zero_waits_for_server_rev", func(t *testing.T) {
c, st := newCacheForWaitTest(10, 5, newTestProgressRequestor())

go func() {
time.Sleep(200 * time.Millisecond)
st.Restore(nil, 10)
}()

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := c.waitTillRevision(ctx, 0); err != nil {
t.Fatalf("unexpected error: %v", err)
}
})

t.Run("context_cancelled", func(t *testing.T) {
c, _ := newCacheForWaitTest(10, 5, newTestProgressRequestor())

ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
err := c.waitTillRevision(ctx, 10)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("got %v, want context.DeadlineExceeded", err)
}
})

t.Run("timeout", func(t *testing.T) {
c, _ := newCacheForWaitTest(10, 5, newTestProgressRequestor())

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test just waits 10 seconds on every execution. First that's too long time for happy path, second maybe using https://go.dev/blog/synctest would help with test design here.

defer cancel()
err := c.waitTillRevision(ctx, 10)
if !errors.Is(err, ErrCacheTimeout) {
t.Fatalf("got %v, want ErrCacheTimeout", err)
}
})
}

// TestWaitTillRevisionTriggersProgressRequests verifies that waitTillRevision
// registers itself with the progress requestor: progress requests are sent
// only while at least one waiter is active, and stop again once the wait
// completes. Time is driven by a fake clock; the short real sleeps only yield
// to background goroutines.
func TestWaitTillRevisionTriggersProgressRequests(t *testing.T) {
	fc := newFakeClock()
	pr := newTestConditionalProgressRequestor(fc, 50*time.Millisecond)
	c, st := newCacheForWaitTest(15, 5, pr)

	// Start progress requestor
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go pr.run(ctx)

	// Wait for goroutine to start
	time.Sleep(10 * time.Millisecond)

	// Initially, no progress requests should be sent (no waiters)
	fc.Advance(100 * time.Millisecond)
	if err := pollConditionNoChange(func() bool {
		return pr.progressRequestsSentCount.Load() == 0
	}); err != nil {
		t.Fatal("expected no progress requests without active waiters")
	}

	// Start waiting - this should trigger progress requests
	errCh := make(chan error, 1)
	go func() {
		errCh <- c.waitTillRevision(context.Background(), 10)
	}()

	// Advance time and wait for progress requests to start
	fc.Advance(50 * time.Millisecond)
	time.Sleep(10 * time.Millisecond)

	// Verify progress requests are being sent while waiting
	if pr.progressRequestsSentCount.Load() == 0 {
		t.Fatal("expected progress requests during wait")
	}

	// Complete the wait: the store revision (15) now satisfies the requested 10.
	st.Restore(nil, 15)

	if err := <-errCh; err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// After completion, progress requests should stop
	finalCount := pr.progressRequestsSentCount.Load()
	fc.Advance(100 * time.Millisecond)
	if err := pollConditionNoChange(func() bool {
		return pr.progressRequestsSentCount.Load() == finalCount
	}); err != nil {
		t.Fatalf("expected no new progress requests after completion, got %d initially, then changed", finalCount)
	}
}
Loading
Loading