Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 35 additions & 43 deletions token/services/selector/sherdlock/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,17 @@ const (
defaultCacheMaxQueries = maxImmediateRetries
)

type tokenFetcher interface {
UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (iterator[*token2.UnspentTokenInWallet], error)
}

//go:generate counterfeiter -o mock/tokendb.go -fake-name TokenDB . TokenDB
type TokenDB interface {
SpendableTokensIteratorBy(ctx context.Context, walletID string, typ token2.Type) (driver.SpendableTokensIterator, error)
}

type enhancedIterator[T any] interface {
HasNext() bool
}

type permutatableIterator[T any] interface {
iterators.Iterator[T]
NewPermutation() iterators.Iterator[T]
}

type FetcherStrategy string

const (
Lazy = "lazy"
Eager = "eager"
Mixed = "mixed"
Listener = "listener"
Cached = "cached"
Lazy FetcherStrategy = "lazy"
Eager FetcherStrategy = "eager"
Mixed FetcherStrategy = "mixed"
Listener FetcherStrategy = "listener"
Cached FetcherStrategy = "cached"
)

type FetcherProvider interface {
GetFetcher(tmsID token.TMSID) (tokenFetcher, error)
}

type fetchFunc func(db *tokendb.StoreService, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) tokenFetcher
type fetchFunc func(db *tokendb.StoreService, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) TokenFetcher

