diff --git a/app.go b/app.go index 0f51531ea..958a20e96 100644 --- a/app.go +++ b/app.go @@ -276,6 +276,9 @@ type App struct { clock fxclock.Clock lifecycle *lifecycleWrapper + stopch chan struct{} // closed when Stop is called + stopChLock sync.RWMutex // mutex for init and closing of stopch + container *dig.Container root *module modules []*module @@ -286,10 +289,16 @@ type App struct { // Decides how we react to errors when building the graph. errorHooks []ErrorHandler validate bool + // Used to signal shutdowns. - donesMu sync.Mutex // guards dones and shutdownSig - dones []chan os.Signal - shutdownSig os.Signal + shutdownMu sync.Mutex + shutdownSig *ShutdownSignal + sigReceivers []signalReceiver + signalOnce sync.Once + + // Used to make sure Start/Stop is called only once. + runStart sync.Once + runStop sync.Once osExit func(code int) // os.Exit override; used for testing only } @@ -394,6 +403,7 @@ func New(opts ...Option) *App { startTimeout: DefaultTimeout, stopTimeout: DefaultTimeout, } + app.root = &module{ app: app, // We start with a logger that writes to stderr. One of the @@ -544,27 +554,32 @@ func (app *App) Run() { // Historically, we do not os.Exit(0) even though most applications // cede control to Fx with they call app.Run. To avoid a breaking // change, never os.Exit for success. - if code := app.run(app.Done()); code != 0 { + if code := app.run(app.Wait()); code != 0 { app.exit(code) } } -func (app *App) run(done <-chan os.Signal) (exitCode int) { +func (app *App) run(done <-chan ShutdownSignal) (exitCode int) { startCtx, cancel := app.clock.WithTimeout(context.Background(), app.StartTimeout()) defer cancel() if err := app.Start(startCtx); err != nil { + app.closeStopChannel() return 1 } sig := <-done - app.log().LogEvent(&fxevent.Stopping{Signal: sig}) + app.log().LogEvent(&fxevent.Stopping{Signal: sig.Signal}) stopCtx, cancel := app.clock.WithTimeout(context.Background(), app.StopTimeout()) defer cancel() if err := app.Stop(stopCtx); err != nil { - return 1 + // if we encounter a timeout during stop, force exit code 1 + if errors.Is(err, context.DeadlineExceeded) { + return 1 + } + return sig.ExitCode } return 0 @@ -605,7 +620,9 @@ var ( // encountered any errors in application initialization. func (app *App) Start(ctx context.Context) (err error) { defer func() { - app.log().LogEvent(&fxevent.Started{Err: err}) + app.runStart.Do(func() { + app.log().LogEvent(&fxevent.Started{Err: err}) + }) }() if app.err != nil { @@ -613,6 +630,8 @@ func (app *App) Start(ctx context.Context) (err error) { return app.err } + app.initStopChannel() + return withTimeout(ctx, &withTimeoutParams{ hook: _onStartHook, callback: app.start, @@ -638,6 +657,30 @@ func (app *App) start(ctx context.Context) error { return nil } +func (app *App) initStopChannel() { + app.stopChLock.Lock() + defer app.stopChLock.Unlock() + if app.stopch == nil { + app.stopch = make(chan struct{}) + } +} + +func (app *App) stopChannel() chan struct{} { + app.stopChLock.RLock() + defer app.stopChLock.RUnlock() + ch := app.stopch + return ch +} + +func (app *App) closeStopChannel() { + app.stopChLock.Lock() + defer app.stopChLock.Unlock() + if app.stopch != nil { + close(app.stopch) + app.stopch = nil + } +} + // Stop gracefully stops the application. It executes any registered OnStop // hooks in reverse order, so that each constructor's stop hooks are called // before its dependencies' stop hooks. @@ -646,16 +689,23 @@ func (app *App) start(ctx context.Context) error { // called are executed. However, all those hooks are executed, even if some // fail. func (app *App) Stop(ctx context.Context) (err error) { + defer func() { - app.log().LogEvent(&fxevent.Stopped{Err: err}) + // Protect the Stop hooks from being called multiple times. + app.runStop.Do(func() { + app.log().LogEvent(&fxevent.Stopped{Err: err}) + app.closeStopChannel() + }) }() - return withTimeout(ctx, &withTimeoutParams{ + err = withTimeout(ctx, &withTimeoutParams{ hook: _onStopHook, callback: app.lifecycle.Stop, lifecycle: app.lifecycle, log: app.log(), }) + + return } // Done returns a channel of signals to block on after starting the @@ -666,21 +716,53 @@ func (app *App) Stop(ctx context.Context) (err error) { // Alternatively, a signal can be broadcast to all done channels manually by // using the Shutdown functionality (see the Shutdowner documentation for details). func (app *App) Done() <-chan os.Signal { - c := make(chan os.Signal, 1) + rcv, ch := newOSSignalReceiver() + app.appendSignalReceiver(rcv) + return ch +} - app.donesMu.Lock() - defer app.donesMu.Unlock() - // If shutdown signal has been received already - // send it and return. If not, wait for user to send a termination - // signal. - if app.shutdownSig != nil { - c <- app.shutdownSig - return c - } +func (app *App) Wait() <-chan ShutdownSignal { + rcv, ch := newShutdownSignalReceiver() + app.appendSignalReceiver(rcv) + return ch +} + +func (app *App) appendSignalReceiver(r signalReceiver) { + app.shutdownMu.Lock() + defer app.shutdownMu.Unlock() - signal.Notify(c, os.Interrupt, _sigINT, _sigTERM) - app.dones = append(app.dones, c) - return c + // If shutdown signal has been received already + // send it and return. + // If not, wait for user to send a termination signal. + if sig := app.shutdownSig; sig != nil { + // Ignore the error from ReceiveSignal. + // This is a newly created channel and can't possibly be + // blocked. + _ = r.ReceiveShutdownSignal(*sig) + return + } + + app.sigReceivers = append(app.sigReceivers, r) + + // The first time either Wait or Done is called, + // register an OS signal handler + // and make that broadcast the signal to all sigReceivers + // regardless of whether they're Wait or Done based. + app.signalOnce.Do(func() { + sigch := make(chan os.Signal, 1) + signal.Notify(sigch, os.Interrupt, _sigINT, _sigTERM) + go func() { + // if the stop channel is nil; that means that the app was never started + // thus, do not broadcast any signals + if stopch := app.stopChannel(); stopch != nil { + select { + case sig := <-sigch: + app.broadcastSignal(sig, 1) + case <-stopch: + } + } + }() + }) } // StartTimeout returns the configured startup timeout. Apps default to using diff --git a/app_internal_test.go b/app_internal_test.go index 82de85e2c..65263e1fd 100644 --- a/app_internal_test.go +++ b/app_internal_test.go @@ -22,7 +22,6 @@ package fx import ( "fmt" - "os" "sync" "testing" @@ -41,7 +40,7 @@ func TestAppRun(t *testing.T) { app := New( WithLogger(func() fxevent.Logger { return spy }), ) - done := make(chan os.Signal) + done := make(chan ShutdownSignal) var wg sync.WaitGroup wg.Add(1) @@ -50,7 +49,7 @@ func TestAppRun(t *testing.T) { app.run(done) }() - done <- _sigINT + done <- ShutdownSignal{Signal: _sigINT} wg.Wait() assert.Equal(t, []string{ diff --git a/app_test.go b/app_test.go index d99a58afd..a887fbb58 100644 --- a/app_test.go +++ b/app_test.go @@ -917,7 +917,7 @@ func TestAppRunTimeout(t *testing.T) { err, _ := errv.Interface().(error) assert.ErrorIs(t, err, context.DeadlineExceeded, - "should fail because of a timeout") + "should fail because of a timeout: %v", err) }) } } diff --git a/shutdown.go b/shutdown.go index d5b8488c0..9b6b992bf 100644 --- a/shutdown.go +++ b/shutdown.go @@ -21,10 +21,55 @@ package fx import ( + "errors" "fmt" "os" + + "go.uber.org/multierr" ) +var errReceiverBlocked = errors.New("receiver is blocked") + +type signalReceiver interface { + ReceiveShutdownSignal(ShutdownSignal) error +} + +type osSignalReceiver struct{ ch chan<- os.Signal } + +var _ signalReceiver = (*osSignalReceiver)(nil) + +func newOSSignalReceiver() (*osSignalReceiver, <-chan os.Signal) { + ch := make(chan os.Signal, 1) + return &osSignalReceiver{ch: ch}, ch +} + +func (r *osSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error { + select { + case r.ch <- sig.Signal: + return nil + default: + return errReceiverBlocked + } +} + +type shutdownSignalReceiver struct{ ch chan<- ShutdownSignal } + +var _ signalReceiver = (*shutdownSignalReceiver)(nil) + +func newShutdownSignalReceiver() (*shutdownSignalReceiver, <-chan ShutdownSignal) { + ch := make(chan ShutdownSignal, 1) + return &shutdownSignalReceiver{ch: ch}, ch +} + +func (r *shutdownSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error { + select { + case r.ch <- sig: + return nil + default: + return errReceiverBlocked + } +} + // Shutdowner provides a method that can manually trigger the shutdown of the // application by sending a signal to all open Done channels. Shutdowner works // on applications using Run as well as Start, Done, and Stop. The Shutdowner is @@ -39,8 +84,26 @@ type ShutdownOption interface { apply(*shutdowner) } +type shutdownCode int + +func (c shutdownCode) apply(s *shutdowner) { + s.exitCode = int(c) +} + +// ShutdownCode implements a shutdown option that allows a user specify the +// os.Exit code that an application should exit with. +func ShutdownCode(code int) ShutdownOption { + return shutdownCode(code) +} + type shutdowner struct { - app *App + exitCode int + app *App +} + +type ShutdownSignal struct { + Signal os.Signal + ExitCode int } // Shutdown broadcasts a signal to all of the application's Done channels @@ -49,35 +112,66 @@ type shutdowner struct { // In practice this means Shutdowner.Shutdown should not be called from an // fx.Invoke, but from a fx.Lifecycle.OnStart hook. func (s *shutdowner) Shutdown(opts ...ShutdownOption) error { - return s.app.broadcastSignal(_sigTERM) + for _, opt := range opts { + opt.apply(s) + } + + return s.app.broadcastSignal(_sigTERM, s.exitCode) } func (app *App) shutdowner() Shutdowner { return &shutdowner{app: app} } -func (app *App) broadcastSignal(signal os.Signal) error { - app.donesMu.Lock() - defer app.donesMu.Unlock() +func (app *App) broadcastSignal(signal os.Signal, code int) error { + app.shutdownMu.Lock() + defer app.shutdownMu.Unlock() - app.shutdownSig = signal + sig := ShutdownSignal{ + Signal: signal, + ExitCode: code, + } + app.shutdownSig = &sig - var unsent int - for _, done := range app.dones { - select { - case done <- signal: - default: - // shutdown called when done channel has already received a - // termination signal that has not been cleared - unsent++ + var ( + unsent int + resultErr error + ) + for _, rcv := range app.sigReceivers { + // shutdown called when done channel has already received a + // termination signal that has not been cleared + if err := rcv.ReceiveShutdownSignal(sig); err != nil { + if errors.Is(err, errReceiverBlocked) { + unsent++ + } else { + resultErr = multierr.Append(resultErr, err) + } } } if unsent != 0 { - return fmt.Errorf("failed to send %v signal to %v out of %v channels", - signal, unsent, len(app.dones), - ) + resultErr = multierr.Append(resultErr, &errOnUnsentSignal{ + Signal: signal, + Unsent: unsent, + Channels: len(app.sigReceivers), + }) } - return nil + return resultErr +} + +type errOnUnsentSignal struct { + Signal os.Signal + Unsent int + Code int + Channels int +} + +func (err *errOnUnsentSignal) Error() string { + return fmt.Sprintf( + "send %v signal: %v/%v channels are blocked", + err.Signal, + err.Unsent, + err.Channels, + ) } diff --git a/shutdown_code_example_test.go b/shutdown_code_example_test.go new file mode 100644 index 000000000..a96d1982d --- /dev/null +++ b/shutdown_code_example_test.go @@ -0,0 +1,47 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package fx_test + +import ( + "fmt" + "go.uber.org/fx" +) + +func ExampleShutdownCode() { + app := fx.New( + fx.Invoke(func(shutdowner fx.Shutdowner) { + // Call the shutdowner Shutdown method with a shutdown code + // option + shutdowner.Shutdown(fx.ShutdownCode(1)) + }), + ) + + app.Run() + + wait := app.Wait() + + signal := <-wait + + fmt.Printf("os.Exit(%v)\n", signal.ExitCode) + + // Output: + // os.Exit(1) +} diff --git a/shutdown_test.go b/shutdown_test.go index b6af93f13..c74367478 100644 --- a/shutdown_test.go +++ b/shutdown_test.go @@ -22,6 +22,7 @@ package fx_test import ( "context" + "fmt" "sync" "testing" @@ -64,7 +65,7 @@ func TestShutdown(t *testing.T) { defer app.RequireStart().RequireStop() assert.NoError(t, s.Shutdown(), "error returned from first shutdown call") - assert.EqualError(t, s.Shutdown(), "failed to send terminated signal to 1 out of 1 channels", + assert.EqualError(t, s.Shutdown(), "send terminated signal: 1/1 channels are blocked", "unexpected error returned when shutdown is called with a blocked channel") assert.NotNil(t, <-done, "done channel did not receive signal") }) @@ -87,6 +88,78 @@ func TestShutdown(t *testing.T) { assert.NotNil(t, <-done1, "done channel 1 did not receive signal") assert.NotNil(t, <-done2, "done channel 2 did not receive signal") }) + + t.Run("shutdown app with exit code(s)", func(t *testing.T) { + t.Parallel() + + t.Run("default", func(t *testing.T) { + t.Parallel() + var s fx.Shutdowner + app := fxtest.New(t, fx.Populate(&s)) + + defer app.RequireStart().RequireStop() + + waits := append( + []<-chan fx.ShutdownSignal{}, + app.Wait(), + app.Wait(), + ) + + assert.NoError(t, s.Shutdown(), "error returned from first shutdown call") + + for _, ch := range waits { + signal := <-ch + assert.NotEmpty(t, signal, "no shutdown signal") + assert.NotNil(t, signal.Signal) + assert.Zero(t, signal.ExitCode) + } + }) + + t.Run("unsent", func(t *testing.T) { + t.Parallel() + + var s fx.Shutdowner + app := fxtest.New( + t, + fx.Populate(&s), + ) + + wait := app.Wait() + defer app.RequireStart().RequireStop() + assert.NoError(t, s.Shutdown(), "error returned from first shutdown call") + + err := s.Shutdown() + assert.Error(t, err) + assert.NotNil(t, <-wait) + }) + + for expected := 0; expected <= 3; expected++ { + expected := expected + t.Run(fmt.Sprintf("with exit code %v", expected), func(t *testing.T) { + t.Parallel() + var s fx.Shutdowner + app := fxtest.New( + t, + fx.Populate(&s), + ) + + defer app.RequireStart().RequireStop() + + wait := app.Wait() + + assert.NoError( + t, + s.Shutdown(fx.ShutdownCode(expected)), + "error in app shutdown", + ) + + signal := <-wait + assert.NotEmpty(t, signal, "no shutdown signal") + assert.NotNil(t, signal.Signal) + assert.Equal(t, expected, signal.ExitCode) + }) + } + }) } func TestDataRace(t *testing.T) { @@ -98,6 +171,7 @@ func TestDataRace(t *testing.T) { fx.Populate(&s), ) require.NoError(t, app.Start(context.Background()), "error starting app") + defer require.NoError(t, app.Stop(context.Background()), "error stopping app") const N = 50 ready := make(chan struct{}) // used to orchestrate goroutines for Done() and ShutdownOption()