diff --git a/README.md b/README.md index b073964..58b1694 100644 --- a/README.md +++ b/README.md @@ -17,13 +17,15 @@ behavior. One caveat is that while the number of concurrently running workers is limited, task results are not and they accumulate until they are collected. Therefore, if a large number of tasks can be expected, the workerpool should be -periodically drained (e.g. every 10k tasks). +periodically drained (e.g. every 10k tasks). Alternatively, +`WithResultCallback` can be used to process results as they complete, avoiding +accumulation entirely. This package is mostly useful when tasks are CPU bound and spawning too many routines would be detrimental to performance. It features a straightforward API -and no external dependencies. See the section below for a usage example. +and no external dependencies. See the sections below for usage examples. -## Example +## Example with Drain ```go package main @@ -52,7 +54,7 @@ func IsPrime(n int64) bool { func main() { wp := workerpool.New(runtime.NumCPU()) - for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { id := fmt.Sprintf("task #%d", i) // Use Submit to submit tasks for processing. Submit blocks when no // worker is available to pick up the task. @@ -63,8 +65,9 @@ func main() { } return nil }) - // Submit fails when the pool is closed (ErrClosed) or being drained - // (ErrDrained). Check for the error when appropriate. + // Submit fails when the pool is closed (ErrClosed), being drained + // (ErrDraining), or the parent context is done (context.Canceled). + // Check for the error when appropriate. if err != nil { fmt.Fprintln(os.Stderr, err) return @@ -93,3 +96,54 @@ func main() { } } ``` + +## Example with result callback + +Use `WithResultCallback` to process each result as it completes rather than +accumulating them for a later `Drain` call. 
The callback receives a `Result`, +which extends `Task` with a `Duration()` method reporting how long the task +took to execute. This is useful for logging, metrics, or long-running pools +where unbounded result accumulation is undesirable. The example reuses the `IsPrime` helper defined in the first example. + +```go +package main + +import ( + "context" + "fmt" + "log" + "os" + "runtime" + + "github.com/cilium/workerpool" +) + +func main() { + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) { + if err := r.Err(); err != nil { + fmt.Fprintf(os.Stderr, "task %s failed after %s: %v\n", r, r.Duration(), err) + } else { + fmt.Printf("task %s completed in %s\n", r, r.Duration()) + } + })) + + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { + id := fmt.Sprintf("task #%d", i) + err := wp.Submit(id, func(_ context.Context) error { + if IsPrime(n) { + fmt.Println(n, "is prime!") + } + return nil + }) + if err != nil { + log.Fatal(err) + } + } + + // Close waits for all in-flight tasks to complete before returning, + // ensuring all callback invocations have finished. + if err := wp.Close(); err != nil { + log.Fatal(err) + } +} +``` diff --git a/example_test.go b/example_test.go index 2504f34..d33b26c 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,7 @@ package workerpool_test import ( "context" "fmt" + "log" "os" "runtime" @@ -27,7 +28,7 @@ func IsPrime(n int64) bool { func Example() { wp := workerpool.New(runtime.NumCPU()) - for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 { + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { id := fmt.Sprintf("task #%d", i) // Use Submit to submit tasks for processing. Submit blocks when no // worker is available to pick up the task. 
@@ -67,3 +68,32 @@ func Example() { fmt.Fprintln(os.Stderr, err) } } + +func ExampleWithResultCallback() { + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) { + if err := r.Err(); err != nil { + fmt.Fprintf(os.Stderr, "task %s failed after %s: %v\n", r, r.Duration(), err) + } else { + fmt.Printf("task %s completed in %s\n", r, r.Duration()) + } + })) + + for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 { + id := fmt.Sprintf("task #%d", i) + err := wp.Submit(id, func(_ context.Context) error { + if IsPrime(n) { + fmt.Println(n, "is prime!") + } + return nil + }) + if err != nil { + log.Fatal(err) + } + } + + // Close waits for all in-flight tasks to complete before returning, + // ensuring all callback invocations have finished. + if err := wp.Close(); err != nil { + log.Fatal(err) + } +} diff --git a/task.go b/task.go index 4aa3eaf..c7a3207 100644 --- a/task.go +++ b/task.go @@ -6,6 +6,7 @@ package workerpool import ( "context" "fmt" + "time" ) // Task is a unit of work. @@ -17,26 +18,39 @@ type Task interface { Err() error } +// Result is a completed Task that also reports its execution duration. +// It is passed to the callback registered with WithResultCallback. +type Result interface { + Task + // Duration returns the time taken to execute the task. + Duration() time.Duration +} + type task struct { run func(context.Context) error id string } type taskResult struct { - err error - id string + err error + id string + duration time.Duration } -// Ensure that taskResult implements the Task interface. -var _ Task = &taskResult{} +// Ensure that taskResult implements the Result interface. +var _ Result = &taskResult{} // String implements fmt.Stringer for taskResult. func (t *taskResult) String() string { return t.id } -// Err returns the error resulting from processing the taskResult. It ensures -// that the taskResult struct implements the Task interface. 
+// Err returns the error resulting from processing the taskResult. func (t *taskResult) Err() error { return t.err } + +// Duration returns the time taken to execute the task. +func (t *taskResult) Duration() time.Duration { + return t.duration +} diff --git a/workerpool.go b/workerpool.go index 457fa9a..91080b1 100644 --- a/workerpool.go +++ b/workerpool.go @@ -15,6 +15,8 @@ // limited, task results are not and they accumulate until they are collected. // Therefore, if a large number of tasks can be expected, the workerpool should // be periodically drained (e.g. every 10k tasks). +// Alternatively, use WithResultCallback to process results as they complete +// without accumulation. package workerpool import ( @@ -22,6 +24,7 @@ import ( "errors" "fmt" "sync" + "time" ) var ( @@ -30,18 +33,41 @@ var ( ErrDraining = errors.New("drain operation in progress") // ErrClosed is returned when operations are attempted after a call to Close. ErrClosed = errors.New("worker pool is closed") + // ErrCallbackSet is returned by Drain when a result callback has been + // registered via WithResultCallback. + ErrCallbackSet = errors.New("a result callback is set") ) +// Option configures a WorkerPool. +type Option func(*WorkerPool) + +// WithResultCallback registers fn to be called each time a task completes. +// When a callback is set, results are not accumulated internally and Drain +// returns ErrCallbackSet. The callback may be invoked concurrently from +// multiple goroutines; fn must be safe for concurrent use. +// WithResultCallback panics if fn is nil. +func WithResultCallback(fn func(Result)) Option { + // TODO(v2): New/NewWithContext should return an error so that option + // validation can propagate errors instead of panicking. + if fn == nil { + panic("workerpool.WithResultCallback: fn must not be nil") + } + return func(wp *WorkerPool) { + wp.onResult = fn + } +} + // WorkerPool spawns, on demand, a number of worker routines to process // submitted tasks concurrently. 
The number of concurrent routines never // exceeds the specified limit. type WorkerPool struct { - workers chan struct{} - tasks chan *task - done <-chan struct{} - cancel context.CancelFunc - results []Task - wg sync.WaitGroup + workers chan struct{} + tasks chan *task + done <-chan struct{} + cancel context.CancelFunc + onResult func(Result) + results []Task + wg sync.WaitGroup mu sync.Mutex draining bool @@ -50,13 +76,13 @@ type WorkerPool struct { // New creates a new pool of workers where at most n workers process submitted // tasks concurrently. New panics if n ≤ 0. -func New(n int) *WorkerPool { - return NewWithContext(context.Background(), n) +func New(n int, opts ...Option) *WorkerPool { + return NewWithContext(context.Background(), n, opts...) } // NewWithContext creates a new pool of workers where at most n workers process submitted // tasks concurrently. New panics if n ≤ 0. The context is used as the parent context to the context of the task func passed to Submit. -func NewWithContext(ctx context.Context, n int) *WorkerPool { +func NewWithContext(ctx context.Context, n int, opts ...Option) *WorkerPool { if n <= 0 { panic(fmt.Sprintf("workerpool.New: n must be > 0, got %d", n)) } @@ -67,6 +93,9 @@ func NewWithContext(ctx context.Context, n int) *WorkerPool { ctx, cancel := context.WithCancel(ctx) wp.cancel = cancel wp.done = ctx.Done() + for _, opt := range opts { + opt(wp) + } go wp.run(ctx) return wp } @@ -124,6 +153,7 @@ func (wp *WorkerPool) Submit(id string, f func(ctx context.Context) error) error // tasks that have been processed. // If a drain operation is already in progress, ErrDraining is returned. // If the worker pool is closed, ErrClosed is returned. +// If a result callback is set via WithResultCallback, ErrCallbackSet is returned. 
func (wp *WorkerPool) Drain() ([]Task, error) { wp.mu.Lock() if wp.closed { @@ -134,6 +164,11 @@ wp.mu.Unlock() return nil, ErrDraining } + // TODO(v2): remove ErrCallbackSet — a pool configured with WithResultCallback should not expose Drain. + if wp.onResult != nil { + wp.mu.Unlock() + return nil, ErrCallbackSet + } wp.draining = true wp.mu.Unlock() @@ -154,7 +189,9 @@ // Close closes the worker pool, rendering it unable to process new tasks. // Close sends the cancellation signal to any running task and waits for all -// workers, if any, to return. +// workers, if any, to return. When a result callback is set via +// WithResultCallback, all callback invocations are guaranteed to have completed +// before Close returns. // Close will return ErrClosed if it has already been called. func (wp *WorkerPool) Close() error { wp.mu.Lock() @@ -181,15 +218,22 @@ // only be called once during the lifetime of a WorkerPool. func (wp *WorkerPool) run(ctx context.Context) { for t := range wp.tasks { - result := taskResult{id: t.id} - wp.results = append(wp.results, &result) + result := taskResult{id: t.id} + if wp.onResult == nil { + wp.results = append(wp.results, &result) + } wp.workers <- struct{}{} go func() { defer wp.wg.Done() + start := time.Now() if t.run != nil { result.err = t.run(ctx) } + result.duration = time.Since(start) + if wp.onResult != nil { + wp.onResult(&result) + } <-wp.workers }() } diff --git a/workerpool_test.go b/workerpool_test.go index 66cbe5b..085f32c 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -15,6 +15,8 @@ import ( "github.com/cilium/workerpool" ) +var errTask = errors.New("task error") + func TestWorkerPoolNewPanics(t *testing.T) { // helper expecting New(n) to panic. 
testWorkerPoolNewPanics := func(n int) { @@ -30,6 +32,15 @@ func TestWorkerPoolNewPanics(t *testing.T) { testWorkerPoolNewPanics(-1) } +func TestWithResultCallbackNilPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("WithResultCallback(nil) should panic()") + } + }() + workerpool.WithResultCallback(nil) +} + func TestWorkerPoolTasksCapacity(t *testing.T) { wp := workerpool.New(runtime.NumCPU()) defer func() { @@ -339,6 +350,21 @@ func TestWorkerPoolDrainAfterClose(t *testing.T) { } } +func TestWorkerPoolDrainAfterCloseWithCallback(t *testing.T) { + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(workerpool.Result) {})) + if err := wp.Close(); err != nil { + t.Fatalf("close: got '%v', want no error", err) + } + // ErrClosed must take precedence over ErrCallbackSet. + tasks, err := wp.Drain() + if !errors.Is(err, workerpool.ErrClosed) { + t.Errorf("got %v; want %v", err, workerpool.ErrClosed) + } + if tasks != nil { + t.Errorf("got %v as tasks; want %v", tasks, nil) + } +} + func TestWorkerPoolSubmitNil(t *testing.T) { wp := workerpool.New(runtime.NumCPU()) defer func() { @@ -367,6 +393,32 @@ func TestWorkerPoolSubmitNil(t *testing.T) { } +func TestWorkerPoolSubmitNilWithCallback(t *testing.T) { + id := "nothing" + var got workerpool.Result + wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) { + got = r + })) + if err := wp.Submit(id, nil); err != nil { + t.Fatalf("got %v; want no error", err) + } + if err := wp.Close(); err != nil { + t.Fatalf("close: got '%v', want no error", err) + } + if got == nil { + t.Fatal("callback was not invoked") + } + if s := got.String(); s != id { + t.Errorf("String: got '%s', want '%s'", s, id) + } + if err := got.Err(); err != nil { + t.Errorf("Err: got '%v', want no error", err) + } + if got.Duration() < 0 { + t.Errorf("Duration: got %v, want >= 0", got.Duration()) + } +} + func TestWorkerPoolSubmitAfterClose(t *testing.T) { wp := 
workerpool.New(runtime.NumCPU()) if err := wp.Close(); err != nil { @@ -512,3 +564,67 @@ func TestWorkerPoolNewWithCancelledContext(t *testing.T) { t.Errorf("drain: got %d results, want 0", len(results)) } } + +func TestWorkerPoolWithResultCallback(t *testing.T) { + n := runtime.NumCPU() + + var mu sync.Mutex + var got []workerpool.Result + + wp := workerpool.New(n, workerpool.WithResultCallback(func(r workerpool.Result) { + mu.Lock() + defer mu.Unlock() + got = append(got, r) + })) + + numTasks := n + 2 + wantErr := errTask + for i := range numTasks { + id := fmt.Sprintf("task #%2d", i) + var f func(context.Context) error + if i == 0 { + f = func(_ context.Context) error { return wantErr } + } else { + f = func(_ context.Context) error { return nil } + } + if err := wp.Submit(id, f); err != nil { + t.Fatalf("failed to submit task '%s': %v", id, err) + } + } + + // Drain must return ErrCallbackSet. + tasks, err := wp.Drain() + if !errors.Is(err, workerpool.ErrCallbackSet) { + t.Errorf("drain: got %v, want %v", err, workerpool.ErrCallbackSet) + } + if tasks != nil { + t.Errorf("drain: got %v, want nil", tasks) + } + + // Close waits for all in-flight tasks, so after it returns all callbacks + // have been invoked. + if err := wp.Close(); err != nil { + t.Fatalf("close: got '%v', want no error", err) + } + + mu.Lock() + defer mu.Unlock() + + if len(got) != numTasks { + t.Fatalf("callback: got %d results, want %d", len(got), numTasks) + } + for _, r := range got { + if r.Duration() < 0 { + t.Errorf("%s: Duration: got %v, want >= 0", r, r.Duration()) + } + if r.String() == "task # 0" { + if !errors.Is(r.Err(), wantErr) { + t.Errorf("%s: Err: got %v, want %v", r, r.Err(), wantErr) + } + } else { + if r.Err() != nil { + t.Errorf("%s: Err: got %v, want nil", r, r.Err()) + } + } + } +}