-
Notifications
You must be signed in to change notification settings - Fork 139
Avoid deadlocks by consolidating locks #412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,63 +1,80 @@ | ||
| package graphql | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "reflect" | ||
| "sync" | ||
| ) | ||
|
|
||
| // map of subscription ID to subscription | ||
| // subscriptionMap is a map of subscription ID to subscription. | ||
| // It is NOT thread-safe and must be protected by the caller's lock. | ||
| type subscriptionMap struct { | ||
| map_ map[string]subscription | ||
| sync.RWMutex | ||
| } | ||
|
|
||
| type subscription struct { | ||
| interfaceChan interface{} | ||
| forwardDataFunc ForwardDataFunction | ||
| id string | ||
| hasBeenUnsubscribed bool | ||
| interfaceChan interface{} | ||
| forwardDataFunc ForwardDataFunction | ||
| id string | ||
| closed bool // true if the channel has been closed | ||
| } | ||
|
|
||
| func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { | ||
| s.Lock() | ||
| defer s.Unlock() | ||
| // create adds a new subscription to the map. | ||
| // The caller must hold the webSocketClient lock. | ||
| func (s *subscriptionMap) create(subscriptionID string, interfaceChan interface{}, forwardDataFunc ForwardDataFunction) { | ||
| s.map_[subscriptionID] = subscription{ | ||
| id: subscriptionID, | ||
| interfaceChan: interfaceChan, | ||
| forwardDataFunc: forwardDataFunc, | ||
| hasBeenUnsubscribed: false, | ||
| id: subscriptionID, | ||
| interfaceChan: interfaceChan, | ||
| forwardDataFunc: forwardDataFunc, | ||
| closed: false, | ||
| } | ||
| } | ||
|
|
||
| func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { | ||
| s.Lock() | ||
| defer s.Unlock() | ||
| unsub, success := s.map_[subscriptionID] | ||
| if !success { | ||
| return fmt.Errorf("tried to unsubscribe from unknown subscription with ID '%s'", subscriptionID) | ||
| // get retrieves a subscription by ID. | ||
| // The caller must hold the webSocketClient lock. | ||
| // Returns nil if not found. | ||
| func (s *subscriptionMap) get(subscriptionID string) *subscription { | ||
| sub, ok := s.map_[subscriptionID] | ||
| if !ok { | ||
| return nil | ||
| } | ||
| hasBeenUnsubscribed := unsub.hasBeenUnsubscribed | ||
| unsub.hasBeenUnsubscribed = true | ||
| s.map_[subscriptionID] = unsub | ||
| return &sub | ||
| } | ||
|
|
||
| if !hasBeenUnsubscribed { | ||
| reflect.ValueOf(s.map_[subscriptionID].interfaceChan).Close() | ||
| } | ||
| return nil | ||
| // update updates a subscription in the map. | ||
| // The caller must hold the webSocketClient lock. | ||
| func (s *subscriptionMap) update(subscriptionID string, sub subscription) { | ||
| s.map_[subscriptionID] = sub | ||
| } | ||
|
|
||
| func (s *subscriptionMap) GetAllIDs() (subscriptionIDs []string) { | ||
| s.RLock() | ||
| defer s.RUnlock() | ||
| // getAllIDs returns all subscription IDs. | ||
| // The caller must hold the webSocketClient lock. | ||
| func (s *subscriptionMap) getAllIDs() []string { | ||
| subscriptionIDs := make([]string, 0, len(s.map_)) | ||
| for subID := range s.map_ { | ||
| subscriptionIDs = append(subscriptionIDs, subID) | ||
| } | ||
| return subscriptionIDs | ||
| } | ||
|
|
||
| func (s *subscriptionMap) Delete(subscriptionID string) { | ||
| s.Lock() | ||
| defer s.Unlock() | ||
| // delete removes a subscription from the map. | ||
| // The caller must hold the webSocketClient lock. | ||
| func (s *subscriptionMap) delete(subscriptionID string) { | ||
| delete(s.map_, subscriptionID) | ||
| } | ||
|
|
||
| // closeChannel closes a subscription's channel if it hasn't been closed yet. | ||
| // The caller must hold the webSocketClient lock. | ||
| // Returns true if the channel was closed, false if it was already closed. | ||
| func (s *subscriptionMap) closeChannel(subscriptionID string) bool { | ||
| sub := s.get(subscriptionID) | ||
| if sub == nil || sub.closed { | ||
| return false | ||
| } | ||
|
|
||
| // Mark as closed before actually closing to prevent double-close | ||
| sub.closed = true | ||
| s.update(subscriptionID, *sub) | ||
|
|
||
| // Close the channel | ||
| reflect.ValueOf(sub.interfaceChan).Close() | ||
| return true | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,6 @@ import ( | |
| "encoding/json" | ||
| "fmt" | ||
| "net/http" | ||
| "reflect" | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
|
|
@@ -46,11 +45,11 @@ const ( | |
| type webSocketClient struct { | ||
| Dialer Dialer | ||
| header http.Header | ||
| endpoint string | ||
| conn WSConn | ||
| connParams map[string]interface{} | ||
| errChan chan error | ||
| subscriptions subscriptionMap | ||
| endpoint string | ||
| isClosing bool | ||
| sync.Mutex | ||
| } | ||
|
|
@@ -108,13 +107,18 @@ func (w *webSocketClient) handleErr(err error) { | |
| w.Lock() | ||
| defer w.Unlock() | ||
| if !w.isClosing { | ||
| // Send while holding lock to prevent Close() from closing | ||
| // the channel between our check and our send | ||
| w.errChan <- err | ||
| } | ||
| } | ||
|
|
||
| func (w *webSocketClient) listenWebSocket() { | ||
| for { | ||
| if w.isClosing { | ||
| w.Lock() | ||
| isClosing := w.isClosing | ||
| w.Unlock() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't it better to always follow this pattern? w.Lock()
defer w.Unlock()
isClosing := w.isClosing
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that will work here since the lock is happening inside a loop. The You could make a separate function that just assigns |
||
| if isClosing { | ||
| return | ||
| } | ||
| _, message, err := w.conn.ReadMessage() | ||
|
|
@@ -139,22 +143,31 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { | |
| if wsMsg.ID == "" { // e.g. keep-alive messages | ||
| return nil | ||
| } | ||
| w.subscriptions.Lock() | ||
| defer w.subscriptions.Unlock() | ||
| sub, success := w.subscriptions.map_[wsMsg.ID] | ||
| if !success { | ||
|
|
||
| w.Lock() | ||
| sub := w.subscriptions.get(wsMsg.ID) | ||
| if sub == nil { | ||
| w.Unlock() | ||
| return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) | ||
| } | ||
| if sub.hasBeenUnsubscribed { | ||
| if sub.closed { | ||
| // Already closed, ignore message | ||
| w.Unlock() | ||
| return nil | ||
| } | ||
|
|
||
| if wsMsg.Type == webSocketTypeComplete { | ||
| sub.hasBeenUnsubscribed = true | ||
| w.subscriptions.map_[wsMsg.ID] = sub | ||
| reflect.ValueOf(sub.interfaceChan).Close() | ||
| // Server is telling us the subscription is complete | ||
| w.subscriptions.closeChannel(wsMsg.ID) | ||
| w.subscriptions.delete(wsMsg.ID) | ||
| w.Unlock() | ||
| return nil | ||
| } | ||
|
|
||
| // Forward the data to the subscription channel. | ||
| // We release the lock while calling the forward function to avoid holding | ||
| // the lock while doing potentially slow user code. | ||
| w.Unlock() | ||
|
Comment on lines
+147
to
+170
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case, I do agree that having to remember to unlock in all the exit paths is fragile. It may make sense to pull all this code out into its own function so you can use |
||
| return sub.forwardDataFunc(sub.interfaceChan, wsMsg.Payload) | ||
| } | ||
|
|
||
|
|
@@ -224,15 +237,21 @@ func (w *webSocketClient) Subscribe(req *Request, interfaceChan interface{}, for | |
| } | ||
|
|
||
| subscriptionID := uuid.NewString() | ||
| w.subscriptions.Create(subscriptionID, interfaceChan, forwardDataFunc) | ||
|
|
||
| w.Lock() | ||
| w.subscriptions.create(subscriptionID, interfaceChan, forwardDataFunc) | ||
| w.Unlock() | ||
|
|
||
| subscriptionMsg := webSocketSendMessage{ | ||
| Type: webSocketTypeSubscribe, | ||
| Payload: req, | ||
| ID: subscriptionID, | ||
| } | ||
| err := w.sendStructAsJSON(subscriptionMsg) | ||
| if err != nil { | ||
| w.subscriptions.Delete(subscriptionID) | ||
| w.Lock() | ||
| w.subscriptions.delete(subscriptionID) | ||
| w.Unlock() | ||
| return "", err | ||
| } | ||
| return subscriptionID, nil | ||
|
|
@@ -247,15 +266,19 @@ func (w *webSocketClient) Unsubscribe(subscriptionID string) error { | |
| if err != nil { | ||
| return err | ||
| } | ||
| err = w.subscriptions.Unsubscribe(subscriptionID) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| w.Lock() | ||
| defer w.Unlock() | ||
| w.subscriptions.closeChannel(subscriptionID) | ||
| w.subscriptions.delete(subscriptionID) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The old |
||
| return nil | ||
| } | ||
|
|
||
| func (w *webSocketClient) UnsubscribeAll() error { | ||
| subscriptionIDs := w.subscriptions.GetAllIDs() | ||
| w.Lock() | ||
| subscriptionIDs := w.subscriptions.getAllIDs() | ||
| w.Unlock() | ||
|
|
||
| for _, subscriptionID := range subscriptionIDs { | ||
| err := w.Unsubscribe(subscriptionID) | ||
| if err != nil { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this really needed given you're in a lock for this whole function? In any case, I don't think it hurts.