diff --git a/token/request.go b/token/request.go index cb9628e704..237e521b5d 100644 --- a/token/request.go +++ b/token/request.go @@ -20,6 +20,7 @@ import ( "github.com/hyperledger-labs/fabric-token-sdk/token/driver" "github.com/hyperledger-labs/fabric-token-sdk/token/driver/protos-go/request" "github.com/hyperledger-labs/fabric-token-sdk/token/services/utils" + "github.com/hyperledger-labs/fabric-token-sdk/token/services/validation" "github.com/hyperledger-labs/fabric-token-sdk/token/token" "go.uber.org/zap/zapcore" ) @@ -283,24 +284,30 @@ func (r *Request) ID() RequestAnchor { // Additional options can be passed to customize the action. func (r *Request) Issue(ctx context.Context, wallet *IssuerWallet, receiver Identity, typ token.Type, q uint64, opts ...IssueOption) (*IssueAction, error) { logger.DebugfContext(ctx, "Start issue") - logger.DebugfContext(ctx, "Done issue") + if wallet == nil { return nil, errors.Errorf("wallet is nil") } - if typ == "" { - return nil, errors.Errorf("type is empty") + if err := validation.ValidateTokenType(string(typ)); err != nil { + return nil, errors.Wrap(err, "invalid token type") + } + if err := validation.ValidateAmount(q, 0); err != nil { + return nil, errors.Wrap(err, "invalid amount") + } + if receiver.IsNone() { + return nil, errors.Errorf("all recipients should be defined") } - if q == 0 { - return nil, errors.Errorf("q is zero") + if r.TokenService == nil || r.TokenService.PublicParametersManager() == nil || r.TokenService.PublicParametersManager().PublicParameters() == nil { + return nil, errors.Errorf("token service is not properly initialized") } + + // Validate amount doesn't exceed max token value maxTokenValue := r.TokenService.PublicParametersManager().PublicParameters().MaxTokenValue() if q > maxTokenValue { - return nil, errors.Errorf("q is larger than max token value [%d]", maxTokenValue) + return nil, errors.Errorf("amount exceeds max token value [%d]", maxTokenValue) } - if receiver.IsNone() { - return nil, 
errors.Errorf("all recipients should be defined") - } + logger.DebugfContext(ctx, "Done issue") id, err := wallet.GetIssuerIdentity(typ) if err != nil { @@ -312,6 +319,11 @@ func (r *Request) Issue(ctx context.Context, wallet *IssuerWallet, receiver Iden return nil, errors.WithMessagef(err, "failed compiling options [%v]", opts) } + // Validate metadata using the validation package + if err := validation.ValidateMetadata(opt.Attributes); err != nil { + return nil, errors.Wrap(err, "invalid metadata") + } + // Compute Issue action, metaRaw, err := r.TokenService.tms.IssueService().Issue( ctx, @@ -343,15 +355,51 @@ func (r *Request) Issue(ctx context.Context, wallet *IssuerWallet, receiver Iden // In other words, owners[0] will receives values[0], and so on. // Additional options can be passed to customize the action. func (r *Request) Transfer(ctx context.Context, wallet *OwnerWallet, typ token.Type, values []uint64, owners []Identity, opts ...TransferOption) (*TransferAction, error) { - for _, v := range values { + if wallet == nil { + return nil, errors.Errorf("wallet is nil") + } + if r.TokenService == nil || r.TokenService.PublicParametersManager() == nil || r.TokenService.PublicParametersManager().PublicParameters() == nil { + return nil, errors.Errorf("token service is not properly initialized") + } + + // Validate token type + if err := validation.ValidateTokenType(string(typ)); err != nil { + return nil, errors.Wrap(err, "invalid token type") + } + + // Validate values using the validation package + maxTokenValue := r.TokenService.PublicParametersManager().PublicParameters().MaxTokenValue() + for i, v := range values { if v == 0 { - return nil, errors.Errorf("value is zero") + return nil, errors.Errorf("value at index %d is zero", i) + } + if v > maxTokenValue { + return nil, errors.Errorf("value at index %d exceeds max token value [%d]", i, maxTokenValue) } } + + // Validate owners match values length + if len(owners) != len(values) { + return nil, 
errors.Errorf("number of owners [%d] does not match number of values [%d]", len(owners), len(values)) + } + + // Validate all owners are defined + for i, owner := range owners { + if owner.IsNone() { + return nil, errors.Errorf("owner at index %d is not defined", i) + } + } + opt, err := CompileTransferOptions(opts...) if err != nil { return nil, errors.WithMessagef(err, "failed compiling options [%v]", opts) } + + // Validate metadata using the validation package + if err := validation.ValidateMetadata(opt.Attributes); err != nil { + return nil, errors.Wrap(err, "invalid metadata") + } + tokenIDs, outputTokens, err := r.prepareTransfer(ctx, false, wallet, typ, values, owners, opt) if err != nil { return nil, errors.Wrap(err, "failed preparing transfer") @@ -397,10 +445,37 @@ func (r *Request) Transfer(ctx context.Context, wallet *OwnerWallet, typ token.T // The action redeems tokens of the passed type for a total amount matching the passed value. // Additional options can be passed to customize the action. func (r *Request) Redeem(ctx context.Context, wallet *OwnerWallet, typ token.Type, value uint64, opts ...TransferOption) (*TransferAction, error) { + if wallet == nil { + return nil, errors.Errorf("wallet is nil") + } + if r.TokenService == nil || r.TokenService.PublicParametersManager() == nil || r.TokenService.PublicParametersManager().PublicParameters() == nil { + return nil, errors.Errorf("token service is not properly initialized") + } + + // Validate token type + if err := validation.ValidateTokenType(string(typ)); err != nil { + return nil, errors.Wrap(err, "invalid token type") + } + + // Validate value doesn't exceed max + maxTokenValue := r.TokenService.PublicParametersManager().PublicParameters().MaxTokenValue() + if value == 0 { + return nil, errors.Errorf("redeem value is zero") + } + if value > maxTokenValue { + return nil, errors.Errorf("redeem value exceeds max token value [%d]", maxTokenValue) + } + opt, err := CompileTransferOptions(opts...) 
if err != nil { return nil, errors.WithMessagef(err, "failed compiling options [%v]", opts) } + + // Validate metadata using the validation package + if err := validation.ValidateMetadata(opt.Attributes); err != nil { + return nil, errors.Wrap(err, "invalid metadata") + } + tokenIDs, outputTokens, err := r.prepareTransfer(ctx, true, wallet, typ, []uint64{value}, []Identity{nil}, opt) if err != nil { return nil, errors.Wrap(err, "failed preparing transfer") diff --git a/token/request_test.go b/token/request_test.go index 85b71e5993..937b1f6743 100644 --- a/token/request_test.go +++ b/token/request_test.go @@ -525,7 +525,7 @@ func TestRequest_Issue(t *testing.T) { wallet := &IssuerWallet{} _, err := req.Issue(ctx, wallet, Identity("receiver"), "", 100) require.Error(t, err) - assert.Contains(t, err.Error(), "type is empty") + assert.Contains(t, err.Error(), "invalid token type") }) t.Run("zero quantity", func(t *testing.T) { @@ -533,7 +533,7 @@ func TestRequest_Issue(t *testing.T) { wallet := &IssuerWallet{} _, err := req.Issue(ctx, wallet, Identity("receiver"), "USD", 0) require.Error(t, err) - assert.Contains(t, err.Error(), "q is zero") + assert.Contains(t, err.Error(), "invalid amount") }) t.Run("none receiver", func(t *testing.T) { @@ -577,7 +577,7 @@ func TestRequest_Issue(t *testing.T) { mockWallet := &IssuerWallet{} _, err := req.Issue(ctx, mockWallet, Identity("receiver"), "USD", 200) require.Error(t, err) - assert.Contains(t, err.Error(), "q is larger than max token value") + assert.Contains(t, err.Error(), "amount exceeds max token value") }) } @@ -585,20 +585,51 @@ func TestRequest_Issue(t *testing.T) { func TestRequest_Transfer(t *testing.T) { ctx := t.Context() - t.Run("zero value", func(t *testing.T) { + t.Run("nil wallet", func(t *testing.T) { req := NewRequest(nil, "test-anchor") + _, err := req.Transfer(ctx, nil, "USD", []uint64{100}, []Identity{Identity("receiver")}) + require.Error(t, err) + assert.Contains(t, err.Error(), "wallet is nil") + }) + 
+ t.Run("zero value", func(t *testing.T) { + mockPP := &driver2.PublicParameters{} + mockPP.MaxTokenValueReturns(1000000) + + mockPPM := &driver2.PublicParamsManager{} + mockPPM.PublicParametersReturns(mockPP) + + tms := &ManagementService{ + publicParametersManager: &PublicParametersManager{ + ppm: mockPPM, + pp: &PublicParameters{PublicParameters: mockPP}, + }, + } + req := NewRequest(tms, "test-anchor") wallet := &OwnerWallet{} _, err := req.Transfer(ctx, wallet, "USD", []uint64{0, 100}, []Identity{Identity("receiver1"), Identity("receiver2")}) require.Error(t, err) - assert.Contains(t, err.Error(), "value is zero") + assert.Contains(t, err.Error(), "value at index 0 is zero") }) t.Run("multiple zero values", func(t *testing.T) { - req := NewRequest(nil, "test-anchor") + mockPP := &driver2.PublicParameters{} + mockPP.MaxTokenValueReturns(1000000) + + mockPPM := &driver2.PublicParamsManager{} + mockPPM.PublicParametersReturns(mockPP) + + tms := &ManagementService{ + publicParametersManager: &PublicParametersManager{ + ppm: mockPPM, + pp: &PublicParameters{PublicParameters: mockPP}, + }, + } + req := NewRequest(tms, "test-anchor") wallet := &OwnerWallet{} _, err := req.Transfer(ctx, wallet, "USD", []uint64{100, 0}, []Identity{Identity("receiver1"), Identity("receiver2")}) require.Error(t, err) - assert.Contains(t, err.Error(), "value is zero") + assert.Contains(t, err.Error(), "value at index 1 is zero") }) } diff --git a/token/services/selector/sherdlock/fetcher.go b/token/services/selector/sherdlock/fetcher.go index f1480f3411..708f6c8607 100644 --- a/token/services/selector/sherdlock/fetcher.go +++ b/token/services/selector/sherdlock/fetcher.go @@ -8,10 +8,12 @@ package sherdlock import ( "context" + "io" "sync" "sync/atomic" "time" + "github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors" "github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections" 
"github.com/hyperledger-labs/fabric-smart-client/platform/common/utils/collections/iterators" "github.com/hyperledger-labs/fabric-token-sdk/token" @@ -170,7 +172,11 @@ type cachedFetcher struct { queriesResponded uint32 // prevKeys tracks cache keys from the previous update cycle to identify stale entries that need removal. prevKeys map[string]struct{} - mu sync.RWMutex + // isUpdating indicates if a cache refresh is currently in progress. + isUpdating bool + // updateCond allows goroutines to wait for an in-progress update to complete. + updateCond *sync.Cond + mu sync.RWMutex } // NewCachedFetcher creates a fetcher that maintains a periodically refreshed cache of all tokens. @@ -198,48 +204,110 @@ func NewCachedFetcher(tokenDB TokenDB, cacheSize int64, freshnessInterval time.D panic("failed to create ristretto cache: " + err.Error()) } - return &cachedFetcher{ + f := &cachedFetcher{ tokenDB: tokenDB, cache: ristrettoCache, freshnessInterval: freshnessInterval, maxQueriesBeforeRefresh: uint32(maxQueriesBeforeRefresh), prevKeys: make(map[string]struct{}), } + f.updateCond = sync.NewCond(&f.mu) + + return f +} + +// finishUpdate signals completion and releases the lock. +// Must be called while holding f.mu. +func (f *cachedFetcher) finishUpdate() { + f.isUpdating = false + f.updateCond.Broadcast() + f.mu.Unlock() } +// completeUpdate signals completion without unlocking. +// Use this when the lock is not held (e.g., on error paths before re-acquiring lock). +func (f *cachedFetcher) completeUpdate() { + f.isUpdating = false + f.updateCond.Broadcast() +} + +// update refreshes the token cache from the database. It releases the lock during the +// potentially slow DB operation to avoid blocking other goroutines, then re-acquires +// the lock to atomically update the cache. A re-check of staleness is performed +// after the DB call completes to avoid overwriting a cache that was refreshed by +// another goroutine while waiting for the database. 
func (f *cachedFetcher) update(ctx context.Context) { f.mu.Lock() - defer f.mu.Unlock() + if f.isUpdating { + // Wait for the in-progress update to finish + for f.isUpdating { + f.updateCond.Wait() + } + f.mu.Unlock() + + return + } + if !f.isCacheStale() && !f.isCacheOverused() { logger.DebugfContext(ctx, "Cache renewed in the meantime by another process") + f.mu.Unlock() return } logger.DebugfContext(ctx, "Renew token cache") + f.isUpdating = true + + // Release lock during slow DB operation to not block other token operations + f.mu.Unlock() + it, err := f.tokenDB.SpendableTokensIteratorBy(ctx, "", "") if err != nil { logger.Warnf("Failed to get token iterator: %v", err) + f.completeUpdate() return } defer it.Close() - m := f.groupTokensByKey(ctx, it) + m, err := f.groupTokensByKey(ctx, it) + if err != nil { + logger.Warnf("Failed to group tokens from iterator: %v", err) + f.completeUpdate() + + return + } + + f.mu.Lock() + // Re-check: another goroutine may have refreshed while we waited for DB + if !f.isCacheStale() && !f.isCacheOverused() { + logger.DebugfContext(ctx, "Cache renewed in the meantime by another process, skipping") + f.finishUpdate() + + return + } + f.updateCache(ctx, m) atomic.StoreInt64(&f.lastFetched, time.Now().UnixNano()) atomic.StoreUint32(&f.queriesResponded, 0) + f.finishUpdate() } // groupTokensByKey reads tokens from the iterator and groups them by wallet/currency key. -func (f *cachedFetcher) groupTokensByKey(ctx context.Context, it driver.SpendableTokensIterator) map[string][]*token2.UnspentTokenInWallet { +// It returns an error if the iterator fails mid-way to prevent partial updates. 
+func (f *cachedFetcher) groupTokensByKey(ctx context.Context, it driver.SpendableTokensIterator) (map[string][]*token2.UnspentTokenInWallet, error) {
 	m := map[string][]*token2.UnspentTokenInWallet{}
-	for t, err := it.Next(); err == nil && t != nil; t, err = it.Next() {
+	t, err := it.Next()
+	for ; err == nil && t != nil; t, err = it.Next() {
 		key := tokenKey(t.WalletID, t.Type)
 		logger.DebugfContext(ctx, "Adding token with key [%s]", key)
 		m[key] = append(m[key], t)
 	}
+	// Propagate the error that terminated the loop; a fresh it.Next() call is not guaranteed to return it again (io.EOF means normal exhaustion).
+	if err != nil && !errors.Is(err, io.EOF) {
+		return nil, err
+	}
-	return m
+	return m, nil
 }

 // updateCache updates the cache by adding new entries before removing stale ones.
diff --git a/token/services/selector/sherdlock/fetcher_test.go b/token/services/selector/sherdlock/fetcher_test.go
index c931e08935..5dcdb6b505 100644
--- a/token/services/selector/sherdlock/fetcher_test.go
+++ b/token/services/selector/sherdlock/fetcher_test.go
@@ -9,6 +9,7 @@ package sherdlock
 import (
 	"context"
 	"errors"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -81,8 +82,8 @@ func TestCachedFetcher_IsCacheStale(t *testing.T) {
 	atomic.StoreInt64(&fetcher.lastFetched, time.Now().UnixNano())
 	assert.False(t, fetcher.isCacheStale())

-	// Wait for cache to become stale
-	time.Sleep(150 * time.Millisecond)
+	// Manually set lastFetched to the past instead of sleeping
+	atomic.StoreInt64(&fetcher.lastFetched, time.Now().Add(-fetcher.freshnessInterval*2).UnixNano())
 	assert.True(t, fetcher.isCacheStale())
 }

@@ -235,8 +236,9 @@ func TestCachedFetcher_UnspentTokensIteratorBy_StaleCache(t *testing.T) {
 	ctx := t.Context()
 	fetcher.update(ctx)

-	// Wait for cache to become stale
-	time.Sleep(100 * time.Millisecond)
+	// Trigger hard refresh by setting lastFetched to the past
+	atomic.StoreInt64(&fetcher.lastFetched, time.Now().Add(-fetcher.freshnessInterval*2).UnixNano())
+	assert.True(t, fetcher.isCacheStale())

 	// Setup second call expectation
 	tokens2 := []*token2.UnspentTokenInWallet{
@@ -303,10 +305,6 @@ func TestCachedFetcher_CacheClear(t
*testing.T) { fetcher.update(ctx) - // Note: Ristretto cache uses probabilistic eviction and may not immediately reflect changes - // We wait a bit for the cache to process the clear and new additions - time.Sleep(10 * time.Millisecond) - // New key should exist _, ok3 := fetcher.cache.Get(tokenKey("wallet3", "GBP")) assert.True(t, ok3) @@ -608,7 +606,8 @@ func TestCachedFetcher_GroupTokensByKey(t *testing.T) { mockIterator := iterators.Slice(tokens) ctx := t.Context() - grouped := fetcher.groupTokensByKey(ctx, mockIterator) + grouped, err := fetcher.groupTokensByKey(ctx, mockIterator) + require.NoError(t, err) // Should have 3 keys: wallet1-USD, wallet1-EUR, wallet2-USD assert.Len(t, grouped, 3) @@ -628,7 +627,8 @@ func TestCachedFetcher_GroupTokensByKey(t *testing.T) { mockIterator := iterators.Slice(tokens) ctx := t.Context() - grouped := fetcher.groupTokensByKey(ctx, mockIterator) + grouped, err := fetcher.groupTokensByKey(ctx, mockIterator) + require.NoError(t, err) assert.Empty(t, grouped) }) @@ -667,9 +667,6 @@ func TestCachedFetcher_UpdateCache(t *testing.T) { } fetcher.updateCache(ctx, tokensByKey2) - // Wait for cache to process deletions - time.Sleep(10 * time.Millisecond) - // First key should still exist, second should be removed _, ok1 = fetcher.cache.Get(tokenKey("wallet1", "USD")) assert.True(t, ok1) @@ -731,12 +728,12 @@ func TestCachedFetcher_UpdateCache(t *testing.T) { }, } fetcher.updateCache(ctx, newTokens) - time.Sleep(5 * time.Millisecond) + time.Sleep(20 * time.Millisecond) } // Stop readers close(stopReading) - time.Sleep(20 * time.Millisecond) + time.Sleep(100 * time.Millisecond) // Check for errors select { @@ -787,6 +784,61 @@ func TestCachedFetcher_SoftRefresh(t *testing.T) { }) } +func TestCachedFetcher_Update_ThunderingHerd(t *testing.T) { + mockDB := new(mockTokenDB) + // Short freshness interval + fetcher := NewCachedFetcher(mockDB, 0, 50*time.Millisecond, 100) + + // Initial population + mockDB.On("SpendableTokensIteratorBy", 
mock.Anything, "", token2.Type("")). + Return(iterators.Slice([]*token2.UnspentTokenInWallet{}), nil).Once() + + ctx := t.Context() + fetcher.update(ctx) + + // Trigger staleness manually + fetcher.mu.Lock() + atomic.StoreInt64(&fetcher.lastFetched, time.Now().Add(-10*time.Second).UnixNano()) + fetcher.mu.Unlock() + + // Block the next DB call with a mock that waits + dbCallStarted := make(chan struct{}) + dbCallRelease := make(chan struct{}) + + mockDB.On("SpendableTokensIteratorBy", mock.Anything, "", token2.Type("")). + Run(func(args mock.Arguments) { + close(dbCallStarted) + <-dbCallRelease + }). + Return(iterators.Slice([]*token2.UnspentTokenInWallet{}), nil).Once() + + // Start multiple concurrent reads + var wg sync.WaitGroup + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = fetcher.UnspentTokensIteratorBy(ctx, "wallet1", "USD") + }() + } + + // Wait for at least one to start the DB call + select { + case <-dbCallStarted: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for DB call to start") + } + + // Release the DB call + close(dbCallRelease) + + // Wait for all to finish + wg.Wait() + + // Mock should only have been called TWICE total (once for initial, once for the 10 concurrent ones) + mockDB.AssertNumberOfCalls(t, "SpendableTokensIteratorBy", 2) +} + // TestNewFetcherProvider verifies provider creation with valid/invalid strategies and zero values. func TestNewFetcherProvider(t *testing.T) { t.Run("creates provider with valid strategy", func(t *testing.T) { @@ -980,3 +1032,107 @@ func (m *mockStoreServiceManager) StoreServiceByTMSId(tmsID token.TMSID) (*token return nil, errors.New("not implemented") } + +// TestCachedFetcher_UpdateDoesNotBlockReaders tests that the update() function +// releases the lock during the potentially slow DB operation, allowing concurrent +// readers to access the cache. This is the fix for issue #16. 
+func TestCachedFetcher_UpdateDoesNotBlockReaders(t *testing.T) { + mockDB := new(mockTokenDB) + // Use long freshness interval so cache won't be stale + fetcher := NewCachedFetcher(mockDB, 0, 10*time.Second, 100) + + // Pre-populate the cache so readers can hit it + initialTokens := []*token2.UnspentTokenInWallet{ + {WalletID: "wallet1", Type: "USD", Quantity: "100"}, + } + mockIterator := iterators.Slice(initialTokens) + mockDB.On("SpendableTokensIteratorBy", mock.Anything, "", token2.Type("")).Return(mockIterator, nil).Once() + + ctx := t.Context() + fetcher.update(ctx) + + // Make cache stale so update() will be called + atomic.StoreInt64(&fetcher.lastFetched, time.Now().Add(-20*time.Second).UnixNano()) + + // Use channels to synchronize instead of Sleep + dbStarted := make(chan struct{}) + slowDB := make(chan struct{}) + tokensAfterSlowDB := []*token2.UnspentTokenInWallet{ + {WalletID: "wallet1", Type: "USD", Quantity: "200"}, + } + mockIterator2 := iterators.Slice(tokensAfterSlowDB) + mockDB.On("SpendableTokensIteratorBy", mock.Anything, "", token2.Type("")).Return(mockIterator2, nil).Run(func(args mock.Arguments) { + close(dbStarted) // Signal that the DB operation has started + <-slowDB // Wait before returning to simulate slow DB + }).Once() + + // Track whether reader succeeded while update() was blocked on DB + var readerSuccess atomic.Bool + var readerWg sync.WaitGroup + + // Start update in background (it will block on DB call) + readerWg.Add(1) + go func() { + defer readerWg.Done() + fetcher.update(ctx) + }() + + // Wait for the background update to actually reach the DB call + select { + case <-dbStarted: + // Background update is now at line 240, having released the lock at line 238 + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for background update to reach DB operation") + } + + // Reader should be able to acquire RLock while update() waits on DB + // This would deadlock before the fix + fetcher.mu.RLock() + _, ok := 
fetcher.cache.Get(tokenKey("wallet1", "USD")) + fetcher.mu.RUnlock() + + if ok { + readerSuccess.Store(true) + } + + // Signal DB to complete + close(slowDB) + + // Wait for update to complete + readerWg.Wait() + + // Verify reader succeeded - the cache should still be accessible during update + assert.True(t, readerSuccess.Load(), "reader should be able to access cache while update() is blocked on DB") + mockDB.AssertExpectations(t) +} + +// TestCachedFetcher_UpdateReacquiresLockAfterDB tests that after the DB operation +// completes, update() correctly re-acquires the lock and performs the cache update. +func TestCachedFetcher_UpdateReacquiresLockAfterDB(t *testing.T) { + mockDB := new(mockTokenDB) + fetcher := NewCachedFetcher(mockDB, 0, 1*time.Second, 100) + + // Pre-populate to make cache appear stale + atomic.StoreInt64(&fetcher.lastFetched, time.Now().Add(-20*time.Second).UnixNano()) + + tokens := []*token2.UnspentTokenInWallet{ + {WalletID: "wallet1", Type: "USD", Quantity: "300"}, + } + mockIterator := iterators.Slice(tokens) + mockDB.On("SpendableTokensIteratorBy", mock.Anything, "", token2.Type("")).Return(mockIterator, nil).Once() + + ctx := t.Context() + fetcher.update(ctx) + + // After update completes, cache should be refreshed (not stale) + assert.False(t, fetcher.isCacheStale()) + assert.Equal(t, uint32(0), atomic.LoadUint32(&fetcher.queriesResponded)) + + // Token should be in cache + fetcher.mu.RLock() + _, ok := fetcher.cache.Get(tokenKey("wallet1", "USD")) + fetcher.mu.RUnlock() + assert.True(t, ok, "token should be in cache after update") + + mockDB.AssertExpectations(t) +} diff --git a/token/services/validation/validator.go b/token/services/validation/validator.go new file mode 100644 index 0000000000..1610a35f40 --- /dev/null +++ b/token/services/validation/validator.go @@ -0,0 +1,184 @@ +/* +Copyright IBM Corp. All Rights Reserved. 
+ +SPDX-License-Identifier: Apache-2.0 +*/ + +package validation + +import ( + "github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors" +) + +const ( + // MaxMetadataSize is the maximum size of metadata in bytes (10KB) + MaxMetadataSize = 10 * 1024 + // MaxAddressLength is the maximum length of an address + MaxAddressLength = 256 +) + +// InvalidAmountError indicates a token amount validation failure +type InvalidAmountError struct { + Message string + Value uint64 +} + +func (e *InvalidAmountError) Error() string { + return e.Message +} + +// InvalidAddressError indicates an address validation failure +type InvalidAddressError struct { + Message string + Address []byte +} + +func (e *InvalidAddressError) Error() string { + return e.Message +} + +// InvalidMetadataError indicates a metadata validation failure +type InvalidMetadataError struct { + Message string + Key string +} + +func (e *InvalidMetadataError) Error() string { + return e.Message +} + +// InvalidTokenTypeError indicates a token type validation failure +type InvalidTokenTypeError struct { + Message string + Type string +} + +func (e *InvalidTokenTypeError) Error() string { + return e.Message +} + +// ValidationError is a generic validation error +type ValidationError struct { + Message string +} + +func (e *ValidationError) Error() string { + return e.Message +} + +// NewInvalidAmountError creates a new InvalidAmountError +func NewInvalidAmountError(message string, value uint64) *InvalidAmountError { + return &InvalidAmountError{Message: message, Value: value} +} + +// NewInvalidAddressError creates a new InvalidAddressError +func NewInvalidAddressError(message string, address []byte) *InvalidAddressError { + return &InvalidAddressError{Message: message, Address: address} +} + +// NewInvalidMetadataError creates a new InvalidMetadataError +func NewInvalidMetadataError(message, key string) *InvalidMetadataError { + return &InvalidMetadataError{Message: message, Key: key} +} + +// 
NewInvalidTokenTypeError creates a new InvalidTokenTypeError +func NewInvalidTokenTypeError(message, tokenType string) *InvalidTokenTypeError { + return &InvalidTokenTypeError{Message: message, Type: tokenType} +} + +// NewValidationError creates a new ValidationError +func NewValidationError(message string) *ValidationError { + return &ValidationError{Message: message} +} + +// ValidateAmount validates a token amount value +func ValidateAmount(value uint64, maxValue uint64) error { + if value == 0 { + return NewInvalidAmountError("token amount must be greater than zero", value) + } + + if maxValue > 0 && value > maxValue { + return NewInvalidAmountError("token amount exceeds maximum allowed value", value) + } + + return nil +} + +// ValidateAddress validates a recipient address +func ValidateAddress(address []byte) error { + if len(address) == 0 { + return NewInvalidAddressError("address cannot be empty", nil) + } + + if len(address) > MaxAddressLength { + return NewInvalidAddressError("address exceeds maximum length", address) + } + + return nil +} + +// ValidateTokenType validates a token type +func ValidateTokenType(tokenType string) error { + if tokenType == "" { + return NewInvalidTokenTypeError("token type cannot be empty", tokenType) + } + + return nil +} + +// ValidateMetadata validates metadata fields +func ValidateMetadata(metadata map[interface{}]interface{}) error { + if metadata == nil { + return nil + } + + for key, value := range metadata { + keyStr, isString := key.(string) + if key == nil || (isString && keyStr == "") { + return NewInvalidMetadataError("metadata key cannot be empty", "") + } + + // Check size for byte slice values + if bytes, ok := value.([]byte); ok { + if len(bytes) > MaxMetadataSize { + return NewInvalidMetadataError("metadata value exceeds maximum size", keyStr) + } + } + } + + return nil +} + +// ValidateTransferValues validates transfer values and owners +func ValidateTransferValues(values []uint64, owners [][]byte, maxValue 
uint64) error { + if len(values) == 0 { + return NewValidationError("values cannot be empty") + } + + if len(owners) == 0 { + return NewValidationError("owners cannot be empty") + } + + if len(values) != len(owners) { + return NewValidationError("values and owners must have the same length") + } + + for i, v := range values { + if err := ValidateAmount(v, maxValue); err != nil { + return errors.Wrapf(err, "value at index %d", i) + } + } + + for i, o := range owners { + if err := ValidateAddress(o); err != nil { + return errors.Wrapf(err, "owner at index %d", i) + } + } + + return nil +} + +// ValidateRedeemValue validates a redeem value +func ValidateRedeemValue(value uint64, maxValue uint64) error { + return ValidateAmount(value, maxValue) +} diff --git a/token/services/validation/validator_test.go b/token/services/validation/validator_test.go new file mode 100644 index 0000000000..e65d00eb66 --- /dev/null +++ b/token/services/validation/validator_test.go @@ -0,0 +1,225 @@ +/* +Copyright IBM Corp. All Rights Reserved. 
+ +SPDX-License-Identifier: Apache-2.0 +*/ + +package validation + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateAmount(t *testing.T) { + tests := []struct { + name string + value uint64 + maxValue uint64 + wantErr bool + }{ + {"zero value", 0, 1000, true}, + {"positive value within limit", 500, 1000, false}, + {"max value", 1000, 1000, false}, + {"exceeds max", 1001, 1000, true}, + {"no max limit", 1001, 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateAmount(tt.value, tt.maxValue) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateAddress(t *testing.T) { + tests := []struct { + name string + address []byte + wantErr bool + }{ + {"empty address", []byte{}, true}, + {"nil address", nil, true}, + {"valid address", []byte("valid-address"), false}, + {"address too long", make([]byte, MaxAddressLength+1), true}, + {"max length address", make([]byte, MaxAddressLength), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateAddress(tt.address) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateTokenType(t *testing.T) { + tests := []struct { + name string + tokenType string + wantErr bool + }{ + {"empty type", "", true}, + {"valid type", "USD", false}, + {"valid type EUR", "EUR", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateTokenType(tt.tokenType) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateMetadata(t *testing.T) { + t.Run("nil metadata", func(t *testing.T) { + err := ValidateMetadata(nil) + assert.NoError(t, err) + }) + + t.Run("empty metadata", func(t *testing.T) { + err := ValidateMetadata(map[interface{}]interface{}{}) + assert.NoError(t, err) + }) + + t.Run("valid metadata with byte values", 
func(t *testing.T) { + metadata := map[interface{}]interface{}{ + "key1": []byte("value1"), + "key2": []byte("value2"), + } + err := ValidateMetadata(metadata) + assert.NoError(t, err) + }) + + t.Run("valid metadata with non-byte values", func(t *testing.T) { + metadata := map[interface{}]interface{}{ + "key1": "string value", + "key2": 12345, + } + err := ValidateMetadata(metadata) + assert.NoError(t, err) + }) + + t.Run("empty key", func(t *testing.T) { + metadata := map[interface{}]interface{}{ + "": []byte("value"), + } + err := ValidateMetadata(metadata) + assert.Error(t, err) + }) + + t.Run("value too large", func(t *testing.T) { + metadata := map[interface{}]interface{}{ + "key1": make([]byte, MaxMetadataSize+1), + } + err := ValidateMetadata(metadata) + assert.Error(t, err) + }) +} + +func TestValidateTransferValues(t *testing.T) { + t.Run("empty values", func(t *testing.T) { + err := ValidateTransferValues([]uint64{}, [][]byte{[]byte("owner")}, 1000) + assert.Error(t, err) + }) + + t.Run("empty owners", func(t *testing.T) { + err := ValidateTransferValues([]uint64{100}, [][]byte{}, 1000) + assert.Error(t, err) + }) + + t.Run("mismatched lengths", func(t *testing.T) { + err := ValidateTransferValues([]uint64{100, 200}, [][]byte{[]byte("owner1")}, 1000) + assert.Error(t, err) + }) + + t.Run("zero value", func(t *testing.T) { + err := ValidateTransferValues([]uint64{0}, [][]byte{[]byte("owner")}, 1000) + assert.Error(t, err) + }) + + t.Run("exceeds max value", func(t *testing.T) { + err := ValidateTransferValues([]uint64{1001}, [][]byte{[]byte("owner")}, 1000) + assert.Error(t, err) + }) + + t.Run("empty owner", func(t *testing.T) { + err := ValidateTransferValues([]uint64{100}, [][]byte{{}}, 1000) + assert.Error(t, err) + }) + + t.Run("valid transfer", func(t *testing.T) { + err := ValidateTransferValues([]uint64{100, 200}, [][]byte{[]byte("owner1"), []byte("owner2")}, 1000) + assert.NoError(t, err) + }) +} + +func TestValidateRedeemValue(t *testing.T) { + 
tests := []struct { + name string + value uint64 + maxValue uint64 + wantErr bool + }{ + {"zero value", 0, 1000, true}, + {"positive value", 100, 1000, false}, + {"exceeds max", 1001, 1000, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateRedeemValue(tt.value, tt.maxValue) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestErrorTypes(t *testing.T) { + t.Run("InvalidAmountError", func(t *testing.T) { + err := NewInvalidAmountError("test message", 0) + assert.Contains(t, err.Error(), "test message") + assert.Equal(t, uint64(0), err.Value) + }) + + t.Run("InvalidAddressError", func(t *testing.T) { + err := NewInvalidAddressError("test message", []byte("addr")) + assert.Contains(t, err.Error(), "test message") + assert.Equal(t, []byte("addr"), err.Address) + }) + + t.Run("InvalidMetadataError", func(t *testing.T) { + err := NewInvalidMetadataError("test message", "key") + assert.Contains(t, err.Error(), "test message") + assert.Equal(t, "key", err.Key) + }) + + t.Run("InvalidTokenTypeError", func(t *testing.T) { + err := NewInvalidTokenTypeError("test message", "USD") + assert.Contains(t, err.Error(), "test message") + assert.Equal(t, "USD", err.Type) + }) + + t.Run("ValidationError", func(t *testing.T) { + err := NewValidationError("test message") + assert.Contains(t, err.Error(), "test message") + }) +}