Skip to content

Commit 6ca7528

Browse files
handle concurrent unsubs
1 parent 6c18c82 commit 6ca7528

File tree

3 files changed

+12
-21
lines changed

3 files changed

+12
-21
lines changed

graphql/subscription.go

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,6 @@ func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{
3030
}
3131
}
3232

33-
func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) {
34-
s.RLock()
35-
defer s.RUnlock()
36-
sub, success = s.map_[subscriptionID]
37-
return sub, success
38-
}
39-
40-
func (s *subscriptionMap) markUnsubscribed(subscriptionID string) {
41-
s.Lock()
42-
defer s.Unlock()
43-
if sub, ok := s.map_[subscriptionID]; ok {
44-
sub.hasBeenUnsubscribed = true
45-
s.map_[subscriptionID] = sub
46-
}
47-
}
48-
4933
func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
5034
s.Lock()
5135
defer s.Unlock()

graphql/websocket.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,18 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
139139
if wsMsg.ID == "" { // e.g. keep-alive messages
140140
return nil
141141
}
142-
sub, ok := w.subscriptions.Read(wsMsg.ID)
143-
if !ok {
142+
w.subscriptions.Lock()
143+
defer w.subscriptions.Unlock()
144+
sub, success := w.subscriptions.map_[wsMsg.ID]
145+
if !success {
144146
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
145147
}
146148
if sub.hasBeenUnsubscribed {
147149
return nil
148150
}
149151
if wsMsg.Type == webSocketTypeComplete {
150-
w.subscriptions.markUnsubscribed(wsMsg.ID)
152+
sub.hasBeenUnsubscribed = true
153+
w.subscriptions.map_[wsMsg.ID] = sub
151154
reflect.ValueOf(sub.interfaceChan).Close()
152155
return nil
153156
}

internal/integration/integration_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ func TestSubscription(t *testing.T) {
104104

105105
dataChan, subscriptionID, err := count(ctx, wsClient)
106106
require.NoError(t, err)
107-
defer wsClient.Close()
107+
defer func() {
108+
require.NoError(t, wsClient.Close())
109+
}()
108110

109111
var (
110112
counter = 0
@@ -198,7 +200,9 @@ func TestSubscriptionConnectionParams(t *testing.T) {
198200

199201
dataChan, subscriptionID, err := countAuthorized(ctx, wsClient)
200202
require.NoError(t, err)
201-
defer wsClient.Close()
203+
defer func() {
204+
require.NoError(t, wsClient.Close())
205+
}()
202206

203207
var (
204208
counter = 0

0 commit comments

Comments
 (0)