diff --git a/commands/dial_stdio.go b/commands/dial_stdio.go index e848efba5100..ef641401cc0c 100644 --- a/commands/dial_stdio.go +++ b/commands/dial_stdio.go @@ -1,6 +1,7 @@ package commands import ( + "context" "io" "net" "os" @@ -15,7 +16,6 @@ import ( ocispecs "github.com/opencontainers/image-spec/specs-go/v1" "github.com/pkg/errors" "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" ) type stdioOptions struct { @@ -79,29 +79,45 @@ func runDialStdio(dockerCli command.Cli, opts stdioOptions) error { return err } - defer conn.Close() - - go func() { - <-ctx.Done() - closeWrite(conn) - }() - - var eg errgroup.Group - - eg.Go(func() error { - _, err := io.Copy(conn, os.Stdin) - closeWrite(conn) - return err - }) - eg.Go(func() error { - _, err := io.Copy(os.Stdout, conn) - closeRead(conn) - return err - }) - return eg.Wait() + return proxyConn(ctx, conn, os.Stdin, os.Stdout) }) } +func proxyConn(ctx context.Context, conn net.Conn, stdin io.Reader, stdout io.Writer) error { + defer conn.Close() + + stdinDone := make(chan error, 1) + stdoutDone := make(chan error, 1) + + go func() { + _, err := io.Copy(conn, stdin) + closeWrite(conn) + stdinDone <- err + }() + go func() { + _, err := io.Copy(stdout, conn) + closeRead(conn) + stdoutDone <- err + }() + + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + case err := <-stdinDone: + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.ErrClosedPipe) { + return err + } + stdinDone = nil + case err := <-stdoutDone: + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.ErrClosedPipe) { + return err + } + return nil + } + } +} + func closeRead(conn net.Conn) error { if c, ok := conn.(interface{ CloseRead() error }); ok { return c.CloseRead() diff --git a/commands/dial_stdio_test.go b/commands/dial_stdio_test.go new file mode 100644 index 000000000000..a304da393d99 --- /dev/null +++ b/commands/dial_stdio_test.go @@ -0,0 +1,56 @@ +package commands + +import ( + "bytes" + "context" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestProxyConnRemoteClose(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer serverConn.Close() + + stdin := &blockingReader{waitCh: make(chan struct{})} + defer stdin.Close() + + var stdout bytes.Buffer + errCh := make(chan error, 1) + go func() { + errCh <- proxyConn(context.Background(), clientConn, stdin, &stdout) + }() + + go func() { + _, _ = serverConn.Write([]byte("hello")) + _ = serverConn.Close() + }() + + select { + case err := <-errCh: + require.NoError(t, err) + require.Equal(t, "hello", stdout.String()) + case <-time.After(2 * time.Second): + t.Fatal("proxyConn did not return after the remote side closed") + } +} + +type blockingReader struct { + waitCh chan struct{} + closeOnce sync.Once +} + +func (r *blockingReader) Read([]byte) (int, error) { + <-r.waitCh + return 0, io.EOF +} + +func (r *blockingReader) Close() { + r.closeOnce.Do(func() { + close(r.waitCh) + }) +}