Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ type App struct {
err error
clock fxclock.Clock
lifecycle *lifecycleWrapper
stopch chan struct{} // closed when Stop is called

container *dig.Container
root *module
Expand All @@ -295,13 +296,12 @@ type App struct {
// Decides how we react to errors when building the graph.
errorHooks []ErrorHandler
validate bool

// Used to signal shutdowns.
donesMu sync.Mutex // guards dones and shutdownSig
dones []chan os.Signal
shutdownSig os.Signal
waitsMu sync.Mutex // guards waits and shutdownCode
waits []chan ShutdownSignal
shutdownSignal *ShutdownSignal
shutdownMu sync.Mutex // guards sigReceivers and shutdownSig
sigReceivers []signalReceiver
shutdownSig *ShutdownSignal
signalOnce sync.Once

// Used to make sure Start/Stop is called only once.
runStart sync.Once
Expand Down Expand Up @@ -447,6 +447,7 @@ func New(opts ...Option) *App {
clock: fxclock.System,
startTimeout: DefaultTimeout,
stopTimeout: DefaultTimeout,
stopch: make(chan struct{}),
}
app.root = &module{app: app}
app.modules = append(app.modules, app.root)
Expand Down Expand Up @@ -605,12 +606,12 @@ func (app *App) Run() {
// Historically, we do not os.Exit(0) even though most applications
// cede control to Fx with they call app.Run. To avoid a breaking
// change, never os.Exit for success.
if code := app.run(app.Done()); code != 0 {
if code := app.run(app.Wait()); code != 0 {
app.exit(code)
}
}

func (app *App) run(done <-chan os.Signal) (exitCode int) {
func (app *App) run(done <-chan ShutdownSignal) (exitCode int) {
startCtx, cancel := app.clock.WithTimeout(context.Background(), app.StartTimeout())
defer cancel()

Expand All @@ -619,13 +620,13 @@ func (app *App) run(done <-chan os.Signal) (exitCode int) {
}

sig := <-done
app.log.LogEvent(&fxevent.Stopping{Signal: sig})
app.log.LogEvent(&fxevent.Stopping{Signal: sig.Signal})

stopCtx, cancel := app.clock.WithTimeout(context.Background(), app.StopTimeout())
defer cancel()

if err := app.Stop(stopCtx); err != nil {
return 1
return sig.ExitCode
}

return 0
Expand Down Expand Up @@ -715,6 +716,7 @@ func (app *App) Stop(ctx context.Context) (err error) {
// Protect the Stop hooks from being called multiple times.
defer func() {
app.log.LogEvent(&fxevent.Stopped{Err: err})
close(app.stopch)
}()

err = withTimeout(ctx, &withTimeoutParams{
Expand All @@ -735,36 +737,49 @@ func (app *App) Stop(ctx context.Context) (err error) {
// Alternatively, a signal can be broadcast to all done channels manually by
// using the Shutdown functionality (see the Shutdowner documentation for details).
func (app *App) Done() <-chan os.Signal {
c := make(chan os.Signal, 1)

app.donesMu.Lock()
defer app.donesMu.Unlock()
// If shutdown signal has been received already
// send it and return. If not, wait for user to send a termination
// signal.
if app.shutdownSig != nil {
c <- app.shutdownSig
return c
}

signal.Notify(c, os.Interrupt, _sigINT, _sigTERM)
app.dones = append(app.dones, c)
return c
rcv, ch := newOSSignalReceiver()
app.appendSignalReceiver(rcv)
return ch
}

func (app *App) Wait() <-chan ShutdownSignal {
c := make(chan ShutdownSignal, 1)

app.waitsMu.Lock()
defer app.waitsMu.Unlock()
rcv, ch := newShutdownSignalReceiver()
app.appendSignalReceiver(rcv)
return ch
}

if app.shutdownSignal != nil {
c <- *app.shutdownSignal
return c
}
func (app *App) appendSignalReceiver(r signalReceiver) {
app.shutdownMu.Lock()
defer app.shutdownMu.Unlock()

app.waits = append(app.waits, c)
return c
// If shutdown signal has been received already
// send it and return.
// If not, wait for user to send a termination signal.
if sig := app.shutdownSig; sig != nil {
// Ignore the error from ReceiveSignal.
// This is a newly created channel and can't possibly be
// blocked.
_ = r.ReceiveShutdownSignal(*sig)
return
}

app.sigReceivers = append(app.sigReceivers, r)

// The first time either Wait or Done is called,
// register an OS signal handler
// and make that broadcast the signal to all sigReceivers
// regardless of whether they're Wait or Done based.
app.signalOnce.Do(func() {
sigch := make(chan os.Signal, 1)
signal.Notify(sigch, os.Interrupt, _sigINT, _sigTERM)
go func() {
select {
case sig := <-sigch:
app.broadcastSignal(sig, 1)
case <-app.stopch:
}
}()
})
}

// StartTimeout returns the configured startup timeout. Apps default to using
Expand Down
117 changes: 65 additions & 52 deletions shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,55 @@
package fx

import (
"errors"
"fmt"
"os"

"go.uber.org/multierr"
)

var errReceiverBlocked = errors.New("receiver is blocked")

type signalReceiver interface {
ReceiveShutdownSignal(ShutdownSignal) error
}

type osSignalReceiver struct{ ch chan<- os.Signal }

var _ signalReceiver = (*osSignalReceiver)(nil)

func newOSSignalReceiver() (*osSignalReceiver, <-chan os.Signal) {
ch := make(chan os.Signal, 1)
return &osSignalReceiver{ch: ch}, ch
}

func (r *osSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error {
select {
case r.ch <- sig.Signal:
return nil
default:
return errReceiverBlocked
}
}

type shutdownSignalReceiver struct{ ch chan<- ShutdownSignal }

var _ signalReceiver = (*shutdownSignalReceiver)(nil)

func newShutdownSignalReceiver() (*shutdownSignalReceiver, <-chan ShutdownSignal) {
ch := make(chan ShutdownSignal, 1)
return &shutdownSignalReceiver{ch: ch}, ch
}

func (r *shutdownSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error {
select {
case r.ch <- sig:
return nil
default:
return errReceiverBlocked
}
}

// Shutdowner provides a method that can manually trigger the shutdown of the
// application by sending a signal to all open Done channels. Shutdowner works
// on applications using Run as well as Start, Done, and Stop. The Shutdowner is
Expand Down Expand Up @@ -81,70 +124,40 @@ func (app *App) shutdowner() Shutdowner {
}

func (app *App) broadcastSignal(signal os.Signal, code int) error {
return multierr.Combine(
app.broadcastDoneSignal(signal),
app.broadcastWaitSignal(signal, code),
)
}

func (app *App) broadcastDoneSignal(signal os.Signal) error {
app.donesMu.Lock()
defer app.donesMu.Unlock()

app.shutdownSig = signal
app.shutdownMu.Lock()
defer app.shutdownMu.Unlock()

var unsent int
for _, done := range app.dones {
select {
case done <- signal:
default:
// shutdown called when done channel has already received a
// termination signal that has not been cleared
unsent++
}
}

if unsent != 0 {
return ErrOnUnsentSignal{
Signal: signal,
Unsent: unsent,
Channels: len(app.dones),
}
}

return nil
}

func (app *App) broadcastWaitSignal(signal os.Signal, code int) error {
app.waitsMu.Lock()
defer app.waitsMu.Unlock()

app.shutdownSignal = &ShutdownSignal{
sig := ShutdownSignal{
Signal: signal,
ExitCode: code,
}
app.shutdownSig = &sig

var unsent int
for _, wait := range app.waits {
select {
case wait <- *app.shutdownSignal:
default:
// shutdown called when wait channel has already received a
// termination signal that has not been cleared
unsent++
var (
unsent int
resultErr error
)
for _, rcv := range app.sigReceivers {
// shutdown called when done channel has already received a
// termination signal that has not been cleared
if err := rcv.ReceiveShutdownSignal(sig); err != nil {
if errors.Is(err, errReceiverBlocked) {
unsent++
} else {
resultErr = multierr.Append(resultErr, err)
}
}
}

if unsent != 0 {
return ErrOnUnsentSignal{
resultErr = multierr.Append(resultErr, &ErrOnUnsentSignal{
Signal: signal,
Unsent: unsent,
Code: code,
Channels: len(app.waits),
}
Channels: len(app.sigReceivers),
})
}

return nil
return resultErr
}

// ErrOnUnsentSignal ... TBD
Expand All @@ -155,7 +168,7 @@ type ErrOnUnsentSignal struct {
Channels int
}

func (err ErrOnUnsentSignal) Error() string {
func (err *ErrOnUnsentSignal) Error() string {
return fmt.Sprintf(
"failed to send %v signal to %v out of %v channels",
err.Signal,
Expand Down
2 changes: 1 addition & 1 deletion shutdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestShutdown(t *testing.T) {

err := s.Shutdown()
assert.Error(t, err)
var o fx.ErrOnUnsentSignal
var o *fx.ErrOnUnsentSignal
assert.True(t, errors.As(err, &o))

assert.Equal(t, 1, o.Unsent)
Expand Down