From 358c94f34b73757ead3aaa9296d88bf1462363de Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Sun, 26 Oct 2025 20:39:39 +0100 Subject: [PATCH 1/4] websocket: do unsubscribe before closing connection --- graphql/websocket.go | 8 ++--- internal/integration/integration_test.go | 43 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/graphql/websocket.go b/graphql/websocket.go index ac83665d..0ed430e6 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -194,13 +194,13 @@ func (w *webSocketClient) Close() error { if w.conn == nil { return nil } - err := w.conn.WriteMessage(closeMessage, formatCloseMessage(closeNormalClosure, "")) + err := w.UnsubscribeAll() if err != nil { - return fmt.Errorf("failed to send closure message: %w", err) + return fmt.Errorf("failed to unsubscribe: %w", err) } - err = w.UnsubscribeAll() + err = w.conn.WriteMessage(closeMessage, formatCloseMessage(closeNormalClosure, "")) if err != nil { - return fmt.Errorf("failed to unsubscribe: %w", err) + return fmt.Errorf("failed to send closure message: %w", err) } w.Lock() defer w.Unlock() diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 042c20f6..68341f14 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -242,6 +242,49 @@ func TestSubscriptionConnectionParams(t *testing.T) { } } +func TestSubscriptionClose(t *testing.T) { + _ = `# @genqlient + subscription count { count }` + + ctx := context.Background() + server := server.RunServer() + defer server.Close() + + cases := []struct { + name string + unsub bool + }{ + { + name: "unsubscribed_manually", + unsub: true, + }, + { + name: "unsubscribed_automatically", + unsub: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + wsClient := newRoundtripWebSocketClient(t, server.URL) + + _, err := wsClient.Start(ctx) + require.NoError(t, err) + + _, subscriptionID, err := count(ctx, wsClient) + require.NoError(t, err) + + if tc.unsub { + err = wsClient.Unsubscribe(subscriptionID) + require.NoError(t, err) + } + + err = wsClient.Close() + require.NoError(t, err) + }) + } +} + func TestServerError(t *testing.T) { _ = `# @genqlient query failingQuery { fail me { id } }` From accdda9f3106114ca3dba66d59ef3819ea0793e8 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Sun, 26 Oct 2025 23:28:02 +0100 Subject: [PATCH 2/4] try to unbreak --- graphql/subscription.go | 9 +++++++++ graphql/websocket.go | 1 + 2 files changed, 10 insertions(+) diff --git a/graphql/subscription.go b/graphql/subscription.go index 8e20ba51..daa065ee 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -37,6 +37,15 @@ func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success return sub, success } +func (s *subscriptionMap) markUnsubscribed(subscriptionID string) { + s.Lock() + defer s.Unlock() + if sub, ok := s.map_[subscriptionID]; ok { + sub.hasBeenUnsubscribed = true + s.map_[subscriptionID] = sub + } +} + func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { s.Lock() defer s.Unlock() diff --git a/graphql/websocket.go b/graphql/websocket.go index 0ed430e6..b49fe66d 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -147,6 +147,7 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { return nil } if wsMsg.Type == webSocketTypeComplete { + w.subscriptions.markUnsubscribed(wsMsg.ID) reflect.ValueOf(sub.interfaceChan).Close() return nil } From 6c18c82223f1597d6d49258525e65ff320b6130c Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Sun, 26 Oct 2025 23:38:12 +0100 Subject: [PATCH 3/4] test infra --- internal/integration/generated.go | 56 +++++++++++++++++ internal/integration/integration_test.go | 4 +- internal/integration/schema.graphql | 1 + internal/integration/server/gqlgen_exec.go | 70 ++++++++++++++++++++++ internal/integration/server/server.go | 4 ++ 5 files changed, 133 insertions(+), 2 deletions(-) diff --git a/internal/integration/generated.go b/internal/integration/generated.go index 446f5d2c..fd034833 100644 --- a/internal/integration/generated.go +++ b/internal/integration/generated.go @@ -1319,6 +1319,14 @@ type countAuthorizedResponse struct { // GetCountAuthorized returns countAuthorizedResponse.CountAuthorized, and is useful for accessing the field via an interface. func (v *countAuthorizedResponse) GetCountAuthorized() int { return v.CountAuthorized } +// countCloseResponse is returned by countClose on success. +type countCloseResponse struct { + CountClose int `json:"countClose"` +} + +// GetCountClose returns countCloseResponse.CountClose, and is useful for accessing the field via an interface. +func (v *countCloseResponse) GetCountClose() int { return v.CountClose } + // countResponse is returned by count on success. type countResponse struct { Count int `json:"count"` @@ -3200,6 +3208,54 @@ func countAuthorizedForwardData(interfaceChan interface{}, jsonRawMsg json.RawMe return nil } +// The subscription executed by countClose. +const countClose_Operation = ` +subscription countClose { + countClose +} +` + +// To unsubscribe, use [graphql.WebSocketClient.Unsubscribe] +func countClose( + ctx_ context.Context, + client_ graphql.WebSocketClient, +) (dataChan_ chan countCloseWsResponse, subscriptionID_ string, err_ error) { + req_ := &graphql.Request{ + OpName: "countClose", + Query: countClose_Operation, + } + + dataChan_ = make(chan countCloseWsResponse) + subscriptionID_, err_ = client_.Subscribe(req_, dataChan_, countCloseForwardData) + + return dataChan_, subscriptionID_, err_ +} + +type countCloseWsResponse graphql.BaseResponse[*countCloseResponse] + +func countCloseForwardData(interfaceChan interface{}, jsonRawMsg json.RawMessage) error { + var gqlResp graphql.Response + var wsResp countCloseWsResponse + err := json.Unmarshal(jsonRawMsg, &gqlResp) + if err != nil { + return err + } + if len(gqlResp.Errors) == 0 { + err = json.Unmarshal(jsonRawMsg, &wsResp) + if err != nil { + return err + } + } else { + wsResp.Errors = gqlResp.Errors + } + dataChan_, ok := interfaceChan.(chan countCloseWsResponse) + if !ok { + return errors.New("failed to cast interface into 'chan countCloseWsResponse'") + } + dataChan_ <- wsResp + return nil +} + // The mutation executed by createUser. const createUser_Operation = ` mutation createUser ($user: NewUser!) { diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 68341f14..ca3ffae6 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -244,7 +244,7 @@ func TestSubscriptionConnectionParams(t *testing.T) { func TestSubscriptionClose(t *testing.T) { _ = `# @genqlient - subscription count { count }` + subscription countClose { countClose }` ctx := context.Background() server := server.RunServer() @@ -271,7 +271,7 @@ func TestSubscriptionClose(t *testing.T) { _, err := wsClient.Start(ctx) require.NoError(t, err) - _, subscriptionID, err := count(ctx, wsClient) + _, subscriptionID, err := countClose(ctx, wsClient) require.NoError(t, err) if tc.unsub { diff --git a/internal/integration/schema.graphql b/internal/integration/schema.graphql index 16e13cab..1e64ec29 100644 --- a/internal/integration/schema.graphql +++ b/internal/integration/schema.graphql @@ -20,6 +20,7 @@ type Mutation { type Subscription { count: Int! countAuthorized: Int! + countClose: Int! } type User implements Being & Lucky { diff --git a/internal/integration/server/gqlgen_exec.go b/internal/integration/server/gqlgen_exec.go index 4fd4a2f6..cd5e0b02 100644 --- a/internal/integration/server/gqlgen_exec.go +++ b/internal/integration/server/gqlgen_exec.go @@ -82,6 +82,7 @@ type ComplexityRoot struct { Subscription struct { Count func(childComplexity int) int CountAuthorized func(childComplexity int) int + CountClose func(childComplexity int) int } User struct { @@ -112,6 +113,7 @@ type QueryResolver interface { type SubscriptionResolver interface { Count(ctx context.Context) (<-chan int, error) CountAuthorized(ctx context.Context) (<-chan int, error) + CountClose(ctx context.Context) (<-chan int, error) } type executableSchema struct { @@ -306,6 +308,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Subscription.CountAuthorized(childComplexity), true + case "Subscription.countClose": + if e.complexity.Subscription.CountClose == nil { + break + } + + return e.complexity.Subscription.CountClose(childComplexity), true + case "User.birthdate": if e.complexity.User.Birthdate == nil { break @@ -500,6 +509,7 @@ type Mutation { type Subscription { count: Int! countAuthorized: Int! + countClose: Int! } type User implements Being & Lucky { @@ -2090,6 +2100,64 @@ func (ec *executionContext) fieldContext_Subscription_countAuthorized(_ context. return fc, nil } +func (ec *executionContext) _Subscription_countClose(ctx context.Context, field graphql.CollectedField) (ret func(ctx context.Context) graphql.Marshaler) { + fc, err := ec.fieldContext_Subscription_countClose(ctx, field) + if err != nil { + return nil + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Subscription().CountClose(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return nil + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return nil + } + return func(ctx context.Context) graphql.Marshaler { + select { + case res, ok := <-resTmp.(<-chan int): + if !ok { + return nil + } + return graphql.WriterFunc(func(w io.Writer) { + w.Write([]byte{'{'}) + graphql.MarshalString(field.Alias).MarshalGQL(w) + w.Write([]byte{':'}) + ec.marshalNInt2int(ctx, field.Selections, res).MarshalGQL(w) + w.Write([]byte{'}'}) + }) + case <-ctx.Done(): + return nil + } + } +} + +func (ec *executionContext) fieldContext_Subscription_countClose(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Subscription", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _User_id(ctx context.Context, field graphql.CollectedField, obj *User) (ret graphql.Marshaler) { fc, err := ec.fieldContext_User_id(ctx, field) if err != nil { @@ -4677,6 +4745,8 @@ func (ec *executionContext) _Subscription(ctx context.Context, sel ast.Selection return ec._Subscription_count(ctx, fields[0]) case "countAuthorized": return ec._Subscription_countAuthorized(ctx, fields[0]) + case "countClose": + return ec._Subscription_countClose(ctx, fields[0]) default: panic("unknown field " + strconv.Quote(fields[0].Name)) } diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index 71e08ad3..d8071ae1 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -182,6 +182,10 @@ func (s *subscriptionResolver) CountAuthorized(ctx context.Context) (<-chan int, return s.Count(ctx) } +func (s *subscriptionResolver) CountClose(ctx context.Context) (<-chan int, error) { + return s.Count(ctx) +} + const AuthKey = "authToken" type ( From b14a34847bf1b1a380099a18980f40f7555ef8e7 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Mon, 27 Oct 2025 09:09:48 +0100 Subject: [PATCH 4/4] handle concurrent unsubs --- graphql/subscription.go | 16 ---------------- graphql/websocket.go | 9 ++++++--- internal/integration/integration_test.go | 10 ++++++++-- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/graphql/subscription.go b/graphql/subscription.go index daa065ee..9d39d791 100644 --- a/graphql/subscription.go +++ b/graphql/subscription.go @@ -30,22 +30,6 @@ func (s *subscriptionMap) Create(subscriptionID string, interfaceChan interface{ } } -func (s *subscriptionMap) Read(subscriptionID string) (sub subscription, success bool) { - s.RLock() - defer s.RUnlock() - sub, success = s.map_[subscriptionID] - return sub, success -} - -func (s *subscriptionMap) markUnsubscribed(subscriptionID string) { - s.Lock() - defer s.Unlock() - if sub, ok := s.map_[subscriptionID]; ok { - sub.hasBeenUnsubscribed = true - s.map_[subscriptionID] = sub - } -} - func (s *subscriptionMap) Unsubscribe(subscriptionID string) error { s.Lock() defer s.Unlock() diff --git a/graphql/websocket.go b/graphql/websocket.go index b49fe66d..07e209d4 100644 --- a/graphql/websocket.go +++ b/graphql/websocket.go @@ -139,15 +139,18 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error { if wsMsg.ID == "" { // e.g. keep-alive messages return nil } - sub, ok := w.subscriptions.Read(wsMsg.ID) - if !ok { + w.subscriptions.Lock() + defer w.subscriptions.Unlock() + sub, success := w.subscriptions.map_[wsMsg.ID] + if !success { return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID) } if sub.hasBeenUnsubscribed { return nil } if wsMsg.Type == webSocketTypeComplete { - w.subscriptions.markUnsubscribed(wsMsg.ID) + sub.hasBeenUnsubscribed = true + w.subscriptions.map_[wsMsg.ID] = sub reflect.ValueOf(sub.interfaceChan).Close() return nil } diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index ca3ffae6..88e92380 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -104,7 +104,10 @@ func TestSubscription(t *testing.T) { dataChan, subscriptionID, err := count(ctx, wsClient) require.NoError(t, err) - defer wsClient.Close() + defer func() { + err := wsClient.Close() + require.NoError(t, err) + }() var ( counter = 0 @@ -198,7 +201,10 @@ func TestSubscriptionConnectionParams(t *testing.T) { dataChan, subscriptionID, err := countAuthorized(ctx, wsClient) require.NoError(t, err) - defer wsClient.Close() + defer func() { + err := wsClient.Close() + require.NoError(t, err) + }() var ( counter = 0