Skip to content
Open
26 changes: 15 additions & 11 deletions internal/pb/pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type ProgressBar struct {
type progressBar struct {
*mpbv8.Bar
size int64
msg string
msg atomic.Value // stores string; accessed by mpb render goroutine
startTime time.Time
}

Expand Down Expand Up @@ -93,27 +93,26 @@ func (p *ProgressBar) Add(prompt, name string, size int64, reader io.Reader) io.
return reader
}

p.mu.RLock()
oldBar := p.bars[name]
p.mu.RUnlock()
p.mu.Lock()
defer p.mu.Unlock()

// If the bar exists, drop and remove it.
if oldBar != nil {
if oldBar := p.bars[name]; oldBar != nil {
oldBar.Abort(true)
}

newBar := &progressBar{
size: size,
msg: fmt.Sprintf("%s %s", prompt, name),
startTime: time.Now(),
}
// Create a new bar if it does not exist.
newBar.msg.Store(fmt.Sprintf("%s %s", prompt, name))

newBar.Bar = p.mpb.New(size,
mpbv8.BarStyle(),
mpbv8.BarFillerOnComplete("|"),
mpbv8.PrependDecorators(
decor.Any(func(s decor.Statistics) string {
return newBar.msg
return newBar.msg.Load().(string)
}, decor.WCSyncSpaceR),
),
mpbv8.AppendDecorators(
Expand All @@ -129,9 +128,7 @@ func (p *ProgressBar) Add(prompt, name string, size int64, reader io.Reader) io.
),
)

p.mu.Lock()
p.bars[name] = newBar
p.mu.Unlock()

if reader != nil {
return newBar.ProxyReader(reader)
Expand All @@ -156,7 +153,7 @@ func (p *ProgressBar) Complete(name string, msg string) {
p.mu.RUnlock()

if ok {
bar.msg = msg
bar.msg.Store(msg)
bar.Bar.SetCurrent(bar.size)
}
}
Expand All @@ -173,6 +170,13 @@ func (p *ProgressBar) Abort(name string, err error) {
}
}

// Reset resets an existing progress bar for a new phase.
// Aborts the old bar and creates a new one with updated prompt,
// reset progress, and fresh speed counter. Parameter order matches Add.
func (p *ProgressBar) Reset(prompt, name string, size int64, reader io.Reader) io.Reader {
return p.Add(prompt, name, size, reader)
}

// Start starts the progress bar.
func (p *ProgressBar) Start() {}

Expand Down
245 changes: 245 additions & 0 deletions internal/pb/pb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Copyright 2025 The CNAI Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package pb

import (
"bytes"
"errors"
"io"
"strings"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// --- Functional tests ---

func TestAdd_WrapsReader(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

input := "hello world"
reader := pb.Add("Building =>", "test-file", int64(len(input)), strings.NewReader(input))

require.NotNil(t, reader)
var buf bytes.Buffer
n, err := io.Copy(&buf, reader)
assert.NoError(t, err)
assert.Equal(t, int64(len(input)), n)
assert.Equal(t, input, buf.String())
}

func TestAdd_NilReader(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

reader := pb.Add("Checking =>", "test-file", 100, nil)
assert.Nil(t, reader)

// Bar should still be created and tracked.
bar := pb.Get("test-file")
assert.NotNil(t, bar)
}

func TestAdd_ReplacesExistingBar(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

pb.Add("Phase1 =>", "test-file", 100, strings.NewReader("first"))
bar1 := pb.Get("test-file")
require.NotNil(t, bar1)

pb.Add("Phase2 =>", "test-file", 200, strings.NewReader("second"))
bar2 := pb.Get("test-file")
require.NotNil(t, bar2)

// Should be a different bar instance with new size.
assert.Equal(t, int64(200), bar2.size)
assert.Equal(t, "Phase2 => test-file", bar2.msg.Load().(string))
}

func TestAdd_DisabledProgress(t *testing.T) {
SetDisableProgress(true)
defer SetDisableProgress(false)

pb := NewProgressBar(io.Discard)
defer pb.Stop()

input := strings.NewReader("test")
reader := pb.Add("Building =>", "test-file", 4, input)
// When disabled, should return the exact same reader.
assert.Equal(t, input, reader)
}

func TestReset_SwitchesPhase(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

pb.Add("Hashing =>", "test-file", 100, strings.NewReader("hash-data"))

reader := pb.Reset("Building =>", "test-file", 100, strings.NewReader("build-data"))
assert.NotNil(t, reader)

bar := pb.Get("test-file")
require.NotNil(t, bar)
assert.Equal(t, "Building => test-file", bar.msg.Load().(string))
}

func TestReset_NoExistingBar(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

reader := pb.Reset("Building =>", "new-file", 50, strings.NewReader("data"))
assert.NotNil(t, reader)

bar := pb.Get("new-file")
require.NotNil(t, bar)
assert.Equal(t, "Building => new-file", bar.msg.Load().(string))
}

// --- Concurrency tests (must pass go test -race) ---

func TestAdd_ConcurrentSameName(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

const goroutines = 20
var wg sync.WaitGroup
wg.Add(goroutines)

for i := 0; i < goroutines; i++ {
go func() {
defer wg.Done()
pb.Add("Phase =>", "shared-name", 100, strings.NewReader("data"))
}()
}

wg.Wait()

// Exactly one bar should exist for the name.
bar := pb.Get("shared-name")
assert.NotNil(t, bar)
}

func TestComplete_ConcurrentWithRender(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

pb.Add("Building =>", "test-file", 100, nil)

const iterations = 100
var wg sync.WaitGroup
wg.Add(2)

// Goroutine 1: simulate Complete updating msg.
go func() {
defer wg.Done()
for i := 0; i < iterations; i++ {
pb.Complete("test-file", "Done => test-file")
}
}()

// Goroutine 2: simulate render goroutine reading msg.
go func() {
defer wg.Done()
bar := pb.Get("test-file")
if bar == nil {
return
}
for i := 0; i < iterations; i++ {
_ = bar.msg.Load().(string)
}
}()

wg.Wait()
}

func TestAdd_ConcurrentDifferentNames(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

const goroutines = 20
names := make([]string, goroutines)
for i := 0; i < goroutines; i++ {
names[i] = strings.Repeat("x", i+1) // unique names
}

var wg sync.WaitGroup
wg.Add(goroutines)

for _, name := range names {
go func() {
defer wg.Done()
pb.Add("Building =>", name, 100, strings.NewReader("data"))
}()
}

wg.Wait()

// All bars should exist.
for _, name := range names {
bar := pb.Get(name)
assert.NotNil(t, bar, "bar for %q should exist", name)
}
}

// --- Error / idempotency tests ---

func TestAbort_NonExistentBar(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

// Should not panic.
pb.Abort("does-not-exist", errors.New("test error"))
}

func TestAbort_AlreadyAbortedBar(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

pb.Add("Building =>", "test-file", 100, nil)
pb.Abort("test-file", errors.New("first abort"))

// Second abort should not panic (mpb uses sync.Once internally).
pb.Abort("test-file", errors.New("second abort"))
}

func TestComplete_NonExistentBar(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

// Should not panic.
pb.Complete("does-not-exist", "done")
}

func TestAdd_AfterAbort(t *testing.T) {
pb := NewProgressBar(io.Discard)
defer pb.Stop()

pb.Add("Phase1 =>", "test-file", 100, nil)
pb.Abort("test-file", errors.New("abort"))

// Add same name again should work.
reader := pb.Add("Phase2 =>", "test-file", 200, strings.NewReader("data"))
assert.NotNil(t, reader)

bar := pb.Get("test-file")
require.NotNil(t, bar)
assert.Equal(t, int64(200), bar.size)
}
14 changes: 11 additions & 3 deletions pkg/backend/build/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ func (ab *abstractBuilder) BuildLayer(ctx context.Context, mediaType, workDir, p
return ocispec.Descriptor{}, fmt.Errorf("failed to encode file: %w", err)
}

reader, digest, size, err := ab.computeDigestAndSize(ctx, mediaType, path, workDirPath, info, reader, codec)
reader, digest, size, err := ab.computeDigestAndSize(ctx, mediaType, path, workDirPath, info, reader, codec,
func(size int64, r io.Reader) io.Reader {
return hooks.OnHash(relPath, size, r)
},
)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to compute digest and size: %w", err)
}
Expand Down Expand Up @@ -247,7 +251,8 @@ func (ab *abstractBuilder) BuildManifest(ctx context.Context, layers []ocispec.D
}

// computeDigestAndSize computes the digest and size for the encoded content, using cache if available.
func (ab *abstractBuilder) computeDigestAndSize(ctx context.Context, mediaType, path, workDirPath string, info os.FileInfo, reader io.Reader, codec pkgcodec.Codec) (io.Reader, string, int64, error) {
// The onHash callback wraps the reader with progress tracking before hashing.
func (ab *abstractBuilder) computeDigestAndSize(ctx context.Context, mediaType, path, workDirPath string, info os.FileInfo, reader io.Reader, codec pkgcodec.Codec, onHash func(size int64, r io.Reader) io.Reader) (io.Reader, string, int64, error) {
// Try to retrieve valid digest from cache for raw model weights.
if mediaType == modelspec.MediaTypeModelWeightRaw {
if digest, size, ok := ab.retrieveCache(ctx, path, info); ok {
Expand All @@ -257,8 +262,11 @@ func (ab *abstractBuilder) computeDigestAndSize(ctx context.Context, mediaType,

logrus.Infof("builder: calculating digest for file %s", path)

// Wrap reader with progress tracking via onHash callback.
wrappedReader := onHash(info.Size(), reader)

hash := sha256.New()
size, err := io.Copy(hash, reader)
size, err := io.Copy(hash, wrappedReader)
if err != nil {
return reader, "", 0, fmt.Errorf("failed to copy content to hash: %w", err)
}
Expand Down
Loading
Loading