diff --git a/locker.go b/locker.go index 5ad5a1c..50703fb 100644 --- a/locker.go +++ b/locker.go @@ -10,15 +10,12 @@ import ( // callback function is a function that should be executed right after it inserted into hashmap // generally callback is responsible for the removing itself from the hashmap -// id - id of the lock // notifyCh - channel to notify that all locks were removed // stopCh - broadcast channel to stop all the callbacks associated with the resource -type callback func(id string, notifyCh chan<- struct{}, stopCh <-chan struct{}) +type callback func(notifyCh chan<- struct{}, stopCh <-chan struct{}) // item represents callback element type item struct { - // callback to remove the item - callback callback // item's stop channel stopCh chan struct{} // item's update TTL channel @@ -102,21 +99,19 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { rr.ownerID.Store(new(id)) rr.writerCount.Store(1) - rr.readerCount.Store(0) l.resources[res] = rr callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) rr.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, rr.notificationCh, rr.stopCh) + callb(rr.notificationCh, rr.stopCh) }() l.globalMu.unlock() @@ -167,12 +162,14 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { l.log.Debug("got release mutex back", "id", id) // inconsistent, still have readers/writers after notification - if r.writerCount.Load() != 0 && r.readerCount.Load() != 0 { + if r.writerCount.Load() != 0 || r.readerCount.Load() != 0 { l.log.Error("inconsistent state, should be zero writers and zero readers", "resource", res, "id", id, "writers", r.writerCount.Load(), "readers", r.readerCount.Load()) + + r.resourceMu.unlock() return false } @@ -183,13 +180,12 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) r.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, r.notificationCh, r.stopCh) + callb(r.notificationCh, r.stopCh) }() r.resourceMu.unlock() @@ -264,7 +260,7 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { } // inconsistent, still have readers/writers after notification - if r.writerCount.Load() != 0 && r.readerCount.Load() != 0 { + if r.writerCount.Load() != 0 || r.readerCount.Load() != 0 { l.log.Error("inconsistent state, should be zero writers and zero readers", "resource", res, "id", id, @@ -284,14 +280,13 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) r.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, r.notificationCh, r.stopCh) + callb(r.notificationCh, r.stopCh) }() r.resourceMu.unlock() @@ -322,14 +317,13 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) r.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, r.notificationCh, r.stopCh) + callb(r.notificationCh, r.stopCh) }() r.resourceMu.unlock() @@ -363,14 +357,13 @@ func (l *locker) lock(ctx context.Context, res, id string, ttl int) bool { callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) r.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, r.notificationCh, r.stopCh) + callb(r.notificationCh, r.stopCh) }() r.resourceMu.unlock() @@ -416,7 +409,6 @@ func (l *locker) lockRead(ctx context.Context, res, id string, ttl int) bool { } rr.ownerID.Store(new("")) - rr.writerCount.Store(0) rr.readerCount.Store(1) l.resources[res] = rr @@ -424,14 +416,13 @@ func (l *locker) lockRead(ctx context.Context, res, id string, ttl int) bool { callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) rr.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, rr.notificationCh, rr.stopCh) + callb(rr.notificationCh, rr.stopCh) }() l.globalMu.unlock() @@ -490,7 +481,7 @@ func (l *locker) lockRead(ctx context.Context, res, id string, ttl int) bool { } // inconsistent, still have readers/writers after notification - if r.writerCount.Load() != 0 && r.readerCount.Load() != 0 { + if r.writerCount.Load() != 0 || r.readerCount.Load() != 0 { l.log.Error("inconsistent state, should be zero writers and zero readers", "resource", res, "id", id, @@ -507,14 +498,13 @@ func (l *locker) lockRead(ctx context.Context, res, id string, ttl int) bool { callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) r.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, r.notificationCh, r.stopCh) + callb(r.notificationCh, r.stopCh) }() r.resourceMu.unlock() @@ -534,21 +524,19 @@ func (l *locker) lockRead(ctx context.Context, res, id string, ttl int) bool { "resource", res, "id", id) // increase readers - r.writerCount.Store(0) r.readerCount.Add(1) // we have TTL, create callback callb, stopCbCh, updateTTLCh := l.makeLockCallback(res, id, ttl) r.locks.Store(id, &item{ - callback: callb, stopCh: stopCbCh, updateTTLCh: updateTTLCh, }) // run the callback go func() { - callb(id, r.notificationCh, r.stopCh) + callb(r.notificationCh, r.stopCh) }() r.resourceMu.unlock() @@ -698,11 +686,7 @@ func (l *locker) exists(ctx context.Context, res, id string) bool { // Special case, check if we have any locks if id == "*" { - if r.writerCount.Load() > 0 || r.readerCount.Load() > 0 { - return true - } - - return false + return r.writerCount.Load() > 0 || r.readerCount.Load() > 0 } if _, existsID := r.locks.Load(id); !existsID { @@ -751,15 +735,6 @@ func (l *locker) updateTTL(ctx context.Context, res, id string, ttl int) bool { "resource", res, "id", id) - if !ok { - l.log.Warn("no such resource", - "resource", res, - "id", id) - - r.resourceMu.unlockRelease() - return false - } - rl, ok := r.locks.Load(id) if !ok { l.log.Warn("no such resource ID", @@ -811,7 +786,7 @@ func (l *locker) makeLockCallback(res, id string, ttl int) (callback, chan struc updateTTLCh := make(chan int, 1) // at this point, when adding lock, we should not have the callback - return func(lockID string, notifCh chan<- struct{}, sCh <-chan struct{}) { + return func(notifCh chan<- struct{}, sCh <-chan struct{}) { // case for the items without TTL. We should add such items to control their flow cbttl := ttl if cbttl == 0 { @@ -820,45 +795,47 @@ func (l *locker) makeLockCallback(res, id string, ttl int) (callback, chan struc // TTL channel ta := time.NewTicker(time.Microsecond * time.Duration(cbttl)) - loop: - select { - case <-ta.C: - l.log.Debug("r/lock: ttl expired", - "resource", res, - "id", lockID, - "ttl microseconds", cbttl, - ) - ta.Stop() - // broadcast stop channel - case <-sCh: - l.log.Debug("r/lock: ttl removed, stop broadcast call", - "resource", res, - "id", lockID, - "ttl microseconds", cbttl, - ) - ta.Stop() - // item stop channel - case <-stopCbCh: - l.log.Debug("r/lock: ttl removed, stop callback call", - "resource", res, - "id", lockID, - "ttl microseconds", cbttl, - ) - ta.Stop() - case newTTL := <-updateTTLCh: - // if the new TTL is 0, we should treat it as unlimited - if newTTL == 0 { - newTTL = 31555952000000 // year + for { + select { + case <-ta.C: + l.log.Debug("r/lock: ttl expired", + "resource", res, + "id", id, + "ttl microseconds", cbttl, + ) + ta.Stop() + // broadcast stop channel + case <-sCh: + l.log.Debug("r/lock: ttl removed, stop broadcast call", + "resource", res, + "id", id, + "ttl microseconds", cbttl, + ) + ta.Stop() + // item stop channel + case <-stopCbCh: + l.log.Debug("r/lock: ttl removed, stop callback call", + "resource", res, + "id", id, + "ttl microseconds", cbttl, + ) + ta.Stop() + case newTTL := <-updateTTLCh: + // if the new TTL is 0, we should treat it as unlimited + if newTTL == 0 { + newTTL = 31555952000000 // year + } + l.log.Debug("r/lock: ttl was updated", + "resource", res, + "id", id, + "new ttl microseconds", newTTL) + // update the initial ttl + cbttl = newTTL + ta.Reset(time.Microsecond * time.Duration(cbttl)) + // in case of TTL, we don't need to remove the item, only update TTL + continue } - l.log.Debug("r/lock: ttl was updated", - "resource", res, - "id", id, - "new ttl microseconds", newTTL) - // update the initial ttl - cbttl = newTTL - ta.Reset(time.Microsecond * time.Duration(cbttl)) - // in case of TTL, we don't need to remove the item, only update TTL - goto loop + break } // unlimited but should not be long diff --git a/rpc.go b/rpc.go index f87cf7a..eede395 100644 --- a/rpc.go +++ b/rpc.go @@ -2,7 +2,7 @@ package lock import ( "context" - stderr "errors" + "errors" "time" "connectrpc.com/connect" @@ -11,7 +11,7 @@ import ( const defaultImmediateTimeout = time.Millisecond -var errEmptyID = stderr.New("empty ID is not allowed") +var errEmptyID = errors.New("empty ID is not allowed") type rpc struct { pl *Plugin diff --git a/tests/lock_test.go b/tests/lock_test.go index e95e833..d136ece 100644 --- a/tests/lock_test.go +++ b/tests/lock_test.go @@ -677,6 +677,111 @@ func TestForceRelease(t *testing.T) { assert.Equal(t, 2, oLogger.FilterMessageSnippet("r/lock: ttl removed, stop callback call").Len()) } +// startLockContainer brings up rpc + lock + logger on 127.0.0.1:6001 and returns +// a stop function the test must defer. Used by the deterministic coverage tests +// below (the api-test helper lives in lock_api_test.go, which is not part of the +// coverage build). +func startLockContainer(t *testing.T) func() { + t.Helper() + + cont := endure.New(slog.LevelError) + cfg := &config.Plugin{ + Version: "2024.2.0", + Path: "configs/.rr-lock-init.yaml", + } + + require.NoError(t, cont.RegisterAll( + cfg, + &logger.Plugin{}, + &rpcPlugin.Plugin{}, + &lockPlugin.Plugin{}, + )) + require.NoError(t, cont.Init()) + + ch, err := cont.Serve() + require.NoError(t, err) + + wg := &sync.WaitGroup{} + stop := make(chan struct{}) + wg.Go(func() { + select { + case e := <-ch: + assert.NoError(t, e.Error, "container reported error") + case <-stop: + } + }) + + time.Sleep(time.Second) // let rpc bind 6001 + + return func() { + close(stop) + require.NoError(t, cont.Stop()) + wg.Wait() + } +} + +// TestLockWaitThenAcquire covers the write-lock wait-then-acquire arm of lock(): +// a second writer blocks on the notification channel and acquires the lock once +// the holder's TTL expires (rather than timing out, as the other tests do). +func TestLockWaitThenAcquire(t *testing.T) { + defer startLockContainer(t)() + + // A holds the write lock; it expires on its own after ~2s. + ok, err := lock("wta", "A", 2*secMult, 0) + require.NoError(t, err) + require.True(t, ok) + + // B blocks waiting for the lock (up to its 10s wait) and acquires it once + // A's TTL expires, exercising the wait-then-acquire arm instead of timing out. + ok, err = lock("wta", "B", 10*secMult, 10*secMult) + require.NoError(t, err) + require.True(t, ok, "B should acquire the lock after A expires") +} + +// TestLockPromoteReadToWrite covers the read->write promotion arm of lock(): +// a single read lock is promoted to a write lock by the same id. +func TestLockPromoteReadToWrite(t *testing.T) { + defer startLockContainer(t)() + + ok, err := lockRead("promote", "X", 100*secMult, 0) // r=1, w=0 + require.NoError(t, err) + require.True(t, ok) + + // Same id requests a write lock: the promotion arm signals the reader to + // stop, waits for the notification, then takes the write lock. + ok, err = lock("promote", "X", 10*secMult, 10*secMult) + require.NoError(t, err) + require.True(t, ok) +} + +// TestExistsWildcard covers the exists() wildcard branch (id == "*"), which +// reports whether a resource holds any lock at all. +func TestExistsWildcard(t *testing.T) { + defer startLockContainer(t)() + + ok, err := lock("wild", "Y", 100*secMult, 0) + require.NoError(t, err) + require.True(t, ok) + + ok, err = exists("wild", "*") // resource has a writer -> true + require.NoError(t, err) + require.True(t, ok) + + ok, err = exists("absent", "*") // no such resource -> false + require.NoError(t, err) + require.False(t, ok) + + ok, err = release("wild", "Y") + require.NoError(t, err) + require.True(t, ok) + + time.Sleep(time.Second) // let the callback zero the counters + + ok, err = exists("wild", "*") // resource exists but holds no locks -> false + require.NoError(t, err) + require.False(t, ok) +} + const letterBytes = "abc" func randomString(n int) string {