diff --git a/app.go b/app.go index d7639ac20..1e348ff72 100644 --- a/app.go +++ b/app.go @@ -281,6 +281,7 @@ type App struct { err error clock fxclock.Clock lifecycle *lifecycleWrapper + stopch chan struct{} // closed when Stop is called container *dig.Container root *module @@ -295,13 +296,12 @@ type App struct { // Decides how we react to errors when building the graph. errorHooks []ErrorHandler validate bool + // Used to signal shutdowns. - donesMu sync.Mutex // guards dones and shutdownSig - dones []chan os.Signal - shutdownSig os.Signal - waitsMu sync.Mutex // guards waits and shutdownCode - waits []chan ShutdownSignal - shutdownSignal *ShutdownSignal + shutdownMu sync.Mutex // guards sigReceivers and shutdownSig + sigReceivers []signalReceiver + shutdownSig *ShutdownSignal + signalOnce sync.Once // Used to make sure Start/Stop is called only once. runStart sync.Once @@ -447,6 +447,7 @@ func New(opts ...Option) *App { clock: fxclock.System, startTimeout: DefaultTimeout, stopTimeout: DefaultTimeout, + stopch: make(chan struct{}), } app.root = &module{app: app} app.modules = append(app.modules, app.root) @@ -605,12 +606,12 @@ func (app *App) Run() { // Historically, we do not os.Exit(0) even though most applications // cede control to Fx with they call app.Run. To avoid a breaking // change, never os.Exit for success. - if code := app.run(app.Done()); code != 0 { + if code := app.run(app.Wait()); code != 0 { app.exit(code) } } -func (app *App) run(done <-chan os.Signal) (exitCode int) { +func (app *App) run(done <-chan ShutdownSignal) (exitCode int) { startCtx, cancel := app.clock.WithTimeout(context.Background(), app.StartTimeout()) defer cancel() @@ -619,13 +620,13 @@ func (app *App) run(done <-chan os.Signal) (exitCode int) { } sig := <-done - app.log.LogEvent(&fxevent.Stopping{Signal: sig}) + app.log.LogEvent(&fxevent.Stopping{Signal: sig.Signal}) stopCtx, cancel := app.clock.WithTimeout(context.Background(), app.StopTimeout()) defer cancel() if err := app.Stop(stopCtx); err != nil { - return 1 + return sig.ExitCode } return 0 @@ -715,6 +716,7 @@ func (app *App) Stop(ctx context.Context) (err error) { // Protect the Stop hooks from being called multiple times. defer func() { app.log.LogEvent(&fxevent.Stopped{Err: err}) + close(app.stopch) }() err = withTimeout(ctx, &withTimeoutParams{ @@ -735,36 +737,49 @@ func (app *App) Stop(ctx context.Context) (err error) { // Alternatively, a signal can be broadcast to all done channels manually by // using the Shutdown functionality (see the Shutdowner documentation for details). func (app *App) Done() <-chan os.Signal { - c := make(chan os.Signal, 1) - - app.donesMu.Lock() - defer app.donesMu.Unlock() - // If shutdown signal has been received already - // send it and return. If not, wait for user to send a termination - // signal. - if app.shutdownSig != nil { - c <- app.shutdownSig - return c - } - - signal.Notify(c, os.Interrupt, _sigINT, _sigTERM) - app.dones = append(app.dones, c) - return c + rcv, ch := newOSSignalReceiver() + app.appendSignalReceiver(rcv) + return ch } func (app *App) Wait() <-chan ShutdownSignal { - c := make(chan ShutdownSignal, 1) - - app.waitsMu.Lock() - defer app.waitsMu.Unlock() + rcv, ch := newShutdownSignalReceiver() + app.appendSignalReceiver(rcv) + return ch +} - if app.shutdownSignal != nil { - c <- *app.shutdownSignal - return c - } +func (app *App) appendSignalReceiver(r signalReceiver) { + app.shutdownMu.Lock() + defer app.shutdownMu.Unlock() - app.waits = append(app.waits, c) - return c + // If shutdown signal has been received already + // send it and return. + // If not, wait for user to send a termination signal. + if sig := app.shutdownSig; sig != nil { + // Ignore the error from ReceiveSignal. + // This is a newly created channel and can't possibly be + // blocked. + _ = r.ReceiveShutdownSignal(*sig) + return + } + + app.sigReceivers = append(app.sigReceivers, r) + + // The first time either Wait or Done is called, + // register an OS signal handler + // and make that broadcast the signal to all sigReceivers + // regardless of whether they're Wait or Done based. + app.signalOnce.Do(func() { + sigch := make(chan os.Signal, 1) + signal.Notify(sigch, os.Interrupt, _sigINT, _sigTERM) + go func() { + select { + case sig := <-sigch: + app.broadcastSignal(sig, 1) + case <-app.stopch: + } + }() + }) } // StartTimeout returns the configured startup timeout. Apps default to using diff --git a/shutdown.go b/shutdown.go index dcec1b23a..3ac7cea5c 100644 --- a/shutdown.go +++ b/shutdown.go @@ -21,12 +21,55 @@ package fx import ( + "errors" "fmt" "os" "go.uber.org/multierr" ) +var errReceiverBlocked = errors.New("receiver is blocked") + +type signalReceiver interface { + ReceiveShutdownSignal(ShutdownSignal) error +} + +type osSignalReceiver struct{ ch chan<- os.Signal } + +var _ signalReceiver = (*osSignalReceiver)(nil) + +func newOSSignalReceiver() (*osSignalReceiver, <-chan os.Signal) { + ch := make(chan os.Signal, 1) + return &osSignalReceiver{ch: ch}, ch +} + +func (r *osSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error { + select { + case r.ch <- sig.Signal: + return nil + default: + return errReceiverBlocked + } +} + +type shutdownSignalReceiver struct{ ch chan<- ShutdownSignal } + +var _ signalReceiver = (*shutdownSignalReceiver)(nil) + +func newShutdownSignalReceiver() (*shutdownSignalReceiver, <-chan ShutdownSignal) { + ch := make(chan ShutdownSignal, 1) + return &shutdownSignalReceiver{ch: ch}, ch +} + +func (r *shutdownSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error { + select { + case r.ch <- sig: + return nil + default: + return errReceiverBlocked + } +} + // Shutdowner provides a method that can manually trigger the shutdown of the // application by sending a signal to all open Done channels. Shutdowner works // on applications using Run as well as Start, Done, and Stop. The Shutdowner is @@ -81,70 +124,40 @@ func (app *App) shutdowner() Shutdowner { } func (app *App) broadcastSignal(signal os.Signal, code int) error { - return multierr.Combine( - app.broadcastDoneSignal(signal), - app.broadcastWaitSignal(signal, code), - ) -} - -func (app *App) broadcastDoneSignal(signal os.Signal) error { - app.donesMu.Lock() - defer app.donesMu.Unlock() - - app.shutdownSig = signal + app.shutdownMu.Lock() + defer app.shutdownMu.Unlock() - var unsent int - for _, done := range app.dones { - select { - case done <- signal: - default: - // shutdown called when done channel has already received a - // termination signal that has not been cleared - unsent++ - } - } - - if unsent != 0 { - return ErrOnUnsentSignal{ - Signal: signal, - Unsent: unsent, - Channels: len(app.dones), - } - } - - return nil -} - -func (app *App) broadcastWaitSignal(signal os.Signal, code int) error { - app.waitsMu.Lock() - defer app.waitsMu.Unlock() - - app.shutdownSignal = &ShutdownSignal{ + sig := ShutdownSignal{ Signal: signal, ExitCode: code, } + app.shutdownSig = &sig - var unsent int - for _, wait := range app.waits { - select { - case wait <- *app.shutdownSignal: - default: - // shutdown called when wait channel has already received a - // termination signal that has not been cleared - unsent++ + var ( + unsent int + resultErr error + ) + for _, rcv := range app.sigReceivers { + // shutdown called when done channel has already received a + // termination signal that has not been cleared + if err := rcv.ReceiveShutdownSignal(sig); err != nil { + if errors.Is(err, errReceiverBlocked) { + unsent++ + } else { + resultErr = multierr.Append(resultErr, err) + } } } if unsent != 0 { - return ErrOnUnsentSignal{ + resultErr = multierr.Append(resultErr, &ErrOnUnsentSignal{ Signal: signal, Unsent: unsent, - Code: code, - Channels: len(app.waits), - } + Channels: len(app.sigReceivers), + }) } - return nil + return resultErr } // ErrOnUnsentSignal ... TBD @@ -155,7 +168,7 @@ type ErrOnUnsentSignal struct { Channels int } -func (err ErrOnUnsentSignal) Error() string { +func (err *ErrOnUnsentSignal) Error() string { return fmt.Sprintf( "failed to send %v signal to %v out of %v channels", err.Signal, diff --git a/shutdown_test.go b/shutdown_test.go index 517bae739..20b48de51 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -131,7 +131,7 @@ func TestShutdown(t *testing.T) { err := s.Shutdown() assert.Error(t, err) - var o fx.ErrOnUnsentSignal + var o *fx.ErrOnUnsentSignal assert.True(t, errors.As(err, &o)) assert.Equal(t, 1, o.Unsent)