Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ behavior.
One caveat is that while the number of concurrently running workers is limited,
task results are not and they accumulate until they are collected. Therefore,
if a large number of tasks can be expected, the workerpool should be
periodically drained (e.g. every 10k tasks).
periodically drained (e.g. every 10k tasks). Alternatively,
`WithResultCallback` can be used to process results as they complete, avoiding
accumulation entirely.

This package is mostly useful when tasks are CPU bound and spawning too many
routines would be detrimental to performance. It features a straightforward API
and no external dependencies. See the section below for a usage example.
and no external dependencies. See the sections below for usage examples.

## Example
## Example with Drain

```go
package main
Expand Down Expand Up @@ -52,7 +54,7 @@ func IsPrime(n int64) bool {

func main() {
wp := workerpool.New(runtime.NumCPU())
for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 {
for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 {
id := fmt.Sprintf("task #%d", i)
// Use Submit to submit tasks for processing. Submit blocks when no
// worker is available to pick up the task.
Expand All @@ -63,8 +65,9 @@ func main() {
}
return nil
})
// Submit fails when the pool is closed (ErrClosed) or being drained
// (ErrDrained). Check for the error when appropriate.
// Submit fails when the pool is closed (ErrClosed), being drained
// (ErrDraining), or the parent context is done (context.Canceled).
// Check for the error when appropriate.
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
Expand Down Expand Up @@ -93,3 +96,54 @@ func main() {
}
}
```

## Example with result callback

Use `WithResultCallback` to process each result as it completes rather than
accumulating them for a later `Drain` call. The callback receives a `Result`,
which extends `Task` with a `Duration()` method reporting how long the task
took to execute. This is useful for logging, metrics, or long-running pools
where unbounded result accumulation is undesirable.

```go
package main

import (
"context"
"fmt"
"log"
"os"
"runtime"

"github.com/cilium/workerpool"
)

func main() {
wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) {
if err := r.Err(); err != nil {
fmt.Fprintf(os.Stderr, "task %s failed after %s: %v\n", r, r.Duration(), err)
} else {
fmt.Printf("task %s completed in %s\n", r, r.Duration())
}
}))

for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 {
id := fmt.Sprintf("task #%d", i)
err := wp.Submit(id, func(_ context.Context) error {
if IsPrime(n) {
fmt.Println(n, "is prime!")
}
return nil
})
if err != nil {
log.Fatal(err)
}
}

// Close waits for all in-flight tasks to complete before returning,
// ensuring all callback invocations have finished.
if err := wp.Close(); err != nil {
log.Fatal(err)
}
}
```
32 changes: 31 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package workerpool_test
import (
"context"
"fmt"
"log"
"os"
"runtime"

Expand All @@ -27,7 +28,7 @@ func IsPrime(n int64) bool {

func Example() {
wp := workerpool.New(runtime.NumCPU())
for i, n := 0, int64(1_000_000_000_000_000_000); n < 1_000_000_000_000_000_100; i, n = i+1, n+1 {
for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 {
id := fmt.Sprintf("task #%d", i)
// Use Submit to submit tasks for processing. Submit blocks when no
// worker is available to pick up the task.
Expand Down Expand Up @@ -67,3 +68,32 @@ func Example() {
fmt.Fprintln(os.Stderr, err)
}
}

func ExampleWithResultCallback() {
wp := workerpool.New(runtime.NumCPU(), workerpool.WithResultCallback(func(r workerpool.Result) {
if err := r.Err(); err != nil {
fmt.Fprintf(os.Stderr, "task %s failed after %s: %v\n", r, r.Duration(), err)
} else {
fmt.Printf("task %s completed in %s\n", r, r.Duration())
}
}))

for i, n := 0, int64(1_000_000_000_000_000_000); i < 100; i, n = i+1, n+1 {
id := fmt.Sprintf("task #%d", i)
err := wp.Submit(id, func(_ context.Context) error {
if IsPrime(n) {
fmt.Println(n, "is prime!")
}
return nil
})
if err != nil {
log.Fatal(err)
}
}

// Close waits for all in-flight tasks to complete before returning,
// ensuring all callback invocations have finished.
if err := wp.Close(); err != nil {
log.Fatal(err)
}
}
26 changes: 20 additions & 6 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package workerpool
import (
"context"
"fmt"
"time"
)

// Task is a unit of work.
Expand All @@ -17,26 +18,39 @@ type Task interface {
Err() error
}

// Result is a completed Task that also reports its execution duration.
// It is passed to the callback registered with WithResultCallback.
type Result interface {
Task
// Duration returns the time taken to execute the task.
Duration() time.Duration
}

type task struct {
run func(context.Context) error
id string
}

type taskResult struct {
err error
id string
err error
id string
duration time.Duration
}

// Ensure that taskResult implements the Task interface.
var _ Task = &taskResult{}
// Ensure that taskResult implements the Result interface.
var _ Result = &taskResult{}

// String implements fmt.Stringer for taskResult.
func (t *taskResult) String() string {
return t.id
}

// Err returns the error resulting from processing the taskResult. It ensures
// that the taskResult struct implements the Task interface.
// Err returns the error resulting from processing the taskResult.
func (t *taskResult) Err() error {
return t.err
}

// Duration returns the time taken to execute the task.
func (t *taskResult) Duration() time.Duration {
return t.duration
}
67 changes: 55 additions & 12 deletions workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
// limited, task results are not and they accumulate until they are collected.
// Therefore, if a large number of tasks can be expected, the workerpool should
// be periodically drained (e.g. every 10k tasks).
// Alternatively, use WithResultCallback to process results as they complete
// without accumulation.
package workerpool

import (
"context"
"errors"
"fmt"
"sync"
"time"
)

var (
Expand All @@ -30,18 +33,41 @@ var (
ErrDraining = errors.New("drain operation in progress")
// ErrClosed is returned when operations are attempted after a call to Close.
ErrClosed = errors.New("worker pool is closed")
// ErrCallbackSet is returned by Drain when a result callback has been
// registered via WithResultCallback.
ErrCallbackSet = errors.New("a result callback is set")
Comment thread
rolinh marked this conversation as resolved.
)

// Option configures a WorkerPool.
type Option func(*WorkerPool)

// WithResultCallback registers fn to be called each time a task completes.
// When a callback is set, results are not accumulated internally and Drain
// returns ErrCallbackSet. The callback may be invoked concurrently from
// multiple goroutines; fn must be safe for concurrent use.
// WithResultCallback panics if fn is nil.
func WithResultCallback(fn func(Result)) Option {
// TODO(v2): New/NewWithContext should return an error so that option
// validation can propagate errors instead of panicking.
if fn == nil {
panic("workerpool.WithResultCallback: fn must not be nil")
Comment thread
rolinh marked this conversation as resolved.
}
return func(wp *WorkerPool) {
wp.onResult = fn
}
}

// WorkerPool spawns, on demand, a number of worker routines to process
// submitted tasks concurrently. The number of concurrent routines never
// exceeds the specified limit.
type WorkerPool struct {
workers chan struct{}
tasks chan *task
done <-chan struct{}
cancel context.CancelFunc
results []Task
wg sync.WaitGroup
workers chan struct{}
tasks chan *task
done <-chan struct{}
cancel context.CancelFunc
onResult func(Result)
results []Task
wg sync.WaitGroup

mu sync.Mutex
draining bool
Expand All @@ -50,13 +76,13 @@ type WorkerPool struct {

// New creates a new pool of workers where at most n workers process submitted
// tasks concurrently. New panics if n ≤ 0.
func New(n int) *WorkerPool {
return NewWithContext(context.Background(), n)
func New(n int, opts ...Option) *WorkerPool {
return NewWithContext(context.Background(), n, opts...)
}

// NewWithContext creates a new pool of workers where at most n workers process submitted
// tasks concurrently. New panics if n ≤ 0. The context is used as the parent context to the context of the task func passed to Submit.
func NewWithContext(ctx context.Context, n int) *WorkerPool {
func NewWithContext(ctx context.Context, n int, opts ...Option) *WorkerPool {
if n <= 0 {
panic(fmt.Sprintf("workerpool.New: n must be > 0, got %d", n))
}
Expand All @@ -67,6 +93,9 @@ func NewWithContext(ctx context.Context, n int) *WorkerPool {
ctx, cancel := context.WithCancel(ctx)
wp.cancel = cancel
wp.done = ctx.Done()
for _, opt := range opts {
opt(wp)
}
go wp.run(ctx)
return wp
}
Expand Down Expand Up @@ -124,6 +153,7 @@ func (wp *WorkerPool) Submit(id string, f func(ctx context.Context) error) error
// tasks that have been processed.
// If a drain operation is already in progress, ErrDraining is returned.
// If the worker pool is closed, ErrClosed is returned.
// If a result callback is set via WithResultCallback, ErrCallbackSet is returned.
func (wp *WorkerPool) Drain() ([]Task, error) {
wp.mu.Lock()
if wp.closed {
Expand All @@ -134,6 +164,11 @@ func (wp *WorkerPool) Drain() ([]Task, error) {
wp.mu.Unlock()
return nil, ErrDraining
}
// TODO(v2): remove ErrCallbackSet — a pool configured with WithResultCallback should not expose Drain.
if wp.onResult != nil {
wp.mu.Unlock()
return nil, ErrCallbackSet
}
wp.draining = true
wp.mu.Unlock()

Expand All @@ -154,7 +189,9 @@ func (wp *WorkerPool) Drain() ([]Task, error) {

// Close closes the worker pool, rendering it unable to process new tasks.
// Close sends the cancellation signal to any running task and waits for all
// workers, if any, to return.
// workers, if any, to return. When a result callback is set via
// WithResultCallback, all callback invocations are guaranteed to have completed
// before Close returns.
// Close will return ErrClosed if it has already been called.
func (wp *WorkerPool) Close() error {
wp.mu.Lock()
Expand All @@ -181,15 +218,21 @@ func (wp *WorkerPool) Close() error {
// only be called once during the lifetime of a WorkerPool.
func (wp *WorkerPool) run(ctx context.Context) {
for t := range wp.tasks {

result := taskResult{id: t.id}
wp.results = append(wp.results, &result)
if wp.onResult == nil {
wp.results = append(wp.results, &result)
}
Comment thread
kaworu marked this conversation as resolved.
wp.workers <- struct{}{}
go func() {
defer wp.wg.Done()
start := time.Now()
if t.run != nil {
result.err = t.run(ctx)
}
result.duration = time.Since(start)
if wp.onResult != nil {
wp.onResult(&result)
}
<-wp.workers
}()
}
Expand Down
Loading
Loading