From b50414f5adfe48a85275a4b92b5580b2477722d8 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 12 May 2026 13:01:16 +0000 Subject: [PATCH 01/15] fixes --- client.go | 86 +++++++++++++++++++++++++++++++++++-------- client_test.go | 35 ++++++++++++++++++ internal/sync/pool.go | 12 +++++- 3 files changed, 116 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index 7428cdf6..9b918584 100644 --- a/client.go +++ b/client.go @@ -281,6 +281,10 @@ func (c *clientConn) recv(ctx context.Context, reqid uint32, ch chan result) (*s c.discard(ch) return nil, ctx.Err() + case <-c.closed: + c.discard(ch) + return nil, sshfx.StatusConnectionLost + case res := <-ch: c.resPool.Put(ch) @@ -488,7 +492,6 @@ func statusToError(status *sshfx.StatusPacket, okExpected bool) error { return fmt.Errorf("unexpected SSH_FX_OK") } return nil - case sshfx.StatusEOF: return io.EOF } @@ -705,11 +708,17 @@ func (cl *Client) ReportPoolMetrics(wr io.Writer) { } } +// Wait blocks until the connection has shut down, +// and returns the error causing the shutdown. +// It can be called concurrently from multiple goroutines. +func (cl *Client) Wait() error { + return cl.conn.Wait() +} + // Close closes the SFTP session. func (cl *Client) Close() error { cl.conn.disconnect(nil) - cl.conn.wr.Close() - return nil + return cl.conn.wr.Close() } func wrapPathError(op, path string, err error) error { @@ -753,11 +762,8 @@ func wrapLinkError(op, oldpath, newpath string, err error) error { func (cl *Client) Mkdir(name string, perm fs.FileMode) error { return wrapPathError("mkdir", name, cl.sendPacket(context.Background(), nil, &sshfx.MkdirPacket{ - Path: name, - Attrs: sshfx.Attributes{ - Flags: sshfx.AttrPermissions, - Permissions: sshfx.FileMode(perm.Perm()), - }, + Path: name, + Attrs: attrsFromPerm(perm), }), ) } @@ -845,6 +851,45 @@ func (cl *Client) Remove(name string) error { return wrapPathError("remove", name, errF) } +// RemoveAll removes path and any children it contains. 
+// It removes everything it can but returns the first error it encounters. +// If the path does not exist, RemoveAll returns nil (no error). +// If there is an error, it will be of type *PathError. +func (cl *Client) RemoveAll(pathname string) error { + fi, err := cl.Stat(pathname) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err + } + + if !fi.IsDir() { + return cl.Remove(pathname) + } + + files, err := cl.ReadDir(pathname) + if err != nil { + return err + } + + for _, file := range files { + filename := path.Join(pathname, file.Name()) + switch { + case file.IsDir(): + if err := cl.RemoveAll(filename); err != nil { + return err + } + default: + if err := cl.Remove(filename); err != nil { + return err + } + } + } + + return nil +} + func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error { return wrapPathError("setstat", name, cl.sendPacket(ctx, nil, &sshfx.SetStatPacket{ @@ -935,7 +980,7 @@ func (cl *Client) ReadLink(name string) (string, error) { // Server-specific restrictions may apply when old path and new path are in different directories. // Even within the same directory, on non-Unix servers Rename is not guaranteed to be an atomic operation. 
func (cl *Client) Rename(oldpath, newpath string) error { - if cl.hasExtension(openssh.ExtensionPOSIXRename()) { + if cl.HasExtension(openssh.ExtensionPOSIXRename()) { return wrapLinkError("rename", oldpath, newpath, cl.sendPacket(context.Background(), nil, &openssh.POSIXRenameExtendedPacket{ OldPath: oldpath, @@ -963,7 +1008,11 @@ func (cl *Client) Symlink(oldname, newname string) error { ) } -func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { +func (cl *Client) GetExtension(name string) string { + return cl.exts[name] +} + +func (cl *Client) HasExtension(ext *sshfx.ExtensionPair) bool { return cl.exts[ext.Name] == ext.Data } @@ -986,7 +1035,7 @@ func (cl *Client) StatVFS(path string) (*openssh.StatVFSExtendedReplyPacket, err // then no request will be sent, // and Link returns an *fs.LinkError wrapping sshfx.StatusOpUnsupported. func (cl *Client) Link(oldname, newname string) error { - if !cl.hasExtension(openssh.ExtensionHardlink()) { + if !cl.HasExtension(openssh.ExtensionHardlink()) { return wrapLinkError("hardlink", oldname, newname, sshfx.StatusOpUnsupported) } @@ -1422,6 +1471,14 @@ func (cl *Client) Create(name string) (*File, error) { return cl.OpenFile(name, OpenFlagReadWrite|OpenFlagCreate|OpenFlagTruncate, 0666) } +func attrsFromPerm(perm fs.FileMode) sshfx.Attributes { + var attrs sshfx.Attributes + if perm != 0 { + attrs.SetPermissions(sshfx.FileMode(perm.Perm())) + } + return attrs +} + // OpenFile is the generalized open call; // most users can use the simplified Open or Create methods instead. // It opens the named file with the specified flag (OpenFlagReadOnly, etc.). 
@@ -1434,10 +1491,7 @@ func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, erro handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenPacket{ Filename: name, PFlags: toPortableFlags(flag), - Attrs: sshfx.Attributes{ - Flags: sshfx.AttrPermissions, - Permissions: sshfx.FileMode(perm.Perm()), - }, + Attrs: attrsFromPerm(perm), }) if err != nil { return nil, wrapPathError("openfile", name, err) @@ -2468,7 +2522,7 @@ func (f *File) Sync() error { return f.wrapErr("fsync", err) } - if !f.cl.hasExtension(openssh.ExtensionFSync()) { + if !f.cl.HasExtension(openssh.ExtensionFSync()) { return f.wrapErr("fsync", sshfx.StatusOpUnsupported) } diff --git a/client_test.go b/client_test.go index e9e68eab..2892b0a0 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,13 @@ package sftp import ( + "bytes" + "errors" "io" "io/fs" "testing" + + sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" ) func TestClient(t *testing.T) { @@ -36,3 +40,34 @@ func TestClient(t *testing.T) { var _ allFS = new(fsys) } + +type sink struct{} + +func (sink) Close() error { return nil } +func (sink) Write(p []byte) (int, error) { return len(p), nil } + +// Issue #418: panic in clientConn.recv when the sid is incomplete. 
+func TestClientNoSid(t *testing.T) { + initPkt := &sshfx.VersionPacket{ + Version: sftpProtocolVersion, + } + + initData, err := initPkt.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + stream := new(bytes.Buffer) + stream.Write(initData) + stream.Write([]byte{ 0, 0, 0, 10, 0, 0 }) + + cl, err := NewClientPipe(t.Context(), stream, sink{}) + if err != nil { + t.Fatal(err) + } + + _, err = cl.Stat("anything") + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("cl.Stat = %v, expected sshfx.StatusConnectionLost", err) + } +} diff --git a/internal/sync/pool.go b/internal/sync/pool.go index 9c3e6dc2..7ffe6b98 100644 --- a/internal/sync/pool.go +++ b/internal/sync/pool.go @@ -168,7 +168,8 @@ func (p *Pool[T]) Put(v *T) { // relieving pressure on the garbage collector and amortizing allocation overhead. // While also co-ordinating outstanding work, so the caller can wait for all work to be complete. type WorkPool[T any] struct { - wg sync.WaitGroup + wg sync.WaitGroup + closed chan struct{} ch chan chan T } @@ -178,6 +179,8 @@ type WorkPool[T any] struct { // It will panic if given a negative depth, the same as making a negative-buffer channel. 
func NewWorkPool[T any](depth int) *WorkPool[T] { p := &WorkPool[T]{ + closed: make(chan struct{}), + ch: make(chan chan T, depth), } @@ -203,6 +206,7 @@ func (p *WorkPool[T]) Close() error { return errors.New("cannot close nil work pool") } + close(p.closed) close(p.ch) p.wg.Wait() @@ -243,6 +247,12 @@ func (p *WorkPool[T]) Put(v chan T) { return } + select { + case <-p.closed: + return + default: + } + select { case p.ch <- v: p.wg.Done() From 972cbeb5c3d02b0084e42744ce0caa815f1001da Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 12 May 2026 13:46:39 +0000 Subject: [PATCH 02/15] Client.ReaddirContext --- client.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 9b918584..7a88f25f 100644 --- a/client.go +++ b/client.go @@ -1051,13 +1051,20 @@ func (cl *Client) Link(oldname, newname string) error { // If an error occurs reading the directory, // Readdir returns the entries it was able to read before the error, along with the error. func (cl *Client) Readdir(name string) ([]fs.FileInfo, error) { + return cl.ReaddirContext(context.Background(), name) +} + +// ReaddirContext reads the named directory, returning all its directory entries as [fs.FileInfo] sorted by filename. +// If an error occurs reading the directory, including the context being canceled, +// ReaddirContext returns the entries it was able to read before the error, along with the error. 
+func (cl *Client) ReaddirContext(ctx context.Context, name string) ([]fs.FileInfo, error) { d, err := cl.OpenDir(name) if err != nil { return nil, err } defer d.Close() - fis, err := d.Readdir(0) + fis, err := d.ReaddirContext(ctx, 0) slices.SortFunc(fis, func(a, b fs.FileInfo) int { return cmp.Compare(a.Name(), b.Name()) From b2536a6351f2a38aa71a917a950f51c094607b3b Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 12 May 2026 14:42:43 +0000 Subject: [PATCH 03/15] =?UTF-8?q?path=20import=20=E2=86=92=20stdpath?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 7a88f25f..8b47d533 100644 --- a/client.go +++ b/client.go @@ -11,7 +11,7 @@ import ( "iter" "math" "os" - "path" + stdpath "path" "slices" "sync/atomic" "syscall" @@ -783,7 +783,7 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { // Slow path: make sure parent exists and then call Mkdir for name. - if parent := path.Dir(name); parent != "" { + if parent := stdpath.Dir(name); parent != "" { err = cl.MkdirAll(parent, perm) if err != nil { return err @@ -874,7 +874,7 @@ func (cl *Client) RemoveAll(pathname string) error { } for _, file := range files { - filename := path.Join(pathname, file.Name()) + filename := stdpath.Join(pathname, file.Name()) switch { case file.IsDir(): if err := cl.RemoveAll(filename); err != nil { @@ -1267,6 +1267,7 @@ func (d *Dir) rangedir(ctx context.Context, grow func(int)) iter.Seq2[*sshfx.Nam // Pull from saved entries first. for i, ent := range d.entries { + switch ent.Name() { if !yield(ent, nil) { // This is a break condition. // We need to remove all entries that have been consumed, From c3fd5a1a6e4ae4d3833875d08762730f7a3cb6c9 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 12 May 2026 14:43:21 +0000 Subject: [PATCH 04/15] rangedir: bug: skip . and .. 
--- client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client.go b/client.go index 8b47d533..b9bd00de 100644 --- a/client.go +++ b/client.go @@ -1268,6 +1268,10 @@ func (d *Dir) rangedir(ctx context.Context, grow func(int)) iter.Seq2[*sshfx.Nam // Pull from saved entries first. for i, ent := range d.entries { switch ent.Name() { + case ".", "..": // skip useless names. + continue + } + if !yield(ent, nil) { // This is a break condition. // We need to remove all entries that have been consumed, From 77c75f658209398f8b63f774d340f606458930ea Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 12 May 2026 14:51:44 +0000 Subject: [PATCH 05/15] go fmt; golint --- client.go | 14 ++++++++++++-- client_test.go | 4 ++-- encoding/ssh/filexfer/buffer.go | 2 +- encoding/ssh/filexfer/extended_packets_test.go | 4 ++-- encoding/ssh/filexfer/response_packets.go | 2 +- 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index b9bd00de..8e220364 100644 --- a/client.go +++ b/client.go @@ -580,6 +580,8 @@ func (cl *Client) getDataBuf(size int) []byte { return hint[:size] // trim our buffer to length, it might be longer than chunkSize. } +// SSHSession is an interface for x/crypto/ssh.Session, +// in order to break the dependency on x/crypto. type SSHSession interface { StdinPipe() (io.WriteCloser, error) StdoutPipe() (io.Reader, error) @@ -591,6 +593,8 @@ type SSHSession interface { Wait() error } +// SSHClient is an interface for x/crypto/ssh.Client, +// in order to break the dependency on x/crypto. type SSHClient[Session SSHSession] interface { NewSession() (Session, error) } @@ -1008,12 +1012,18 @@ func (cl *Client) Symlink(oldname, newname string) error { ) } +// GetExtension returns the data associated with the extension name. +// If there is no data associated with the extension name, it returns the empty string. 
func (cl *Client) GetExtension(name string) string { return cl.exts[name] } -func (cl *Client) HasExtension(ext *sshfx.ExtensionPair) bool { - return cl.exts[ext.Name] == ext.Data +// HasExtension returns true if the data associated with the extension name is the same as extension data. +// +// This is a much more strict test for extensions than simple existence of data associated with the extension name. +// To test for anything wider than extension data, use [Client.GetExtension]. +func (cl *Client) HasExtension(extension *sshfx.ExtensionPair) bool { + return cl.exts[extension.Name] == extension.Data } // StatVFS retrieves VFS statistics from a remote host. diff --git a/client_test.go b/client_test.go index 2892b0a0..91d99080 100644 --- a/client_test.go +++ b/client_test.go @@ -43,7 +43,7 @@ func TestClient(t *testing.T) { type sink struct{} -func (sink) Close() error { return nil } +func (sink) Close() error { return nil } func (sink) Write(p []byte) (int, error) { return len(p), nil } // Issue #418: panic in clientConn.recv when the sid is incomplete. @@ -59,7 +59,7 @@ func TestClientNoSid(t *testing.T) { stream := new(bytes.Buffer) stream.Write(initData) - stream.Write([]byte{ 0, 0, 0, 10, 0, 0 }) + stream.Write([]byte{0, 0, 0, 10, 0, 0}) cl, err := NewClientPipe(t.Context(), stream, sink{}) if err != nil { diff --git a/encoding/ssh/filexfer/buffer.go b/encoding/ssh/filexfer/buffer.go index 0d8cec1c..8dde0101 100644 --- a/encoding/ssh/filexfer/buffer.go +++ b/encoding/ssh/filexfer/buffer.go @@ -197,7 +197,7 @@ func (b *Buffer) ConsumeUint32() uint32 { func (b *Buffer) AppendUint32(v uint32) { b.b = binary.BigEndian.AppendUint32(b.b, v) } -//*/ + // ConsumeCount consumes a single uint32 count from the buffer, in network byte order (big-endian) as an int. // If the buffer does not have enough data, it will set Err to ErrShortPacket. 
func (b *Buffer) ConsumeCount() (int, error) { diff --git a/encoding/ssh/filexfer/extended_packets_test.go b/encoding/ssh/filexfer/extended_packets_test.go index e097e5b9..79843f2a 100644 --- a/encoding/ssh/filexfer/extended_packets_test.go +++ b/encoding/ssh/filexfer/extended_packets_test.go @@ -183,7 +183,7 @@ func TestExtendedPacketTestData(t *testing.T) { t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) } - wantBuffer := []byte{ textValue^0x2a } + wantBuffer := []byte{textValue ^ 0x2a} if buf, ok := p.Data.(*Buffer); !ok { t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) @@ -243,7 +243,7 @@ func TestExtendedPacketTestBuffer(t *testing.T) { t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) } - wantBuffer := []byte{ 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r' } + wantBuffer := []byte{0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r'} if buf, ok := p.Data.(*Buffer); !ok { t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) diff --git a/encoding/ssh/filexfer/response_packets.go b/encoding/ssh/filexfer/response_packets.go index 60ec6d04..c85fcdd8 100644 --- a/encoding/ssh/filexfer/response_packets.go +++ b/encoding/ssh/filexfer/response_packets.go @@ -44,7 +44,7 @@ func (p *StatusPacket) MarshalSize() int { // uint32(length) + uint8(type) + uint32(request-id) const size = 4 + 1 + 4 - // uint32(error/status code) + string(error message) + string(language tag) + // uint32(error/status code) + string(error message) + string(language tag) return size + 4 + 4 + len(p.ErrorMessage) + 4 + len(p.LanguageTag) } From 7fa0a047ca1f5ad68aad05f13e75422fc4de8216 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 11:56:22 +0000 Subject: [PATCH 06/15] update go tags --- errno_posix.go | 1 - fuzz.go | 2 -- ls_plan9.go | 1 - ls_stub.go | 1 - ls_unix.go | 1 - 5 files changed, 6 deletions(-) diff --git 
a/errno_posix.go b/errno_posix.go index bc5bbcd5..45a1f914 100644 --- a/errno_posix.go +++ b/errno_posix.go @@ -1,5 +1,4 @@ //go:build !plan9 -// +build !plan9 package sftp diff --git a/fuzz.go b/fuzz.go index 72202818..324eb20b 100644 --- a/fuzz.go +++ b/fuzz.go @@ -1,6 +1,4 @@ -// go:build gofuzz //go:build gofuzz -// +build gofuzz package sftp diff --git a/ls_plan9.go b/ls_plan9.go index ed7dc044..21e8f11e 100644 --- a/ls_plan9.go +++ b/ls_plan9.go @@ -1,5 +1,4 @@ //go:build plan9 -// +build plan9 package sftp diff --git a/ls_stub.go b/ls_stub.go index 1cf29cd6..b584d695 100644 --- a/ls_stub.go +++ b/ls_stub.go @@ -1,5 +1,4 @@ //go:build windows || android -// +build windows android package sftp diff --git a/ls_unix.go b/ls_unix.go index ec4efa2f..2e659801 100644 --- a/ls_unix.go +++ b/ls_unix.go @@ -1,5 +1,4 @@ //go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js || wasip1 || zos -// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js wasip1 zos package sftp From 5524c1065a24994ad15cc8da896db798497c456d Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 12:01:12 +0000 Subject: [PATCH 07/15] sync.WorkPool: if closed Get returns false, wg.Done before testing if closed --- internal/sync/pool.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/internal/sync/pool.go b/internal/sync/pool.go index 7ffe6b98..54bb1e4b 100644 --- a/internal/sync/pool.go +++ b/internal/sync/pool.go @@ -229,6 +229,13 @@ func (p *WorkPool[T]) Get() (chan T, bool) { return make(chan T, 1), true } + select { + case <-p.closed: + var ch chan T + return ch, false + default: + } + v, ok := <-p.ch if ok { p.wg.Add(1) @@ -247,6 +254,8 @@ func (p *WorkPool[T]) Put(v chan T) { return } + p.wg.Done() + select { case <-p.closed: return @@ -255,7 +264,6 @@ func (p *WorkPool[T]) Put(v chan T) { select { case p.ch <- v: - p.wg.Done() default: panic("worker pool overfill") // 
This is an overfill, which shouldn't happen, but just in case... From 87fe5f4808e385a6bc4478e7e4514a073f5bb578 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 12:01:37 +0000 Subject: [PATCH 08/15] sync: include Once alias --- internal/sync/aliases.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/sync/aliases.go b/internal/sync/aliases.go index d611b71e..53d72e1f 100644 --- a/internal/sync/aliases.go +++ b/internal/sync/aliases.go @@ -12,3 +12,6 @@ type RWMutex = sync.RWMutex // WaitGroup is an alias to [sync.WaitGroup] type WaitGroup = sync.WaitGroup + +// Once is an alias to [sync.Once] +type Once = sync.Once From 09c2367c37214cac20bd6512facafe05f4969866 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 12:03:50 +0000 Subject: [PATCH 09/15] import tests from v1 --- Makefile | 33 +- client_integration_test.go | 819 +++++++++++++++++++++++++++++++++++++ client_test.go | 121 +++++- 3 files changed, 960 insertions(+), 13 deletions(-) create mode 100644 client_integration_test.go diff --git a/Makefile b/Makefile index 65c6ae99..789c8ce4 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,36 @@ .PHONY: integration integration_w_race +DELAY ?= 0 + +ifneq ($(DELAY),0) +DELAY_FLAG=-delay $(DELAY) +endif + +TAGS=integration,sftp.sync.metrics + integration: - go test -v ./... + go test -v $(DELAY_FLAG) -tags=$(TAGS) + go test ./encoding/... ./internal/... make -C localfs integration integration_w_race: - go test -race -v ./... - make -C localfs integration + go test -race -v $(DELAY_FLAG) -tags=$(TAGS) + go test -race ./encoding/... ./internal/... + make -C localfs integration_w_race + +COUNT ?= 1 +BENCHMARK_PATTERN ?= "." 
+ +benchmark: + go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) $(DELAY_FLAG) -tags=$(TAGS) + make -C localfs benchmark + +benchmark_w_memprofile: +ifneq ($(DELAY),0) + @echo "memprofile with DELAY produces invalid data" >&2 + @exit 1 +endif + go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) -memprofile memprofile.out -tags=integration + go tool pprof -sample_index=alloc_space -svg -output=memprofile-space.svg memprofile.out + go tool pprof -sample_index=alloc_objects -svg -output=memprofile-allocs.svg memprofile.out + make -C localfs benchmark_w_memprofile diff --git a/client_integration_test.go b/client_integration_test.go new file mode 100644 index 00000000..6c60e9f8 --- /dev/null +++ b/client_integration_test.go @@ -0,0 +1,819 @@ +//go:build integration && !windows + +package sftp + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "os" + "os/exec" + "slices" + "testing" + "time" + + sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" + "github.com/pkg/sftp/v2/internal/sync" +) + +const ( + kibi = 1024 + mebi = 1024 * 1024 +) + +const ( + ReadOnly = true + ReadWrite = false + NoDelay time.Duration = 0 +) + +const debuglevel = "ERROR" // set to "DEBUG" for debugging + +var testSftp *string + +func TestMain(m *testing.M) { + sftpServerLocations := []string{ + "/usr/libexec/ssh/sftp-server", + "/usr/libexec/sftp-server", + "/usr/lib/openssh/sftp-server", + "/usr/lib/ssh/sftp-server", + `C:\Program Files\Git\usr\lib\ssh\sftp-server.exe`, + } + + sftpServer, _ := exec.LookPath("sftp-server") + if sftpServer == "" { + for _, loc := range sftpServerLocations { + if _, err := os.Stat(loc); err == nil { + sftpServer = loc + break + } + } + + if sftpServer == "" { + fmt.Fprintln(os.Stdout, "FAIL: could not find sftp-server") + os.Exit(1) + } + } + + testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary") + flag.Parse() + + os.Exit(m.Run()) +} + +type delayedWrite struct { + t time.Time 
+ b []byte +} + +// delayedWriter wraps a writer and artificially delays the write. +// This is meant to mimic connections with various latencies. +// Errors returned from the underlying writer will panic so this should only be used over reliable connections. +type delayedWriter struct { + ch chan delayedWrite + closing chan struct{} + + wg sync.WaitGroup + closed <-chan struct{} +} + +func newDelayedWriter(t testing.TB, wr io.WriteCloser, delay time.Duration) *delayedWriter { + closed := make(chan struct{}) + + dw := &delayedWriter{ + ch: make(chan delayedWrite, 128), + closing: make(chan struct{}), + closed: closed, + } + + ctx := t.Context() + + go func() { + defer close(closed) + + defer wr.Close() + + for write := range dw.ch { + select { + case <-ctx.Done(): + return + case <-time.After(time.Until(write.t.Add(delay))): + } + + n, err := wr.Write(write.b) + if err != nil { + panic(err) + } + + if n < len(write.b) { + panic(io.ErrShortWrite) + } + } + }() + + return dw +} + +func (dw *delayedWriter) Write(b []byte) (int, error) { + select { + case <-dw.closing: + return 0, io.ErrClosedPipe + default: + } + + dw.wg.Add(1) + defer dw.wg.Done() + + dw.ch <- delayedWrite{ + t: time.Now(), + b: slices.Clone(b), + } + + return len(b), nil +} + +func (dw *delayedWriter) Close() error { + close(dw.closing) + + dw.wg.Wait() // wait for any outstanding blocking writes + + close(dw.ch) + + <-dw.closed // wait for writer goroutine to finish. + + return nil +} + +func testClient(t testing.TB, readonly bool, delay time.Duration, opts ...ClientOption) (*Client, *exec.Cmd) { + args := []string{ + "-e", + "-l", debuglevel, + } + if readonly { + args = append(args, "-R") + } + + cmd := exec.Command(*testSftp, args...) 
+ + cmd.Stderr = os.Stdout + pw, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + + if delay > 0 { + pw = newDelayedWriter(t, pw, delay) + } + + pr, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + + if err := cmd.Start(); err != nil { + t.Skip("could not start sftp-server process:", err) + } + + cl, err := newClientPipe(t.Context(), pr, nil, pw, nil, opts) + if err != nil { + t.Fatal(err) + } + + return cl, cmd +} + +// github.com/pkg/sftp/issues/42, abrupt server hangup would result in client hangs. +func TestServerRoughDisconnect(t *testing.T) { + cl, cmd := testClient(t, ReadOnly, NoDelay) + defer cmd.Wait() + defer cl.Close() + + f, err := cl.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.Copy(io.Discard, f) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.Copy error = %#v, but wanted sshfx.StatusConnectionLost", err) + } +} + +// github.com/pkg/sftp/issues/181, abrupt server hangup would result in client hangs. 
+// due to broadcastErr filling up the request channel +// this reproduces it about 50% of the time +func TestServerRoughDisconnect2(t *testing.T) { + cl, cmd := testClient(t, ReadOnly, NoDelay) + defer cmd.Wait() + defer cl.Close() + + f, err := cl.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + b := make([]byte, 100*32*kibi) + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + for { + if _, err := f.Read(b); err != nil { + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("File.Read error = %#v, but wanted sshfx.StatusConnectionLost", err) + } + break + } + } +} + +// github.com/pkg/sftp/issues/234 - abrupt shutdown during ReadFrom hangs client +func TestServerRoughDisconnect3(t *testing.T) { + cl, cmd := testClient(t, ReadWrite, NoDelay) + defer cmd.Wait() + defer cl.Close() + + dst, err := cl.OpenFile("/dev/null", OpenFlagReadWrite, 0) + if err != nil { + t.Fatal(err) + } + defer dst.Close() + + src, err := os.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer src.Close() + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.Copy(dst, src) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.Copy error = %#v, but wanted sshfx.StatusConnectionLost", err) + } +} + +// github.com/pkg/sftp/issues/234 - also affected Write +func TestServerRoughDisconnect4(t *testing.T) { + cl, cmd := testClient(t, ReadWrite, NoDelay) + defer cmd.Wait() + defer cl.Close() + + dst, err := cl.OpenFile("/dev/null", OpenFlagReadWrite, 0) + if err != nil { + t.Fatal(err) + } + defer dst.Close() + + src, err := os.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer src.Close() + + b := make([]byte, 200*32*kibi) + + if _, err = src.Read(b); err != nil { + t.Fatal(err) + } + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + for { + if _, err := dst.Write(b); err != nil { + if !errors.Is(err, 
sshfx.StatusConnectionLost) { + t.Errorf("dst.Write error = %#v, but wanted sshfx.StatusConnectionLost", err) + } + break + } + } + + _, err = io.Copy(dst, src) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.Copy error = %#v, but wanted sshfx.StatusConnectionLost", err) + } +} + +func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 123 // ~10MiB + + cl, cmd := testClient(b, ReadOnly, delay) + defer cmd.Wait() + defer cl.Close() + + buf := make([]byte, bufsize) + + b.SetBytes(int64(size)) + + for b.Loop() { + offset := 0 + + f, err := cl.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + + for offset < size { + remaining := size - offset + buf := buf[:min(remaining, len(buf))] + + n, err := io.ReadFull(f, buf) + offset += n + + if err != nil { + b.Fatalf("read error at %d: %v", offset, err) + } + } + + switch { + case offset < size: + b.Fatalf("read too few bytes! read: %d, wanted: %d", offset, size) + case offset > size: + b.Fatalf("read too many bytes! 
read: %d, wanted: %d", offset, size) + } + + f.Close() + } +} + +func BenchmarkRead1k(b *testing.B) { + benchmarkRead(b, 1*kibi, NoDelay) +} + +func BenchmarkRead16k(b *testing.B) { + benchmarkRead(b, 16*kibi, NoDelay) +} + +func BenchmarkRead32k(b *testing.B) { + benchmarkRead(b, 32*kibi, NoDelay) +} + +func BenchmarkRead128k(b *testing.B) { + benchmarkRead(b, 128*kibi, NoDelay) +} + +func BenchmarkRead512k(b *testing.B) { + benchmarkRead(b, 512*kibi, NoDelay) +} + +func BenchmarkRead1MiB(b *testing.B) { + benchmarkRead(b, mebi, NoDelay) +} + +func BenchmarkRead4MiB(b *testing.B) { + benchmarkRead(b, 4*mebi, NoDelay) +} + +func BenchmarkRead4MiBDelay10Msec(b *testing.B) { + benchmarkRead(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkRead4MiBDelay50Msec(b *testing.B) { + benchmarkRead(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkRead4MiBDelay150Msec(b *testing.B) { + benchmarkRead(b, 4*mebi, 150*time.Millisecond) +} + +func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 0x123 // ~10MiB + + cl, cmd := testClient(b, ReadWrite, delay) + defer cmd.Wait() + defer cl.Close() + + data := make([]byte, size) + for i := range data { + data[i] = uint8(i >> ((i % 4) * 8)) + } + + b.SetBytes(int64(size)) + + for b.Loop() { + func() { + offset := 0 + + f, err := os.CreateTemp("", "sftptest-benchwrite") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + f2, err := cl.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + for offset < size { + buf := data[offset:] + buf = buf[:min(len(buf), bufsize)] + + n, err := f2.Write(buf) + offset += n + if err != nil { + b.Fatalf("write error at %d: %v", offset, err) + } + + if n != len(buf) { + b.Fatalf("wrote too few bytes! 
written: %d, wanted: %d", n, len(buf)) + } + } + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: got: %d, want: %d", fi.Size(), size) + } + }() + } +} + +func BenchmarkWrite1k(b *testing.B) { + benchmarkWrite(b, 1*kibi, NoDelay) +} + +func BenchmarkWrite16k(b *testing.B) { + benchmarkWrite(b, 16*kibi, NoDelay) +} + +func BenchmarkWrite32k(b *testing.B) { + benchmarkWrite(b, 32*kibi, NoDelay) +} + +func BenchmarkWrite128k(b *testing.B) { + benchmarkWrite(b, 128*kibi, NoDelay) +} + +func BenchmarkWrite512k(b *testing.B) { + benchmarkWrite(b, 512*kibi, NoDelay) +} + +func BenchmarkWrite1MiB(b *testing.B) { + benchmarkWrite(b, mebi, NoDelay) +} + +func BenchmarkWrite4MiB(b *testing.B) { + benchmarkWrite(b, 4*mebi, NoDelay) +} + +func BenchmarkWrite4MiBDelay10Msec(b *testing.B) { + benchmarkWrite(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay50Msec(b *testing.B) { + benchmarkWrite(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay150Msec(b *testing.B) { + benchmarkWrite(b, 4*mebi, 150*time.Millisecond) +} + +func benchmarkReadFrom(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 123 // ~10MiB + + // open sftp client + cl, cmd := testClient(b, ReadWrite, delay) + defer cmd.Wait() + defer cl.Close() + + data := make([]byte, size) + + b.SetBytes(int64(size)) + + for b.Loop() { + func() { + f, err := os.CreateTemp("", "sftptest-benchreadfrom") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + f2, err := cl.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + f2.ReadFrom(bytes.NewReader(data)) + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: got: %d, want: %d", fi.Size(), size) + } + }() + } +} + +func BenchmarkReadFrom1k(b *testing.B) { + benchmarkReadFrom(b, 1*kibi, NoDelay) +} + +func 
BenchmarkReadFrom16k(b *testing.B) { + benchmarkReadFrom(b, 16*kibi, NoDelay) +} + +func BenchmarkReadFrom32k(b *testing.B) { + benchmarkReadFrom(b, 32*kibi, NoDelay) +} + +func BenchmarkReadFrom128k(b *testing.B) { + benchmarkReadFrom(b, 128*kibi, NoDelay) +} + +func BenchmarkReadFrom512k(b *testing.B) { + benchmarkReadFrom(b, 512*kibi, NoDelay) +} + +func BenchmarkReadFrom1MiB(b *testing.B) { + benchmarkReadFrom(b, mebi, NoDelay) +} + +func BenchmarkReadFrom4MiB(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, NoDelay) +} + +func BenchmarkReadFrom4MiBDelay10Msec(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay50Msec(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, 150*time.Millisecond) +} + +func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 123 // ~10MiB + + // open sftp client + cl, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer cl.Close() + + f, err := os.CreateTemp("", "sftptest-benchwriteto") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + data := make([]byte, size) + + if _, err = f.Write(data); err != nil { + b.Fatal(err) + } + + if err = f.Close(); err != nil { + b.Fatal(err) + } + + buf := bytes.NewBuffer(make([]byte, 0, size)) + + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + buf.Reset() + + f2, err := cl.Open(f.Name()) + if err != nil { + b.Fatal(err) + } + + if _, err = f2.WriteTo(buf); err != nil { + b.Fatal(err) + } + + if err = f2.Close(); err != nil { + b.Fatal(err) + } + + if buf.Len() != size { + b.Fatalf("wrong buffer size: got: %d, want: %d", buf.Len(), size) + } + } +} + +func BenchmarkWriteTo1k(b *testing.B) { + benchmarkWriteTo(b, 1*kibi, NoDelay) +} + +func BenchmarkWriteTo16k(b *testing.B) { + benchmarkWriteTo(b, 16*kibi, NoDelay) +} + +func BenchmarkWriteTo32k(b *testing.B) { + 
benchmarkWriteTo(b, 32*kibi, NoDelay) +} + +func BenchmarkWriteTo128k(b *testing.B) { + benchmarkWriteTo(b, 128*kibi, NoDelay) +} + +func BenchmarkWriteTo512k(b *testing.B) { + benchmarkWriteTo(b, 512*kibi, NoDelay) +} + +func BenchmarkWriteTo1MiB(b *testing.B) { + benchmarkWriteTo(b, mebi, NoDelay) +} + +func BenchmarkWriteTo4MiB(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, NoDelay) +} + +func BenchmarkWriteTo4MiBDelay10Msec(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay50Msec(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay150Msec(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, 150*time.Millisecond) +} + +type zeroSource struct{} + +func (zeroSource) Read(b []byte) (n int, err error) { + for i := range b { + b[i] = 0 + } + return len(b), nil +} + +func (zeroSource) Close() error { return nil } + +func zeroFile(t testing.TB, filename string, filesize int64) string { + src, err := os.CreateTemp("", filename) + if err != nil { + t.Fatal(err) + } + defer src.Close() + + n, err := io.Copy(src, io.LimitReader(zeroSource{}, filesize)) + if err != nil { + t.Fatal(err) + } + + if n < filesize { + t.Fatal("short copy") + } + + return src.Name() +} + +func benchmarkCopyDown(b *testing.B, filesize int64, delay time.Duration) { + srcFilename := zeroFile(b, "sftptest-benchcopydown-src", filesize) + defer os.Remove(srcFilename) + + cl, cmd := testClient(b, ReadOnly, delay) + defer cmd.Wait() + defer cl.Close() + + b.SetBytes(filesize) + + for b.Loop() { + func() { + dst, err := os.CreateTemp("", "sftptest-benchcopydown-dst") + if err != nil { + b.Fatal(err) + } + defer os.Remove(dst.Name()) + defer dst.Close() + + src, err := cl.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + + n, err := io.Copy(dst, src) + if err != nil { + b.Fatal("copy error:", err) + } + + if n != filesize { + b.Fatalf("wrong bytes copied: got: %d, want: %d", n, 
filesize) + } + + fi, err := dst.Stat() + if err != nil { + b.Fatal(err) + } + + if fi.Size() != filesize { + b.Fatalf("wrong file size: got: %d, want: %d", fi.Size(), filesize) + } + }() + } +} + +func BenchmarkCopyDown10MiBDelay10Msec(b *testing.B) { + benchmarkCopyDown(b, 10*mebi, 10*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay50Msec(b *testing.B) { + benchmarkCopyDown(b, 10*mebi, 50*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay150Msec(b *testing.B) { + benchmarkCopyDown(b, 10*mebi, 150*time.Millisecond) +} + +func benchmarkCopyUp(b *testing.B, filesize int64, delay time.Duration) { + srcFilename := zeroFile(b, "sftptest-benchcopyup-src", filesize) + defer os.Remove(srcFilename) + + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer sftp.Close() + + b.SetBytes(filesize) + + for b.Loop() { + func() { + // We need to create the destination filename through the OS first. + tmpDst, err := os.CreateTemp("", "sftptest-benchcopyup-dst") + if err != nil { + b.Fatal(err) + } + dstFilename := tmpDst.Name() + tmpDst.Close() + + defer os.Remove(dstFilename) + + // Now, we can use the filename from above as our destination. 
+ dst, err := sftp.Create(dstFilename) + if err != nil { + b.Fatal(err) + } + defer dst.Close() + + src, err := os.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + + n, err := io.Copy(dst, src) + if err != nil { + b.Fatal("copy error:", err) + } + + if n < filesize { + b.Error("unable to copy all bytes") + } + + fi, err := os.Stat(dstFilename) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != filesize { + b.Errorf("wrong file size: got %d, want %d", fi.Size(), filesize) + } + }() + } +} + +func BenchmarkCopyUp10MiBDelay10Msec(b *testing.B) { + benchmarkCopyUp(b, 10*mebi, 10*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay50Msec(b *testing.B) { + benchmarkCopyUp(b, 10*mebi, 50*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay150Msec(b *testing.B) { + benchmarkCopyUp(b, 10*mebi, 150*time.Millisecond) +} diff --git a/client_test.go b/client_test.go index 91d99080..3a36afff 100644 --- a/client_test.go +++ b/client_test.go @@ -46,28 +46,129 @@ type sink struct{} func (sink) Close() error { return nil } func (sink) Write(p []byte) (int, error) { return len(p), nil } +func TestClientZeroLengthPacket(t *testing.T) { + // Packet length zero (never valid). This used to crash the client. + r := bytes.NewReader([]byte{0, 0, 0, 0}) + + cl, err := NewClientPipe(t.Context(), r, sink{}) + if err == nil { + t.Error("expected an error, got nil") + } + if cl != nil { + cl.Close() + } +} + +func TestClientShortPacket(t *testing.T) { + // init packet too short. 
+ r := bytes.NewReader([]byte{0, 0, 0, 1, 2}) + + cl, err := NewClientPipe(t.Context(), r, sink{}) + if !errors.Is(err, sshfx.ErrShortPacket) { + t.Fatalf("got error %#v, but expected sshfx.ErrShortPacket", err) + } + if cl != nil { + cl.Close() + } +} + +func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.PacketMarshaller) io.ReadCloser { + pr, pw := io.Pipe() + + go func() { + defer pw.Close() + + bindata, err := verPkt.MarshalBinary() + if err != nil { + t.Errorf("could not binary marshal %#v: %v", verPkt, err) + return + } + pw.Write(bindata) + + var reqid uint32 + for _, packet := range packets { + reqid++ + + header, payload, err := packet.MarshalPacket(reqid, nil) + if err != nil { + t.Errorf("could not packet marshal %#v: %v", packet, err) + return + } + + if _, err := pw.Write(header); err != nil { + t.Error("could not write packet header:", err) + return + } + + if _, err := pw.Write(payload); err != nil { + t.Error("could not write packet payload:", err) + return + } + } + }() + + return pr +} + +type noSidBrokenPacket struct{} + +func (noSidBrokenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + return []byte{0, 0, 0, 10, 0, 0}, nil, nil +} + +func (noSidBrokenPacket) MarshalSize() int { + return 10 +} + // Issue #418: panic in clientConn.recv when the sid is incomplete. 
func TestClientNoSid(t *testing.T) { - initPkt := &sshfx.VersionPacket{ - Version: sftpProtocolVersion, - } + stream := streamPackets(t, + &sshfx.VersionPacket{ + Version: sftpProtocolVersion, + }, + noSidBrokenPacket{}, + ) - initData, err := initPkt.MarshalBinary() + cl, err := NewClientPipe(t.Context(), stream, sink{}) if err != nil { - t.Fatal("unexpected error:", err) + t.Fatal(err) } + defer cl.Close() - stream := new(bytes.Buffer) - stream.Write(initData) - stream.Write([]byte{0, 0, 0, 10, 0, 0}) + _, err = cl.Stat("anything") + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("cl.Stat = %q, expected sshfx.StatusConnectionLost", err) + } +} + +// sftp/issue/390 - server disconnect should not cause io.EOF or +// io.ErrUnexpectedEOF in sftp.File.Read, because those confuse io.ReadFull. +func TestClientRoughDisconnectEOF(t *testing.T) { + stream := streamPackets(t, + &sshfx.VersionPacket{ + Version: sftpProtocolVersion, + }, + &sshfx.HandlePacket{ + Handle: "foo", + }, + &sshfx.DataPacket{ + Data: []byte("foo"), + }, + ) cl, err := NewClientPipe(t.Context(), stream, sink{}) if err != nil { t.Fatal(err) } + defer cl.Close() - _, err = cl.Stat("anything") + f, err := cl.Open("anything") + if err != nil { + t.Fatal(err) + } + + _, err = io.ReadFull(f, make([]byte, 10)) if !errors.Is(err, sshfx.StatusConnectionLost) { - t.Errorf("cl.Stat = %v, expected sshfx.StatusConnectionLost", err) + t.Errorf("io.ReadFull error = %q, but wanted sshfx.StatusConnectionLost", err) } } From c44c329e8ff71b8074ceb51ba0b4c197eb799e83 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 12:05:11 +0000 Subject: [PATCH 10/15] bugfixes --- client.go | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 8e220364..754f9954 100644 --- a/client.go +++ b/client.go @@ -226,12 +226,6 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller default: } 
- if c.inflight == nil { - c.inflight = make(map[uint32]chan<- result) - } - - c.inflight[reqid] = ch - if _, err := c.wr.Write(header); err != nil { c.resPool.Put(ch) return reqid, nil, fmt.Errorf("sftp: write packet header: %w", err) @@ -244,6 +238,11 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller } } + if c.inflight == nil { + c.inflight = make(map[uint32]chan<- result) + } + c.inflight[reqid] = ch + return reqid, ch, nil } @@ -253,10 +252,18 @@ func (c *clientConn) returnRaw(raw *sshfx.RawPacket) { } func (c *clientConn) discardBlocking(ch chan result) { - res := <-ch + select { + case <-c.closed: + // We've been disconnected. + // We have to return something to the work pool, or resPool.Close will deadlock. + // It also has to be able to receive a result, + // otherwise a broadcasted error will lock up on a send. + c.resPool.Put(make(chan result, 1)) - c.returnRaw(res.pkt) - c.resPool.Put(ch) + case res := <-ch: + c.returnRaw(res.pkt) + c.resPool.Put(ch) + } } func (c *clientConn) discard(ch chan result) { @@ -492,6 +499,7 @@ func statusToError(status *sshfx.StatusPacket, okExpected bool) error { return fmt.Errorf("unexpected SSH_FX_OK") } return nil + case sshfx.StatusEOF: return io.EOF } @@ -1778,14 +1786,12 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e // // Either way, this should be the last successfully written offset. written := int64(firstErr.off) - f.offset - f.offset = int64(firstErr.off) return int(written), f.wrapErr("writeat", firstErr.err) } // We didn't hit any errors, so we must have written all the bytes in the buffer. 
written = len(b) - f.offset += int64(written) return written, nil } @@ -1911,6 +1917,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { workCh := make(chan work, f.cl.maxInflight) type rwErr struct { + op string off uint64 err error } @@ -1936,7 +1943,11 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { for { n, err := io.ReadFull(r, b) if n < 0 { - errCh <- rwErr{req.Offset, panicInstead("sftp: readfrom: read returned negative count")} + errCh <- rwErr{ + op: "read", + off: req.Offset, + err: panicInstead("sftp: readfrom: read returned negative count"), + } return } @@ -1966,7 +1977,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { if err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { - errCh <- rwErr{req.Offset, err} + errCh <- rwErr{"read", req.Offset, err} } return } @@ -1984,7 +1995,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { for work := range workCh { err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint) if err != nil { - errCh <- rwErr{work.off, err} + errCh <- rwErr{"write", work.off, err} // DO NOT return. // We want to ensure that workCh is drained before errCh is closed. @@ -2020,7 +2031,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { } // ReadFrom is defined to return the read bytes, regardless of any write errors. - return read, f.wrapErr("readfrom", firstErr.err) + return read, f.wrapErr(firstErr.op, firstErr.err) } // We didn't hit any errors, so we must have written all the bytes that we read until EOF. @@ -2341,7 +2352,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { } }() - var writeErr error + var readErr error // Dispatch: Dispatch into any number of Reads of length <= f.cl.maxDataLen. 
go func() { @@ -2358,7 +2369,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { for { reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { - writeErr = err + readErr = err return } @@ -2418,11 +2429,11 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { return written, nil // return nil instead of EOF } - return written, f.wrapErr("writeto", err) + return written, f.wrapErr("recv", err) } } - return written, f.wrapErr("writeto", writeErr) + return written, f.wrapErr("read", readErr) } // WriteFile writes data to the named file, creating it if neccessary. From d86f5dfefa812240fa7d4ffdb7d47cecf7428b89 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 12:10:01 +0000 Subject: [PATCH 11/15] writeat: we shouldn't be using f.offset at all --- client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 754f9954..470bed4f 100644 --- a/client.go +++ b/client.go @@ -1716,7 +1716,7 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e req := &sshfx.WritePacket{ Handle: handle, - Offset: uint64(f.offset), + Offset: uint64(off), } for len(b) > 0 { @@ -1785,7 +1785,7 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e // * the offset of the start of the first error received dispatching a write packet offset. // // Either way, this should be the last successfully written offset. 
- written := int64(firstErr.off) - f.offset + written := int64(firstErr.off) - off return int(written), f.wrapErr("writeat", firstErr.err) } From a711c561b996b90fd410bc4a347fc5bc95f853ce Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 13:00:23 +0000 Subject: [PATCH 12/15] fix race condition between chan close and chan write --- client_test.go | 59 ++++++++++++++++++++++++++++++++++++++----- internal/sync/pool.go | 9 ++++--- 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/client_test.go b/client_test.go index 3a36afff..f259ffc6 100644 --- a/client_test.go +++ b/client_test.go @@ -5,9 +5,11 @@ import ( "errors" "io" "io/fs" + "sync/atomic" "testing" sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" + "github.com/pkg/sftp/v2/internal/sync" ) func TestClient(t *testing.T) { @@ -72,10 +74,46 @@ func TestClientShortPacket(t *testing.T) { } } -func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.PacketMarshaller) io.ReadCloser { +type tickWriter struct { + count atomic.Uint32 + steps atomic.Uint32 + + once sync.Once + step chan struct{} +} + +func (t *tickWriter) wait() { + _ = t.steps.Add(1) + + <-t.step +} + +func (t *tickWriter) Write(b []byte) (written int, err error) { + _ = t.count.Add(1) + + t.step <- struct{}{} + + return len(b), nil +} + +func (t *tickWriter) Close() error { + t.once.Do(func() { close(t.step) }) + return nil +} + +func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.PacketMarshaller) (io.ReadCloser, io.WriteCloser) { pr, pw := io.Pipe() + ticker := &tickWriter{ + step: make(chan struct{}), + } + go func() { + defer func() { + for range ticker.step { + } + }() + defer pw.Close() bindata, err := verPkt.MarshalBinary() @@ -83,7 +121,12 @@ func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.P t.Errorf("could not binary marshal %#v: %v", verPkt, err) return } - pw.Write(bindata) + + ticker.wait() + if _, err := pw.Write(bindata); 
err != nil { + t.Error(err) + return + } var reqid uint32 for _, packet := range packets { @@ -95,6 +138,8 @@ func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.P return } + ticker.wait() + if _, err := pw.Write(header); err != nil { t.Error("could not write packet header:", err) return @@ -107,7 +152,7 @@ func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.P } }() - return pr + return pr, ticker } type noSidBrokenPacket struct{} @@ -122,14 +167,14 @@ func (noSidBrokenPacket) MarshalSize() int { // Issue #418: panic in clientConn.recv when the sid is incomplete. func TestClientNoSid(t *testing.T) { - stream := streamPackets(t, + rd, wr := streamPackets(t, &sshfx.VersionPacket{ Version: sftpProtocolVersion, }, noSidBrokenPacket{}, ) - cl, err := NewClientPipe(t.Context(), stream, sink{}) + cl, err := NewClientPipe(t.Context(), rd, wr) if err != nil { t.Fatal(err) } @@ -144,7 +189,7 @@ func TestClientNoSid(t *testing.T) { // sftp/issue/390 - server disconnect should not cause io.EOF or // io.ErrUnexpectedEOF in sftp.File.Read, because those confuse io.ReadFull. 
func TestClientRoughDisconnectEOF(t *testing.T) { - stream := streamPackets(t, + rd, wr := streamPackets(t, &sshfx.VersionPacket{ Version: sftpProtocolVersion, }, @@ -156,7 +201,7 @@ func TestClientRoughDisconnectEOF(t *testing.T) { }, ) - cl, err := NewClientPipe(t.Context(), stream, sink{}) + cl, err := NewClientPipe(t.Context(), rd, wr) if err != nil { t.Fatal(err) } diff --git a/internal/sync/pool.go b/internal/sync/pool.go index 54bb1e4b..5d706f2d 100644 --- a/internal/sync/pool.go +++ b/internal/sync/pool.go @@ -180,8 +180,7 @@ type WorkPool[T any] struct { func NewWorkPool[T any](depth int) *WorkPool[T] { p := &WorkPool[T]{ closed: make(chan struct{}), - - ch: make(chan chan T, depth), + ch: make(chan chan T, depth), } for len(p.ch) < cap(p.ch) { @@ -207,10 +206,11 @@ func (p *WorkPool[T]) Close() error { } close(p.closed) - close(p.ch) p.wg.Wait() + close(p.ch) + for range p.ch { // drain the pool and drop them on all on the ground for GC. } @@ -254,7 +254,7 @@ func (p *WorkPool[T]) Put(v chan T) { return } - p.wg.Done() + defer p.wg.Done() select { case <-p.closed: @@ -263,6 +263,7 @@ func (p *WorkPool[T]) Put(v chan T) { } select { + case <-p.closed: case p.ch <- v: default: panic("worker pool overfill") From 2587c66fd1a6c7e61c4c22a4cd73333f4e79ad91 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 14:06:08 +0000 Subject: [PATCH 13/15] lock unlock around closed --- client_test.go | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/client_test.go b/client_test.go index f259ffc6..c4c33c73 100644 --- a/client_test.go +++ b/client_test.go @@ -78,8 +78,10 @@ type tickWriter struct { count atomic.Uint32 steps atomic.Uint32 - once sync.Once - step chan struct{} + mu sync.Mutex + once sync.Once + closed bool + step chan struct{} } func (t *tickWriter) wait() { @@ -89,6 +91,13 @@ func (t *tickWriter) wait() { } func (t *tickWriter) Write(b []byte) (written int, err error) { + t.mu.Lock() + defer 
t.mu.Unlock() + + if t.closed { + return 0, io.ErrClosedPipe + } + _ = t.count.Add(1) t.step <- struct{}{} @@ -97,7 +106,13 @@ func (t *tickWriter) Write(b []byte) (written int, err error) { } func (t *tickWriter) Close() error { - t.once.Do(func() { close(t.step) }) + t.once.Do(func() { + t.mu.Lock() + defer t.mu.Unlock() + + t.closed = true + close(t.step) + }) return nil } From d36cabc7ec137618d7921c878edddbeaee92581b Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 14:07:43 +0000 Subject: [PATCH 14/15] lock before respool.Get --- client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 470bed4f..0daf8f96 100644 --- a/client.go +++ b/client.go @@ -208,6 +208,9 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller } defer c.bufPool.Put(header) + c.mu.Lock() + defer c.mu.Unlock() + // payload by design of the API is all but guaranteed to alias a caller-held byte slice, // so, _do not_ put it into the bufPool. @@ -216,9 +219,6 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller return reqid, nil, sshfx.StatusConnectionLost } - c.mu.Lock() - defer c.mu.Unlock() - select { case <-cancel: c.resPool.Put(ch) From 7029f9de0592ece929c3acf826dcaf4ea0181c30 Mon Sep 17 00:00:00 2001 From: Cassondra Foesch Date: Tue, 19 May 2026 14:08:04 +0000 Subject: [PATCH 15/15] more stabbing at deadlocks --- internal/sync/pool.go | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/internal/sync/pool.go b/internal/sync/pool.go index 5d706f2d..7215efa7 100644 --- a/internal/sync/pool.go +++ b/internal/sync/pool.go @@ -210,7 +210,6 @@ func (p *WorkPool[T]) Close() error { p.wg.Wait() close(p.ch) - for range p.ch { // drain the pool and drop them on all on the ground for GC. 
} @@ -231,16 +230,22 @@ func (p *WorkPool[T]) Get() (chan T, bool) { select { case <-p.closed: - var ch chan T - return ch, false - default: - } + return nil, false - v, ok := <-p.ch - if ok { - p.wg.Add(1) + case v, ok := <-p.ch: + if !ok { + return nil, false + } + + select { + case <-p.closed: + return nil, false + + default: + p.wg.Add(1) + return v, true + } } - return v, ok } // Put returns the given work channel to the pool. @@ -256,12 +261,6 @@ func (p *WorkPool[T]) Put(v chan T) { defer p.wg.Done() - select { - case <-p.closed: - return - default: - } - select { case <-p.closed: case p.ch <- v: