diff --git a/Makefile b/Makefile index 65c6ae99..789c8ce4 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,36 @@ .PHONY: integration integration_w_race +DELAY ?= 0 + +ifneq ($(DELAY),0) +DELAY_FLAG=-delay $(DELAY) +endif + +TAGS=integration,sftp.sync.metrics + integration: - go test -v ./... + go test -v $(DELAY_FLAG) -tags=$(TAGS) + go test ./encoding/... ./internal/... make -C localfs integration integration_w_race: - go test -race -v ./... - make -C localfs integration + go test -race -v $(DELAY_FLAG) -tags=$(TAGS) + go test -race ./encoding/... ./internal/... + make -C localfs integration_w_race + +COUNT ?= 1 +BENCHMARK_PATTERN ?= "." + +benchmark: + go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) $(DELAY_FLAG) -tags=$(TAGS) + make -C localfs benchmark + +benchmark_w_memprofile: +ifneq ($(DELAY),0) + @echo "memprofile with DELAY produces invalid data" >&2 + @exit 1 +endif + go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) -memprofile memprofile.out -tags=integration + go tool pprof -sample_index=alloc_space -svg -output=memprofile-space.svg memprofile.out + go tool pprof -sample_index=alloc_objects -svg -output=memprofile-allocs.svg memprofile.out + make -C localfs benchmark_w_memprofile diff --git a/client.go b/client.go index 7428cdf6..0daf8f96 100644 --- a/client.go +++ b/client.go @@ -11,7 +11,7 @@ import ( "iter" "math" "os" - "path" + stdpath "path" "slices" "sync/atomic" "syscall" @@ -208,6 +208,9 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller } defer c.bufPool.Put(header) + c.mu.Lock() + defer c.mu.Unlock() + // payload by design of the API is all but guaranteed to alias a caller-held byte slice, // so, _do not_ put it into the bufPool. @@ -216,9 +219,6 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller return reqid, nil, sshfx.StatusConnectionLost } - c.mu.Lock() - defer c.mu.Unlock() - select { case <-cancel: c.resPool.Put(ch) @@ -226,12 +226,6 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller default: } - if c.inflight == nil { - c.inflight = make(map[uint32]chan<- result) - } - - c.inflight[reqid] = ch - if _, err := c.wr.Write(header); err != nil { c.resPool.Put(ch) return reqid, nil, fmt.Errorf("sftp: write packet header: %w", err) @@ -244,6 +238,11 @@ func (c *clientConn) dispatch(cancel <-chan struct{}, req sshfx.PacketMarshaller } } + if c.inflight == nil { + c.inflight = make(map[uint32]chan<- result) + } + c.inflight[reqid] = ch + return reqid, ch, nil } @@ -253,10 +252,18 @@ func (c *clientConn) returnRaw(raw *sshfx.RawPacket) { } func (c *clientConn) discardBlocking(ch chan result) { - res := <-ch + select { + case <-c.closed: + // We've been disconnected. + // We have to return something to the work pool, or resPool.Close will deadlock. + // It also has to be able to receive a result, + // otherwise a broadcasted error will lock up on a send. + c.resPool.Put(make(chan result, 1)) - c.returnRaw(res.pkt) - c.resPool.Put(ch) + case res := <-ch: + c.returnRaw(res.pkt) + c.resPool.Put(ch) + } } func (c *clientConn) discard(ch chan result) { @@ -281,6 +288,10 @@ func (c *clientConn) recv(ctx context.Context, reqid uint32, ch chan result) (*s c.discard(ch) return nil, ctx.Err() + case <-c.closed: + c.discard(ch) + return nil, sshfx.StatusConnectionLost + case res := <-ch: c.resPool.Put(ch) @@ -577,6 +588,8 @@ func (cl *Client) getDataBuf(size int) []byte { return hint[:size] // trim our buffer to length, it might be longer than chunkSize. } +// SSHSession is an interface for x/crypto/ssh.Session, +// in order to break the dependency on x/crypto. type SSHSession interface { StdinPipe() (io.WriteCloser, error) StdoutPipe() (io.Reader, error) @@ -588,6 +601,8 @@ type SSHSession interface { Wait() error } +// SSHClient is an interface for x/crypto/ssh.Client, +// in order to break the dependency on x/crypto. type SSHClient[Session SSHSession] interface { NewSession() (Session, error) } @@ -705,11 +720,17 @@ func (cl *Client) ReportPoolMetrics(wr io.Writer) { } } +// Wait blocks until the connection has shut down, +// and returns the error causing the shutdown. +// It can be called concurrently from multiple goroutines. +func (cl *Client) Wait() error { + return cl.conn.Wait() +} + // Close closes the SFTP session. func (cl *Client) Close() error { cl.conn.disconnect(nil) - cl.conn.wr.Close() - return nil + return cl.conn.wr.Close() } func wrapPathError(op, path string, err error) error { @@ -753,11 +774,8 @@ func wrapLinkError(op, oldpath, newpath string, err error) error { func (cl *Client) Mkdir(name string, perm fs.FileMode) error { return wrapPathError("mkdir", name, cl.sendPacket(context.Background(), nil, &sshfx.MkdirPacket{ - Path: name, - Attrs: sshfx.Attributes{ - Flags: sshfx.AttrPermissions, - Permissions: sshfx.FileMode(perm.Perm()), - }, + Path: name, + Attrs: attrsFromPerm(perm), }), ) } @@ -777,7 +795,7 @@ func (cl *Client) MkdirAll(name string, perm fs.FileMode) error { // Slow path: make sure parent exists and then call Mkdir for name. - if parent := path.Dir(name); parent != "" { + if parent := stdpath.Dir(name); parent != "" { err = cl.MkdirAll(parent, perm) if err != nil { return err @@ -845,6 +863,45 @@ func (cl *Client) Remove(name string) error { return wrapPathError("remove", name, errF) } +// RemoveAll removes path and any children it contains. +// It removes everything it can but returns the first error it encounters. +// f the path does not exist, RemoveAll returns nil (no error). +// If there is an error, it will be of type *PathError. +func (cl *Client) RemoveAll(pathname string) error { + fi, err := cl.Stat(pathname) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err + } + + if !fi.IsDir() { + return cl.Remove(pathname) + } + + files, err := cl.ReadDir(pathname) + if err != nil { + return err + } + + for _, file := range files { + filename := stdpath.Join(pathname, file.Name()) + switch { + case file.IsDir(): + if err := cl.RemoveAll(filename); err != nil { + return err + } + default: + if err := cl.Remove(filename); err != nil { + return err + } + } + } + + return nil +} + func (cl *Client) setstat(ctx context.Context, name string, attrs *sshfx.Attributes) error { return wrapPathError("setstat", name, cl.sendPacket(ctx, nil, &sshfx.SetStatPacket{ @@ -935,7 +992,7 @@ func (cl *Client) ReadLink(name string) (string, error) { // Server-specific restrictions may apply when old path and new path are in different directories. // Even within the same directory, on non-Unix servers Rename is not guaranteed to be an atomic operation. func (cl *Client) Rename(oldpath, newpath string) error { - if cl.hasExtension(openssh.ExtensionPOSIXRename()) { + if cl.HasExtension(openssh.ExtensionPOSIXRename()) { return wrapLinkError("rename", oldpath, newpath, cl.sendPacket(context.Background(), nil, &openssh.POSIXRenameExtendedPacket{ OldPath: oldpath, @@ -963,8 +1020,18 @@ func (cl *Client) Symlink(oldname, newname string) error { ) } -func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool { - return cl.exts[ext.Name] == ext.Data +// GetExtension returns the data associated with the extension name. +// If there is no data associated with the extension name, it returns the empty string. +func (cl *Client) GetExtension(name string) string { + return cl.exts[name] +} + +// HasExtension returns true if the data associated with the extension name is the same as extension data. +// +// This is a much more strict test for extensions than simple existence of data associated with the extension name. +// To test for anything wider than extension data, use [Client.GetExtension]. +func (cl *Client) HasExtension(extension *sshfx.ExtensionPair) bool { + return cl.exts[extension.Name] == extension.Data } // StatVFS retrieves VFS statistics from a remote host. @@ -986,7 +1053,7 @@ func (cl *Client) StatVFS(path string) (*openssh.StatVFSExtendedReplyPacket, err // then no request will be sent, // and Link returns an *fs.LinkError wrapping sshfx.StatusOpUnsupported. func (cl *Client) Link(oldname, newname string) error { - if !cl.hasExtension(openssh.ExtensionHardlink()) { + if !cl.HasExtension(openssh.ExtensionHardlink()) { return wrapLinkError("hardlink", oldname, newname, sshfx.StatusOpUnsupported) } @@ -1002,13 +1069,20 @@ func (cl *Client) Link(oldname, newname string) error { // If an error occurs reading the directory, // Readdir returns the entries it was able to read before the error, along with the error. func (cl *Client) Readdir(name string) ([]fs.FileInfo, error) { + return cl.ReaddirContext(context.Background(), name) +} + +// ReaddirContext reads the named directory, returning all its directory entries as [fs.FileInfo] sorted by filename. +// If an error occurs reading the directory, including the context being canceled, +// Readdir returns the entries it was able to read before the error, along with the error. +func (cl *Client) ReaddirContext(ctx context.Context, name string) ([]fs.FileInfo, error) { d, err := cl.OpenDir(name) if err != nil { return nil, err } defer d.Close() - fis, err := d.Readdir(0) + fis, err := d.ReaddirContext(ctx, 0) slices.SortFunc(fis, func(a, b fs.FileInfo) int { return cmp.Compare(a.Name(), b.Name()) @@ -1211,6 +1285,11 @@ func (d *Dir) rangedir(ctx context.Context, grow func(int)) iter.Seq2[*sshfx.Nam // Pull from saved entries first. for i, ent := range d.entries { + switch ent.Name() { + case ".", "..": // skip useless names. + continue + } + if !yield(ent, nil) { // This is a break condition. // We need to remove all entries that have been consumed, @@ -1422,6 +1501,14 @@ func (cl *Client) Create(name string) (*File, error) { return cl.OpenFile(name, OpenFlagReadWrite|OpenFlagCreate|OpenFlagTruncate, 0666) } +func attrsFromPerm(perm fs.FileMode) sshfx.Attributes { + var attrs sshfx.Attributes + if perm != 0 { + attrs.SetPermissions(sshfx.FileMode(perm.Perm())) + } + return attrs +} + // OpenFile is the generalized open call; // most users can use the simplified Open or Create methods instead. // It opens the named file with the specified flag (OpenFlagReadOnly, etc.). @@ -1434,10 +1521,7 @@ func (cl *Client) OpenFile(name string, flag int, perm fs.FileMode) (*File, erro handle, err := cl.getHandle(context.Background(), nil, &sshfx.OpenPacket{ Filename: name, PFlags: toPortableFlags(flag), - Attrs: sshfx.Attributes{ - Flags: sshfx.AttrPermissions, - Permissions: sshfx.FileMode(perm.Perm()), - }, + Attrs: attrsFromPerm(perm), }) if err != nil { return nil, wrapPathError("openfile", name, err) @@ -1632,7 +1716,7 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e req := &sshfx.WritePacket{ Handle: handle, - Offset: uint64(f.offset), + Offset: uint64(off), } for len(b) > 0 { @@ -1701,15 +1785,13 @@ func (f *File) writeat(ctx context.Context, b []byte, off int64) (written int, e // * the offset of the start of the first error received dispatching a write packet offset. // // Either way, this should be the last successfully written offset. - written := int64(firstErr.off) - f.offset - f.offset = int64(firstErr.off) + written := int64(firstErr.off) - off return int(written), f.wrapErr("writeat", firstErr.err) } // We didn't hit any errors, so we must have written all the bytes in the buffer. written = len(b) - f.offset += int64(written) return written, nil } @@ -1835,6 +1917,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { workCh := make(chan work, f.cl.maxInflight) type rwErr struct { + op string off uint64 err error } @@ -1860,7 +1943,11 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { for { n, err := io.ReadFull(r, b) if n < 0 { - errCh <- rwErr{req.Offset, panicInstead("sftp: readfrom: read returned negative count")} + errCh <- rwErr{ + op: "read", + off: req.Offset, + err: panicInstead("sftp: readfrom: read returned negative count"), + } return } @@ -1890,7 +1977,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { if err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { - errCh <- rwErr{req.Offset, err} + errCh <- rwErr{"read", req.Offset, err} } return } @@ -1908,7 +1995,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { for work := range workCh { err := f.cl.recvStatus(ctx, work.reqid, work.res, statusHint) if err != nil { - errCh <- rwErr{work.off, err} + errCh <- rwErr{"write", work.off, err} // DO NOT return. // We want to ensure that workCh is drained before errCh is closed. @@ -1944,7 +2031,7 @@ func (f *File) ReadFrom(r io.Reader) (read int64, err error) { } // ReadFrom is defined to return the read bytes, regardless of any write errors. - return read, f.wrapErr("readfrom", firstErr.err) + return read, f.wrapErr(firstErr.op, firstErr.err) } // We didn't hit any errors, so we must have written all the bytes that we read until EOF. @@ -2265,7 +2352,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { } }() - var writeErr error + var readErr error // Dispatch: Dispatch into any number of Reads of length <= f.cl.maxDataLen. go func() { @@ -2282,7 +2369,7 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { for { reqid, res, err := f.cl.conn.dispatch(closed, req) if err != nil { - writeErr = err + readErr = err return } @@ -2342,11 +2429,11 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { return written, nil // return nil instead of EOF } - return written, f.wrapErr("writeto", err) + return written, f.wrapErr("recv", err) } } - return written, f.wrapErr("writeto", writeErr) + return written, f.wrapErr("read", readErr) } // WriteFile writes data to the named file, creating it if neccessary. @@ -2468,7 +2555,7 @@ func (f *File) Sync() error { return f.wrapErr("fsync", err) } - if !f.cl.hasExtension(openssh.ExtensionFSync()) { + if !f.cl.HasExtension(openssh.ExtensionFSync()) { return f.wrapErr("fsync", sshfx.StatusOpUnsupported) } diff --git a/client_integration_test.go b/client_integration_test.go new file mode 100644 index 00000000..6c60e9f8 --- /dev/null +++ b/client_integration_test.go @@ -0,0 +1,819 @@ +//go:build integration && !windows + +package sftp + +import ( + "bytes" + "errors" + "flag" + "fmt" + "io" + "os" + "os/exec" + "slices" + "testing" + "time" + + sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" + "github.com/pkg/sftp/v2/internal/sync" +) + +const ( + kibi = 1024 + mebi = 1024 * 1024 +) + +const ( + ReadOnly = true + ReadWrite = false + NoDelay time.Duration = 0 +) + +const debuglevel = "ERROR" // set to "DEBUG" for debugging + +var testSftp *string + +func TestMain(m *testing.M) { + sftpServerLocations := []string{ + "/usr/libexec/ssh/sftp-server", + "/usr/libexec/sftp-server", + "/usr/lib/openssh/sftp-server", + "/usr/lib/ssh/sftp-server", + `C:\Program Files\Git\usr\lib\ssh\sftp-server.exe`, + } + + sftpServer, _ := exec.LookPath("sftp-server") + if sftpServer == "" { + for _, loc := range sftpServerLocations { + if _, err := os.Stat(loc); err == nil { + sftpServer = loc + break + } + } + + if sftpServer == "" { + fmt.Fprintln(os.Stdout, "FAIL: could not find sftp-server") + os.Exit(1) + } + } + + testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary") + flag.Parse() + + os.Exit(m.Run()) +} + +type delayedWrite struct { + t time.Time + b []byte +} + +// delayedWriter wraps a writer and artificially delays the write. +// This is meant to mimic connections with various latencies. +// Errors returned from the underlying writer will panic so this should only be used over reliable connections. +type delayedWriter struct { + ch chan delayedWrite + closing chan struct{} + + wg sync.WaitGroup + closed <-chan struct{} +} + +func newDelayedWriter(t testing.TB, wr io.WriteCloser, delay time.Duration) *delayedWriter { + closed := make(chan struct{}) + + dw := &delayedWriter{ + ch: make(chan delayedWrite, 128), + closing: make(chan struct{}), + closed: closed, + } + + ctx := t.Context() + + go func() { + defer close(closed) + + defer wr.Close() + + for write := range dw.ch { + select { + case <-ctx.Done(): + return + case <-time.After(time.Until(write.t.Add(delay))): + } + + n, err := wr.Write(write.b) + if err != nil { + panic(err) + } + + if n < len(write.b) { + panic(io.ErrShortWrite) + } + } + }() + + return dw +} + +func (dw *delayedWriter) Write(b []byte) (int, error) { + select { + case <-dw.closing: + return 0, io.ErrClosedPipe + default: + } + + dw.wg.Add(1) + defer dw.wg.Done() + + dw.ch <- delayedWrite{ + t: time.Now(), + b: slices.Clone(b), + } + + return len(b), nil +} + +func (dw *delayedWriter) Close() error { + close(dw.closing) + + dw.wg.Wait() // wait for any outstanding blocking writes + + close(dw.ch) + + <-dw.closed // wait for writer goroutine to finish. + + return nil +} + +func testClient(t testing.TB, readonly bool, delay time.Duration, opts ...ClientOption) (*Client, *exec.Cmd) { + args := []string{ + "-e", + "-l", debuglevel, + } + if readonly { + args = append(args, "-R") + } + + cmd := exec.Command(*testSftp, args...) + + cmd.Stderr = os.Stdout + pw, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + + if delay > 0 { + pw = newDelayedWriter(t, pw, delay) + } + + pr, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + + if err := cmd.Start(); err != nil { + t.Skip("could not start sftp-server process:", err) + } + + cl, err := newClientPipe(t.Context(), pr, nil, pw, nil, opts) + if err != nil { + t.Fatal(err) + } + + return cl, cmd +} + +// github.com/pkg/sftp/issues/42, abrupt server hangup would result in client hangs. +func TestServerRoughDisconnect(t *testing.T) { + cl, cmd := testClient(t, ReadOnly, NoDelay) + defer cmd.Wait() + defer cl.Close() + + f, err := cl.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.Copy(io.Discard, f) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.Copy error = %#v, but wanted sshfx.StatusConnectionLost", err) + } +} + +// github.com/pkg/sftp/issues/181, abrupt server hangup would result in client hangs. +// due to broadcastErr filling up the request channel +// this reproduces it about 50% of the time +func TestServerRoughDisconnect2(t *testing.T) { + cl, cmd := testClient(t, ReadOnly, NoDelay) + defer cmd.Wait() + defer cl.Close() + + f, err := cl.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + b := make([]byte, 100*32*kibi) + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + for { + if _, err := f.Read(b); err != nil { + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("File.Read error = %#v, but wanted sshfx.StatusConnectionLost", err) + } + break + } + } +} + +// github.com/pkg/sftp/issues/234 - abrupt shutdown during ReadFrom hangs client +func TestServerRoughDisconnect3(t *testing.T) { + cl, cmd := testClient(t, ReadWrite, NoDelay) + defer cmd.Wait() + defer cl.Close() + + dst, err := cl.OpenFile("/dev/null", OpenFlagReadWrite, 0) + if err != nil { + t.Fatal(err) + } + defer dst.Close() + + src, err := os.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer src.Close() + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.Copy(dst, src) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.Copy error = %#v, but wanted sshfx.StatusConnectionLost", err) + } +} + +// github.com/pkg/sftp/issues/234 - also affected Write +func TestServerRoughDisconnect4(t *testing.T) { + cl, cmd := testClient(t, ReadWrite, NoDelay) + defer cmd.Wait() + defer cl.Close() + + dst, err := cl.OpenFile("/dev/null", OpenFlagReadWrite, 0) + if err != nil { + t.Fatal(err) + } + defer dst.Close() + + src, err := os.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer src.Close() + + b := make([]byte, 200*32*kibi) + + if _, err = src.Read(b); err != nil { + t.Fatal(err) + } + + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + for { + if _, err := dst.Write(b); err != nil { + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("dst.Write error = %#v, but wanted sshfx.StatusConnectionLost", err) + } + break + } + } + + _, err = io.Copy(dst, src) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.Copy error = %#v, but wanted sshfx.StatusConnectionLost", err) + } +} + +func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 123 // ~10MiB + + cl, cmd := testClient(b, ReadOnly, delay) + defer cmd.Wait() + defer cl.Close() + + buf := make([]byte, bufsize) + + b.SetBytes(int64(size)) + + for b.Loop() { + offset := 0 + + f, err := cl.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + + for offset < size { + remaining := size - offset + buf := buf[:min(remaining, len(buf))] + + n, err := io.ReadFull(f, buf) + offset += n + + if err != nil { + b.Fatalf("read error at %d: %v", offset, err) + } + } + + switch { + case offset < size: + b.Fatalf("read too few bytes! read: %d, wanted: %d", offset, size) + case offset > size: + b.Fatalf("read too many bytes! read: %d, wanted: %d", offset, size) + } + + f.Close() + } +} + +func BenchmarkRead1k(b *testing.B) { + benchmarkRead(b, 1*kibi, NoDelay) +} + +func BenchmarkRead16k(b *testing.B) { + benchmarkRead(b, 16*kibi, NoDelay) +} + +func BenchmarkRead32k(b *testing.B) { + benchmarkRead(b, 32*kibi, NoDelay) +} + +func BenchmarkRead128k(b *testing.B) { + benchmarkRead(b, 128*kibi, NoDelay) +} + +func BenchmarkRead512k(b *testing.B) { + benchmarkRead(b, 512*kibi, NoDelay) +} + +func BenchmarkRead1MiB(b *testing.B) { + benchmarkRead(b, mebi, NoDelay) +} + +func BenchmarkRead4MiB(b *testing.B) { + benchmarkRead(b, 4*mebi, NoDelay) +} + +func BenchmarkRead4MiBDelay10Msec(b *testing.B) { + benchmarkRead(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkRead4MiBDelay50Msec(b *testing.B) { + benchmarkRead(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkRead4MiBDelay150Msec(b *testing.B) { + benchmarkRead(b, 4*mebi, 150*time.Millisecond) +} + +func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 0x123 // ~10MiB + + cl, cmd := testClient(b, ReadWrite, delay) + defer cmd.Wait() + defer cl.Close() + + data := make([]byte, size) + for i := range data { + data[i] = uint8(i >> ((i % 4) * 8)) + } + + b.SetBytes(int64(size)) + + for b.Loop() { + func() { + offset := 0 + + f, err := os.CreateTemp("", "sftptest-benchwrite") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + f2, err := cl.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + for offset < size { + buf := data[offset:] + buf = buf[:min(len(buf), bufsize)] + + n, err := f2.Write(buf) + offset += n + if err != nil { + b.Fatalf("write error at %d: %v", offset, err) + } + + if n != len(buf) { + b.Fatalf("wrote too few bytes! written: %d, wanted: %d", n, len(buf)) + } + } + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: got: %d, want: %d", fi.Size(), size) + } + }() + } +} + +func BenchmarkWrite1k(b *testing.B) { + benchmarkWrite(b, 1*kibi, NoDelay) +} + +func BenchmarkWrite16k(b *testing.B) { + benchmarkWrite(b, 16*kibi, NoDelay) +} + +func BenchmarkWrite32k(b *testing.B) { + benchmarkWrite(b, 32*kibi, NoDelay) +} + +func BenchmarkWrite128k(b *testing.B) { + benchmarkWrite(b, 128*kibi, NoDelay) +} + +func BenchmarkWrite512k(b *testing.B) { + benchmarkWrite(b, 512*kibi, NoDelay) +} + +func BenchmarkWrite1MiB(b *testing.B) { + benchmarkWrite(b, mebi, NoDelay) +} + +func BenchmarkWrite4MiB(b *testing.B) { + benchmarkWrite(b, 4*mebi, NoDelay) +} + +func BenchmarkWrite4MiBDelay10Msec(b *testing.B) { + benchmarkWrite(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay50Msec(b *testing.B) { + benchmarkWrite(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay150Msec(b *testing.B) { + benchmarkWrite(b, 4*mebi, 150*time.Millisecond) +} + +func benchmarkReadFrom(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 123 // ~10MiB + + // open sftp client + cl, cmd := testClient(b, ReadWrite, delay) + defer cmd.Wait() + defer cl.Close() + + data := make([]byte, size) + + b.SetBytes(int64(size)) + + for b.Loop() { + func() { + f, err := os.CreateTemp("", "sftptest-benchreadfrom") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + f2, err := cl.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + f2.ReadFrom(bytes.NewReader(data)) + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: got: %d, want: %d", fi.Size(), size) + } + }() + } +} + +func BenchmarkReadFrom1k(b *testing.B) { + benchmarkReadFrom(b, 1*kibi, NoDelay) +} + +func BenchmarkReadFrom16k(b *testing.B) { + benchmarkReadFrom(b, 16*kibi, NoDelay) +} + +func BenchmarkReadFrom32k(b *testing.B) { + benchmarkReadFrom(b, 32*kibi, NoDelay) +} + +func BenchmarkReadFrom128k(b *testing.B) { + benchmarkReadFrom(b, 128*kibi, NoDelay) +} + +func BenchmarkReadFrom512k(b *testing.B) { + benchmarkReadFrom(b, 512*kibi, NoDelay) +} + +func BenchmarkReadFrom1MiB(b *testing.B) { + benchmarkReadFrom(b, mebi, NoDelay) +} + +func BenchmarkReadFrom4MiB(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, NoDelay) +} + +func BenchmarkReadFrom4MiBDelay10Msec(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay50Msec(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { + benchmarkReadFrom(b, 4*mebi, 150*time.Millisecond) +} + +func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { + size := 10*mebi + 123 // ~10MiB + + // open sftp client + cl, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer cl.Close() + + f, err := os.CreateTemp("", "sftptest-benchwriteto") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + data := make([]byte, size) + + if _, err = f.Write(data); err != nil { + b.Fatal(err) + } + + if err = f.Close(); err != nil { + b.Fatal(err) + } + + buf := bytes.NewBuffer(make([]byte, 0, size)) + + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + buf.Reset() + + f2, err := cl.Open(f.Name()) + if err != nil { + b.Fatal(err) + } + + if _, err = f2.WriteTo(buf); err != nil { + b.Fatal(err) + } + + if err = f2.Close(); err != nil { + b.Fatal(err) + } + + if buf.Len() != size { + b.Fatalf("wrong buffer size: got: %d, want: %d", buf.Len(), size) + } + } +} + +func BenchmarkWriteTo1k(b *testing.B) { + benchmarkWriteTo(b, 1*kibi, NoDelay) +} + +func BenchmarkWriteTo16k(b *testing.B) { + benchmarkWriteTo(b, 16*kibi, NoDelay) +} + +func BenchmarkWriteTo32k(b *testing.B) { + benchmarkWriteTo(b, 32*kibi, NoDelay) +} + +func BenchmarkWriteTo128k(b *testing.B) { + benchmarkWriteTo(b, 128*kibi, NoDelay) +} + +func BenchmarkWriteTo512k(b *testing.B) { + benchmarkWriteTo(b, 512*kibi, NoDelay) +} + +func BenchmarkWriteTo1MiB(b *testing.B) { + benchmarkWriteTo(b, mebi, NoDelay) +} + +func BenchmarkWriteTo4MiB(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, NoDelay) +} + +func BenchmarkWriteTo4MiBDelay10Msec(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, 10*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay50Msec(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, 50*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay150Msec(b *testing.B) { + benchmarkWriteTo(b, 4*mebi, 150*time.Millisecond) +} + +type zeroSource struct{} + +func (zeroSource) Read(b []byte) (n int, err error) { + for i := range b { + b[i] = 0 + } + return len(b), nil +} + +func (zeroSource) Close() error { return nil } + +func zeroFile(t testing.TB, filename string, filesize int64) string { + src, err := os.CreateTemp("", filename) + if err != nil { + t.Fatal(err) + } + defer src.Close() + + n, err := io.Copy(src, io.LimitReader(zeroSource{}, filesize)) + if err != nil { + t.Fatal(err) + } + + if n < filesize { + t.Fatal("short copy") + } + + return src.Name() +} + +func benchmarkCopyDown(b *testing.B, filesize int64, delay time.Duration) { + srcFilename := zeroFile(b, "sftptest-benchcopydown-src", filesize) + defer os.Remove(srcFilename) + + cl, cmd := testClient(b, ReadOnly, delay) + defer cmd.Wait() + defer cl.Close() + + b.SetBytes(filesize) + + for b.Loop() { + func() { + dst, err := os.CreateTemp("", "sftptest-benchcopydown-dst") + if err != nil { + b.Fatal(err) + } + defer os.Remove(dst.Name()) + defer dst.Close() + + src, err := cl.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + + n, err := io.Copy(dst, src) + if err != nil { + b.Fatal("copy error:", err) + } + + if n != filesize { + b.Fatalf("wrong bytes copied: got: %d, want: %d", n, filesize) + } + + fi, err := dst.Stat() + if err != nil { + b.Fatal(err) + } + + if fi.Size() != filesize { + b.Fatalf("wrong file size: got: %d, want: %d", fi.Size(), filesize) + } + }() + } +} + +func BenchmarkCopyDown10MiBDelay10Msec(b *testing.B) { + benchmarkCopyDown(b, 10*mebi, 10*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay50Msec(b *testing.B) { + benchmarkCopyDown(b, 10*mebi, 50*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay150Msec(b *testing.B) { + benchmarkCopyDown(b, 10*mebi, 150*time.Millisecond) +} + +func benchmarkCopyUp(b *testing.B, filesize int64, delay time.Duration) { + srcFilename := zeroFile(b, "sftptest-benchcopyup-src", filesize) + defer os.Remove(srcFilename) + + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer sftp.Close() + + b.SetBytes(filesize) + + for b.Loop() { + func() { + // We need to create the destination filename through the OS first. + tmpDst, err := os.CreateTemp("", "sftptest-benchcopyup-dst") + if err != nil { + b.Fatal(err) + } + dstFilename := tmpDst.Name() + tmpDst.Close() + + defer os.Remove(dstFilename) + + // Now, we can use the filename from above as our destination. + dst, err := sftp.Create(dstFilename) + if err != nil { + b.Fatal(err) + } + defer dst.Close() + + src, err := os.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + + n, err := io.Copy(dst, src) + if err != nil { + b.Fatal("copy error:", err) + } + + if n < filesize { + b.Error("unable to copy all bytes") + } + + fi, err := os.Stat(dstFilename) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != filesize { + b.Errorf("wrong file size: got %d, want %d", fi.Size(), filesize) + } + }() + } +} + +func BenchmarkCopyUp10MiBDelay10Msec(b *testing.B) { + benchmarkCopyUp(b, 10*mebi, 10*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay50Msec(b *testing.B) { + benchmarkCopyUp(b, 10*mebi, 50*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay150Msec(b *testing.B) { + benchmarkCopyUp(b, 10*mebi, 150*time.Millisecond) +} diff --git a/client_test.go b/client_test.go index e9e68eab..c4c33c73 100644 --- a/client_test.go +++ b/client_test.go @@ -1,9 +1,15 @@ package sftp import ( + "bytes" + "errors" "io" "io/fs" + "sync/atomic" "testing" + + sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" + "github.com/pkg/sftp/v2/internal/sync" ) func TestClient(t *testing.T) { @@ -36,3 +42,193 @@ func TestClient(t *testing.T) { var _ allFS = new(fsys) } + +type sink struct{} + +func (sink) Close() error { return nil } +func (sink) Write(p []byte) (int, error) { return len(p), nil } + +func TestClientZeroLengthPacket(t *testing.T) { + // Packet length zero (never valid). This used to crash the client. + r := bytes.NewReader([]byte{0, 0, 0, 0}) + + cl, err := NewClientPipe(t.Context(), r, sink{}) + if err == nil { + t.Error("expected an error, got nil") + } + if cl != nil { + cl.Close() + } +} + +func TestClientShortPacket(t *testing.T) { + // init packet too short. + r := bytes.NewReader([]byte{0, 0, 0, 1, 2}) + + cl, err := NewClientPipe(t.Context(), r, sink{}) + if !errors.Is(err, sshfx.ErrShortPacket) { + t.Fatalf("got error %#v, but expected sshfx.ErrShortPacket", err) + } + if cl != nil { + cl.Close() + } +} + +type tickWriter struct { + count atomic.Uint32 + steps atomic.Uint32 + + mu sync.Mutex + once sync.Once + closed bool + step chan struct{} +} + +func (t *tickWriter) wait() { + _ = t.steps.Add(1) + + <-t.step +} + +func (t *tickWriter) Write(b []byte) (written int, err error) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.closed { + return 0, io.ErrClosedPipe + } + + _ = t.count.Add(1) + + t.step <- struct{}{} + + return len(b), nil +} + +func (t *tickWriter) Close() error { + t.once.Do(func() { + t.mu.Lock() + defer t.mu.Unlock() + + t.closed = true + close(t.step) + }) + return nil +} + +func streamPackets(t testing.TB, verPkt *sshfx.VersionPacket, packets ...sshfx.PacketMarshaller) (io.ReadCloser, io.WriteCloser) { + pr, pw := io.Pipe() + + ticker := &tickWriter{ + step: make(chan struct{}), + } + + go func() { + defer func() { + for range ticker.step { + } + }() + + defer pw.Close() + + bindata, err := verPkt.MarshalBinary() + if err != nil { + t.Errorf("could not binary marshal %#v: %v", verPkt, err) + return + } + + ticker.wait() + if _, err := pw.Write(bindata); err != nil { + t.Error(err) + return + } + + var reqid uint32 + for _, packet := range packets { + reqid++ + + header, payload, err := packet.MarshalPacket(reqid, nil) + if err != nil { + t.Errorf("could not packet marshal %#v: %v", packet, err) + return + } + + ticker.wait() + + if _, err := pw.Write(header); err != nil { + t.Error("could not write packet header:", err) + return + } + + if _, err := pw.Write(payload); err != nil { + t.Error("could not write packet payload:", err) + return + } + } + }() + + return pr, ticker +} + +type noSidBrokenPacket struct{} + +func (noSidBrokenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + return []byte{0, 0, 0, 10, 0, 0}, nil, nil +} + +func (noSidBrokenPacket) MarshalSize() int { + return 10 +} + +// Issue #418: panic in clientConn.recv when the sid is incomplete. +func TestClientNoSid(t *testing.T) { + rd, wr := streamPackets(t, + &sshfx.VersionPacket{ + Version: sftpProtocolVersion, + }, + noSidBrokenPacket{}, + ) + + cl, err := NewClientPipe(t.Context(), rd, wr) + if err != nil { + t.Fatal(err) + } + defer cl.Close() + + _, err = cl.Stat("anything") + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("cl.Stat = %q, expected sshfx.StatusConnectionLost", err) + } +} + +// sftp/issue/390 - server disconnect should not cause io.EOF or +// io.ErrUnexpectedEOF in sftp.File.Read, because those confuse io.ReadFull. +func TestClientRoughDisconnectEOF(t *testing.T) { + rd, wr := streamPackets(t, + &sshfx.VersionPacket{ + Version: sftpProtocolVersion, + }, + &sshfx.HandlePacket{ + Handle: "foo", + }, + &sshfx.DataPacket{ + Data: []byte("foo"), + }, + ) + + cl, err := NewClientPipe(t.Context(), rd, wr) + if err != nil { + t.Fatal(err) + } + defer cl.Close() + + f, err := cl.Open("anything") + if err != nil { + t.Fatal(err) + } + + _, err = io.ReadFull(f, make([]byte, 10)) + if !errors.Is(err, sshfx.StatusConnectionLost) { + t.Errorf("io.ReadFull error = %q, but wanted sshfx.StatusConnectionLost", err) + } +} diff --git a/encoding/ssh/filexfer/buffer.go b/encoding/ssh/filexfer/buffer.go index 0d8cec1c..8dde0101 100644 --- a/encoding/ssh/filexfer/buffer.go +++ b/encoding/ssh/filexfer/buffer.go @@ -197,7 +197,7 @@ func (b *Buffer) ConsumeUint32() uint32 { func (b *Buffer) AppendUint32(v uint32) { b.b = binary.BigEndian.AppendUint32(b.b, v) } -//*/ + // ConsumeCount consumes a single uint32 count from the buffer, in network byte order (big-endian) as an int. // If the buffer does not have enough data, it will set Err to ErrShortPacket. func (b *Buffer) ConsumeCount() (int, error) { diff --git a/encoding/ssh/filexfer/extended_packets_test.go b/encoding/ssh/filexfer/extended_packets_test.go index e097e5b9..79843f2a 100644 --- a/encoding/ssh/filexfer/extended_packets_test.go +++ b/encoding/ssh/filexfer/extended_packets_test.go @@ -183,7 +183,7 @@ func TestExtendedPacketTestData(t *testing.T) { t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) } - wantBuffer := []byte{ textValue^0x2a } + wantBuffer := []byte{textValue ^ 0x2a} if buf, ok := p.Data.(*Buffer); !ok { t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) @@ -243,7 +243,7 @@ func TestExtendedPacketTestBuffer(t *testing.T) { t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) } - wantBuffer := []byte{ 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r' } + wantBuffer := []byte{0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r'} if buf, ok := p.Data.(*Buffer); !ok { t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) diff --git a/encoding/ssh/filexfer/response_packets.go b/encoding/ssh/filexfer/response_packets.go index 60ec6d04..c85fcdd8 100644 --- a/encoding/ssh/filexfer/response_packets.go +++ b/encoding/ssh/filexfer/response_packets.go @@ -44,7 +44,7 @@ func (p *StatusPacket) MarshalSize() int { // uint32(length) + uint8(type) + uint32(request-id) const size = 4 + 1 + 4 - // uint32(error/status code) + string(error message) + string(language tag) + // uint32(error/status code) + string(error message) + string(language tag) return size + 4 + 4 + len(p.ErrorMessage) + 4 + len(p.LanguageTag) } diff --git a/errno_posix.go b/errno_posix.go index bc5bbcd5..45a1f914 100644 --- a/errno_posix.go +++ b/errno_posix.go @@ -1,5 +1,4 @@ //go:build !plan9 -// +build !plan9 package sftp diff --git a/fuzz.go b/fuzz.go index 72202818..324eb20b 100644 --- a/fuzz.go +++ b/fuzz.go @@ -1,6 +1,4 @@ -// go:build gofuzz //go:build gofuzz -// +build gofuzz package sftp diff --git a/internal/sync/aliases.go b/internal/sync/aliases.go index d611b71e..53d72e1f 100644 --- a/internal/sync/aliases.go +++ b/internal/sync/aliases.go @@ -12,3 +12,6 @@ type RWMutex = sync.RWMutex // WaitGroup is an alias to [sync.WaitGroup] type WaitGroup = sync.WaitGroup + +// Once is an alias to [sync.Once] +type Once = sync.Once diff --git a/internal/sync/pool.go b/internal/sync/pool.go index 9c3e6dc2..7215efa7 100644 --- a/internal/sync/pool.go +++ b/internal/sync/pool.go @@ -168,7 +168,8 @@ func (p *Pool[T]) Put(v *T) { // relieving pressure on the garbage collector and amortizing allocation overhead. // While also co-ordinating outstanding work, so the caller can wait for all work to be complete. type WorkPool[T any] struct { - wg sync.WaitGroup + wg sync.WaitGroup + closed chan struct{} ch chan chan T } @@ -178,7 +179,8 @@ type WorkPool[T any] struct { // It will panic if given a negative depth, the same as making a negative-buffer channel. func NewWorkPool[T any](depth int) *WorkPool[T] { p := &WorkPool[T]{ - ch: make(chan chan T, depth), + closed: make(chan struct{}), + ch: make(chan chan T, depth), } for len(p.ch) < cap(p.ch) { @@ -203,10 +205,11 @@ func (p *WorkPool[T]) Close() error { return errors.New("cannot close nil work pool") } - close(p.ch) + close(p.closed) p.wg.Wait() + close(p.ch) for range p.ch { // drain the pool and drop them on all on the ground for GC. } @@ -225,11 +228,24 @@ func (p *WorkPool[T]) Get() (chan T, bool) { return make(chan T, 1), true } - v, ok := <-p.ch - if ok { - p.wg.Add(1) + select { + case <-p.closed: + return nil, false + + case v, ok := <-p.ch: + if !ok { + return nil, false + } + + select { + case <-p.closed: + return nil, false + + default: + p.wg.Add(1) + return v, true + } } - return v, ok } // Put returns the given work channel to the pool. @@ -243,9 +259,11 @@ func (p *WorkPool[T]) Put(v chan T) { return } + defer p.wg.Done() + select { + case <-p.closed: case p.ch <- v: - p.wg.Done() default: panic("worker pool overfill") // This is an overfill, which shouldn't happen, but just in case... diff --git a/ls_plan9.go b/ls_plan9.go index ed7dc044..21e8f11e 100644 --- a/ls_plan9.go +++ b/ls_plan9.go @@ -1,5 +1,4 @@ //go:build plan9 -// +build plan9 package sftp diff --git a/ls_stub.go b/ls_stub.go index 1cf29cd6..b584d695 100644 --- a/ls_stub.go +++ b/ls_stub.go @@ -1,5 +1,4 @@ //go:build windows || android -// +build windows android package sftp diff --git a/ls_unix.go b/ls_unix.go index ec4efa2f..2e659801 100644 --- a/ls_unix.go +++ b/ls_unix.go @@ -1,5 +1,4 @@ //go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js || wasip1 || zos -// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js wasip1 zos package sftp