Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
# Needed for the example-test to run.
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
go test -cover -v ./...
go test -race -cover -v ./...

lint:
name: Lint
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ lint:
internal/lint/golangci-lint run ./... --fix

check: lint
go test -cover ./...
go test -race -cover ./...
go mod tidy

.PHONY: example
85 changes: 51 additions & 34 deletions graphql/subscription.go
Original file line number Diff line number Diff line change
@@ -1,63 +1,80 @@
package graphql

import (
"fmt"
"reflect"
"sync"
)

// map of subscription ID to subscription
// subscriptionMap is a map of subscription ID to subscription.
// It is NOT thread-safe and must be protected by the caller's lock.
type subscriptionMap struct {
map_ map[string]subscription
sync.RWMutex
}

type subscription struct {
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
id string
hasBeenUnsubscribed bool
interfaceChan interface{}
forwardDataFunc ForwardDataFunction
id string
closed bool // true if the channel has been closed
}

func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
s.Lock()
defer s.Unlock()
// create adds a new subscription to the map.
// The caller must hold the webSocketClient lock.
func (s *subscriptionMap) create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) {
s.map_[subscriptionID] = subscription{
id: subscriptionID,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
hasBeenUnsubscribed: false,
id: subscriptionID,
interfaceChan: interfaceChan,
forwardDataFunc: forwardDataFunc,
closed: false,
}
}

