diff --git a/client.go b/client.go index be20ed489..db82b4182 100644 --- a/client.go +++ b/client.go @@ -386,7 +386,7 @@ func (c *Client) receiveLoop() error { // createStream creates a new stream and registers it with the client // Introduce stream types for multiple or single response -func (c *Client) createStream(flags uint8, b []byte) (*stream, error) { +func (c *Client) createStream(flags uint8, b []byte, recvBuf int) (*stream, error) { // sendLock must be held across both allocation of the stream ID and sending it across the wire. // This ensures that new stream IDs sent on the wire are always increasing, which is a // requirement of the TTRPC protocol. @@ -417,7 +417,7 @@ func (c *Client) createStream(flags uint8, b []byte) (*stream, error) { default: } - s = newStream(c.nextStreamID, c) + s = newStream(c.nextStreamID, c, recvBuf) c.streams[s.id] = s c.nextStreamID = c.nextStreamID + 2 @@ -517,7 +517,7 @@ func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, metho } else { flags = flagRemoteClosed } - s, err := c.createStream(flags, p) + s, err := c.createStream(flags, p, streamRecvBufferSize) if err != nil { return nil, err } @@ -536,7 +536,7 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err return err } - s, err := c.createStream(0, p) + s, err := c.createStream(0, p, 1) if err != nil { return err } diff --git a/errors.go b/errors.go index 632dbe8bd..1e6f6b9c9 100644 --- a/errors.go +++ b/errors.go @@ -36,6 +36,12 @@ var ( // ErrStreamClosed is when the streaming connection is closed. ErrStreamClosed = errors.New("ttrpc: stream closed") + + // ErrStreamFull is returned when a stream's receive buffer is full + // and the message cannot be delivered without blocking the + // connection's receive loop. This prevents a single unconsumed + // stream from deadlocking all other streams on the same connection. 
+ ErrStreamFull = errors.New("ttrpc: stream buffer full") ) // OversizedMessageErr is used to indicate refusal to send an oversized message. diff --git a/go.mod b/go.mod index bd6b82027..413e01a7b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/containerd/ttrpc -go 1.22 +go 1.23 require ( github.com/containerd/log v0.1.0 diff --git a/services.go b/services.go index 6d092bf95..ac7c752f7 100644 --- a/services.go +++ b/services.go @@ -23,6 +23,7 @@ import ( "io" "os" "path" + "time" "unsafe" "google.golang.org/grpc/codes" @@ -128,10 +129,14 @@ func (s *serviceSet) handle(ctx context.Context, req *Request, respond func(*sta StreamingClient: stream.StreamingClient, StreamingServer: stream.StreamingServer, } + recvBuf := streamRecvBufferSize + if !stream.StreamingClient { + recvBuf = 1 + } sh := &streamHandler{ ctx: ctx, respond: respond, - recv: make(chan Unmarshaler, 5), + recv: make(chan Unmarshaler, recvBuf), info: info, } go func() { @@ -158,6 +163,12 @@ func (s *serviceSet) handle(ctx context.Context, req *Request, respond func(*sta return nil, status.Errorf(codes.Unimplemented, "method %v", req.Method) } +// streamRecvBufferSize is the buffer size for stream recv channels. It +// should be large enough to absorb normal bursts without hitting the +// 1-second timeout fallback in receive/data, but small enough that +// per-stream memory overhead stays trivial. +const streamRecvBufferSize = 64 + type streamHandler struct { ctx context.Context respond func(*status.Status, []byte, bool, bool) error @@ -184,6 +195,17 @@ func (s *streamHandler) data(unmarshal Unmarshaler) error { return nil case <-s.ctx.Done(): return s.ctx.Err() + default: + // If recv channel is full, wait up to a second for an item + // to drain and unblock, otherwise return an error. 
+ select { + case s.recv <- unmarshal: + return nil + case <-s.ctx.Done(): + return s.ctx.Err() + case <-time.After(time.Second): + return ErrStreamFull + } } } diff --git a/stream.go b/stream.go index 739a4c967..a6a71def6 100644 --- a/stream.go +++ b/stream.go @@ -19,6 +19,7 @@ package ttrpc import ( "context" "sync" + "time" ) type streamID uint32 @@ -38,11 +39,11 @@ type stream struct { recvClose chan struct{} } -func newStream(id streamID, send sender) *stream { +func newStream(id streamID, send sender, recvBuf int) *stream { return &stream{ id: id, sender: send, - recv: make(chan *streamMessage, 1), + recv: make(chan *streamMessage, recvBuf), recvClose: make(chan struct{}), } } @@ -63,6 +64,11 @@ func (s *stream) send(mt messageType, flags uint8, b []byte) error { return s.sender.send(uint32(s.id), mt, flags, b) } +// receive delivers a message to this stream from the connection receive loop. +// If the stream's recv buffer is full, it waits up to 1 second for the +// consumer to make progress. This keeps the receive loop moving for other +// streams while still providing backpressure under normal operation. If the +// timeout expires the stream is closed with ErrStreamFull. func (s *stream) receive(ctx context.Context, msg *streamMessage) error { select { case <-s.recvClose: @@ -76,6 +82,20 @@ func (s *stream) receive(ctx context.Context, msg *streamMessage) error { return nil case <-ctx.Done(): return ctx.Err() + default: + // If recv channel is full, wait up to a second for an item + // to drain and unblock, otherwise close the stream. + select { + case <-s.recvClose: + return s.recvErr + case s.recv <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Second): + s.closeWithError(ErrStreamFull) + return ErrStreamFull + } } } diff --git a/stream_full_test.go b/stream_full_test.go new file mode 100644 index 000000000..f454d3135 --- /dev/null +++ b/stream_full_test.go @@ -0,0 +1,238 @@ +/* + Copyright The containerd Authors. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package ttrpc + +import ( + "context" + "io" + "testing" + "time" + + "github.com/containerd/ttrpc/internal" +) + +// TestStreamNotConsumedDoesNotBlockConnection verifies that a stream whose +// receive buffer fills up (because the client stopped consuming) does not +// block other streams or unary calls on the same connection. +// +// This guards against a deadlock where the client's receiveLoop blocks +// trying to deliver a message to a full stream, which prevents all other +// streams on the same connection from receiving anything. 
+func TestStreamNotConsumedDoesNotBlockConnection(t *testing.T) {
+	var (
+		ctx             = context.Background()
+		server          = mustServer(t)(NewServer())
+		addr, listener  = newTestListener(t)
+		client, cleanup = newTestClient(t, addr)
+		serviceName     = "streamService"
+	)
+
+	defer listener.Close()
+	defer cleanup()
+
+	desc := &ServiceDesc{
+		Methods: map[string]Method{
+			"Echo": func(_ context.Context, unmarshal func(interface{}) error) (interface{}, error) {
+				var req internal.EchoPayload
+				if err := unmarshal(&req); err != nil {
+					return nil, err
+				}
+				req.Seq++
+				return &req, nil
+			},
+		},
+		Streams: map[string]Stream{
+			"EchoStream": {
+				Handler: func(_ context.Context, ss StreamServer) (interface{}, error) {
+					for {
+						var req internal.EchoPayload
+						if err := ss.RecvMsg(&req); err != nil {
+							if err == io.EOF {
+								err = nil
+							}
+							return nil, err
+						}
+						req.Seq++
+						if err := ss.SendMsg(&req); err != nil {
+							return nil, err
+						}
+					}
+				},
+				StreamingClient: true,
+				StreamingServer: true,
+			},
+		},
+	}
+	server.RegisterService(serviceName, desc)
+
+	go server.Serve(ctx, listener)
+	defer server.Close()
+
+	// Create a bidirectional streaming RPC and send messages into it,
+	// but never call RecvMsg. This will fill up the stream's receive
+	// buffer (capacity streamRecvBufferSize) once the server echoes back.
+	abandonedStream, err := client.NewStream(ctx, &StreamDesc{true, true}, serviceName, "EchoStream", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Send more messages than the client-side buffer (streamRecvBufferSize)
+	// can hold, so the server's echoes are guaranteed to overflow it.
+	for i := 0; i < streamRecvBufferSize+10; i++ {
+		if err := abandonedStream.SendMsg(&internal.EchoPayload{
+			Seq: int64(i),
+			Msg: "abandoned",
+		}); err != nil {
+			// Send may fail if the stream is closed due to buffer full,
+			// which is acceptable.
+			break
+		}
+	}
+
+	// Wait for the receive loop to detect the abandoned stream. Once the
+	// buffer is full, the 1-second timeout fires, closing the
+	// stream and unblocking the receive loop for other streams.
+	time.Sleep(2 * time.Second)
+
+	// A unary call on the same connection must succeed. Without the
+	// timeout in stream.receive, the receiveLoop would still be blocked
+	// trying to deliver to the abandoned stream.
+	callCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+
+	var req, resp internal.EchoPayload
+	req.Seq = 42
+	req.Msg = "must not deadlock"
+	if err := client.Call(callCtx, serviceName, "Echo", &req, &resp); err != nil {
+		t.Fatalf("unary Call blocked by unconsumed stream: %v", err)
+	}
+	if resp.Seq != 43 {
+		t.Fatalf("unexpected sequence: got %d, want 43", resp.Seq)
+	}
+
+	// Also verify a second stream works.
+	stream2, err := client.NewStream(callCtx, &StreamDesc{true, true}, serviceName, "EchoStream", nil)
+	if err != nil {
+		t.Fatalf("NewStream blocked by unconsumed stream: %v", err)
+	}
+	if err := stream2.SendMsg(&internal.EchoPayload{Seq: 1, Msg: "hello"}); err != nil {
+		t.Fatalf("SendMsg on second stream failed: %v", err)
+	}
+	var resp2 internal.EchoPayload
+	if err := stream2.RecvMsg(&resp2); err != nil {
+		t.Fatalf("RecvMsg on second stream failed: %v", err)
+	}
+	if resp2.Seq != 2 {
+		t.Fatalf("unexpected sequence on stream2: got %d, want 2", resp2.Seq)
+	}
+}
+
+// TestStreamFullOnServer verifies that when a server-side stream handler
+// stops consuming messages, the server's receive goroutine is not blocked
+// and can still process other streams. This guards against the same
+// deadlock as TestStreamNotConsumedDoesNotBlockConnection but on the
+// server side, where streamHandler.data() blocks the receive goroutine.
+func TestStreamFullOnServer(t *testing.T) {
+	var (
+		ctx             = context.Background()
+		server          = mustServer(t)(NewServer())
+		addr, listener  = newTestListener(t)
+		client, cleanup = newTestClient(t, addr)
+		serviceName     = "streamService"
+		handlerReady    = make(chan struct{})
+	)
+
+	defer listener.Close()
+	defer cleanup()
+
+	desc := &ServiceDesc{
+		Methods: map[string]Method{
+			"Echo": func(_ context.Context, unmarshal func(interface{}) error) (interface{}, error) {
+				var req internal.EchoPayload
+				if err := unmarshal(&req); err != nil {
+					return nil, err
+				}
+				req.Seq++
+				return &req, nil
+			},
+		},
+		Streams: map[string]Stream{
+			"SlowConsumer": {
+				Handler: func(ctx context.Context, _ StreamServer) (interface{}, error) {
+					// Signal that the handler is running, then stop consuming.
+					close(handlerReady)
+					// Block until the context is cancelled (server shutdown).
+					<-ctx.Done()
+					return nil, ctx.Err()
+				},
+				StreamingClient: true,
+				StreamingServer: false,
+			},
+		},
+	}
+	server.RegisterService(serviceName, desc)
+
+	go server.Serve(ctx, listener)
+	defer server.Close()
+
+	// Open a stream whose server handler stops consuming after setup.
+	slowStream, err := client.NewStream(ctx, &StreamDesc{StreamingClient: true}, serviceName, "SlowConsumer", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Wait for the handler to be ready (and stopped consuming).
+	select {
+	case <-handlerReady:
+	case <-time.After(3 * time.Second):
+		t.Fatal("timed out waiting for handler to start")
+	}
+
+	// Overfill the server's recv buffer (capacity streamRecvBufferSize for
+	// client-streaming methods). The handler is not consuming, so these pile up;
+	// the overflow is small and sent from a goroutine because sends may block.
+	sendDone := make(chan struct{})
+	go func() {
+		defer close(sendDone)
+		for i := 0; i < streamRecvBufferSize+2; i++ {
+			if err := slowStream.SendMsg(&internal.EchoPayload{
+				Seq: int64(i),
+				Msg: "filling buffer",
+			}); err != nil {
+				break
+			}
+		}
+	}()
+
+	// Wait for the server receive goroutine to detect the full buffer.
+	// The 1-second timeout in data() fires, after which the receive
+	// goroutine can process other streams again.
+	time.Sleep(2 * time.Second)
+
+	// Verify we can still make a unary call on the same connection.
+	callCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+
+	var req, resp internal.EchoPayload
+	req.Seq = 99
+	if err := client.Call(callCtx, serviceName, "Echo", &req, &resp); err != nil {
+		t.Fatalf("unary Call blocked by full server stream: %v", err)
+	}
+	if resp.Seq != 100 {
+		t.Fatalf("unexpected sequence: got %d, want 100", resp.Seq)
+	}
+}