diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2467191d..eb52df90 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -29,7 +29,7 @@ jobs: # Needed for the example-test to run. GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - go test -cover -v ./... + go test -race -cover -v ./... lint: name: Lint diff --git a/Makefile b/Makefile index c1117278..45114a04 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ lint: internal/lint/golangci-lint run ./... --fix check: lint - go test -cover ./... + go test -race -cover ./... go mod tidy .PHONY: example diff --git a/graphql/subscription.go b/graphql/subscription.go index 9d39d791..895184a2 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -1,63 +1,80 @@ package graphql import ( - "fmt" "reflect" - "sync" ) -// map of subscription ID to subscription +// subscriptionMap is a map of subscription ID to subscription. +// It is NOT thread-safe and must be protected by the caller's lock. type subscriptionMap struct { map_ map[string]subscription - sync.RWMutex } type subscription struct { - interfaceChan interface{} - forwardDataFunc ForwardDataFunction - id string - hasBeenUnsubscribed bool + interfaceChan interface{} + forwardDataFunc ForwardDataFunction + id string + closed bool // true if the channel has been closed } -func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { - s.Lock() - defer s.Unlock() +// create adds a new subscription to the map. +// The caller must hold the webSocketClient lock. +func (s *subscriptionMap) create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { s.map_[subscriptionID] = subscription{ - id: subscriptionID, - interfaceChan: interfaceChan, - forwardDataFunc: forwardDataFunc, - hasBeenUnsubscribed: false, + id: subscriptionID, + interfaceChan: interfaceChan, + forwardDataFunc: forwardDataFunc, + closed: false, } } -func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { - s.Lock() - defer s.Unlock() - unsub, success := s.map_[subscriptionID] - if !success { - return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID) +// get retrieves a subscription by ID. +// The caller must hold the webSocketClient lock. +// Returns nil if not found. +func (s *subscriptionMap) get(subscriptionID string) *subscription { + sub, ok := s.map_[subscriptionID] + if !ok { + return nil } - hasBeenUnsubscribed := unsub.hasBeenUnsubscribed - unsub.hasBeenUnsubscribed = true - s.map_[subscriptionID] = unsub + return &sub +} - if !hasBeenUnsubscribed { - reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close() - } - return nil +// update updates a subscription in the map. +// The caller must hold the webSocketClient lock. +func (s *subscriptionMap) update(subscriptionID string, sub subscription) { + s.map_[subscriptionID] = sub } -func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) { - s.RLock() - defer s.RUnlock() +// getAllIDs returns all subscription IDs. +// The caller must hold the webSocketClient lock. +func (s *subscriptionMap) getAllIDs() []string { + subscriptionIDs := make([]string, 0, len(s.map_)) for subID := range s.map_ { subscriptionIDs = append(subscriptionIDs, subID) } return subscriptionIDs } -func (s *subscriptionMap) Delete(subscriptionID string) { - s.Lock() - defer s.Unlock() +// delete removes a subscription from the map. +// The caller must hold the webSocketClient lock. +func (s *subscriptionMap) delete(subscriptionID string) { delete(s.map_, subscriptionID) } + +// closeChannel closes a subscription's channel if it hasn't been closed yet. +// The caller must hold the webSocketClient lock. +// Returns true if the channel was closed, false if it was already closed. +func (s *subscriptionMap) closeChannel(subscriptionID string) bool { + sub := s.get(subscriptionID) + if sub == nil || sub.closed { + return false + } + + // Mark as closed before actually closing to prevent double-close + sub.closed = true + s.update(subscriptionID, *sub) + + // Close the channel + reflect.ValueOf(sub.interfaceChan).Close() + return true +} diff --git a/graphql/subscription_test.go b/graphql/subscription_test.go index f0270015..eb92d076 100644 --- a/graphql/subscription_test.go +++ b/graphql/subscription_test.go @@ -4,61 +4,57 @@ import ( "testing" ) -func Test_subscriptionMap_Unsubscribe(t *testing.T) { - type args struct { - subscriptionID string - } +func Test_subscriptionMap_closeChannel(t *testing.T) { tests := []struct { - name string - args args - sm subscriptionMap - wantErr bool + name string + sm subscriptionMap + subscriptionID string + wantClosed bool }{ { - name: "unsubscribe existing subscription", + name: "close existing open channel", sm: subscriptionMap{ map_: map[string]subscription{ "sub1": { - id: "sub1", - interfaceChan: make(chan struct{}), - forwardDataFunc: nil, - hasBeenUnsubscribed: false, + id: "sub1", + interfaceChan: make(chan struct{}), + closed: false, }, }, }, - args: args{subscriptionID: "sub1"}, - wantErr: false, + subscriptionID: "sub1", + wantClosed: true, }, { - name: "unsubscribe non-existent subscription", - sm: subscriptionMap{ - map_: map[string]subscription{}, - }, - args: args{subscriptionID: "doesnotexist"}, - wantErr: true, - }, - { - name: "unsubscribe already unsubscribed subscription", + name: "close already closed channel", sm: subscriptionMap{ map_: map[string]subscription{ "sub2": { - id: "sub2", - interfaceChan: nil, - forwardDataFunc: nil, - hasBeenUnsubscribed: true, + id: "sub2", + interfaceChan: make(chan struct{}), + closed: true, }, }, }, - args: args{subscriptionID: "sub2"}, - wantErr: false, + subscriptionID: "sub2", + wantClosed: false, + }, + { + name: "close non-existent subscription", + sm: subscriptionMap{ + map_: map[string]subscription{}, + }, + subscriptionID: "doesnotexist", + wantClosed: false, }, } for i := range tests { tt := &tests[i] t.Run(tt.name, func(t *testing.T) { s := &tt.sm - if err := s.Unsubscribe(tt.args.subscriptionID); (err != nil) != tt.wantErr { - t.Errorf("subscriptionMap.Unsubscribe() error = %v, wantErr %v", err, tt.wantErr) + gotClosed := s.closeChannel(tt.subscriptionID) + if gotClosed != tt.wantClosed { + t.Errorf("subscriptionMap.closeChannel() = %v, want %v", gotClosed, tt.wantClosed) } }) } diff --git a/graphql/websocket.go b/graphql/websocket.go index 07e209d4..31d2a5da 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "net/http" - "reflect" "strings" "sync" "time" @@ -46,11 +45,11 @@ const ( type webSocketClient struct { Dialer Dialer header http.Header - endpoint string conn WSConn connParams map[string]interface{} errChan chan error subscriptions subscriptionMap + endpoint string isClosing bool sync.Mutex } @@ -108,13 +107,18 @@ func (w *webSocketClient) handleErr(err error) { w.Lock() defer w.Unlock() if !w.isClosing { + // Send while holding lock to prevent Close() from closing + // the channel between our check and our send w.errChan <- err } } func (w *webSocketClient) listenWebSocket() { for { - if w.isClosing { + w.Lock() + isClosing := w.isClosing + w.Unlock() + if isClosing { return } _, message, err := w.conn.ReadMessage() @@ -139,22 +143,31 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { if wsMsg.ID == "" { // e.g. keep-alive messages return nil } - w.subscriptions.Lock() - defer w.subscriptions.Unlock() - sub, success := w.subscriptions.map_[wsMsg.ID] - if !success { + + w.Lock() + sub := w.subscriptions.get(wsMsg.ID) + if sub == nil { + w.Unlock() return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) } - if sub.hasBeenUnsubscribed { + if sub.closed { + // Already closed, ignore message + w.Unlock() return nil } + if wsMsg.Type == webSocketTypeComplete { - sub.hasBeenUnsubscribed = true - w.subscriptions.map_[wsMsg.ID] = sub - reflect.ValueOf(sub.interfaceChan).Close() + // Server is telling us the subscription is complete + w.subscriptions.closeChannel(wsMsg.ID) + w.subscriptions.delete(wsMsg.ID) + w.Unlock() return nil } + // Forward the data to the subscription channel. + // We release the lock while calling the forward function to avoid holding + // the lock while doing potentially slow user code. + w.Unlock() return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) } @@ -224,7 +237,11 @@ func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, for } subscriptionID := uuid.NewString() - w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc) + + w.Lock() + w.subscriptions.create(subscriptionID, interfaceChan, forwardDataFunc) + w.Unlock() + subscriptionMsg := webSocketSendMessage{ Type: webSocketTypeSubscribe, Payload: req, @@ -232,7 +249,9 @@ func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, for } err := w.sendStructAsJSON(subscriptionMsg) if err != nil { - w.subscriptions.Delete(subscriptionID) + w.Lock() + w.subscriptions.delete(subscriptionID) + w.Unlock() return "", err } return subscriptionID, nil @@ -247,15 +266,19 @@ func (w *webSocketClient) Unsubscribe(subscriptionID string) error { if err != nil { return err } - err = w.subscriptions.Unsubscribe(subscriptionID) - if err != nil { - return err - } + + w.Lock() + defer w.Unlock() + w.subscriptions.closeChannel(subscriptionID) + w.subscriptions.delete(subscriptionID) return nil } func (w *webSocketClient) UnsubscribeAll() error { - subscriptionIDs := w.subscriptions.GetAllIDs() + w.Lock() + subscriptionIDs := w.subscriptions.getAllIDs() + w.Unlock() + for _, subscriptionID := range subscriptionIDs { err := w.Unsubscribe(subscriptionID) if err != nil { diff --git a/graphql/websocket_test.go b/graphql/websocket_test.go index dd0fa65a..47444e64 100644 --- a/graphql/websocket_test.go +++ b/graphql/websocket_test.go @@ -2,20 +2,18 @@ package graphql import ( "encoding/json" - "sync" "testing" ) const testSubscriptionID = "test-subscription-id" -func forgeTestWebSocketClient(hasBeenUnsubscribed bool) *webSocketClient { +func forgeTestWebSocketClient(closed bool) *webSocketClient { return &webSocketClient{ subscriptions: subscriptionMap{ - RWMutex: sync.RWMutex{}, map_: map[string]subscription{ testSubscriptionID: { - hasBeenUnsubscribed: hasBeenUnsubscribed, - interfaceChan: make(chan any), + closed: closed, + interfaceChan: make(chan any), forwardDataFunc: func(interfaceChan any, jsonRawMsg json.RawMessage) error { return nil }, @@ -60,7 +58,7 @@ func Test_webSocketClient_forwardWebSocketData(t *testing.T) { wantErr: false, }, { - name: "unsubscribed subscription", + name: "closed subscription", args: args{message: []byte(`{"type":"next","id":"test-subscription-id","payload":{}}`)}, wc: forgeTestWebSocketClient(true), wantErr: false,