func (s *subscriptionMap) Unsubscribe(subscriptionID string) error {
s.Lock()
defer s.Unlock()
unsub, success := s.map_[subscriptionID]
if !success {
return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID)
// get retrieves a subscription by ID.
// The caller must hold the webSocketClient lock.
// Returns nil if not found.
func (s *subscriptionMap) get(subscriptionID string) *subscription {
sub, ok := s.map_[subscriptionID]
if !ok {
return nil
}
hasBeenUnsubscribed := unsub.hasBeenUnsubscribed
unsub.hasBeenUnsubscribed = true
s.map_[subscriptionID] = unsub
return &sub
}

if !hasBeenUnsubscribed {
reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close()
}
return nil
// update updates a subscription in the map.
// The caller must hold the webSocketClient lock.
func (s *subscriptionMap) update(subscriptionID string, sub subscription) {
s.map_[subscriptionID] = sub
}

func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) {
s.RLock()
defer s.RUnlock()
// getAllIDs returns all subscription IDs.
// The caller must hold the webSocketClient lock.
func (s *subscriptionMap) getAllIDs() []string {
subscriptionIDs := make([]string, 0, len(s.map_))
for subID := range s.map_ {
subscriptionIDs = append(subscriptionIDs, subID)
}
return subscriptionIDs
}

func (s *subscriptionMap) Delete(subscriptionID string) {
s.Lock()
defer s.Unlock()
// delete removes a subscription from the map.
// The caller must hold the webSocketClient lock.
func (s *subscriptionMap) delete(subscriptionID string) {
delete(s.map_, subscriptionID)
}

// closeChannel closes a subscription's channel if it hasn't been closed yet.
// The caller must hold the webSocketClient lock.
// Returns true if the channel was closed, false if it was already closed.
func (s *subscriptionMap) closeChannel(subscriptionID string) bool {
sub := s.get(subscriptionID)
if sub == nil || sub.closed {
return false
}

// Mark as closed before actually closing to prevent double-close
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed given you're in a lock for this whole function? In any case, I don't think it hurts.

sub.closed = true
s.update(subscriptionID, *sub)

// Close the channel
reflect.ValueOf(sub.interfaceChan).Close()
return true
}
60 changes: 28 additions & 32 deletions graphql/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,57 @@ import (
"testing"
)

func Test_subscriptionMap_Unsubscribe(t *testing.T) {
type args struct {
subscriptionID string
}
func Test_subscriptionMap_closeChannel(t *testing.T) {
tests := []struct {
name string
args args
sm subscriptionMap
wantErr bool
name string
sm subscriptionMap
subscriptionID string
wantClosed bool
}{
{
name: "unsubscribe existing subscription",
name: "close existing open channel",
sm: subscriptionMap{
map_: map[string]subscription{
"sub1": {
id: "sub1",
interfaceChan: make(chan struct{}),
forwardDataFunc: nil,
hasBeenUnsubscribed: false,
id: "sub1",
interfaceChan: make(chan struct{}),
closed: false,
},
},
},
args: args{subscriptionID: "sub1"},
wantErr: false,
subscriptionID: "sub1",
wantClosed: true,
},
{
name: "unsubscribe non-existent subscription",
sm: subscriptionMap{
map_: map[string]subscription{},
},
args: args{subscriptionID: "doesnotexist"},
wantErr: true,
},
{
name: "unsubscribe already unsubscribed subscription",
name: "close already closed channel",
sm: subscriptionMap{
map_: map[string]subscription{
"sub2": {
id: "sub2",
interfaceChan: nil,
forwardDataFunc: nil,
hasBeenUnsubscribed: true,
id: "sub2",
interfaceChan: make(chan struct{}),
closed: true,
},
},
},
args: args{subscriptionID: "sub2"},
wantErr: false,
subscriptionID: "sub2",
wantClosed: false,
},
{
name: "close non-existent subscription",
sm: subscriptionMap{
map_: map[string]subscription{},
},
subscriptionID: "doesnotexist",
wantClosed: false,
},
}
for i := range tests {
tt := &tests[i]
t.Run(tt.name, func(t *testing.T) {
s := &tt.sm
if err := s.Unsubscribe(tt.args.subscriptionID); (err != nil) != tt.wantErr {
t.Errorf("subscriptionMap.Unsubscribe() error = %v, wantErr %v", err, tt.wantErr)
gotClosed := s.closeChannel(tt.subscriptionID)
if gotClosed != tt.wantClosed {
t.Errorf("subscriptionMap.closeChannel() = %v, want %v", gotClosed, tt.wantClosed)
}
})
}
Expand Down
59 changes: 41 additions & 18 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -46,11 +45,11 @@ const (
type webSocketClient struct {
Dialer Dialer
header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
errChan chan error
subscriptions subscriptionMap
endpoint string
isClosing bool
sync.Mutex
}
Expand Down Expand Up @@ -108,13 +107,18 @@ func (w *webSocketClient) handleErr(err error) {
w.Lock()
defer w.Unlock()
if !w.isClosing {
// Send while holding lock to prevent Close() from closing
// the channel between our check and our send
w.errChan <- err
}
}

func (w *webSocketClient) listenWebSocket() {
for {
if w.isClosing {
w.Lock()
isClosing := w.isClosing
w.Unlock()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it better to always follow this pattern?

w.Lock()
defer w.Unlock()

isClosing := w.isClosing

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that will work here since the lock is happening inside a loop. The defer won't run until the entire function exits, which means that the second time through the loop the Lock() would block since the lock is already being held from the first time.

You could make a separate function that just assigns w.isClosing, and use defer in that, but that's possibly overkill for this.

if isClosing {
return
}
_, message, err := w.conn.ReadMessage()
Expand All @@ -139,22 +143,31 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
if wsMsg.ID == "" { // e.g. keep-alive messages
return nil
}
w.subscriptions.Lock()
defer w.subscriptions.Unlock()
sub, success := w.subscriptions.map_[wsMsg.ID]
if !success {

w.Lock()
sub := w.subscriptions.get(wsMsg.ID)
if sub == nil {
w.Unlock()
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)
}
if sub.hasBeenUnsubscribed {
if sub.closed {
// Already closed, ignore message
w.Unlock()
return nil
}

if wsMsg.Type == webSocketTypeComplete {
sub.hasBeenUnsubscribed = true
w.subscriptions.map_[wsMsg.ID] = sub
reflect.ValueOf(sub.interfaceChan).Close()
// Server is telling us the subscription is complete
w.subscriptions.closeChannel(wsMsg.ID)
w.subscriptions.delete(wsMsg.ID)
w.Unlock()
return nil
}

// Forward the data to the subscription channel.
// We release the lock while calling the forward function to avoid holding
// the lock while doing potentially slow user code.
w.Unlock()
Comment on lines +147 to +170
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, I do agree that having to remember to unlock in all the exit paths is fragile. It may make sense to pull all this code out into its own function so you can use defer to close.

return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload)
}

Expand Down Expand Up @@ -224,15 +237,21 @@ func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, for
}

subscriptionID := uuid.NewString()
w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc)

w.Lock()
w.subscriptions.create(subscriptionID, interfaceChan, forwardDataFunc)
w.Unlock()

subscriptionMsg := webSocketSendMessage{
Type: webSocketTypeSubscribe,
Payload: req,
ID: subscriptionID,
}
err := w.sendStructAsJSON(subscriptionMsg)
if err != nil {
w.subscriptions.Delete(subscriptionID)
w.Lock()
w.subscriptions.delete(subscriptionID)
w.Unlock()
return "", err
}
return subscriptionID, nil
Expand All @@ -247,15 +266,19 @@ func (w *webSocketClient) Unsubscribe(subscriptionID string) error {
if err != nil {
return err
}
err = w.subscriptions.Unsubscribe(subscriptionID)
if err != nil {
return err
}

w.Lock()
defer w.Unlock()
w.subscriptions.closeChannel(subscriptionID)
w.subscriptions.delete(subscriptionID)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old Unsubscribe didn't delete, if I'm reading the code right. Is this a bugfix?

return nil
}

func (w *webSocketClient) UnsubscribeAll() error {
subscriptionIDs := w.subscriptions.GetAllIDs()
w.Lock()
subscriptionIDs := w.subscriptions.getAllIDs()
w.Unlock()

for _, subscriptionID := range subscriptionIDs {
err := w.Unsubscribe(subscriptionID)
if err != nil {
Expand Down
10 changes: 4 additions & 6 deletions graphql/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ package graphql

import (
"encoding/json"
"sync"
"testing"
)

const testSubscriptionID = "test-subscription-id"

func forgeTestWebSocketClient(hasBeenUnsubscribed bool) *webSocketClient {
func forgeTestWebSocketClient(closed bool) *webSocketClient {
return &webSocketClient{
subscriptions: subscriptionMap{
RWMutex: sync.RWMutex{},
map_: map[string]subscription{
testSubscriptionID: {
hasBeenUnsubscribed: hasBeenUnsubscribed,
interfaceChan: make(chan any),
closed: closed,
interfaceChan: make(chan any),
forwardDataFunc: func(interfaceChan any, jsonRawMsg json.RawMessage) error {
return nil
},
Expand Down Expand Up @@ -60,7 +58,7 @@ func Test_webSocketClient_forwardWebSocketData(t *testing.T) {
wantErr: false,
},
{
name: "unsubscribed subscription",
name: "closed subscription",
args: args{message: []byte(`{"type":"next","id":"test-subscription-id","payload":{}}`)},
wc: forgeTestWebSocketClient(true),
wantErr: false,
Expand Down
Loading