Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 37 additions & 21 deletions commands/dial_stdio.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package commands

import (
"context"
"io"
"net"
"os"
Expand All @@ -15,7 +16,6 @@ import (
ocispecs "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)

type stdioOptions struct {
Expand Down Expand Up @@ -79,29 +79,45 @@ func runDialStdio(dockerCli command.Cli, opts stdioOptions) error {
return err
}

defer conn.Close()

go func() {
<-ctx.Done()
closeWrite(conn)
}()

var eg errgroup.Group

eg.Go(func() error {
_, err := io.Copy(conn, os.Stdin)
closeWrite(conn)
return err
})
eg.Go(func() error {
_, err := io.Copy(os.Stdout, conn)
closeRead(conn)
return err
})
return eg.Wait()
return proxyConn(ctx, conn, os.Stdin, os.Stdout)
})
}

func proxyConn(ctx context.Context, conn net.Conn, stdin io.Reader, stdout io.Writer) error {
defer conn.Close()

stdinDone := make(chan error, 1)
stdoutDone := make(chan error, 1)

go func() {
_, err := io.Copy(conn, stdin)
closeWrite(conn)
stdinDone <- err
}()
go func() {
_, err := io.Copy(stdout, conn)
closeRead(conn)
stdoutDone <- err
}()

for {
select {
case <-ctx.Done():
return context.Cause(ctx)
case err := <-stdinDone:
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.ErrClosedPipe) {
return err
}
stdinDone = nil
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem necessary.

case err := <-stdoutDone:
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.ErrClosedPipe) {
return err
}
return nil
}
}
}

func closeRead(conn net.Conn) error {
if c, ok := conn.(interface{ CloseRead() error }); ok {
return c.CloseRead()
Expand Down
56 changes: 56 additions & 0 deletions commands/dial_stdio_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package commands

import (
"bytes"
"context"
"io"
"net"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestProxyConnRemoteClose(t *testing.T) {
clientConn, serverConn := net.Pipe()
defer serverConn.Close()

stdin := &blockingReader{waitCh: make(chan struct{})}
defer stdin.Close()

var stdout bytes.Buffer
errCh := make(chan error, 1)
go func() {
errCh <- proxyConn(context.Background(), clientConn, stdin, &stdout)
}()

go func() {
_, _ = serverConn.Write([]byte("hello"))
_ = serverConn.Close()
}()

select {
case err := <-errCh:
require.NoError(t, err)
require.Equal(t, "hello", stdout.String())
case <-time.After(2 * time.Second):
t.Fatal("proxyConn did not return after the remote side closed")
}
}

type blockingReader struct {
waitCh chan struct{}
closeOnce sync.Once
}

func (r *blockingReader) Read([]byte) (int, error) {
<-r.waitCh
return 0, io.EOF
}

func (r *blockingReader) Close() {
r.closeOnce.Do(func() {
close(r.waitCh)
})
}
Loading