Skip to content

Commit d242dc2

Browse files
committed
App.Done/App.Wait: Share internals
This is a proposed change to #912 by @jasonmills that DRYs up internal state management by unifying `chan os.Signal` and `chan ShutdownSignal` into a single interface as suggested in this comment: #912 (comment) This change isn't quite right because mapping os.Signal to a ShutdownSignal currently relies on a goroutine which isn't reliably shut down -- so we have leaking tests. Note that this also fixes a behavioral bug in #912: `Wait()` channels would not resolve if a plain signal was received.
1 parent b6d4a49 commit d242dc2

File tree

3 files changed

+112
-84
lines changed

3 files changed

+112
-84
lines changed

app.go

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ type App struct {
281281
err error
282282
clock fxclock.Clock
283283
lifecycle *lifecycleWrapper
284+
stopch chan struct{} // closed when Stop is called
284285

285286
container *dig.Container
286287
root *module
@@ -295,13 +296,12 @@ type App struct {
295296
// Decides how we react to errors when building the graph.
296297
errorHooks []ErrorHandler
297298
validate bool
299+
298300
// Used to signal shutdowns.
299-
donesMu sync.Mutex // guards dones and shutdownSig
300-
dones []chan os.Signal
301-
shutdownSig os.Signal
302-
waitsMu sync.Mutex // guards waits and shutdownCode
303-
waits []chan ShutdownSignal
304-
shutdownSignal *ShutdownSignal
301+
shutdownMu sync.Mutex // guards sigReceivers and shutdownSig
302+
sigReceivers []signalReceiver
303+
shutdownSig *ShutdownSignal
304+
signalOnce sync.Once
305305

306306
// Used to make sure Start/Stop is called only once.
307307
runStart sync.Once
@@ -447,6 +447,7 @@ func New(opts ...Option) *App {
447447
clock: fxclock.System,
448448
startTimeout: DefaultTimeout,
449449
stopTimeout: DefaultTimeout,
450+
stopch: make(chan struct{}),
450451
}
451452
app.root = &module{app: app}
452453
app.modules = append(app.modules, app.root)
@@ -715,6 +716,7 @@ func (app *App) Stop(ctx context.Context) (err error) {
715716
// Protect the Stop hooks from being called multiple times.
716717
defer func() {
717718
app.log.LogEvent(&fxevent.Stopped{Err: err})
719+
close(app.stopch)
718720
}()
719721

720722
err = withTimeout(ctx, &withTimeoutParams{
@@ -735,36 +737,49 @@ func (app *App) Stop(ctx context.Context) (err error) {
735737
// Alternatively, a signal can be broadcast to all done channels manually by
736738
// using the Shutdown functionality (see the Shutdowner documentation for details).
737739
func (app *App) Done() <-chan os.Signal {
738-
c := make(chan os.Signal, 1)
739-
740-
app.donesMu.Lock()
741-
defer app.donesMu.Unlock()
742-
// If shutdown signal has been received already
743-
// send it and return. If not, wait for user to send a termination
744-
// signal.
745-
if app.shutdownSig != nil {
746-
c <- app.shutdownSig
747-
return c
748-
}
749-
750-
signal.Notify(c, os.Interrupt, _sigINT, _sigTERM)
751-
app.dones = append(app.dones, c)
752-
return c
740+
rcv, ch := newOSSignalReceiver()
741+
app.appendSignalReceiver(rcv)
742+
return ch
753743
}
754744

755745
func (app *App) Wait() <-chan ShutdownSignal {
756-
c := make(chan ShutdownSignal, 1)
757-
758-
app.waitsMu.Lock()
759-
defer app.waitsMu.Unlock()
746+
rcv, ch := newShutdownSignalReceiver()
747+
app.appendSignalReceiver(rcv)
748+
return ch
749+
}
760750

761-
if app.shutdownSignal != nil {
762-
c <- *app.shutdownSignal
763-
return c
764-
}
751+
func (app *App) appendSignalReceiver(r signalReceiver) {
752+
app.shutdownMu.Lock()
753+
defer app.shutdownMu.Unlock()
765754

766-
app.waits = append(app.waits, c)
767-
return c
755+
// If shutdown signal has been received already
756+
// send it and return.
757+
// If not, wait for user to send a termination signal.
758+
if sig := app.shutdownSig; sig != nil {
759+
// Ignore the error from ReceiveSignal.
760+
// This is a newly created channel and can't possibly be
761+
// blocked.
762+
_ = r.ReceiveShutdownSignal(*sig)
763+
return
764+
}
765+
766+
app.sigReceivers = append(app.sigReceivers, r)
767+
768+
// The first time either Wait or Done is called,
769+
// register an OS signal handler
770+
// and make that broadcast the signal to all sigReceivers
771+
// regardless of whether they're Wait or Done based.
772+
app.signalOnce.Do(func() {
773+
sigch := make(chan os.Signal, 1)
774+
signal.Notify(sigch, os.Interrupt, _sigINT, _sigTERM)
775+
go func() {
776+
select {
777+
case sig := <-sigch:
778+
app.broadcastSignal(sig, 1)
779+
case <-app.stopch:
780+
}
781+
}()
782+
})
768783
}
769784

770785
// StartTimeout returns the configured startup timeout. Apps default to using

shutdown.go

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,55 @@
2121
package fx
2222

2323
import (
24+
"errors"
2425
"fmt"
2526
"os"
2627

2728
"go.uber.org/multierr"
2829
)
2930

31+
var errReceiverBlocked = errors.New("receiver is blocked")
32+
33+
type signalReceiver interface {
34+
ReceiveShutdownSignal(ShutdownSignal) error
35+
}
36+
37+
type osSignalReceiver struct{ ch chan<- os.Signal }
38+
39+
var _ signalReceiver = (*osSignalReceiver)(nil)
40+
41+
func newOSSignalReceiver() (*osSignalReceiver, <-chan os.Signal) {
42+
ch := make(chan os.Signal, 1)
43+
return &osSignalReceiver{ch: ch}, ch
44+
}
45+
46+
func (r *osSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error {
47+
select {
48+
case r.ch <- sig.Signal:
49+
return nil
50+
default:
51+
return errReceiverBlocked
52+
}
53+
}
54+
55+
type shutdownSignalReceiver struct{ ch chan<- ShutdownSignal }
56+
57+
var _ signalReceiver = (*shutdownSignalReceiver)(nil)
58+
59+
func newShutdownSignalReceiver() (*shutdownSignalReceiver, <-chan ShutdownSignal) {
60+
ch := make(chan ShutdownSignal, 1)
61+
return &shutdownSignalReceiver{ch: ch}, ch
62+
}
63+
64+
func (r *shutdownSignalReceiver) ReceiveShutdownSignal(sig ShutdownSignal) error {
65+
select {
66+
case r.ch <- sig:
67+
return nil
68+
default:
69+
return errReceiverBlocked
70+
}
71+
}
72+
3073
// Shutdowner provides a method that can manually trigger the shutdown of the
3174
// application by sending a signal to all open Done channels. Shutdowner works
3275
// on applications using Run as well as Start, Done, and Stop. The Shutdowner is
@@ -81,70 +124,40 @@ func (app *App) shutdowner() Shutdowner {
81124
}
82125

83126
func (app *App) broadcastSignal(signal os.Signal, code int) error {
84-
return multierr.Combine(
85-
app.broadcastDoneSignal(signal),
86-
app.broadcastWaitSignal(signal, code),
87-
)
88-
}
89-
90-
func (app *App) broadcastDoneSignal(signal os.Signal) error {
91-
app.donesMu.Lock()
92-
defer app.donesMu.Unlock()
93-
94-
app.shutdownSig = signal
127+
app.shutdownMu.Lock()
128+
defer app.shutdownMu.Unlock()
95129

96-
var unsent int
97-
for _, done := range app.dones {
98-
select {
99-
case done <- signal:
100-
default:
101-
// shutdown called when done channel has already received a
102-
// termination signal that has not been cleared
103-
unsent++
104-
}
105-
}
106-
107-
if unsent != 0 {
108-
return ErrOnUnsentSignal{
109-
Signal: signal,
110-
Unsent: unsent,
111-
Channels: len(app.dones),
112-
}
113-
}
114-
115-
return nil
116-
}
117-
118-
func (app *App) broadcastWaitSignal(signal os.Signal, code int) error {
119-
app.waitsMu.Lock()
120-
defer app.waitsMu.Unlock()
121-
122-
app.shutdownSignal = &ShutdownSignal{
130+
sig := ShutdownSignal{
123131
Signal: signal,
124132
ExitCode: code,
125133
}
134+
app.shutdownSig = &sig
126135

127-
var unsent int
128-
for _, wait := range app.waits {
129-
select {
130-
case wait <- *app.shutdownSignal:
131-
default:
132-
// shutdown called when wait channel has already received a
133-
// termination signal that has not been cleared
134-
unsent++
136+
var (
137+
unsent int
138+
resultErr error
139+
)
140+
for _, rcv := range app.sigReceivers {
141+
// shutdown called when done channel has already received a
142+
// termination signal that has not been cleared
143+
if err := rcv.ReceiveShutdownSignal(sig); err != nil {
144+
if errors.Is(err, errReceiverBlocked) {
145+
unsent++
146+
} else {
147+
resultErr = multierr.Append(resultErr, err)
148+
}
135149
}
136150
}
137151

138152
if unsent != 0 {
139-
return ErrOnUnsentSignal{
153+
resultErr = multierr.Append(resultErr, &ErrOnUnsentSignal{
140154
Signal: signal,
141155
Unsent: unsent,
142-
Code: code,
143-
Channels: len(app.waits),
144-
}
156+
Channels: len(app.sigReceivers),
157+
})
145158
}
146159

147-
return nil
160+
return resultErr
148161
}
149162

150163
// ErrOnUnsentSignal ... TBD
@@ -155,7 +168,7 @@ type ErrOnUnsentSignal struct {
155168
Channels int
156169
}
157170

158-
func (err ErrOnUnsentSignal) Error() string {
171+
func (err *ErrOnUnsentSignal) Error() string {
159172
return fmt.Sprintf(
160173
"failed to send %v signal to %v out of %v channels",
161174
err.Signal,

shutdown_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func TestShutdown(t *testing.T) {
131131

132132
err := s.Shutdown()
133133
assert.Error(t, err)
134-
var o fx.ErrOnUnsentSignal
134+
var o *fx.ErrOnUnsentSignal
135135
assert.True(t, errors.As(err, &o))
136136

137137
assert.Equal(t, 1, o.Unsent)

0 commit comments

Comments
 (0)