type fetcherProvider struct {
tokenStoreServiceManager tokendb.StoreServiceManager
Expand All @@ -71,7 +49,7 @@ type fetcherProvider struct {
}

var fetchers = map[FetcherStrategy]fetchFunc{
Mixed: func(db *tokendb.StoreService, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) tokenFetcher {
Mixed: func(db *tokendb.StoreService, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) TokenFetcher {
return newMixedFetcher(db, m, cacheSize, freshnessInterval, maxQueries)
},
}
Expand All @@ -94,7 +72,7 @@ func NewFetcherProvider(storeServiceManager tokendb.StoreServiceManager, metrics
}

// GetFetcher returns a token fetcher instance for the specified TMS ID.
func (p *fetcherProvider) GetFetcher(tmsID token.TMSID) (tokenFetcher, error) {
func (p *fetcherProvider) GetFetcher(tmsID token.TMSID) (TokenFetcher, error) {
tokenDB, err := p.tokenStoreServiceManager.StoreServiceByTMSId(tmsID)
if err != nil {
return nil, err
Expand All @@ -113,21 +91,21 @@ type mixedFetcher struct {
m *Metrics
}

// newMixedFetcher creates a fetcher that combines eager (cached) and lazy (on-demand) strategies.
func newMixedFetcher(tokenDB TokenDB, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) *mixedFetcher {
// NewMixedFetcher creates a fetcher that combines eager (cached) and lazy (on-demand) strategies.
func NewMixedFetcher(tokenDB TokenDB, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) *mixedFetcher {
return &mixedFetcher{
lazyFetcher: NewLazyFetcher(tokenDB),
eagerFetcher: newCachedFetcher(tokenDB, cacheSize, freshnessInterval, maxQueries),
eagerFetcher: NewCachedFetcher(tokenDB, cacheSize, freshnessInterval, maxQueries),
m: m,
}
}

// UnspentTokensIteratorBy returns an iterator for unspent tokens, trying cached results first, falling back to database query.
func (f *mixedFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (iterator[*token2.UnspentTokenInWallet], error) {
func (f *mixedFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (Iterator[*token2.UnspentTokenInWallet], error) {
logger.DebugfContext(ctx, "call unspent tokens iterator")
it, err := f.eagerFetcher.UnspentTokensIteratorBy(ctx, walletID, currency)
logger.DebugfContext(ctx, "fetched eager iterator")
if err == nil && it.(enhancedIterator[*token2.UnspentTokenInWallet]).HasNext() {
if err == nil && it.(interface{ HasNext() bool }).HasNext() {
logger.DebugfContext(ctx, "eager iterator had tokens. Returning iterator")
f.m.UnspentTokensInvocations.With(fetcherTypeLabel, eager).Add(1)

Expand All @@ -140,6 +118,11 @@ func (f *mixedFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID str
return f.lazyFetcher.UnspentTokensIteratorBy(ctx, walletID, currency)
}

// newMixedFetcher is an internal alias for NewMixedFetcher.
func newMixedFetcher(tokenDB TokenDB, m *Metrics, cacheSize int64, freshnessInterval time.Duration, maxQueries int) *mixedFetcher {
return NewMixedFetcher(tokenDB, m, cacheSize, freshnessInterval, maxQueries)
}

// lazyFetcher only looks up the results when requested
type lazyFetcher struct {
tokenDB TokenDB
Expand All @@ -151,7 +134,7 @@ func NewLazyFetcher(tokenDB TokenDB) *lazyFetcher {
}

// UnspentTokensIteratorBy queries the database directly for unspent tokens.
func (f *lazyFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (iterator[*token2.UnspentTokenInWallet], error) {
func (f *lazyFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (Iterator[*token2.UnspentTokenInWallet], error) {
logger.DebugfContext(ctx, "Query the DB for new tokens")
it, err := f.tokenDB.SpendableTokensIteratorBy(ctx, walletID, currency)
if err != nil {
Expand All @@ -161,6 +144,11 @@ func (f *lazyFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID stri
return collections.NewPermutatedIterator[token2.UnspentTokenInWallet](it)
}

type permutatableIterator[T any] interface {
iterators.Iterator[T]
NewPermutation() iterators.Iterator[T]
}

type tokenCache interface {
Get(key string) (permutatableIterator[*token2.UnspentTokenInWallet], bool)
Add(key string, value permutatableIterator[*token2.UnspentTokenInWallet])
Expand All @@ -178,15 +166,15 @@ type cachedFetcher struct {
maxQueriesBeforeRefresh uint32

// TODO: A better strategy is to keep following variables per cache key (type/owner combination) and lock/fetch only the 'expired' entry
lastFetched time.Time
lastFetched int64
queriesResponded uint32
// prevKeys tracks cache keys from the previous update cycle to identify stale entries that need removal.
prevKeys map[string]struct{}
mu sync.RWMutex
}

// newCachedFetcher creates a fetcher that maintains a periodically refreshed cache of all tokens.
func newCachedFetcher(tokenDB TokenDB, cacheSize int64, freshnessInterval time.Duration, maxQueriesBeforeRefresh int) *cachedFetcher {
// NewCachedFetcher creates a fetcher that maintains a periodically refreshed cache of all tokens.
func NewCachedFetcher(tokenDB TokenDB, cacheSize int64, freshnessInterval time.Duration, maxQueriesBeforeRefresh int) *cachedFetcher {
// Use defaults if values are not provided (zero values)
if freshnessInterval <= 0 {
freshnessInterval = defaultCacheFreshnessInterval
Expand Down Expand Up @@ -219,7 +207,6 @@ func newCachedFetcher(tokenDB TokenDB, cacheSize int64, freshnessInterval time.D
}
}

// update refreshes the token cache from the database, adding new entries before removing stale ones to prevent race conditions.
func (f *cachedFetcher) update(ctx context.Context) {
f.mu.Lock()
defer f.mu.Unlock()
Expand All @@ -239,7 +226,7 @@ func (f *cachedFetcher) update(ctx context.Context) {

m := f.groupTokensByKey(ctx, it)
f.updateCache(ctx, m)
f.lastFetched = time.Now()
atomic.StoreInt64(&f.lastFetched, time.Now().UnixNano())
atomic.StoreUint32(&f.queriesResponded, 0)
}

Expand Down Expand Up @@ -282,7 +269,7 @@ func (f *cachedFetcher) updateCache(ctx context.Context, tokensByKey map[string]
}

// UnspentTokensIteratorBy returns cached unspent tokens, triggering a refresh if the cache is stale or overused.
func (f *cachedFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (iterator[*token2.UnspentTokenInWallet], error) {
func (f *cachedFetcher) UnspentTokensIteratorBy(ctx context.Context, walletID string, currency token2.Type) (Iterator[*token2.UnspentTokenInWallet], error) {
defer atomic.AddUint32(&f.queriesResponded, 1)
if f.isCacheOverused() {
logger.DebugfContext(ctx, "Overused data. Soft refresh (in the background)...")
Expand Down Expand Up @@ -313,5 +300,10 @@ func (f *cachedFetcher) isCacheOverused() bool {

// isCacheStale checks if the cache has exceeded its freshness interval.
func (f *cachedFetcher) isCacheStale() bool {
return time.Since(f.lastFetched) > f.freshnessInterval
lastFetched := atomic.LoadInt64(&f.lastFetched)
if lastFetched == 0 {
return true
}

return time.Since(time.Unix(0, lastFetched)) > f.freshnessInterval
}
42 changes: 21 additions & 21 deletions token/services/selector/sherdlock/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestNewCachedFetcher_WithDefaults(t *testing.T) {
mockDB := new(mockTokenDB)

// Test with zero values (should use defaults)
fetcher := newCachedFetcher(mockDB, 0, 0, 0)
fetcher := NewCachedFetcher(mockDB, 0, 0, 0)

assert.NotNil(t, fetcher)
assert.Equal(t, defaultCacheFreshnessInterval, fetcher.freshnessInterval)
Expand All @@ -62,7 +62,7 @@ func TestNewCachedFetcher_WithCustomValues(t *testing.T) {
customFreshness := 60 * time.Second
customMaxQueries := 200

fetcher := newCachedFetcher(mockDB, customSize, customFreshness, customMaxQueries)
fetcher := NewCachedFetcher(mockDB, customSize, customFreshness, customMaxQueries)

assert.NotNil(t, fetcher)
assert.Equal(t, customFreshness, fetcher.freshnessInterval)
Expand All @@ -72,13 +72,13 @@ func TestNewCachedFetcher_WithCustomValues(t *testing.T) {

func TestCachedFetcher_IsCacheStale(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 100*time.Millisecond, 0)
fetcher := NewCachedFetcher(mockDB, 0, 100*time.Millisecond, 0)

// Initially cache should be stale (lastFetched is zero time)
assert.True(t, fetcher.isCacheStale())

// Update lastFetched to now
fetcher.lastFetched = time.Now()
atomic.StoreInt64(&fetcher.lastFetched, time.Now().UnixNano())
assert.False(t, fetcher.isCacheStale())

// Wait for cache to become stale
Expand All @@ -89,7 +89,7 @@ func TestCachedFetcher_IsCacheStale(t *testing.T) {
func TestCachedFetcher_IsCacheOverused(t *testing.T) {
mockDB := new(mockTokenDB)
maxQueries := 5
fetcher := newCachedFetcher(mockDB, 0, 0, maxQueries)
fetcher := NewCachedFetcher(mockDB, 0, 0, maxQueries)

// Initially not overused
assert.False(t, fetcher.isCacheOverused())
Expand All @@ -107,7 +107,7 @@ func TestCachedFetcher_IsCacheOverused(t *testing.T) {

func TestCachedFetcher_Update(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 1*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 1*time.Second, 100)

// Create test tokens
tokens := []*token2.UnspentTokenInWallet{
Expand Down Expand Up @@ -158,7 +158,7 @@ func TestCachedFetcher_Update(t *testing.T) {

func TestCachedFetcher_UnspentTokensIteratorBy_CacheHit(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 10*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 10*time.Second, 100)

// Populate cache
tokens := []*token2.UnspentTokenInWallet{
Expand All @@ -179,7 +179,7 @@ func TestCachedFetcher_UnspentTokensIteratorBy_CacheHit(t *testing.T) {

require.NoError(t, err)
assert.NotNil(t, it)
assert.True(t, it.(enhancedIterator[*token2.UnspentTokenInWallet]).HasNext())
assert.True(t, it.(interface{ HasNext() bool }).HasNext())

// Verify query counter incremented
assert.Equal(t, uint32(1), atomic.LoadUint32(&fetcher.queriesResponded))
Expand All @@ -189,7 +189,7 @@ func TestCachedFetcher_UnspentTokensIteratorBy_CacheHit(t *testing.T) {

func TestCachedFetcher_UnspentTokensIteratorBy_CacheMiss(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 10*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 10*time.Second, 100)

// Populate cache with different key
tokens := []*token2.UnspentTokenInWallet{
Expand All @@ -211,15 +211,15 @@ func TestCachedFetcher_UnspentTokensIteratorBy_CacheMiss(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, it)
// Should return empty iterator
assert.False(t, it.(enhancedIterator[*token2.UnspentTokenInWallet]).HasNext())
assert.False(t, it.(interface{ HasNext() bool }).HasNext())

mockDB.AssertExpectations(t)
}

func TestCachedFetcher_UnspentTokensIteratorBy_StaleCache(t *testing.T) {
mockDB := new(mockTokenDB)
// Very short freshness interval
fetcher := newCachedFetcher(mockDB, 0, 50*time.Millisecond, 100)
fetcher := NewCachedFetcher(mockDB, 0, 50*time.Millisecond, 100)

// Initial population
tokens1 := []*token2.UnspentTokenInWallet{
Expand Down Expand Up @@ -260,7 +260,7 @@ func TestCachedFetcher_UnspentTokensIteratorBy_StaleCache(t *testing.T) {

func TestCachedFetcher_CacheClear(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 10*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 10*time.Second, 100)

// First update with tokens
tokens1 := []*token2.UnspentTokenInWallet{
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestCachedFetcher_CacheClear(t *testing.T) {
mockDB.On("SpendableTokensIteratorBy", mock.Anything, "", token2.Type("")).Return(mockIterator2, nil).Once()

// Force cache to be stale so update will actually run
fetcher.lastFetched = time.Now().Add(-20 * time.Second)
atomic.StoreInt64(&fetcher.lastFetched, time.Now().Add(-20*time.Second).UnixNano())

fetcher.update(ctx)

Expand All @@ -317,7 +317,7 @@ func TestCachedFetcher_CacheClear(t *testing.T) {
func TestNewMixedFetcher(t *testing.T) {
mockDB := new(mockTokenDB)

fetcher := newMixedFetcher(mockDB, nil, 100, 30*time.Second, 100)
fetcher := NewMixedFetcher(mockDB, nil, 100, 30*time.Second, 100)

assert.NotNil(t, fetcher)
assert.NotNil(t, fetcher.lazyFetcher)
Expand Down Expand Up @@ -486,7 +486,7 @@ func TestLazyFetcher_UnspentTokensIteratorBy_ErrorHandling(t *testing.T) {

func TestMixedFetcher_FallbackBehavior(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newMixedFetcher(mockDB, NewMetrics(&disabled.Provider{}), 0, 10*time.Second, 100)
fetcher := NewMixedFetcher(mockDB, NewMetrics(&disabled.Provider{}), 0, 10*time.Second, 100)

t.Run("uses lazy fetcher when eager returns error", func(t *testing.T) {
// Setup: eager fetcher will fail to update
Expand Down Expand Up @@ -553,7 +553,7 @@ func TestMixedFetcher_FallbackBehavior(t *testing.T) {

func TestCachedFetcher_ConcurrentAccess(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 1*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 1*time.Second, 100)

t.Run("handles concurrent reads during update", func(t *testing.T) {
// Populate cache
Expand Down Expand Up @@ -596,7 +596,7 @@ func TestCachedFetcher_ConcurrentAccess(t *testing.T) {

func TestCachedFetcher_GroupTokensByKey(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 1*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 1*time.Second, 100)

t.Run("groups tokens correctly", func(t *testing.T) {
tokens := []*token2.UnspentTokenInWallet{
Expand Down Expand Up @@ -637,7 +637,7 @@ func TestCachedFetcher_GroupTokensByKey(t *testing.T) {
// TestCachedFetcher_UpdateCache verifies cache updates without race conditions (add before remove).
func TestCachedFetcher_UpdateCache(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 1*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 1*time.Second, 100)

t.Run("removes stale keys", func(t *testing.T) {
ctx := t.Context()
Expand Down Expand Up @@ -751,7 +751,7 @@ func TestCachedFetcher_UpdateCache(t *testing.T) {
func TestCachedFetcher_SoftRefresh(t *testing.T) {
mockDB := new(mockTokenDB)
maxQueries := 3
fetcher := newCachedFetcher(mockDB, 0, 10*time.Second, maxQueries)
fetcher := NewCachedFetcher(mockDB, 0, 10*time.Second, maxQueries)

t.Run("triggers soft refresh when overused", func(t *testing.T) {
// Initial population
Expand Down Expand Up @@ -886,7 +886,7 @@ func TestFetcherProvider_GetFetcher(t *testing.T) {
// TestCachedFetcher_UpdateWithDatabaseError verifies cache stays stale when DB update fails.
func TestCachedFetcher_UpdateWithDatabaseError(t *testing.T) {
mockDB := new(mockTokenDB)
fetcher := newCachedFetcher(mockDB, 0, 1*time.Second, 100)
fetcher := NewCachedFetcher(mockDB, 0, 1*time.Second, 100)

t.Run("handles database error gracefully", func(t *testing.T) {
expectedErr := errors.New("database connection failed")
Expand Down Expand Up @@ -929,7 +929,7 @@ func TestTokenKey_EdgeCases(t *testing.T) {
func TestMixedFetcher_MetricsTracking(t *testing.T) {
mockDB := new(mockTokenDB)
metrics := NewMetrics(&disabled.Provider{})
fetcher := newMixedFetcher(mockDB, metrics, 0, 10*time.Second, 100)
fetcher := NewMixedFetcher(mockDB, metrics, 0, 10*time.Second, 100)

t.Run("tracks eager fetcher usage", func(t *testing.T) {
// Populate cache
Expand Down
Loading
Loading