
Commit fc41d39

Restore streaming download+extract pipeline
The parallel range download and the streaming extraction are now pipelined: S3 bytes flow directly through pzstd into the extractor with no intermediate temp file, so download and extraction overlap.

Logging is split into two INFO lines:

- "download complete": time from start until the last S3 byte is consumed (the dominant term; extraction runs concurrently, so this is close to the total wall time)
- "restore pipeline complete": total_duration (the full restore wall time) and extract_tail (the small gap between the last S3 byte and the last file written)

Also fix the Linux extractor to use io.CopyBuffer with a fixed 1 MiB block buffer instead of whole-file ReadFull, matching GNU tar's streaming pattern.
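To make the timing semantics above concrete, here is a minimal, self-contained sketch of the countingBody pattern this commit introduces (see the main.go diff below). The strings.Reader and the io.Copy into io.Discard stand in for the real S3 body and the pzstd/tar pipeline; they are illustration only, not code from this repository.

package main

import (
	"fmt"
	"io"
	"strings"
	"time"
)

// countingBody mirrors the type added in this commit: it counts bytes read and
// records the moment the wrapped reader returns io.EOF.
type countingBody struct {
	r       io.Reader
	n       int64
	dlStart time.Time
	eofAt   time.Time
}

func (c *countingBody) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.n += int64(n)
	if err == io.EOF && c.eofAt.IsZero() {
		c.eofAt = time.Now()
	}
	return n, err
}

func main() {
	start := time.Now()
	cb := &countingBody{r: strings.NewReader("example payload"), dlStart: start}

	// Stand-in for the pzstd -> tar extraction pipeline consuming the S3 body.
	_, _ = io.Copy(io.Discard, cb)

	// "download complete" reports the span from start to the last byte consumed;
	// "restore pipeline complete" reports the full wall time plus the tail.
	fmt.Println("download complete:", cb.eofAt.Sub(start), "bytes:", cb.n)
	fmt.Println("restore pipeline complete: total", time.Since(start),
		"extract_tail", time.Since(cb.eofAt))
}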
1 parent 8b2c4c5 commit fc41d39

5 files changed

Lines changed: 235 additions & 154 deletions


cmd/gradle-cache/extract_darwin.go

Lines changed: 11 additions & 0 deletions
@@ -13,6 +13,17 @@ import (
 	"github.com/alecthomas/errors"
 )
 
+// extractBufPool is a pool of reusable byte-slice pointers used by extractTarGo
+// on macOS. Reusing slices eliminates per-file heap allocations for the parallel
+// write path. Initial capacity is 256 KiB — large enough for most Gradle cache
+// files without needing a separate allocation.
+var extractBufPool = sync.Pool{
+	New: func() interface{} {
+		b := make([]byte, 0, 256<<10)
+		return &b
+	},
+}
+
 // mmapThreshold is the minimum file size above which ftruncate+mmap+memcpy is
 // faster than write() on macOS APFS. Below this threshold, mmap setup overhead
 // exceeds the savings. 64 KB covers most Gradle .jar files.

cmd/gradle-cache/extract_default.go

Lines changed: 29 additions & 18 deletions
@@ -12,14 +12,28 @@ import (
 	"github.com/alecthomas/errors"
 )
 
-// extractTarPlatform falls back to sequential extraction on unknown platforms.
+// extractTarPlatform falls back to sequential streaming extraction on unknown platforms.
 func extractTarPlatform(r io.Reader, dir string) error {
 	return extractTarSeq(r, dir)
 }
 
 func extractTarSeq(r io.Reader, dir string) error {
+	copyBuf := make([]byte, 1<<20)
 	tr := tar.NewReader(r)
 	cleanDir := filepath.Clean(dir) + string(os.PathSeparator)
+
+	createdDirs := make(map[string]struct{})
+	ensureDir := func(d string) error {
+		if _, ok := createdDirs[d]; ok {
+			return nil
+		}
+		if err := os.MkdirAll(d, 0o750); err != nil {
+			return err
+		}
+		createdDirs[d] = struct{}{}
+		return nil
+	}
+
 	for {
 		hdr, err := tr.Next()
 		if err == io.EOF {
@@ -28,37 +42,34 @@ func extractTarSeq(r io.Reader, dir string) error {
 		if err != nil {
 			return errors.Wrap(err, "read tar entry")
 		}
+
 		target := filepath.Join(dir, hdr.Name)
 		if !strings.HasPrefix(filepath.Clean(target)+string(os.PathSeparator), cleanDir) {
 			return errors.Errorf("tar entry %q escapes destination directory", hdr.Name)
 		}
+
 		switch hdr.Typeflag {
 		case tar.TypeDir:
-			if err := os.MkdirAll(target, hdr.FileInfo().Mode()); err != nil {
+			if err := ensureDir(target); err != nil {
 				return errors.Errorf("mkdir %s: %w", hdr.Name, err)
 			}
 		case tar.TypeReg:
-			bufPtr := extractBufPool.Get().(*[]byte)
-			if int64(cap(*bufPtr)) >= hdr.Size {
-				*bufPtr = (*bufPtr)[:hdr.Size]
-			} else {
-				*bufPtr = make([]byte, hdr.Size)
-			}
-			if _, err := io.ReadFull(tr, *bufPtr); err != nil {
-				extractBufPool.Put(bufPtr)
-				return errors.Errorf("read %s: %w", hdr.Name, err)
-			}
-			if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil {
-				extractBufPool.Put(bufPtr)
+			if err := ensureDir(filepath.Dir(target)); err != nil {
 				return errors.Errorf("mkdir %s: %w", hdr.Name, err)
 			}
-			if err := os.WriteFile(target, *bufPtr, hdr.FileInfo().Mode()); err != nil {
-				extractBufPool.Put(bufPtr)
+			f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode())
+			if err != nil {
+				return errors.Errorf("open %s: %w", hdr.Name, err)
+			}
+			if _, err := io.CopyBuffer(f, io.LimitReader(tr, hdr.Size), copyBuf); err != nil {
+				f.Close() //nolint:errcheck
 				return errors.Errorf("write %s: %w", hdr.Name, err)
 			}
-			extractBufPool.Put(bufPtr)
+			if err := f.Close(); err != nil {
+				return errors.Errorf("close %s: %w", hdr.Name, err)
+			}
 		case tar.TypeSymlink:
-			if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil {
+			if err := ensureDir(filepath.Dir(target)); err != nil {
 				return errors.Errorf("mkdir for symlink %s: %w", hdr.Name, err)
 			}
 			if err := os.Symlink(hdr.Linkname, target); err != nil {

cmd/gradle-cache/extract_linux.go

Lines changed: 40 additions & 20 deletions
@@ -19,12 +19,33 @@ func extractTarPlatform(r io.Reader, dir string) error {
 	return extractTarSeq(r, dir)
 }
 
-// extractTarSeq extracts a tar stream sequentially using pooled buffers.
-// One goroutine reads and writes, avoiding goroutine-scheduling overhead and
-// VFS writeback fragmentation — the same pattern GNU tar uses on Linux.
+// extractTarSeq extracts a tar stream sequentially using a fixed-size copy
+// buffer. Files are streamed directly from the tar reader to disk one 1 MiB
+// block at a time — the same block-streaming pattern GNU tar uses — so the
+// decompressor pipe keeps flowing without large per-file allocations.
 func extractTarSeq(r io.Reader, dir string) error {
+	// Single fixed-size copy buffer for all file writes in this call.
+	// 1 MiB is large enough to amortise write syscall overhead without
+	// creating memory pressure for many-file archives.
+	copyBuf := make([]byte, 1<<20)
+
 	tr := tar.NewReader(r)
 	cleanDir := filepath.Clean(dir) + string(os.PathSeparator)
+
+	// createdDirs tracks parent directories we have already MkdirAll'd so
+	// each unique path is only created once (same optimisation as darwin).
+	createdDirs := make(map[string]struct{})
+	ensureDir := func(d string) error {
+		if _, ok := createdDirs[d]; ok {
+			return nil
+		}
+		if err := os.MkdirAll(d, 0o750); err != nil {
+			return err
+		}
+		createdDirs[d] = struct{}{}
+		return nil
+	}
+
 	for {
 		hdr, err := tr.Next()
 		if err == io.EOF {
@@ -33,37 +54,36 @@ func extractTarSeq(r io.Reader, dir string) error {
 		if err != nil {
 			return errors.Wrap(err, "read tar entry")
 		}
+
 		target := filepath.Join(dir, hdr.Name)
 		if !strings.HasPrefix(filepath.Clean(target)+string(os.PathSeparator), cleanDir) {
 			return errors.Errorf("tar entry %q escapes destination directory", hdr.Name)
 		}
+
 		switch hdr.Typeflag {
 		case tar.TypeDir:
-			if err := os.MkdirAll(target, hdr.FileInfo().Mode()); err != nil {
+			if err := ensureDir(target); err != nil {
 				return errors.Errorf("mkdir %s: %w", hdr.Name, err)
 			}
+
 		case tar.TypeReg:
-			bufPtr := extractBufPool.Get().(*[]byte)
-			if int64(cap(*bufPtr)) >= hdr.Size {
-				*bufPtr = (*bufPtr)[:hdr.Size]
-			} else {
-				*bufPtr = make([]byte, hdr.Size)
-			}
-			if _, err := io.ReadFull(tr, *bufPtr); err != nil {
-				extractBufPool.Put(bufPtr)
-				return errors.Errorf("read %s: %w", hdr.Name, err)
-			}
-			if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil {
-				extractBufPool.Put(bufPtr)
+			if err := ensureDir(filepath.Dir(target)); err != nil {
 				return errors.Errorf("mkdir %s: %w", hdr.Name, err)
 			}
-			if err := os.WriteFile(target, *bufPtr, hdr.FileInfo().Mode()); err != nil {
-				extractBufPool.Put(bufPtr)
+			f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode())
+			if err != nil {
+				return errors.Errorf("open %s: %w", hdr.Name, err)
+			}
+			if _, err := io.CopyBuffer(f, io.LimitReader(tr, hdr.Size), copyBuf); err != nil {
+				f.Close() //nolint:errcheck
 				return errors.Errorf("write %s: %w", hdr.Name, err)
 			}
-			extractBufPool.Put(bufPtr)
+			if err := f.Close(); err != nil {
+				return errors.Errorf("close %s: %w", hdr.Name, err)
+			}
+
 		case tar.TypeSymlink:
-			if err := os.MkdirAll(filepath.Dir(target), 0o750); err != nil {
+			if err := ensureDir(filepath.Dir(target)); err != nil {
 				return errors.Errorf("mkdir for symlink %s: %w", hdr.Name, err)
 			}
 			if err := os.Symlink(hdr.Linkname, target); err != nil {
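The heart of the Linux change is the per-file write path: the removed code read each entry into a buffer sized to the whole file, while the new code streams it in fixed 1 MiB blocks. Isolated as a standalone helper for illustration (writeEntryStreaming is a hypothetical name, not a function in this repository; it assumes archive/tar, io, and os are imported):

// writeEntryStreaming writes one regular tar entry to target using a shared,
// fixed-size copy buffer, the pattern this commit switches to.
func writeEntryStreaming(tr *tar.Reader, hdr *tar.Header, target string, copyBuf []byte) error {
	f, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode())
	if err != nil {
		return err
	}
	// CopyBuffer reuses copyBuf for every block, so memory stays flat regardless
	// of hdr.Size, avoiding the whole-file allocations of the removed ReadFull path.
	if _, err := io.CopyBuffer(f, io.LimitReader(tr, hdr.Size), copyBuf); err != nil {
		f.Close() //nolint:errcheck
		return err
	}
	return f.Close()
}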

cmd/gradle-cache/main.go

Lines changed: 50 additions & 39 deletions
@@ -24,7 +24,6 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/alecthomas/errors"
@@ -110,9 +109,12 @@ func (c *RestoreCmd) Run(ctx context.Context) error {
 	}
 	slog.Info("cache hit", "key", hitKey)
 
-	// ── Download phase ────────────────────────────────────────────────────────
-	// Download to a temp file first so we get a clean download-speed measurement
-	// independent of decompression and file-extraction throughput.
+	// ── Download + extract phase (pipelined) ─────────────────────────────────
+	// The S3 body streams directly into pzstd → extractor with no temp file.
+	// Download and extraction run concurrently: pzstd decompresses as bytes
+	// arrive, and the extractor writes files as blocks are decompressed.
+	// This matches the Ruby aws-sdk-s3 behaviour and keeps total time close to
+	// max(download_time, extract_time) rather than their sum.
 	dlStart := time.Now()
 	slog.Info("downloading bundle", "key", hitKey)
 
@@ -121,36 +123,37 @@ func (c *RestoreCmd) Run(ctx context.Context) error {
 		return errors.Wrap(err, "create temp dir")
 	}
 
-	bundle, err := os.CreateTemp("", "gradle-cache-bundle-*")
+	body, _, err := client.get(ctx, c.Bucket, hitKey)
 	if err != nil {
-		return errors.Wrap(err, "create bundle temp file")
+		return errors.Wrap(err, "get bundle")
 	}
-	defer func() {
-		bundle.Close()           //nolint:errcheck,gosec
-		os.Remove(bundle.Name()) //nolint:errcheck,gosec
-	}()
+	defer body.Close() //nolint:errcheck,gosec
 
-	dlBytes, err := client.download(ctx, c.Bucket, hitKey, bundle)
-	if err != nil {
-		return errors.Wrap(err, "download bundle")
+	// countingBody records bytes consumed and timestamps when the S3 body is
+	// exhausted so we can log download speed independently of extraction.
+	cb := &countingBody{r: body, dlStart: dlStart}
+	if err := extractTarZstd(ctx, cb, tmpDir); err != nil {
+		return errors.Wrap(err, "extract bundle")
 	}
-	dlElapsed := time.Since(dlStart)
-	dlMBps := float64(dlBytes) / dlElapsed.Seconds() / 1e6
-	slog.Info("download complete", "duration", dlElapsed,
-		"size_mb", fmt.Sprintf("%.1f", float64(dlBytes)/1e6),
-		"speed_mbps", fmt.Sprintf("%.1f", dlMBps))
 
-	// ── Extract phase ─────────────────────────────────────────────────────────
-	if _, err := bundle.Seek(0, io.SeekStart); err != nil {
-		return errors.Wrap(err, "rewind bundle")
-	}
-	extractStart := time.Now()
-	if err := extractTarZstd(ctx, bundle, tmpDir); err != nil {
-		return errors.Wrap(err, "extract bundle")
+	totalElapsed := time.Since(dlStart)
+
+	// Log download phase: time from start until the last S3 byte was consumed
+	// by the pzstd pipeline. Because download and extraction run concurrently,
+	// this is normally the dominant term.
+	if !cb.eofAt.IsZero() {
+		dlElapsed := cb.eofAt.Sub(dlStart)
+		slog.Info("download complete", "duration", dlElapsed.Round(time.Millisecond),
+			"size_mb", fmt.Sprintf("%.1f", float64(cb.n)/1e6),
+			"speed_mbps", fmt.Sprintf("%.1f", float64(cb.n)/dlElapsed.Seconds()/1e6))
 	}
-	extractElapsed := time.Since(extractStart)
-	slog.Info("extract complete", "duration", extractElapsed,
-		"speed_mbps", fmt.Sprintf("%.1f", float64(dlBytes)/extractElapsed.Seconds()/1e6))
+
+	// Log total restore time (find + download + extraction, all pipelined).
+	// The "extract tail" is the small gap between the last byte being consumed
+	// and the last file being written; most extraction happened during download.
+	slog.Info("restore pipeline complete",
+		"total_duration", totalElapsed.Round(time.Millisecond),
+		"extract_tail", time.Since(cb.eofAt).Round(time.Millisecond))
 
 	// Symlink $GRADLE_USER_HOME/caches → tmpDir/caches.
 	cachesTarget := filepath.Join(tmpDir, "caches")
@@ -452,6 +455,25 @@ func zstdDecompressCmd(ctx context.Context) *exec.Cmd {
 // pzstd/zstd decompresses in parallel; the resulting tar stream is extracted
 // by extractTarGo (pooled-buffer parallel writer) or piped to system tar as
 // a fallback when building without CGO on platforms where tar is unavailable.
+// countingBody wraps an io.Reader, counts bytes consumed, and records the time
+// at which the underlying reader returns io.EOF (i.e. when the last S3 byte
+// was consumed by the downstream pipeline).
+type countingBody struct {
+	r       io.Reader
+	n       int64
+	dlStart time.Time
+	eofAt   time.Time
+}
+
+func (c *countingBody) Read(p []byte) (int, error) {
+	n, err := c.r.Read(p)
+	c.n += int64(n)
+	if err == io.EOF && c.eofAt.IsZero() {
+		c.eofAt = time.Now()
+	}
+	return n, err
+}
+
 func extractTarZstd(ctx context.Context, r io.Reader, dir string) error {
 	zstdCmd := zstdDecompressCmd(ctx)
 	zstdCmd.Stdin = r
@@ -481,17 +503,6 @@ func extractTarZstd(ctx context.Context, r io.Reader, dir string) error {
 	return errors.Join(errs...)
 }
 
-// extractBufPool is a pool of reusable byte-slice pointers shared by all
-// platform extractors. Reusing slices eliminates per-file heap allocations and
-// the GC pressure they cause. Initial capacity is 256 KiB — large enough for
-// most Gradle cache files without needing a separate allocation.
-var extractBufPool = sync.Pool{
-	New: func() interface{} {
-		b := make([]byte, 0, 256<<10)
-		return &b
-	},
-}
-
 // zstdCompressCmd returns the command for zstd compression.
 // Prefers pzstd (creates parallel frames, decompressable in parallel) and
 // falls back to zstd -TN -c.
