From 9075ab5bad014fb06ea8cf39a1ca93ddbe179e88 Mon Sep 17 00:00:00 2001 From: Mark Phelps <209477+markphelps@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:20:39 -0400 Subject: [PATCH 01/19] fix: mount weights in cog serve like cog run does (#3044) * fix: mount weights in cog serve like cog run does `cog serve` was not mounting managed weights into the container, while `cog run` (via predict.Predictor) did. This caused models with weights to fail during setup with errors like: OSError: Incorrect path_or_model_id: '/src/weights/bge-m3' The fix adds the same weight-handling pattern to cmdServe: 1. weights.CheckDrift() to validate weights.lock is in sync with cog.yaml 2. newWeightManager(src) to create a weights manager 3. wm.Prepare(ctx) to assemble hardlinks from the local store 4. Append mount specs to runOptions.Volumes as read-only binds 5. defer mounts.Release() to clean up on server exit Integration tests: - weights_import_serve.txtar: tests import -> serve (warm cache) - weights_pull_serve.txtar: tests import -> pull -> serve (cold cache) * test: propagate weight env vars to cog serve tests * test: stop cog serve before checking weight cleanup --- integration-tests/harness/harness.go | 45 ++++++++--- .../tests/weights_import_serve.txtar | 69 ++++++++++++++++ .../tests/weights_pull_serve.txtar | 78 +++++++++++++++++++ pkg/cli/serve.go | 26 +++++++ 4 files changed, 209 insertions(+), 9 deletions(-) create mode 100644 integration-tests/tests/weights_import_serve.txtar create mode 100644 integration-tests/tests/weights_pull_serve.txtar diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 014d770ed6..6512897783 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -29,12 +29,14 @@ import ( // into testscript environments (Setup) and background processes (cmdCogServe). // Keep this list in sync: if you add a new env var to propagate, add it here. 
var propagatedEnvVars = []string{ - "COG_SDK_WHEEL", // SDK wheel override - "COGLET_WHEEL", // coglet wheel override - "RUST_LOG", // Rust logging control - "COG_CA_CERT", // custom CA certificates (e.g. Cloudflare WARP) - "BUILDKIT_PROGRESS", // Docker build output format - "COG_REGISTRY_HOST", // registry host for cog base image resolution + "COG_SDK_WHEEL", // SDK wheel override + "COGLET_WHEEL", // coglet wheel override + "COG_CACHE_DIR", // isolated cache for managed weights + "COG_MODEL_REGISTRY", // test registry override for model refs + "RUST_LOG", // Rust logging control + "COG_CA_CERT", // custom CA certificates (e.g. Cloudflare WARP) + "BUILDKIT_PROGRESS", // Docker build output format + "COG_REGISTRY_HOST", // registry host for cog base image resolution } // Harness provides utilities for running cog integration tests. @@ -242,6 +244,7 @@ func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool // Built-in commands (defined in this file) NewCommand("cog", h.cmdCog), NewCommand("curl", h.cmdCurl), + NewCommand("server-stop", h.cmdServerStop), NewCommand("wait-for", h.cmdWaitFor), NewCommand("docker-run", h.cmdDockerRun), @@ -618,6 +621,17 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { ts.Fatalf("curl: all %d attempts failed with status %d: %s", maxAttempts, lastStatus, errorMsg) } +// cmdServerStop stops the background 'cog serve' process for the current test. +func (h *Harness) cmdServerStop(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("server-stop: negation is not supported") + } + if len(args) != 0 { + ts.Fatalf("server-stop: usage: server-stop") + } + h.StopServer(ts) +} + // StopServer stops the background server process for a test script. 
func (h *Harness) StopServer(ts *testscript.TestScript) { workDir := ts.Getenv("WORK") @@ -643,11 +657,24 @@ func (h *Harness) stopServerByWorkDir(workDir string) { _ = resp.Body.Close() } - // Force kill the cog process if still running + // Give the cog process time to exit after /shutdown so defers can run. + done := make(chan struct{}) + go func() { + _ = info.cmd.Wait() + close(done) + }() + + // Force kill the cog process if still running. if info.cmd.Process != nil { - _ = info.cmd.Process.Kill() + select { + case <-done: + case <-time.After(5 * time.Second): + _ = info.cmd.Process.Kill() + <-done + } + } else { + <-done } - _ = info.cmd.Wait() // Also kill any Docker container that may still be running on this port // Find container by port and kill it diff --git a/integration-tests/tests/weights_import_serve.txtar b/integration-tests/tests/weights_import_serve.txtar new file mode 100644 index 0000000000..47ada1ffe0 --- /dev/null +++ b/integration-tests/tests/weights_import_serve.txtar @@ -0,0 +1,69 @@ +# End-to-end test that `cog weights import` warms the local store +# enough for `cog serve` to mount the weights without a separate +# `cog weights pull`. +# +# This is the serve-side companion to weights_import_predict.txtar. + +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Point the model ref at the ephemeral test registry. cog.yaml carries +# the bare repo via 'model:'; COG_MODEL_REGISTRY swaps in the host +# now that registry-start has assigned a port. +env COG_MODEL_REGISTRY=$TEST_REGISTRY + +# Build a deterministic weight directory. +mkdir weights-src +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "X" > weights-src/greeting.txt' + +# Step 1: import. After this, the local content store under +# $COG_CACHE_DIR/weights/files/sha256/ must contain greeting.txt's +# bytes — that's the cog-i12u guarantee. 
+cog weights import +exists weights.lock + +# Step 2: serve, with NO intervening `cog weights pull`. The +# import path warmed the store; serve's hardlink-assemble must +# work straight off the back of import. +cog serve + +# Step 3: predict via HTTP API and verify the weight is mounted. +curl POST /predictions '{"input":{"s":"world"}}' +stdout '"status":"succeeded"' +stdout 'hello world \(weight-size=1024\)' + +# Per-invocation mount dir must be cleaned up after serve exits. +server-stop +exec sh -c 'test ! -d .cog/mounts || test -z "$(ls -A .cog/mounts)"' + +-- cog.yaml -- +# 'model:' carries the bare repo. The registry host is set via +# COG_MODEL_REGISTRY in the test body since $TEST_REGISTRY varies +# per run. +build: + python_version: "3.12" +predict: "predict.py:Predictor" +model: test/import-serve-model +weights: + - name: greeting + target: /src/mounted-weights/greeting + source: + uri: file://./weights-src + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + # Prove the weight was mounted at the target path by reading + # its size. Any failure to mount would raise FileNotFoundError + # here and the prediction would fail. + path = "/src/mounted-weights/greeting/greeting.txt" + size = os.path.getsize(path) + return f"hello {s} (weight-size={size})" diff --git a/integration-tests/tests/weights_pull_serve.txtar b/integration-tests/tests/weights_pull_serve.txtar new file mode 100644 index 0000000000..38388119f2 --- /dev/null +++ b/integration-tests/tests/weights_pull_serve.txtar @@ -0,0 +1,78 @@ +# End-to-end test of the managed-weights local-run flow via `cog serve`. +# Verifies: cog weights import -> cog weights pull -> cog serve, where +# the predictor container sees the weight files at the configured target +# path via the HTTP API. Also verifies mount cleanup after serve exits. +# +# This is the serve-side companion to weights_pull_predict.txtar. 
+ +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Point the model ref at the ephemeral test registry. cog serve +# picks this up via newWeightManager → resolveWeightRepo → ResolveModelRef. +env COG_MODEL_REGISTRY=$TEST_REGISTRY + +# Build a deterministic weight directory. +mkdir weights-src +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "X" > weights-src/greeting.txt' + +# Step 1: import to the local registry. This also warms the local +# cache as a side effect of import (cog-i12u guarantee). +cog weights import +exists weights.lock + +# Purge the cache so step 2 exercises pull's cold-fetch path. The +# realistic scenario for `cog weights pull` is "lockfile checked in, +# cache empty" — e.g. a fresh clone. +exec sh -c 'rm -rf $WORK/cache/weights/files' + +# Step 2: pull into the isolated cache. +cog weights pull +stderr 'Pulled' + +# Step 3: serve. The predictor opens the file at the configured +# target path and returns its size, proving the mount is visible +# and readable inside the container. +cog serve + +curl POST /predictions '{"input":{"s":"world"}}' +stdout '"status":"succeeded"' +stdout 'hello world \(weight-size=1024\)' + +# Step 4: the per-invocation mount dir must be cleaned up after +# serve exits. .cog/mounts/ either doesn't exist (everything +# cleaned) or is empty. +server-stop +exec sh -c 'test ! -d .cog/mounts || test -z "$(ls -A .cog/mounts)"' + +-- cog.yaml -- +# 'model:' carries the bare repo. The registry host is set via +# COG_MODEL_REGISTRY in the test body since $TEST_REGISTRY varies +# per run. 
+build: + python_version: "3.12" +predict: "predict.py:Predictor" +model: test/pull-serve-model +weights: + - name: greeting + target: /src/mounted-weights/greeting + source: + uri: file://./weights-src + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + # Prove the weight was mounted at the target path by reading + # its size. Any failure to mount would raise FileNotFoundError + # here and the prediction would fail. + path = "/src/mounted-weights/greeting/greeting.txt" + size = os.path.getsize(path) + return f"hello {s} (weight-size={size})" diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index b99562ce0c..1d3471345a 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -13,6 +13,7 @@ import ( "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/weights" ) var ( @@ -84,6 +85,10 @@ func cmdServe(cmd *cobra.Command, arg []string) error { } defer src.Close() + if err := weights.CheckDrift(src.ProjectDir, src.Config.Weights); err != nil { + return err + } + console.Info("Building Docker image from environment in cog.yaml...") console.Info("") resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) @@ -125,6 +130,27 @@ func cmdServe(cmd *cobra.Command, arg []string) error { Workdir: "/src", } + wm, err := newWeightManager(src) + if err != nil { + return err + } + mounts, err := wm.Prepare(ctx) + if err != nil { + return fmt.Errorf("prepare weights: %w", err) + } + defer func() { + if err := mounts.Release(); err != nil { + console.Warnf("Failed to clean up weight mounts: %s", err) + } + }() + for _, spec := range mounts.Specs { + runOptions.Volumes = append(runOptions.Volumes, command.Volume{ + Source: spec.Source, + Destination: spec.Target, + ReadOnly: true, + }) + } + // On Linux, host.docker.internal is not available by default — add it. 
// This allows the container to reach services running on the host, // e.g. when --upload-url points to a local upload server. From a1b710ac7dc858a2704365d530fc01e9e93d4283 Mon Sep 17 00:00:00 2001 From: Mark Phelps <209477+markphelps@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:06:30 -0400 Subject: [PATCH 02/19] Bump version to 0.21.0-rc.2 (#3045) --- VERSION.txt | 2 +- crates/Cargo.lock | 4 ++-- crates/Cargo.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index 1e7ac68eed..c22d34c81e 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.21.0-rc.1 +0.21.0-rc.2 diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 74e832ef9e..002d9a53cf 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -259,7 +259,7 @@ dependencies = [ [[package]] name = "coglet" -version = "0.21.0-rc.1" +version = "0.21.0-rc.2" dependencies = [ "anyhow", "async-trait", @@ -291,7 +291,7 @@ dependencies = [ [[package]] name = "coglet-python" -version = "0.21.0-rc.1" +version = "0.21.0-rc.2" dependencies = [ "async-trait", "base64", diff --git a/crates/Cargo.toml b/crates/Cargo.toml index b95a914547..ba042fd22a 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -3,7 +3,7 @@ resolver = "2" members = ["coglet", "coglet-python"] [workspace.package] -version = "0.21.0-rc.1" +version = "0.21.0-rc.2" edition = "2024" license = "Apache-2.0" repository = "https://github.com/replicate/cog" From 9b9f3102af463a969d6dfb45c883346d1e3a6cf0 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Tue, 2 Jun 2026 14:12:25 -0400 Subject: [PATCH 03/19] chore: regen lockfile Signed-off-by: Mark Phelps --- mise.lock | 80 +++++++++++++++++++++---------------------------------- 1 file changed, 30 insertions(+), 50 deletions(-) diff --git a/mise.lock b/mise.lock index 57b2918a40..5c6661efd8 100644 --- a/mise.lock +++ b/mise.lock @@ -39,30 +39,37 @@ backend = "aqua:golangci/golangci-lint" [tools."aqua:golangci/golangci-lint"."platforms.linux-arm64"] 
checksum = "sha256:6652b42ae02915eb2f9cb2a2e0cac99514c8eded8388d88ae3e06e1a52c00de8" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-arm64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-arm64-musl"] checksum = "sha256:6652b42ae02915eb2f9cb2a2e0cac99514c8eded8388d88ae3e06e1a52c00de8" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-arm64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-x64"] checksum = "sha256:dfa775874cf0561b404a02a8f4481fc69b28091da95aa697259820d429b09c99" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-amd64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-x64-musl"] checksum = "sha256:dfa775874cf0561b404a02a8f4481fc69b28091da95aa697259820d429b09c99" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-amd64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.macos-arm64"] checksum = "sha256:03bfadf67e52b441b7ec21305e501c717df93c959836d66c7f97312654acb297" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-darwin-arm64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.macos-x64"] checksum = "sha256:66fb0da81b8033b477f97eea420d4b46b230ca172b8bb87c6610109f3772b6b6" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-darwin-amd64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.windows-x64"] checksum = "sha256:c60c87695e79db8e320f0e5be885059859de52bb5ee5f11be5577828570bc2a3" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-windows-amd64.zip" 
+provenance = "github-attestations" [[tools."aqua:gotestyourself/gotestsum"]] version = "1.13.0" @@ -161,16 +168,16 @@ checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f8 url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-arm64-musl"] -checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f841c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" +checksum = "sha256:a97c8f56d7462908695348dd8c71ea6740c138ce303715793a690503a94fc9a9" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-x64"] checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-x64-musl"] -checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" +checksum = "sha256:e6599a1c7be58a2d8eaca66a80e0dc006d87bbcf780a58b7343d6e14c1605cb2" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.macos-arm64"] checksum = "sha256:20ef5516c31b1ac2290084199ba77dbbcaa1406c45c1d978ca68558ef5964ef5" @@ -193,16 +200,16 @@ checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f8 url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-arm64-musl"] -checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f841c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" +checksum = 
"sha256:a97c8f56d7462908695348dd8c71ea6740c138ce303715793a690503a94fc9a9" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-x64"] checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-x64-musl"] -checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" +checksum = "sha256:e6599a1c7be58a2d8eaca66a80e0dc006d87bbcf780a58b7343d6e14c1605cb2" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.macos-arm64"] checksum = "sha256:20ef5516c31b1ac2290084199ba77dbbcaa1406c45c1d978ca68558ef5964ef5" @@ -248,32 +255,6 @@ url = "https://github.com/vektra/mockery/releases/download/v3.7.0/mockery_3.7.0_ checksum = "sha256:76d524ad1740cd02ed621d90015f538fdb53cf6cd4a1ad4d289db017fb69bd0e" url = "https://github.com/vektra/mockery/releases/download/v3.7.0/mockery_3.7.0_Windows_x86_64.tar.gz" -[[tools."aqua:ziglang/zig"]] -version = "0.15.2" -backend = "aqua:ziglang/zig" - -[tools."aqua:ziglang/zig"."platforms.linux-arm64"] -url = "https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.linux-arm64-musl"] -url = "https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.linux-x64"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.linux-x64-musl"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.macos-arm64"] -checksum = 
"blake3:c7d2fb746701fea2c070f66c29a0300ba42a0d5d2c09c493462b8c0f4f0bd604" -url = "https://ziglang.org/download/0.15.2/zig-aarch64-macos-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.macos-x64"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-macos-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.windows-x64"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-windows-0.15.2.zip" - [[tools.cargo-binstall]] version = "1.16.6" backend = "aqua:cargo-bins/cargo-binstall" @@ -282,10 +263,18 @@ backend = "aqua:cargo-bins/cargo-binstall" checksum = "sha256:b556421835ba67fa98ca1570c85b5511457956b7836ce938b47d3f73899517a3" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-unknown-linux-musl.tgz" +[tools.cargo-binstall."platforms.linux-arm64-musl"] +checksum = "sha256:b556421835ba67fa98ca1570c85b5511457956b7836ce938b47d3f73899517a3" +url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-unknown-linux-musl.tgz" + [tools.cargo-binstall."platforms.linux-x64"] checksum = "sha256:3225eea8041c30d7462761a481883e3aa8fe31c58def4b6c8dd91b7c80973df0" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-unknown-linux-musl.tgz" +[tools.cargo-binstall."platforms.linux-x64-musl"] +checksum = "sha256:3225eea8041c30d7462761a481883e3aa8fe31c58def4b6c8dd91b7c80973df0" +url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-unknown-linux-musl.tgz" + [tools.cargo-binstall."platforms.macos-arm64"] checksum = "sha256:30543b378b96fbddabee1edfaccde7914dd2f851f02c560de859f81a21ab665b" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-apple-darwin.zip" @@ -298,22 +287,10 @@ url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/ca checksum = "sha256:fca962c3d12ae6192280111074db073c15abad3ba162a1a5a2af0f6f01872114" url 
= "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-pc-windows-msvc.zip" -[[tools."cargo:cargo-deny"]] -version = "0.19.0" -backend = "cargo:cargo-deny" - -[[tools."cargo:cargo-insta"]] -version = "1.46.0" -backend = "cargo:cargo-insta" - [[tools."cargo:cargo-nextest"]] version = "0.9.120" backend = "cargo:cargo-nextest" -[[tools."cargo:cargo-zigbuild"]] -version = "0.20.1" -backend = "cargo:cargo-zigbuild" - [[tools."cargo:maturin"]] version = "1.11.5" backend = "cargo:maturin" @@ -370,10 +347,6 @@ backend = "pipx:nox" uvx = "true" uvx_args = "--python-preference=managed -p 3.13" -[[tools.python]] -version = "3.12.12" -backend = "core:python" - [[tools.ruff]] version = "0.14.13" backend = "aqua:astral-sh/ruff" @@ -449,30 +422,37 @@ backend = "aqua:astral-sh/uv" [tools.uv."platforms.linux-arm64"] checksum = "sha256:ba8698c36c00c22efed4bd3506339b03c95604d001f02eaf6fbc814c9224d801" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.linux-arm64-musl"] checksum = "sha256:ba8698c36c00c22efed4bd3506339b03c95604d001f02eaf6fbc814c9224d801" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.linux-x64"] checksum = "sha256:708b752876aeeb753257e1d55470569789e465684c1d3bc1760db26360b6c28b" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.linux-x64-musl"] checksum = "sha256:708b752876aeeb753257e1d55470569789e465684c1d3bc1760db26360b6c28b" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.macos-arm64"] checksum = "sha256:fcf0a9ea6599c6ae28a4c854ac6da76f2c889354d7c36ce136ef071f7ab9721f" url = 
"https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-apple-darwin.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.macos-x64"] checksum = "sha256:171eb8c518313e157c5b4cec7b4f743bc6bab1bd23e09b646679a02d096a047f" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-apple-darwin.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.windows-x64"] checksum = "sha256:eb02fd95d8e0eed462b4a67ecdd320d865b38c560bffcda9a0b87ec944bdf036" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-pc-windows-msvc.zip" +provenance = "github-attestations" [[tools.zig]] version = "0.15.2" From 8428257203c5b2bbbeb9f3b2d156a8ca569645ae Mon Sep 17 00:00:00 2001 From: Mark Phelps <209477+markphelps@users.noreply.github.com> Date: Tue, 2 Jun 2026 15:13:14 -0400 Subject: [PATCH 04/19] ci: pin mise version in release workflows (#3046) * ci: pin mise version in release workflows * ci: pin mise version in all workflows --- .github/workflows/ci.yaml | 20 +++++++++++++++++++ .github/workflows/docs.yaml | 1 + .github/workflows/mirror-cog-base-images.yaml | 1 + .github/workflows/release-build.yaml | 2 ++ .github/workflows/release-publish.yaml | 1 + .github/workflows/rust-advisories.yaml | 1 + .../update-compatibility-matrices.yaml | 1 + 7 files changed, 27 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 66be3e736b..d3f82eadb0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -138,6 +138,7 @@ jobs: fetch-depth: 0 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Build SDK run: mise run ci:build:sdk @@ -190,6 +191,7 @@ jobs: fetch-depth: 0 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Get version from VERSION.txt id: version @@ -221,6 +223,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 
cache_key_prefix: mise-ci-${{ github.job }} - name: Check Go formatting run: mise run fmt:go @@ -233,6 +236,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Ensure Rust components run: rustup component add rustfmt clippy @@ -247,6 +251,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check Python formatting run: mise run fmt:python @@ -259,6 +264,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check llms.txt is up to date run: mise run docs:llm:check @@ -283,6 +289,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check stubs are up to date run: | @@ -302,6 +309,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Lint Go run: mise run lint:go @@ -318,6 +326,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Ensure Rust components run: rustup component add rustfmt clippy @@ -332,6 +341,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check licenses, bans, and sources run: cargo deny --manifest-path crates/Cargo.toml check bans licenses sources @@ -349,6 +359,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check advisories (informational) run: cargo deny --manifest-path crates/Cargo.toml check advisories @@ -373,6 +384,7 @@ jobs: run: tar xf dist/*.tar.gz --strip-components=1 - 
uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Ensure nox is installed run: mise install pipx:nox --force @@ -387,6 +399,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Lint Markdown run: mise run lint:docs @@ -407,6 +420,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Test Go shell: bash @@ -433,6 +447,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Fuzz schema type resolution run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime=30s @@ -455,6 +470,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Test Rust run: mise run test:rust @@ -478,6 +494,7 @@ jobs: run: tar xf dist/*.tar.gz --strip-components=1 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Remove src to ensure tests run against wheel run: rm -rf python/cog @@ -508,6 +525,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Test coglet-python bindings run: uvx nox -s coglet -p ${{ matrix.python-version }} @@ -628,6 +646,7 @@ jobs: chmod +x ./cog - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Set wheel environment run: | @@ -787,6 +806,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check coglet crates.io publish run: cargo publish --dry-run -p coglet --manifest-path crates/Cargo.toml diff 
--git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 1010debde1..8fc3d52412 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -15,6 +15,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-docs-${{ github.job }} - name: Generate CLI docs diff --git a/.github/workflows/mirror-cog-base-images.yaml b/.github/workflows/mirror-cog-base-images.yaml index 1a08bb2d87..9c5f739158 100644 --- a/.github/workflows/mirror-cog-base-images.yaml +++ b/.github/workflows/mirror-cog-base-images.yaml @@ -26,6 +26,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-mirror-${{ github.job }} - name: Install crane diff --git a/.github/workflows/release-build.yaml b/.github/workflows/release-build.yaml index 81f3ca45ce..eb4f5f9865 100644 --- a/.github/workflows/release-build.yaml +++ b/.github/workflows/release-build.yaml @@ -129,6 +129,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-rel-${{ github.job }} - name: Update coglet version constraint @@ -219,6 +220,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-rel-${{ github.job }} - name: Check for existing release diff --git a/.github/workflows/release-publish.yaml b/.github/workflows/release-publish.yaml index 7a5a08a32c..9cef41f876 100644 --- a/.github/workflows/release-publish.yaml +++ b/.github/workflows/release-publish.yaml @@ -115,6 +115,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-pub-${{ github.job }} - uses: rust-lang/crates-io-auth-action@v1 id: auth diff --git a/.github/workflows/rust-advisories.yaml b/.github/workflows/rust-advisories.yaml index 143595999d..d2c5fe5a41 100644 --- a/.github/workflows/rust-advisories.yaml +++ b/.github/workflows/rust-advisories.yaml @@ -21,6 +21,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 
2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check advisories id: advisories diff --git a/.github/workflows/update-compatibility-matrices.yaml b/.github/workflows/update-compatibility-matrices.yaml index 716095687d..2e9694b1aa 100644 --- a/.github/workflows/update-compatibility-matrices.yaml +++ b/.github/workflows/update-compatibility-matrices.yaml @@ -27,6 +27,7 @@ jobs: - name: Install tools uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache: true cache_key_prefix: mise-compatibility-matrices From b5c856331bd7a64f64ea12ca58ad9fa3a49c4425 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:36:45 -0400 Subject: [PATCH 05/19] test: cover kong command registration and root globals --- cmd/cog-kong/cli.go | 20 ++-- cmd/cog-kong/flags.go | 1 - cmd/cog-kong/main.go | 70 ++++++++---- cmd/cog-kong/main_test.go | 227 ++++++++++++++++++++++++++++++++++++++ cmd/cog-kong/push.go | 2 +- cmd/cog-kong/stubs.go | 100 +++++++++++++++++ 6 files changed, 389 insertions(+), 31 deletions(-) create mode 100644 cmd/cog-kong/main_test.go create mode 100644 cmd/cog-kong/stubs.go diff --git a/cmd/cog-kong/cli.go b/cmd/cog-kong/cli.go index 78790ec139..34def0795a 100644 --- a/cmd/cog-kong/cli.go +++ b/cmd/cog-kong/cli.go @@ -1,12 +1,11 @@ package main import ( - "context" + "os" "github.com/alecthomas/kong" "github.com/replicate/cog/pkg/global" - "github.com/replicate/cog/pkg/update" "github.com/replicate/cog/pkg/util/console" ) @@ -14,6 +13,8 @@ import ( // The AfterApply hook replaces Cobra's PersistentPreRun. 
type Globals struct { Debug bool `name:"debug" short:"d" env:"COG_DEBUG" help:"Show debugging output."` + NoColor bool `name:"no-color" help:"Disable colored output."` + Help bool `name:"help" short:"h" help:"Show context-sensitive help."` Registry string `name:"registry" default:"${registry_default}" env:"COG_REGISTRY_HOST" hidden:"" help:"Registry host."` Profile bool `name:"profile" hidden:"" help:"Enable profiling."` Version kong.VersionFlag `name:"version" short:"v" help:"Show version of Cog."` @@ -21,18 +22,23 @@ type Globals struct { // AfterApply runs after flag parsing, before the command's Run. // This is the Kong equivalent of Cobra's PersistentPreRun. -func (g *Globals) AfterApply(ctx context.Context) error { +func (g *Globals) AfterApply() error { if g.Debug { global.Debug = true console.SetLevel(console.DebugLevel) } + if g.NoColor { + global.NoColor = true + } + if global.NoColor || !console.ShouldUseColor() { + console.SetColor(false) + } + if global.NoColor { + _ = os.Setenv("NO_COLOR", "1") + } if g.Profile { global.ProfilingEnabled = true } global.ReplicateRegistryHost = g.Registry - - if err := update.DisplayAndCheckForRelease(ctx); err != nil { - console.Debugf("%s", err) - } return nil } diff --git a/cmd/cog-kong/flags.go b/cmd/cog-kong/flags.go index 11c67b1aa7..ca2c014b12 100644 --- a/cmd/cog-kong/flags.go +++ b/cmd/cog-kong/flags.go @@ -65,7 +65,6 @@ func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]strin Strip: b.Strip, Precompile: b.Precompile, Annotations: annotations, - OCIIndex: model.OCIIndexEnabled(), } } diff --git a/cmd/cog-kong/main.go b/cmd/cog-kong/main.go index 332468b4aa..63cf073a08 100644 --- a/cmd/cog-kong/main.go +++ b/cmd/cog-kong/main.go @@ -12,6 +12,7 @@ import ( "github.com/alecthomas/kong" "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/update" "github.com/replicate/cog/pkg/util/console" ) @@ -26,8 +27,19 @@ var ( type CLI struct { Globals - Build BuildCmd `cmd:"" 
help:"Build an image from cog.yaml."` - Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."` + BaseImage BaseImageCmd `cmd:"" name:"base-image" help:"Tools for working with Cog base images."` + Build BuildCmd `cmd:"" help:"Build an image from cog.yaml."` + Debug DebugCmd `cmd:"" help:"Debug Cog internals."` + Doctor DoctorCmd `cmd:"" help:"Check your project for common issues and fix them (experimental)."` + Exec ExecCmd `cmd:"" help:"Execute a command inside a Docker environment."` + Init InitCmd `cmd:"" help:"Configure your project for use with Cog."` + Login LoginCmd `cmd:"" help:"Log in to a container registry."` + Predict PredictCmd `cmd:"" help:"Run a prediction."` + Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."` + RunCommand RunCmd `cmd:"" name:"run" help:"Run a prediction."` + Serve ServeCmd `cmd:"" help:"Run an HTTP server."` + Train TrainCmd `cmd:"" help:"Run a training job."` + Weights WeightsCmd `cmd:"" help:"Commands for managing model weight files."` } func main() { @@ -35,26 +47,7 @@ func main() { var cli CLI - initOpts := []kong.Option{ - // CLI metadata and variable interpolation for struct tags - kong.Name("cog"), - kong.Description("Containers for machine learning."), - kong.Vars{ - "version": fmt.Sprintf("cog version %s (built %s)", version, buildTime), - "commit": commit, - "progress_default": progressDefault(), - "registry_default": global.DefaultReplicateRegistryHost, - }, - kong.UsageOnError(), - - // bindings for lazily injecting dependencies into Run() methods - kong.BindTo(ctx, (*context.Context)(nil)), - kong.BindSingletonProvider(provideDockerClient), - kong.BindToProvider(provideRegistryClient), - kong.BindSingletonProvider(provideProviderRegistry), - } - - parser, err := kong.New(&cli, initOpts...) + parser, err := newParser(ctx, &cli) if err != nil { // Fatal error creating the parser — this is a bug, so panic to get a stack trace. 
panic(err) @@ -74,14 +67,47 @@ func main() { // otherwise it's a real parse error (e.g. unexpected command or flag), so print the error and exit non-zero. parser.FatalIfErrorf(err) + os.Exit(1) + } + if cli.Help { + _ = kctx.PrintUsage(false) + return } + displayUpdateCheck(ctx) err = kctx.Run() cancel() // command returned an error. Print and exit non-zero. if err != nil { parser.FatalIfErrorf(err) + os.Exit(1) + } +} + +func displayUpdateCheck(ctx context.Context) { + if err := update.DisplayAndCheckForRelease(ctx); err != nil { + console.Debugf("%s", err) + } +} + +func newParser(ctx context.Context, cli *CLI, options ...kong.Option) (*kong.Kong, error) { + defaultOptions := []kong.Option{ + kong.Name("cog"), + kong.Description("Containers for machine learning."), + kong.Vars{ + "version": fmt.Sprintf("cog version %s (built %s)", version, buildTime), + "commit": commit, + "progress_default": progressDefault(), + "registry_default": global.DefaultReplicateRegistryHost, + }, + kong.UsageOnError(), + kong.NoDefaultHelp(), + kong.BindTo(ctx, (*context.Context)(nil)), + kong.BindSingletonProvider(provideDockerClient), + kong.BindToProvider(provideRegistryClient), + kong.BindSingletonProvider(provideProviderRegistry), } + return kong.New(cli, append(defaultOptions, options...)...) 
} func newCancellationContext() (context.Context, context.CancelFunc) { diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go new file mode 100644 index 0000000000..4b3e38d607 --- /dev/null +++ b/cmd/cog-kong/main_test.go @@ -0,0 +1,227 @@ +package main + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/alecthomas/kong" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/util/console" +) + +func newTestParser(t *testing.T) *kong.Kong { + t.Helper() + parser, err := newParser(t.Context(), &CLI{}) + require.NoError(t, err) + return parser +} + +type kongGlobalState struct { + debug bool + noColor bool + profilingEnabled bool + registry string + consoleColor bool + consoleLevel console.Level + noColorEnv string + hadNoColorEnv bool +} + +func snapshotKongGlobalState() kongGlobalState { + noColorEnv, hadNoColorEnv := os.LookupEnv("NO_COLOR") + return kongGlobalState{ + debug: global.Debug, + noColor: global.NoColor, + profilingEnabled: global.ProfilingEnabled, + registry: global.ReplicateRegistryHost, + consoleColor: console.ConsoleInstance.Color, + consoleLevel: console.ConsoleInstance.Level, + noColorEnv: noColorEnv, + hadNoColorEnv: hadNoColorEnv, + } +} + +func restoreKongGlobalState(t *testing.T, state kongGlobalState) { + t.Helper() + global.Debug = state.debug + global.NoColor = state.noColor + global.ProfilingEnabled = state.profilingEnabled + global.ReplicateRegistryHost = state.registry + console.SetColor(state.consoleColor) + console.SetLevel(state.consoleLevel) + if state.hadNoColorEnv { + require.NoError(t, os.Setenv("NO_COLOR", state.noColorEnv)) + } else { + require.NoError(t, os.Unsetenv("NO_COLOR")) + } +} + +func preserveKongGlobalState(t *testing.T) kongGlobalState { + t.Helper() + state := snapshotKongGlobalState() + t.Cleanup(func() { + restoreKongGlobalState(t, state) + }) + return state +} + +type testExitCode int + +func 
newVersionTestParser(t *testing.T, stdout *bytes.Buffer) *kong.Kong { + t.Helper() + parser, err := newParser(t.Context(), &CLI{}, kong.Exit(func(code int) { + panic(testExitCode(code)) + })) + require.NoError(t, err) + parser.Stdout = stdout + parser.Stderr = stdout + return parser +} + +func TestKongRegistersAllTopLevelCommands(t *testing.T) { + parser := newTestParser(t) + commands := map[string]bool{} + for _, node := range parser.Model.Children { + commands[node.Name] = true + } + + for _, name := range []string{ + "base-image", + "build", + "debug", + "doctor", + "exec", + "init", + "login", + "predict", + "push", + "run", + "serve", + "train", + "weights", + } { + require.Truef(t, commands[name], "missing command %q", name) + } +} + +func TestKongRegistersNestedCommands(t *testing.T) { + preserveKongGlobalState(t) + parser := newTestParser(t) + + for _, args := range [][]string{ + {"weights", "import", "--help"}, + {"weights", "pull", "--help"}, + {"weights", "status", "--help"}, + {"base-image", "dockerfile", "--help"}, + {"base-image", "build", "--help"}, + } { + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } +} + +func TestKongRootHelpParses(t *testing.T) { + preserveKongGlobalState(t) + parser := newTestParser(t) + var stdout bytes.Buffer + parser.Stdout = &stdout + parser.Stderr = &stdout + + kctx, err := parser.Parse([]string{"--help"}) + if err != nil { + var parseErr *kong.ParseError + require.True(t, errors.As(err, &parseErr), "expected ParseError, got %T", err) + require.True(t, strings.HasPrefix(parseErr.Error(), "expected"), "expected command selection error, got %q", parseErr.Error()) + kctx = parseErr.Context + } + require.NoError(t, kctx.PrintUsage(false)) + + help := stdout.String() + require.Contains(t, help, "Usage: cog [flags]") + require.Contains(t, help, "build") + require.Contains(t, help, "push") + require.Contains(t, help, "weights") + require.NotContains(t, help, "Usage: cog default") +} + +func 
TestKongRootGlobalFlagsParse(t *testing.T) { + state := preserveKongGlobalState(t) + + for _, args := range [][]string{ + {"--debug", "build", "--help"}, + {"--no-color", "build", "--help"}, + {"--profile", "build", "--help"}, + {"--registry", "example.com", "build", "--help"}, + } { + restoreKongGlobalState(t, state) + parser := newTestParser(t) + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } + + restoreKongGlobalState(t, state) + var stdout bytes.Buffer + versionParser := newVersionTestParser(t, &stdout) + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = versionParser.Parse([]string{"--version"}) + }) +} + +func TestKongHelpParsingDoesNotWriteUpdateState(t *testing.T) { + preserveKongGlobalState(t) + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("COG_NO_UPDATE_CHECK", "") + + for _, args := range [][]string{ + {"--help"}, + {"build", "--help"}, + } { + parser := newTestParser(t) + _, err := parser.Parse(args) + if err != nil { + var parseErr *kong.ParseError + require.True(t, errors.As(err, &parseErr), "expected ParseError, got %T", err) + require.True(t, strings.HasPrefix(parseErr.Error(), "expected"), "expected command selection error, got %q", parseErr.Error()) + } + require.NoFileExists(t, filepath.Join(home, ".config", "cog", "update-state.json"), "parse %v", args) + } +} + +func TestKongNoColorAfterApplySetsGlobalNoColor(t *testing.T) { + preserveKongGlobalState(t) + require.NoError(t, os.Unsetenv("NO_COLOR")) + global.NoColor = false + + globals := Globals{NoColor: true, Registry: global.ReplicateRegistryHost} + require.NoError(t, globals.AfterApply()) + + require.True(t, global.NoColor) + require.Equal(t, "1", os.Getenv("NO_COLOR")) +} + +func TestKongVersionExitsBeforeRootUsage(t *testing.T) { + var stdout bytes.Buffer + parser := newVersionTestParser(t, &stdout) + + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = parser.Parse([]string{"--version"}) + }) + require.Equal(t, "cog version 
dev (built none)\n", stdout.String()) +} + +func TestKongCommandVersionExitsBeforeCommandRun(t *testing.T) { + var stdout bytes.Buffer + parser := newVersionTestParser(t, &stdout) + + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = parser.Parse([]string{"build", "--version"}) + }) + require.Equal(t, "cog version dev (built none)\n", stdout.String()) +} diff --git a/cmd/cog-kong/push.go b/cmd/cog-kong/push.go index 9a92411024..db5139a888 100644 --- a/cmd/cog-kong/push.go +++ b/cmd/cog-kong/push.go @@ -71,7 +71,7 @@ func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regCl // Push the model console.Infof("\nPushing image '%s'...", m.ImageRef()) - pushErr := resolver.Push(ctx, m, model.PushOptions{}) + _, pushErr := resolver.Push(ctx, m, model.PushOptions{}) // PostPush: the provider handles formatting errors and showing success messages if err := p.PostPush(ctx, pushOpts, pushErr); err != nil { diff --git a/cmd/cog-kong/stubs.go b/cmd/cog-kong/stubs.go new file mode 100644 index 0000000000..f9a812362e --- /dev/null +++ b/cmd/cog-kong/stubs.go @@ -0,0 +1,100 @@ +package main + +import "errors" + +var errKongCommandNotImplemented = errors.New("kong command not implemented") + +type BaseImageCmd struct { + Dockerfile BaseImageDockerfileCmd `cmd:"" help:"Generate a Dockerfile for a Cog base image."` + Build BaseImageBuildCmd `cmd:"" help:"Build a Cog base image."` +} + +type BaseImageDockerfileCmd struct{} + +func (cmd *BaseImageDockerfileCmd) Run() error { + return errKongCommandNotImplemented +} + +type BaseImageBuildCmd struct{} + +func (cmd *BaseImageBuildCmd) Run() error { + return errKongCommandNotImplemented +} + +type DebugCmd struct{} + +func (cmd *DebugCmd) Run() error { + return errKongCommandNotImplemented +} + +type DoctorCmd struct{} + +func (cmd *DoctorCmd) Run() error { + return errKongCommandNotImplemented +} + +type ExecCmd struct{} + +func (cmd *ExecCmd) Run() error { + return errKongCommandNotImplemented +} + +type 
InitCmd struct{} + +func (cmd *InitCmd) Run() error { + return errKongCommandNotImplemented +} + +type LoginCmd struct{} + +func (cmd *LoginCmd) Run() error { + return errKongCommandNotImplemented +} + +type PredictCmd struct{} + +func (cmd *PredictCmd) Run() error { + return errKongCommandNotImplemented +} + +type RunCmd struct{} + +func (cmd *RunCmd) Run() error { + return errKongCommandNotImplemented +} + +type ServeCmd struct{} + +func (cmd *ServeCmd) Run() error { + return errKongCommandNotImplemented +} + +type TrainCmd struct{} + +func (cmd *TrainCmd) Run() error { + return errKongCommandNotImplemented +} + +type WeightsCmd struct { + Import WeightsImportCmd `cmd:"" help:"Import model weights."` + Pull WeightsPullCmd `cmd:"" help:"Pull model weights."` + Status WeightsStatusCmd `cmd:"" help:"Show model weight status."` +} + +type WeightsImportCmd struct{} + +func (cmd *WeightsImportCmd) Run() error { + return errKongCommandNotImplemented +} + +type WeightsPullCmd struct{} + +func (cmd *WeightsPullCmd) Run() error { + return errKongCommandNotImplemented +} + +type WeightsStatusCmd struct{} + +func (cmd *WeightsStatusCmd) Run() error { + return errKongCommandNotImplemented +} From 127e8b1a1b61437c5c6715d8f1ad04fe97368b0b Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:38:24 -0400 Subject: [PATCH 06/19] refactor: share build command logic between cobra and kong --- cmd/cog-kong/build.go | 30 +++------- cmd/cog-kong/flags.go | 26 +++++---- pkg/cli/build.go | 129 +++++++++++++++++++++++++++++++++--------- pkg/cli/build_test.go | 60 ++++++++++++++++++++ 4 files changed, 185 insertions(+), 60 deletions(-) create mode 100644 pkg/cli/build_test.go diff --git a/cmd/cog-kong/build.go b/cmd/cog-kong/build.go index e5afb7054c..e73e2f72f9 100644 --- a/cmd/cog-kong/build.go +++ b/cmd/cog-kong/build.go @@ -3,11 +3,9 @@ package main import ( "context" - "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/cli" 
"github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" - "github.com/replicate/cog/pkg/util/console" ) // BuildCmd implements the "cog build" command. @@ -22,23 +20,11 @@ func (cmd *BuildCmd) Validate() error { return cmd.ValidateMutualExclusivity() } -// Run executes the build command. -func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, src *model.Source) error { - imageName := src.Config.Image - if cmd.Tag != "" { - imageName = cmd.Tag - } - if imageName == "" { - imageName = config.DockerImageName(src.ProjectDir) - } - - resolver := model.NewResolver(dockerClient, regClient) - m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, nil)) - if err != nil { - return err - } - - console.Infof("\nImage built as %s", m.ImageRef()) - - return nil +// Run executes the build command via the shared cli.RunBuild runner. +func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunBuild(ctx, dockerClient, regClient, cli.BuildCommandOptions{ + ConfigFilename: cmd.File, + Tag: cmd.Tag, + Flags: cmd.BuildFlags.Options(), + }) } diff --git a/cmd/cog-kong/flags.go b/cmd/cog-kong/flags.go index ca2c014b12..72e8f9b6d6 100644 --- a/cmd/cog-kong/flags.go +++ b/cmd/cog-kong/flags.go @@ -5,6 +5,7 @@ import ( "os" "strings" + "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/model" ) @@ -49,25 +50,30 @@ func (b *BuildFlags) AfterApply() error { return nil } -// BuildOptions constructs a model.BuildOptions from the current flag values. -// The imageName and annotations parameters vary by caller (build vs push). 
-func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]string) model.BuildOptions { - return model.BuildOptions{ - ImageName: imageName, - Secrets: b.Secrets, +// Options converts the Kong build flags into the parser-independent +// cli.BuildFlagsOptions shared with the Cobra CLI. +func (b *BuildFlags) Options() cli.BuildFlagsOptions { + return cli.BuildFlagsOptions{ NoCache: b.NoCache, SeparateWeights: b.SeparateWeights, - UseCudaBaseImage: b.UseCudaBaseImage, + Secrets: b.Secrets, ProgressOutput: b.Progress, - SchemaFile: b.OpenAPISchema, - DockerfileFile: b.Dockerfile, + UseCudaBaseImage: b.UseCudaBaseImage, UseCogBaseImage: b.UseCogBaseImage, + OpenAPISchema: b.OpenAPISchema, + DockerfileFile: b.Dockerfile, Strip: b.Strip, Precompile: b.Precompile, - Annotations: annotations, + Timestamp: b.Timestamp, } } +// BuildOptions constructs a model.BuildOptions from the current flag values. +// The imageName and annotations parameters vary by caller (build vs push). +func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]string) model.BuildOptions { + return b.Options().ModelBuildOptions(imageName, annotations) +} + // ValidateMutualExclusivity ensures that at most one of --use-cog-base-image, // --use-cuda-base-image, and --dockerfile is explicitly set. 
func (b *BuildFlags) ValidateMutualExclusivity() error { diff --git a/pkg/cli/build.go b/pkg/cli/build.go index bbdc5f37f7..73542bdc17 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "os" "strings" @@ -10,6 +11,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" @@ -83,7 +85,101 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } - src, err := model.NewSource(configFilename) + return RunBuild(ctx, dockerClient, registry.NewRegistryClient(), BuildCommandOptions{ + ConfigFilename: configFilename, + Tag: buildTag, + Flags: buildFlagsOptionsFromCobra(cmd), + }) +} + +// BuildFlagsOptions holds the parser-independent build flag values shared by +// both the Cobra and Kong CLIs. It mirrors the flags registered on the build +// and push commands. +type BuildFlagsOptions struct { + NoCache bool + SeparateWeights bool + Secrets []string + ProgressOutput string + UseCudaBaseImage string + UseCogBaseImage *bool + OpenAPISchema string + DockerfileFile string + Strip bool + Precompile bool + SkipSchemaValidation bool + Timestamp int64 +} + +// BuildCommandOptions holds everything RunBuild needs that is independent of +// the argument parser. +type BuildCommandOptions struct { + ConfigFilename string + Tag string + Flags BuildFlagsOptions +} + +// ResolveBuildImageName resolves the output image name from the config image, +// config model, an explicit tag, and the project directory, matching the +// historical precedence: tag > config image > config model > derived name. 
+func ResolveBuildImageName(configImage, configModel, tag, projectDir string) string { + imageName := configImage + if imageName == "" { + imageName = configModel + } + if tag != "" { + imageName = tag + } + if imageName == "" { + imageName = config.DockerImageName(projectDir) + } + return imageName +} + +// ModelBuildOptions converts the shared build flags into model.BuildOptions for +// the given image name and annotations. +func (o BuildFlagsOptions) ModelBuildOptions(imageName string, annotations map[string]string) model.BuildOptions { + return model.BuildOptions{ + ImageName: imageName, + Secrets: o.Secrets, + NoCache: o.NoCache, + SeparateWeights: o.SeparateWeights, + UseCudaBaseImage: o.UseCudaBaseImage, + ProgressOutput: o.ProgressOutput, + SchemaFile: o.OpenAPISchema, + DockerfileFile: o.DockerfileFile, + UseCogBaseImage: o.UseCogBaseImage, + Strip: o.Strip, + Precompile: o.Precompile, + Annotations: annotations, + SkipSchemaValidation: o.SkipSchemaValidation, + } +} + +// buildFlagsOptionsFromCobra reads the package-level Cobra flag globals into a +// BuildFlagsOptions value. +func buildFlagsOptionsFromCobra(cmd *cobra.Command) BuildFlagsOptions { + return BuildFlagsOptions{ + NoCache: buildNoCache, + SeparateWeights: buildSeparateWeights, + Secrets: buildSecrets, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + OpenAPISchema: buildSchemaFile, + DockerfileFile: buildDockerfileFile, + Strip: buildStrip, + Precompile: buildPrecompile, + SkipSchemaValidation: buildSkipSchemaValidation, + Timestamp: config.BuildSourceEpochTimestamp, + } +} + +// RunBuild builds a Docker image from the model source described by opts. It is +// shared by both the Cobra and Kong build commands. 
+func RunBuild(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts BuildCommandOptions) error { + config.BuildSourceEpochTimestamp = opts.Flags.Timestamp + + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -93,22 +189,13 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } - imageName := src.Config.Image - if imageName == "" { - imageName = src.Config.Model - } - if buildTag != "" { - imageName = buildTag - } - if imageName == "" { - imageName = config.DockerImageName(src.ProjectDir) - } + imageName := ResolveBuildImageName(src.Config.Image, src.Config.Model, opts.Tag, src.ProjectDir) console.Infof("Building Docker image from environment in cog.yaml as %s...", console.Bold(imageName)) console.Info("") - resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.Build(ctx, src, buildOptionsFromFlags(cmd, imageName, nil)) + resolver := model.NewResolver(dockerClient, regClient) + m, err := resolver.Build(ctx, src, opts.Flags.ModelBuildOptions(imageName, nil)) if err != nil { return err } @@ -215,19 +302,5 @@ func addSkipSchemaValidationFlag(cmd *cobra.Command) { // buildOptionsFromFlags creates BuildOptions from the current CLI flag values. // The imageName and annotations parameters vary by command and must be provided. 
func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions { - return model.BuildOptions{ - ImageName: imageName, - Secrets: buildSecrets, - NoCache: buildNoCache, - SeparateWeights: buildSeparateWeights, - UseCudaBaseImage: buildUseCudaBaseImage, - ProgressOutput: buildProgressOutput, - SchemaFile: buildSchemaFile, - DockerfileFile: buildDockerfileFile, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - Strip: buildStrip, - Precompile: buildPrecompile, - Annotations: annotations, - SkipSchemaValidation: buildSkipSchemaValidation, - } + return buildFlagsOptionsFromCobra(cmd).ModelBuildOptions(imageName, annotations) } diff --git a/pkg/cli/build_test.go b/pkg/cli/build_test.go new file mode 100644 index 0000000000..80a3554ba8 --- /dev/null +++ b/pkg/cli/build_test.go @@ -0,0 +1,60 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolveBuildImageNameFallsBackToModel(t *testing.T) { + name := ResolveBuildImageName("", "r8.im/user/model", "", "/tmp/project") + require.Equal(t, "r8.im/user/model", name) +} + +func TestResolveBuildImageNamePrefersConfigImageOverModel(t *testing.T) { + name := ResolveBuildImageName("config-image", "r8.im/user/model", "", "/tmp/project") + require.Equal(t, "config-image", name) +} + +func TestResolveBuildImageNamePrefersTag(t *testing.T) { + name := ResolveBuildImageName("config-image", "r8.im/user/model", "custom:tag", "/tmp/project") + require.Equal(t, "custom:tag", name) +} + +func TestUseCogBaseImageExplicitness(t *testing.T) { + falseValue := false + opts := BuildFlagsOptions{UseCogBaseImage: &falseValue} + require.NotNil(t, opts.UseCogBaseImage) + require.False(t, *opts.UseCogBaseImage) + + opts = BuildFlagsOptions{} + require.Nil(t, opts.UseCogBaseImage) +} + +func TestBuildFlagsOptionsModelBuildOptions(t *testing.T) { + opts := BuildFlagsOptions{ + NoCache: true, + SeparateWeights: true, + Secrets: []string{"id=foo"}, + 
ProgressOutput: "plain", + UseCudaBaseImage: "false", + OpenAPISchema: "schema.json", + DockerfileFile: "Dockerfile", + Strip: true, + Precompile: true, + } + annotations := map[string]string{"k": "v"} + bo := opts.ModelBuildOptions("my-image", annotations) + + require.Equal(t, "my-image", bo.ImageName) + require.True(t, bo.NoCache) + require.True(t, bo.SeparateWeights) + require.Equal(t, []string{"id=foo"}, bo.Secrets) + require.Equal(t, "plain", bo.ProgressOutput) + require.Equal(t, "false", bo.UseCudaBaseImage) + require.Equal(t, "schema.json", bo.SchemaFile) + require.Equal(t, "Dockerfile", bo.DockerfileFile) + require.True(t, bo.Strip) + require.True(t, bo.Precompile) + require.Equal(t, annotations, bo.Annotations) +} From 957c9244e787e2ea72bc3b430fa99d34243da7ab Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:41:01 -0400 Subject: [PATCH 07/19] refactor: share push command logic between cobra and kong --- cmd/cog-kong/push.go | 71 +++++------------------------------------- pkg/cli/push.go | 74 ++++++++++++++++++++++++++++++++++---------- pkg/cli/push_test.go | 31 +++++++++++++++++++ 3 files changed, 96 insertions(+), 80 deletions(-) diff --git a/cmd/cog-kong/push.go b/cmd/cog-kong/push.go index db5139a888..be8790c96d 100644 --- a/cmd/cog-kong/push.go +++ b/cmd/cog-kong/push.go @@ -2,15 +2,11 @@ package main import ( "context" - "fmt" - - "github.com/replicate/go/uuid" + "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/provider" "github.com/replicate/cog/pkg/registry" - "github.com/replicate/cog/pkg/util/console" ) // PushCmd implements the "cog push" command. @@ -25,62 +21,11 @@ func (cmd *PushCmd) Validate() error { return cmd.ValidateMutualExclusivity() } -// Run executes the push command: build then push. 
-func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry, src *model.Source) error { - imageName := src.Config.Image - if cmd.Image != "" { - imageName = cmd.Image - } - - if imageName == "" { - return fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.example.com/your-username/model-name'") - } - - // Look up the provider for the target registry - p := providerReg.ForImage(imageName) - if p == nil { - return fmt.Errorf("no provider found for image '%s'", imageName) - } - - pushOpts := provider.PushOptions{ - Image: imageName, - Config: src.Config, - ProjectDir: src.ProjectDir, - } - - // Generate a push ID for annotations - buildID, _ := uuid.NewV7() - annotations := map[string]string{} - if buildID.String() != "" { - annotations["run.cog.push_id"] = buildID.String() - } - - // Build the model - resolver := model.NewResolver(dockerClient, regClient) - m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, annotations)) - if err != nil { - _ = p.PostPush(ctx, pushOpts, err) - return err - } - - // Log weights info - weights := m.WeightArtifacts() - if len(weights) > 0 { - console.Infof("\n%d weight artifact(s)", len(weights)) - } - - // Push the model - console.Infof("\nPushing image '%s'...", m.ImageRef()) - _, pushErr := resolver.Push(ctx, m, model.PushOptions{}) - - // PostPush: the provider handles formatting errors and showing success messages - if err := p.PostPush(ctx, pushOpts, pushErr); err != nil { - return err - } - - if pushErr != nil { - return fmt.Errorf("failed to push image: %w", pushErr) - } - - return nil +// Run executes the push command via the shared cli.RunPush runner. 
+func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry) error { + return cli.RunPush(ctx, dockerClient, regClient, providerReg, cli.PushCommandOptions{ + ConfigFilename: cmd.File, + Image: cmd.Image, + Flags: cmd.BuildFlags.Options(), + }) } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 6246c01772..23ff391f1b 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -10,7 +11,9 @@ import ( "github.com/replicate/go/uuid" + "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/provider" "github.com/replicate/cog/pkg/provider/setup" @@ -64,7 +67,53 @@ func push(cmd *cobra.Command, args []string) error { return err } - src, err := model.NewSource(configFilename) + image := "" + if len(args) > 0 { + image = args[0] + } + + return RunPush(ctx, dockerClient, registry.NewRegistryClient(), provider.DefaultRegistry(), PushCommandOptions{ + ConfigFilename: configFilename, + Image: image, + Flags: buildFlagsOptionsFromCobra(cmd), + }) +} + +// PushCommandOptions holds everything RunPush needs that is independent of the +// argument parser. +type PushCommandOptions struct { + ConfigFilename string + Image string + Flags BuildFlagsOptions +} + +// ResolvePushTarget resolves the push target image reference from the config +// image, config model, and positional args, running the push-specific +// user-input checks. It returns the target string, the resolved model ref (or +// nil for FormatImage paths), and any validation error. 
+func ResolvePushTarget(configImage, configModel string, args []string) (string, *model.ResolvedRef, error) { + modelRef, err := validatePushArgs(configImage, configModel, args) + if err != nil { + return "", nil, err + } + switch { + case modelRef != nil: + return modelRef.String(), modelRef, nil + case len(args) > 0: + return args[0], nil, nil + case configImage != "": + return configImage, nil, nil + default: + return "", nil, errors.New("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.example.com/your-username/model-name'") + } +} + +// RunPush builds and pushes a model image. It is shared by both the Cobra and +// Kong push commands. +func RunPush(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry, opts PushCommandOptions) error { + config.BuildSourceEpochTimestamp = opts.Flags.Timestamp + + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -81,24 +130,17 @@ func push(cmd *cobra.Command, args []string) error { // can drive credential selection from the correct host even when // Build fails (the build-error path calls p.PostPush(...) for // Replicate-specific guidance). - modelRef, err := validatePushArgs(src.Config.Image, src.Config.Model, args) + var args []string + if opts.Image != "" { + args = []string{opts.Image} + } + pushTarget, _, err := ResolvePushTarget(src.Config.Image, src.Config.Model, args) if err != nil { return err } - var pushTarget string - switch { - case modelRef != nil: - pushTarget = modelRef.String() - case len(args) > 0: - pushTarget = args[0] - case src.Config.Image != "": - pushTarget = src.Config.Image - default: - return errors.New("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. 
For example, 'cog push registry.example.com/your-username/model-name'") - } // Look up the provider for the target registry - p := provider.DefaultRegistry().ForImage(pushTarget) + p := providerReg.ForImage(pushTarget) if p == nil { return fmt.Errorf("no provider found for image '%s'", pushTarget) } @@ -116,14 +158,12 @@ func push(cmd *cobra.Command, args []string) error { annotations["run.cog.push_id"] = buildID.String() } - regClient := registry.NewRegistryClient() resolver := model.NewResolver(dockerClient, regClient) // Build the model console.Infof("Building Docker image from environment in cog.yaml as %s...", console.Bold(pushTarget)) console.Info("") - buildOpts := buildOptionsFromFlags(cmd, pushTarget, annotations) - m, err := resolver.Build(ctx, src, buildOpts) + m, err := resolver.Build(ctx, src, opts.Flags.ModelBuildOptions(pushTarget, annotations)) if err != nil { // Call PostPush to handle error logging/analytics _ = p.PostPush(ctx, pushOpts, err) diff --git a/pkg/cli/push_test.go b/pkg/cli/push_test.go index 1e567f36a2..f0ef4a2b18 100644 --- a/pkg/cli/push_test.go +++ b/pkg/cli/push_test.go @@ -331,3 +331,34 @@ func TestFormatPushResult_DefensiveGuards(t *testing.T) { assert.NotContains(t, out, "image") }) } + +func TestResolvePushTargetUsesModelRef(t *testing.T) { + target, ref, err := ResolvePushTarget("", "r8.im/user/model", nil) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, ref.String(), target) +} + +func TestResolvePushTargetRejectsPositionalWithModel(t *testing.T) { + _, _, err := ResolvePushTarget("", "r8.im/user/model", []string{"example.com/image"}) + require.ErrorContains(t, err, "positional image argument not supported") +} + +func TestResolvePushTargetUsesConfigImage(t *testing.T) { + target, ref, err := ResolvePushTarget("registry.example.com/user/model", "", nil) + require.NoError(t, err) + require.Nil(t, ref) + require.Equal(t, "registry.example.com/user/model", target) +} + +func 
TestResolvePushTargetUsesPositionalArg(t *testing.T) { + target, ref, err := ResolvePushTarget("", "", []string{"registry.example.com/user/model"}) + require.NoError(t, err) + require.Nil(t, ref) + require.Equal(t, "registry.example.com/user/model", target) +} + +func TestResolvePushTargetRequiresTarget(t *testing.T) { + _, _, err := ResolvePushTarget("", "", nil) + require.ErrorContains(t, err, "you must either set the 'image' option") +} From 95b223e331a59dab2d478b4d6be302bac49011af Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:43:17 -0400 Subject: [PATCH 08/19] feat: add kong serve and exec parity via shared runtime runners --- cmd/cog-kong/main_test.go | 12 ++++ cmd/cog-kong/runtime.go | 77 ++++++++++++++++++++++++++ cmd/cog-kong/stubs.go | 12 ---- pkg/cli/exec.go | 61 ++++++++++++-------- pkg/cli/serve.go | 113 ++++++++++++++++++++++++++++---------- 5 files changed, 210 insertions(+), 65 deletions(-) create mode 100644 cmd/cog-kong/runtime.go diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index 4b3e38d607..441cb334a3 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -225,3 +225,15 @@ func TestKongCommandVersionExitsBeforeCommandRun(t *testing.T) { }) require.Equal(t, "cog version dev (built none)\n", stdout.String()) } + +func TestKongRuntimeCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"serve", "--port", "5000", "--upload-url", "https://example.com/upload", "--gpus", "all", "--help"}, + {"exec", "--gpus", "all", "--publish", "8888", "--env", "A=B", "python", "-c", "print(1)"}, + } { + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } +} diff --git a/cmd/cog-kong/runtime.go b/cmd/cog-kong/runtime.go new file mode 100644 index 0000000000..437a2d7098 --- /dev/null +++ b/cmd/cog-kong/runtime.go @@ -0,0 +1,77 @@ +package main + +import ( + "context" + "errors" + + "github.com/replicate/cog/pkg/cli" + 
"github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/registry" +) + +// errMissingExecCommand matches Cobra's cobra.MinimumNArgs(1) error message for +// the exec command. +var errMissingExecCommand = errors.New("accepts at least 1 arg(s), received 0") + +// RuntimeFlags groups the build/run flags shared by the serve and exec +// commands. +type RuntimeFlags struct { + ConfigFlag `embed:""` + + Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` + CudaBase string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` + CogBase *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` + GPUs string `name:"gpus" help:"GPU devices to add to the container, in the same format as docker run --gpus."` + Env []string `name:"env" short:"e" help:"Environment variables, in the form name=value."` +} + +// Options converts the Kong runtime flags into cli.RuntimeBuildOptions. +func (f RuntimeFlags) Options() cli.RuntimeBuildOptions { + return cli.RuntimeBuildOptions{ + ConfigFilename: f.File, + ProgressOutput: f.Progress, + UseCudaBaseImage: f.CudaBase, + UseCogBaseImage: f.CogBase, + GPUs: f.GPUs, + Env: f.Env, + } +} + +// ServeCmd implements the "cog serve" command. +type ServeCmd struct { + RuntimeFlags `embed:""` + + Port int `name:"port" short:"p" default:"8393" help:"Port on which to listen."` + UploadURL string `name:"upload-url" help:"Upload URL for file outputs (e.g. https://example.com/upload/)."` +} + +func (cmd *ServeCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunServe(ctx, dockerClient, regClient, cli.ServeCommandOptions{ + RuntimeBuildOptions: cmd.RuntimeFlags.Options(), + Port: cmd.Port, + UploadURL: cmd.UploadURL, + }) +} + +// ExecCmd implements the "cog exec" command. 
+type ExecCmd struct { + RuntimeFlags `embed:""` + + Publish []string `name:"publish" short:"p" help:"Publish a container's port to the host, e.g. -p 8000."` + Args []string `arg:"" passthrough:"" name:"command" help:"Command and arguments to execute."` +} + +func (cmd *ExecCmd) Validate() error { + if len(cmd.Args) == 0 { + return errMissingExecCommand + } + return nil +} + +func (cmd *ExecCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunExec(ctx, dockerClient, regClient, cli.ExecCommandOptions{ + RuntimeBuildOptions: cmd.RuntimeFlags.Options(), + Args: cmd.Args, + Ports: cmd.Publish, + }) +} diff --git a/cmd/cog-kong/stubs.go b/cmd/cog-kong/stubs.go index f9a812362e..ff35ccde76 100644 --- a/cmd/cog-kong/stubs.go +++ b/cmd/cog-kong/stubs.go @@ -33,12 +33,6 @@ func (cmd *DoctorCmd) Run() error { return errKongCommandNotImplemented } -type ExecCmd struct{} - -func (cmd *ExecCmd) Run() error { - return errKongCommandNotImplemented -} - type InitCmd struct{} func (cmd *InitCmd) Run() error { @@ -63,12 +57,6 @@ func (cmd *RunCmd) Run() error { return errKongCommandNotImplemented } -type ServeCmd struct{} - -func (cmd *ServeCmd) Run() error { - return errKongCommandNotImplemented -} - type TrainCmd struct{} func (cmd *TrainCmd) Run() error { diff --git a/pkg/cli/exec.go b/pkg/cli/exec.go index 4d14d5b320..9212015047 100644 --- a/pkg/cli/exec.go +++ b/pkg/cli/exec.go @@ -1,8 +1,8 @@ package cli import ( + "context" "errors" - "os" "strconv" "strings" @@ -75,46 +75,59 @@ func execCmd(cmd *cobra.Command, args []string) error { return err } - src, err := model.NewSource(configFilename) + return RunExec(ctx, dockerClient, registry.NewRegistryClient(), ExecCommandOptions{ + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: envFlags, + 
}, + Args: args, + Ports: execPorts, + }) +} + +// ExecCommandOptions holds everything RunExec needs that is independent of the +// argument parser. +type ExecCommandOptions struct { + RuntimeBuildOptions + Args []string + Ports []string +} + +// RunExec builds a local image and executes an arbitrary command inside it with +// the project directory volume-mounted. It is shared by both the Cobra and Kong +// exec commands. +func RunExec(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts ExecCommandOptions) error { + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } defer src.Close() - resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) + resolver := model.NewResolver(dockerClient, regClient) console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - opts := serveBuildOptions(cmd) - opts.SkipSchemaValidation = true - m, err := resolver.Build(ctx, src, opts) + buildOpts := opts.ServeBuildOptions() + buildOpts.SkipSchemaValidation = true + m, err := resolver.Build(ctx, src, buildOpts) if err != nil { return err } - gpus := "" - if gpusFlag != "" { - gpus = gpusFlag - } else if m.HasGPU() { - gpus = "all" - } - - // Automatically propagate RUST_LOG for Rust coglet debugging - env := envFlags - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env = append(env, "RUST_LOG="+rustLog) - } - runOptions := command.RunOptions{ - Args: args, - Env: env, - GPUs: gpus, + Args: opts.Args, + Env: runtimeEnv(opts.Env), + GPUs: runtimeGPUs(opts.GPUs, m), Image: m.ImageRef(), Volumes: []command.Volume{{Source: src.ProjectDir, Destination: "/src"}}, Workdir: "/src", } - for _, portString := range execPorts { + for _, portString := range opts.Ports { port, err := strconv.Atoi(portString) if err != nil { return err @@ -124,7 +137,7 @@ func execCmd(cmd *cobra.Command, args []string) error { } console.Info("") - console.Infof("Running %s in Docker with the current directory 
mounted as a volume...", console.Bold(strings.Join(args, " "))) + console.Infof("Running %s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(opts.Args, " "))) console.Info("") err = docker.Run(ctx, dockerClient, runOptions) diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index b99562ce0c..c4b82e9cdd 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -56,20 +57,71 @@ and outputs as a REST API. Compatible with the Cog HTTP protocol.`, return cmd } -// serveBuildOptions creates BuildOptions for cog serve. +// RuntimeBuildOptions holds the parser-independent options shared by the +// runtime commands (serve, exec) that build a local image and run it with the +// project directory volume-mounted. +type RuntimeBuildOptions struct { + ConfigFilename string + ProgressOutput string + UseCudaBaseImage string + UseCogBaseImage *bool + GPUs string + Env []string +} + +// ServeBuildOptions creates BuildOptions for cog serve and cog exec. // Same build path as cog build, but with ExcludeSource so COPY . /src is // skipped — source is volume-mounted at runtime instead. All other layers // (wheels, apt, etc.) share Docker layer cache with cog build. -func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { +func (o RuntimeBuildOptions) ServeBuildOptions() model.BuildOptions { return model.BuildOptions{ - UseCudaBaseImage: buildUseCudaBaseImage, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - ProgressOutput: buildProgressOutput, + UseCudaBaseImage: o.UseCudaBaseImage, + UseCogBaseImage: o.UseCogBaseImage, + ProgressOutput: o.ProgressOutput, ExcludeSource: true, SkipLabels: true, } } +// runtimeGPUs resolves the GPU spec: an explicit request wins, otherwise +// "all" if the model declares GPU usage, otherwise empty. 
+func runtimeGPUs(requested string, m *model.Model) string { + if requested != "" { + return requested + } + if m.HasGPU() { + return "all" + } + return "" +} + +// runtimeEnv appends RUST_LOG from the host environment (for Rust coglet +// debugging) to the provided env list without mutating the input slice. +func runtimeEnv(env []string) []string { + out := append([]string{}, env...) + if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { + out = append(out, "RUST_LOG="+rustLog) + } + return out +} + +// serveBuildOptions creates BuildOptions for cog serve from Cobra flag globals. +func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { + return RuntimeBuildOptions{ + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + }.ServeBuildOptions() +} + +// ServeCommandOptions holds everything RunServe needs that is independent of +// the argument parser. +type ServeCommandOptions struct { + RuntimeBuildOptions + Port int + UploadURL string +} + func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() @@ -78,7 +130,24 @@ func cmdServe(cmd *cobra.Command, arg []string) error { return err } - src, err := model.NewSource(configFilename) + return RunServe(ctx, dockerClient, registry.NewRegistryClient(), ServeCommandOptions{ + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: envFlags, + }, + Port: port, + UploadURL: uploadURL, + }) +} + +// RunServe builds the model image and runs the Cog HTTP server. It is shared by +// both the Cobra and Kong serve commands. 
+func RunServe(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts ServeCommandOptions) error { + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -86,19 +155,12 @@ func cmdServe(cmd *cobra.Command, arg []string) error { console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + resolver := model.NewResolver(dockerClient, regClient) + m, err := resolver.Build(ctx, src, opts.ServeBuildOptions()) if err != nil { return err } - gpus := "" - if gpusFlag != "" { - gpus = gpusFlag - } else if m.HasGPU() { - gpus = "all" - } - args := []string{ "python", "--check-hash-based-pycs", "never", @@ -106,38 +168,31 @@ func cmdServe(cmd *cobra.Command, arg []string) error { "--await-explicit-shutdown", "true", } - if uploadURL != "" { - args = append(args, "--upload-url", uploadURL) - } - - // Automatically propagate RUST_LOG for Rust coglet debugging - env := envFlags - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env = append(env, "RUST_LOG="+rustLog) + if opts.UploadURL != "" { + args = append(args, "--upload-url", opts.UploadURL) } runOptions := command.RunOptions{ Args: args, - Env: env, - GPUs: gpus, + Env: runtimeEnv(opts.Env), + GPUs: runtimeGPUs(opts.GPUs, m), Image: m.ImageRef(), Volumes: []command.Volume{{Source: src.ProjectDir, Destination: "/src"}}, Workdir: "/src", + Ports: []command.Port{{HostPort: opts.Port, ContainerPort: 5000}}, } // On Linux, host.docker.internal is not available by default — add it. // This allows the container to reach services running on the host, // e.g. when --upload-url points to a local upload server. 
- if uploadURL != "" { + if opts.UploadURL != "" { runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"} } - runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000}) - console.Info("") console.Infof("Running %[1]s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " "))) console.Info("") - console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://127.0.0.1:%v", port))) + console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://127.0.0.1:%v", opts.Port))) console.Info("") err = docker.Run(ctx, dockerClient, runOptions) From 1477b1f28d871cdb1f79f9e1efe9e6c116806536 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:45:23 -0400 Subject: [PATCH 09/19] feat: add kong init, login, doctor, debug parity --- cmd/cog-kong/main_test.go | 14 +++++++++ cmd/cog-kong/simple.go | 64 +++++++++++++++++++++++++++++++++++++++ cmd/cog-kong/stubs.go | 24 --------------- pkg/cli/debug.go | 43 ++++++++++++++++++++------ pkg/cli/doctor.go | 7 +++-- pkg/cli/init.go | 6 ++++ pkg/cli/login.go | 31 ++++++++++++++----- 7 files changed, 145 insertions(+), 44 deletions(-) create mode 100644 cmd/cog-kong/simple.go diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index 441cb334a3..a42a22154b 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -237,3 +237,17 @@ func TestKongRuntimeCommandFlagsParse(t *testing.T) { require.NoErrorf(t, err, "parse %v", args) } } + +func TestKongSimpleCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"init", "--help"}, + {"login", "--token-stdin", "--help"}, + {"doctor", "--fix", "--file", "custom.yaml", "--help"}, + {"debug", "--image-name", "myimage", "--help"}, + } { + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } +} diff --git a/cmd/cog-kong/simple.go b/cmd/cog-kong/simple.go new file mode 100644 index 
0000000000..ea91374d87 --- /dev/null +++ b/cmd/cog-kong/simple.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/provider" + "github.com/replicate/cog/pkg/registry" +) + +// InitCmd implements the "cog init" command. +type InitCmd struct{} + +func (cmd *InitCmd) Run() error { return cli.RunInit() } + +// LoginCmd implements the "cog login" command. +type LoginCmd struct { + TokenStdin bool `name:"token-stdin" help:"Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token"` +} + +func (cmd *LoginCmd) Run(ctx context.Context, providerReg *provider.Registry) error { + return cli.RunLogin(ctx, providerReg, cli.LoginOptions{ + TokenStdin: cmd.TokenStdin, + Host: global.ReplicateRegistryHost, + }) +} + +// DoctorCmd implements the "cog doctor" command. +type DoctorCmd struct { + ConfigFlag `embed:""` + + Fix bool `name:"fix" help:"Automatically apply fixes."` +} + +func (cmd *DoctorCmd) Run(ctx context.Context) error { + return cli.RunDoctor(ctx, cmd.File, cmd.Fix) +} + +// DebugCmd implements the hidden "cog debug" command, which prints a generated +// Dockerfile. +type DebugCmd struct { + ConfigFlag `embed:""` + + ImageName string `name:"image-name" help:"The image name to use for the generated Dockerfile."` + SeparateWeights bool `name:"separate-weights" help:"Separate model weights from code in image layers."` + UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` + UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` + + // Hidden flags for parity with the Cobra debug command. 
+ Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile."` + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` +} + +func (cmd *DebugCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunDebug(ctx, dockerClient, regClient, cli.DebugCommandOptions{ + ConfigFilename: cmd.File, + ImageName: cmd.ImageName, + SeparateWeights: cmd.SeparateWeights, + UseCudaBaseImage: cmd.UseCudaBaseImage, + UseCogBaseImage: cmd.UseCogBaseImage, + }) +} diff --git a/cmd/cog-kong/stubs.go b/cmd/cog-kong/stubs.go index ff35ccde76..3524caea05 100644 --- a/cmd/cog-kong/stubs.go +++ b/cmd/cog-kong/stubs.go @@ -21,30 +21,6 @@ func (cmd *BaseImageBuildCmd) Run() error { return errKongCommandNotImplemented } -type DebugCmd struct{} - -func (cmd *DebugCmd) Run() error { - return errKongCommandNotImplemented -} - -type DoctorCmd struct{} - -func (cmd *DoctorCmd) Run() error { - return errKongCommandNotImplemented -} - -type InitCmd struct{} - -func (cmd *InitCmd) Run() error { - return errKongCommandNotImplemented -} - -type LoginCmd struct{} - -func (cmd *LoginCmd) Run() error { - return errKongCommandNotImplemented -} - type PredictCmd struct{} func (cmd *PredictCmd) Run() error { diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index 4b83a2c366..b158ba2889 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "strings" @@ -8,6 +9,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/dockerfile" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" @@ -38,35 +40,56 @@ func newDebugCommand() *cobra.Command { func cmdDockerfile(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - src, err := model.NewSource(configFilename) + 
dockerClient, err := docker.NewClient(ctx) if err != nil { return err } - defer src.Close() - dockerClient, err := docker.NewClient(ctx) + return RunDebug(ctx, dockerClient, registry.NewRegistryClient(), DebugCommandOptions{ + ConfigFilename: configFilename, + ImageName: imageName, + SeparateWeights: buildSeparateWeights, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + }) +} + +// DebugCommandOptions holds the parser-independent options for the debug +// command, which generates and prints a Dockerfile. +type DebugCommandOptions struct { + ConfigFilename string + ImageName string + SeparateWeights bool + UseCudaBaseImage string + UseCogBaseImage *bool +} + +// RunDebug generates the Dockerfile(s) for the model and prints them to stdout. +// It is shared by both the Cobra and Kong debug commands. +func RunDebug(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts DebugCommandOptions) error { + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } + defer src.Close() buildDir, err := src.DotCog.TempPath("build") if err != nil { return err } - client := registry.NewRegistryClient() - generator, err := dockerfile.NewStandardGenerator(src.Config, src.ProjectDir, buildDir, src.ConfigFilename, dockerClient, client, true) + generator, err := dockerfile.NewStandardGenerator(src.Config, src.ProjectDir, buildDir, src.ConfigFilename, dockerClient, regClient, true) if err != nil { return fmt.Errorf("Error creating Dockerfile generator: %w", err) } - generator.SetUseCudaBaseImage(buildUseCudaBaseImage) - useCogBaseImage := DetermineUseCogBaseImage(cmd) - if useCogBaseImage != nil { - generator.SetUseCogBaseImage(*useCogBaseImage) + generator.SetUseCudaBaseImage(opts.UseCudaBaseImage) + if opts.UseCogBaseImage != nil { + generator.SetUseCogBaseImage(*opts.UseCogBaseImage) } - if buildSeparateWeights { + if opts.SeparateWeights { + imageName := opts.ImageName if imageName == "" { 
imageName = config.DockerImageName(src.ProjectDir) } diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 62a6e114b7..3d93bdea27 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -32,7 +32,7 @@ NOTE: cog doctor is experimental. Behavior and checks may change in future versi By default, cog doctor reports problems without modifying any files. Pass --fix to automatically apply safe fixes.`, RunE: func(cmd *cobra.Command, args []string) error { - return runDoctor(cmd.Context(), fix) + return RunDoctor(cmd.Context(), configFilename, fix) }, Args: cobra.NoArgs, // printDoctorResults already prints findings to the user; suppress @@ -47,7 +47,10 @@ Pass --fix to automatically apply safe fixes.`, return cmd } -func runDoctor(ctx context.Context, fix bool) error { +// RunDoctor diagnoses (and optionally fixes) common issues in the Cog project +// described by configFilename. It is shared by both the Cobra and Kong doctor +// commands. +func RunDoctor(ctx context.Context, configFilename string, fix bool) error { console.Warnf("NOTE: cog doctor is experimental. Behavior and checks may change in future versions.") console.Info("") diff --git a/pkg/cli/init.go b/pkg/cli/init.go index b75f864794..98b00c5f9b 100644 --- a/pkg/cli/init.go +++ b/pkg/cli/init.go @@ -37,6 +37,12 @@ and run interface. Edit them to match your model's requirements.`, } func initCommand(cmd *cobra.Command, args []string) error { + return RunInit() +} + +// RunInit sets up the current directory for use with Cog by writing the base +// project template. It is shared by both the Cobra and Kong init commands. 
+func RunInit() error { console.Info("Setting up the current directory for use with Cog...") console.Info("") diff --git a/pkg/cli/login.go b/pkg/cli/login.go index 721032b0c3..d5dbe82b07 100644 --- a/pkg/cli/login.go +++ b/pkg/cli/login.go @@ -1,6 +1,8 @@ package cli import ( + "context" + "github.com/spf13/cobra" "github.com/replicate/cog/pkg/global" @@ -36,25 +38,38 @@ func login(cmd *cobra.Command, args []string) error { // Initialize the provider registry setup.Init() - // Use global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var) - registryHost := global.ReplicateRegistryHost - tokenStdin, err := cmd.Flags().GetBool("token-stdin") if err != nil { return err } + return RunLogin(ctx, provider.DefaultRegistry(), LoginOptions{ + TokenStdin: tokenStdin, + // Use global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var) + Host: global.ReplicateRegistryHost, + }) +} + +// LoginOptions holds the parser-independent options for the login command. +type LoginOptions struct { + TokenStdin bool + Host string +} + +// RunLogin logs in to the container registry for opts.Host. It is shared by +// both the Cobra and Kong login commands. 
+func RunLogin(ctx context.Context, providerReg *provider.Registry, opts LoginOptions) error { // Look up the provider for this registry - p := provider.DefaultRegistry().ForHost(registryHost) + p := providerReg.ForHost(opts.Host) if p == nil { // This shouldn't happen since GenericProvider matches everything - console.Warnf("No provider found for registry '%s'.", registryHost) - console.Infof("Please use 'docker login %s' to authenticate.", registryHost) + console.Warnf("No provider found for registry '%s'.", opts.Host) + console.Infof("Please use 'docker login %s' to authenticate.", opts.Host) return nil } return p.Login(ctx, provider.LoginOptions{ - TokenStdin: tokenStdin, - Host: registryHost, + TokenStdin: opts.TokenStdin, + Host: opts.Host, }) } From 1f1b4c4b5a1f27129235fd0bfdac335d7cbd4520 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:48:32 -0400 Subject: [PATCH 10/19] feat: add kong predict, run, and train parity via shared runners --- cmd/cog-kong/main_test.go | 14 ++++++ cmd/cog-kong/predict.go | 51 +++++++++++++++++++++ cmd/cog-kong/stubs.go | 18 -------- cmd/cog-kong/train.go | 28 ++++++++++++ pkg/cli/predict.go | 95 +++++++++++++++++++++++++++------------ pkg/cli/train.go | 59 +++++++++++++++++++----- 6 files changed, 207 insertions(+), 58 deletions(-) create mode 100644 cmd/cog-kong/predict.go create mode 100644 cmd/cog-kong/train.go diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index a42a22154b..e35c5c88e4 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -251,3 +251,17 @@ func TestKongSimpleCommandFlagsParse(t *testing.T) { require.NoErrorf(t, err, "parse %v", args) } } + +func TestKongPredictionCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"predict", "example/image", "--input", "prompt=cat", "--help"}, + {"run", "example/image", "--input", "prompt=cat", "--output", "out.png", "--help"}, + {"run", "--json", "@inputs.json", 
"--gpus", "all", "--help"}, + {"train", "example/image", "--input", "dataset=@data.json", "--help"}, + } { + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } +} diff --git a/cmd/cog-kong/predict.go b/cmd/cog-kong/predict.go new file mode 100644 index 0000000000..0ca3ddf00e --- /dev/null +++ b/cmd/cog-kong/predict.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" +) + +// predictionFlags groups the flags shared by the predict and run commands. +type predictionFlags struct { + RuntimeFlags `embed:""` + + Image string `arg:"" optional:"" name:"image" help:"Image to run. If omitted, builds from cog.yaml in the current directory."` + Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. -i path=@image.jpg."` + Output string `name:"output" short:"o" help:"Output path."` + JSON string `name:"json" help:"Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-)."` + UseReplicate bool `name:"use-replicate-token" help:"Pass REPLICATE_API_TOKEN from local environment into the model context."` + SetupTimeout uint32 `name:"setup-timeout" default:"300" help:"The timeout for a container to setup (in seconds)."` +} + +func (f predictionFlags) options(use string) cli.PredictionCommandOptions { + return cli.PredictionCommandOptions{ + RuntimeBuildOptions: f.RuntimeFlags.Options(), + Use: use, + Image: f.Image, + Input: f.Input, + InputJSON: f.JSON, + OutputPath: f.Output, + SetupTimeout: f.SetupTimeout, + UseReplicateAPI: f.UseReplicate, + } +} + +// PredictCmd implements the hidden, deprecated "cog predict" command. 
+type PredictCmd struct { + predictionFlags `embed:""` +} + +func (cmd *PredictCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunPrediction(ctx, dockerClient, cmd.predictionFlags.options("predict")) +} + +// RunCmd implements the "cog run" command. +type RunCmd struct { + predictionFlags `embed:""` +} + +func (cmd *RunCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunPrediction(ctx, dockerClient, cmd.predictionFlags.options("run")) +} diff --git a/cmd/cog-kong/stubs.go b/cmd/cog-kong/stubs.go index 3524caea05..b9495d1cf3 100644 --- a/cmd/cog-kong/stubs.go +++ b/cmd/cog-kong/stubs.go @@ -21,24 +21,6 @@ func (cmd *BaseImageBuildCmd) Run() error { return errKongCommandNotImplemented } -type PredictCmd struct{} - -func (cmd *PredictCmd) Run() error { - return errKongCommandNotImplemented -} - -type RunCmd struct{} - -func (cmd *RunCmd) Run() error { - return errKongCommandNotImplemented -} - -type TrainCmd struct{} - -func (cmd *TrainCmd) Run() error { - return errKongCommandNotImplemented -} - type WeightsCmd struct { Import WeightsImportCmd `cmd:"" help:"Import model weights."` Pull WeightsPullCmd `cmd:"" help:"Pull model weights."` diff --git a/cmd/cog-kong/train.go b/cmd/cog-kong/train.go new file mode 100644 index 0000000000..d49f9634fc --- /dev/null +++ b/cmd/cog-kong/train.go @@ -0,0 +1,28 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" +) + +// TrainCmd implements the hidden, deprecated "cog train" command. +type TrainCmd struct { + RuntimeFlags `embed:""` + + Image string `arg:"" optional:"" name:"image" help:"Image to train. If omitted, builds from cog.yaml in the current directory."` + Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. 
-i path=@image.jpg."` + Output string `name:"output" short:"o" default:"weights" help:"Output path."` + SetupTimeout uint32 `name:"setup-timeout" default:"300" help:"The timeout for a container to setup (in seconds)."` +} + +func (cmd *TrainCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunTrain(ctx, dockerClient, cli.TrainCommandOptions{ + RuntimeBuildOptions: cmd.RuntimeFlags.Options(), + Image: cmd.Image, + Input: cmd.Input, + OutputPath: cmd.Output, + SetupTimeout: cmd.SetupTimeout, + }) +} diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index a8c322215c..27fea328f8 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -189,21 +189,64 @@ func transformPathsToBase64URLs(inputs map[string]any) (map[string]any, error) { } func cmdPredict(cmd *cobra.Command, args []string) error { - if cmd.CalledAs() == "predict" || cmd.Name() == "predict" { - console.Warn(`"cog predict" is deprecated, use "cog run"`) + dockerClient, err := docker.NewClient(cmd.Context()) + if err != nil { + return err } - ctx, stop := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) - defer stop() + image := "" + if len(args) > 0 { + image = args[0] + } - dockerClient, err := docker.NewClient(ctx) - if err != nil { - return err + return RunPrediction(cmd.Context(), dockerClient, PredictionCommandOptions{ + Use: cmd.CalledAs(), + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: envFlags, + }, + Image: image, + Input: inputFlags, + InputJSON: inputJSON, + OutputPath: outPath, + SetupTimeout: setupTimeout, + UseReplicateAPI: useReplicateAPIToken, + }) +} + +// PredictionCommandOptions holds everything RunPrediction needs that is +// independent of the argument parser. 
+type PredictionCommandOptions struct { + RuntimeBuildOptions + + // Use is the invoked command name ("predict" or "run"); when "predict" a + // deprecation warning is printed. + Use string + Image string + Input []string + InputJSON string + OutputPath string + SetupTimeout uint32 + UseReplicateAPI bool +} + +// RunPrediction builds or pulls a model image and runs a prediction. It is +// shared by both the Cobra and Kong predict/run commands. +func RunPrediction(parent context.Context, dockerClient command.Command, opts PredictionCommandOptions) error { + if opts.Use == "predict" { + console.Warn(`"cog predict" is deprecated, use "cog run"`) } + ctx, stop := signal.NotifyContext(parent, syscall.SIGINT, syscall.SIGTERM) + defer stop() + imageName := "" volumes := []command.Volume{} - gpus := gpusFlag + gpus := opts.GPUs // The Manager is built only when we have cog.yaml in scope (the // build-from-source path). Pre-built images are opaque to Cog and @@ -212,9 +255,9 @@ func cmdPredict(cmd *cobra.Command, args []string) error { resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - if len(args) == 0 { + if opts.Image == "" { // Build image - src, err := model.NewSource(configFilename) + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -226,7 +269,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + m, err := resolver.Build(ctx, src, opts.ServeBuildOptions()) if err != nil { return err } @@ -248,7 +291,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { } } else { // Use existing image - imageName = args[0] + imageName = opts.Image // If the image name contains '=', then it's probably a mistake if strings.Contains(imageName, "=") { @@ -273,11 +316,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { console.Info("") 
console.Info("Starting Docker image and running setup()...") - // Automatically propagate RUST_LOG for Rust coglet debugging - env := envFlags - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env = append(env, "RUST_LOG="+rustLog) - } + env := runtimeEnv(opts.Env) predictor, err := predict.NewPredictor(ctx, predict.PredictorOptions{ RunOptions: command.RunOptions{ @@ -294,7 +333,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - timeout := time.Duration(setupTimeout) * time.Second + timeout := time.Duration(opts.SetupTimeout) * time.Second if err := predictor.Start(ctx, os.Stderr, timeout); err != nil { // Only retry if we're using a GPU but the user didn't explicitly select a GPU with --gpus // If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it @@ -332,21 +371,21 @@ func cmdPredict(cmd *cobra.Command, args []string) error { } }() - if inputJSON != "" { - if len(inputFlags) > 0 { + if opts.InputJSON != "" { + if len(opts.Input) > 0 { return fmt.Errorf("Must use one of --json or --input to provide model inputs") } - return predictJSONInputs(*predictor, inputJSON, outPath, false) + return predictJSONInputs(*predictor, opts.InputJSON, opts.OutputPath, false, opts.UseReplicateAPI) } - return predictIndividualInputs(*predictor, inputFlags, outPath, false) + return predictIndividualInputs(*predictor, opts.Input, opts.OutputPath, false, opts.UseReplicateAPI) } func isURI(ref *openapi3.Schema) bool { return ref != nil && ref.Type.Is("string") && ref.Format == "uri" } -func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath string, isTrain bool) error { +func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath string, isTrain bool, useReplicateAPI bool) error { jsonInputs, err := parseJSONInput(jsonInput) if err != nil { return err @@ -373,10 +412,10 @@ func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath } 
} - return runPrediction(predictor, inputs, outputPath, isTrain, true) + return runPrediction(predictor, inputs, outputPath, isTrain, true, useReplicateAPI) } -func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool) error { +func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool, useReplicateAPI bool) error { schema, err := predictor.GetSchema() if err != nil { return err @@ -387,10 +426,10 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - return runPrediction(predictor, inputs, outputPath, isTrain, false) + return runPrediction(predictor, inputs, outputPath, isTrain, false, useReplicateAPI) } -func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPath string, isTrain bool, needsJSON bool) error { +func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPath string, isTrain bool, needsJSON bool, useReplicateAPI bool) error { if isTrain { console.Info("Running training...") } else { @@ -421,7 +460,7 @@ func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPat context := predict.RequestContext{} - if useReplicateAPIToken { + if useReplicateAPI { context.ReplicateAPIToken = os.Getenv("REPLICATE_API_TOKEN") if context.ReplicateAPIToken == "" { return fmt.Errorf("Failed to find REPLICATE_API_TOKEN in the current environment when called with --use-replicate-token") diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 2fa4824296..337ca8e7ae 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -55,26 +55,61 @@ Otherwise, it will build the model in the current directory and train it.`, } func cmdTrain(cmd *cobra.Command, args []string) error { - ctx, stop := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) - defer stop() - - dockerClient, err := docker.NewClient(ctx) + dockerClient, err := docker.NewClient(cmd.Context()) if 
err != nil { return err } + image := "" + if len(args) > 0 { + image = args[0] + } + + return RunTrain(cmd.Context(), dockerClient, TrainCommandOptions{ + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: trainEnvFlags, + }, + Image: image, + Input: trainInputFlags, + OutputPath: trainOutPath, + SetupTimeout: setupTimeout, + }) +} + +// TrainCommandOptions holds everything RunTrain needs that is independent of +// the argument parser. +type TrainCommandOptions struct { + RuntimeBuildOptions + + Image string + Input []string + OutputPath string + SetupTimeout uint32 +} + +// RunTrain builds or pulls a model image and runs a training job. It is shared +// by both the Cobra and Kong train commands. +func RunTrain(parent context.Context, dockerClient command.Command, opts TrainCommandOptions) error { + ctx, stop := signal.NotifyContext(parent, syscall.SIGINT, syscall.SIGTERM) + defer stop() + imageName := "" volumes := []command.Volume{} - gpus := gpusFlag + gpus := opts.GPUs // Managed-weight mounts only apply when we have cog.yaml in scope. 
var wm *weights.Manager resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - if len(args) == 0 { + if opts.Image == "" { // Build image - src, err := model.NewSource(configFilename) + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -86,7 +121,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + m, err := resolver.Build(ctx, src, opts.ServeBuildOptions()) if err != nil { return err } @@ -108,7 +143,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { } } else { // Use existing image - imageName = args[0] + imageName = opts.Image // Pull the image (if needed) and validate it's a Cog model ref, err := model.ParseRef(imageName) @@ -133,7 +168,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { GPUs: gpus, Image: imageName, Volumes: volumes, - Env: trainEnvFlags, + Env: opts.Env, Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"}, }, IsTrain: true, @@ -144,7 +179,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } - if err := predictor.Start(ctx, os.Stderr, time.Duration(setupTimeout)*time.Second); err != nil { + if err := predictor.Start(ctx, os.Stderr, time.Duration(opts.SetupTimeout)*time.Second); err != nil { return err } @@ -156,5 +191,5 @@ func cmdTrain(cmd *cobra.Command, args []string) error { } }() - return predictIndividualInputs(*predictor, trainInputFlags, trainOutPath, true) + return predictIndividualInputs(*predictor, opts.Input, opts.OutputPath, true, false) } From 89a4b4f2f15c128648ad0b824b6b4d60bb1e6b67 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:51:23 -0400 Subject: [PATCH 11/19] feat: add kong weights command parity and fix version short-flag collision --- cmd/cog-kong/cli.go | 2 +- cmd/cog-kong/main_test.go | 13 ++++++++++ cmd/cog-kong/stubs.go | 23 
------------------ cmd/cog-kong/weights.go | 51 +++++++++++++++++++++++++++++++++++++++ pkg/cli/weights.go | 25 +++++++++++-------- pkg/cli/weights_pull.go | 9 ++++--- pkg/cli/weights_status.go | 9 ++++--- 7 files changed, 92 insertions(+), 40 deletions(-) create mode 100644 cmd/cog-kong/weights.go diff --git a/cmd/cog-kong/cli.go b/cmd/cog-kong/cli.go index 34def0795a..f0ec1c0cee 100644 --- a/cmd/cog-kong/cli.go +++ b/cmd/cog-kong/cli.go @@ -17,7 +17,7 @@ type Globals struct { Help bool `name:"help" short:"h" help:"Show context-sensitive help."` Registry string `name:"registry" default:"${registry_default}" env:"COG_REGISTRY_HOST" hidden:"" help:"Registry host."` Profile bool `name:"profile" hidden:"" help:"Enable profiling."` - Version kong.VersionFlag `name:"version" short:"v" help:"Show version of Cog."` + Version kong.VersionFlag `name:"version" help:"Show version of Cog."` } // AfterApply runs after flag parsing, before the command's Run. diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index e35c5c88e4..bc18ad9859 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -265,3 +265,16 @@ func TestKongPredictionCommandFlagsParse(t *testing.T) { require.NoErrorf(t, err, "parse %v", args) } } + +func TestKongWeightsCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"weights", "import", "--dry-run", "--verbose", "model.safetensors", "--help"}, + {"weights", "pull", "--verbose", "weights-name", "--help"}, + {"weights", "status", "--json", "--verbose", "--help"}, + } { + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } +} diff --git a/cmd/cog-kong/stubs.go b/cmd/cog-kong/stubs.go index b9495d1cf3..27bb6245cf 100644 --- a/cmd/cog-kong/stubs.go +++ b/cmd/cog-kong/stubs.go @@ -21,26 +21,3 @@ func (cmd *BaseImageBuildCmd) Run() error { return errKongCommandNotImplemented } -type WeightsCmd struct { - Import WeightsImportCmd `cmd:"" help:"Import model 
weights."` - Pull WeightsPullCmd `cmd:"" help:"Pull model weights."` - Status WeightsStatusCmd `cmd:"" help:"Show model weight status."` -} - -type WeightsImportCmd struct{} - -func (cmd *WeightsImportCmd) Run() error { - return errKongCommandNotImplemented -} - -type WeightsPullCmd struct{} - -func (cmd *WeightsPullCmd) Run() error { - return errKongCommandNotImplemented -} - -type WeightsStatusCmd struct{} - -func (cmd *WeightsStatusCmd) Run() error { - return errKongCommandNotImplemented -} diff --git a/cmd/cog-kong/weights.go b/cmd/cog-kong/weights.go new file mode 100644 index 0000000000..900b40b6f2 --- /dev/null +++ b/cmd/cog-kong/weights.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" +) + +// WeightsCmd implements the hidden, experimental "cog weights" command group. +type WeightsCmd struct { + Import WeightsImportCmd `cmd:"" help:"Build and push weights to a registry."` + Pull WeightsPullCmd `cmd:"" help:"Populate the local weight cache from the registry."` + Status WeightsStatusCmd `cmd:"" help:"Show the status of configured weights."` +} + +// WeightsImportCmd implements "cog weights import". +type WeightsImportCmd struct { + ConfigFlag `embed:""` + + DryRun bool `name:"dry-run" help:"Show what would be imported without making changes."` + Verbose bool `name:"verbose" short:"v" help:"Show per-file details."` + Names []string `arg:"" optional:"" name:"name" help:"Weight names to import. If omitted, all weights in cog.yaml are imported."` +} + +func (cmd *WeightsImportCmd) Run(ctx context.Context) error { + return cli.RunWeightsImport(ctx, cmd.File, cmd.Names, cmd.DryRun, cmd.Verbose) +} + +// WeightsPullCmd implements "cog weights pull". +type WeightsPullCmd struct { + ConfigFlag `embed:""` + + Verbose bool `name:"verbose" short:"v" help:"Show per-layer and per-file progress."` + Names []string `arg:"" optional:"" name:"name" help:"Weight names to pull. 
If omitted, all weights in cog.yaml are pulled."` +} + +func (cmd *WeightsPullCmd) Run(ctx context.Context) error { + return cli.RunWeightsPull(ctx, cmd.File, cmd.Names, cmd.Verbose) +} + +// WeightsStatusCmd implements "cog weights status". +type WeightsStatusCmd struct { + ConfigFlag `embed:""` + + JSON bool `name:"json" help:"Output as JSON."` + Verbose bool `name:"verbose" short:"v" help:"Show per-layer status."` +} + +func (cmd *WeightsStatusCmd) Run(ctx context.Context) error { + return cli.RunWeightsStatus(ctx, cmd.File, cmd.JSON, cmd.Verbose) +} diff --git a/pkg/cli/weights.go b/pkg/cli/weights.go index de9c1941a4..344d0e4f4b 100644 --- a/pkg/cli/weights.go +++ b/pkg/cli/weights.go @@ -27,11 +27,6 @@ func newWeightsCommand() *cobra.Command { NOTE: cog weights is experimental. Behavior may change in future versions. Do not rely on it in production workflows.`, Hidden: true, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - console.Warnf("NOTE: cog weights is experimental. Behavior may change in future versions.") - console.Warnf("Do not rely on it in production workflows.") - console.Info("") - }, } cmd.AddCommand(newWeightsImportCommand()) @@ -66,7 +61,7 @@ Use --dry-run to preview what would change without importing anything. Add --verbose to see per-file details including which files pass the filter.`, Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - return weightsImportCommand(cmd, args, dryRun, verbose) + return RunWeightsImport(cmd.Context(), configFilename, args, dryRun, verbose) }, } @@ -76,8 +71,18 @@ Add --verbose to see per-file details including which files pass the filter.`, return cmd } -func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose bool) error { - ctx := cmd.Context() +// weightsExperimentalWarning prints the experimental notice shared by all +// weights subcommands. It runs once per subcommand invocation. 
+func weightsExperimentalWarning() { + console.Warnf("NOTE: cog weights is experimental. Behavior may change in future versions.") + console.Warnf("Do not rely on it in production workflows.") + console.Info("") +} + +// RunWeightsImport builds and pushes weights to a registry. It is shared by +// both the Cobra and Kong weights import commands. +func RunWeightsImport(ctx context.Context, configFilename string, args []string, dryRun, verbose bool) error { + weightsExperimentalWarning() src, err := model.NewSource(configFilename) if err != nil { @@ -92,7 +97,7 @@ func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose boo return err } - weightSpecs, err := collectWeightSpecs(src, args) + weightSpecs, err := collectWeightSpecs(src, configFilename, args) if err != nil { return err } @@ -233,7 +238,7 @@ func planStatusIcon(status model.WeightImportPlanStatus) string { // collectWeightSpecs extracts WeightSpecs from the source, optionally // filtered to only the names listed in filterNames. An error is returned // if no weights match or if a requested name doesn't exist. -func collectWeightSpecs(src *model.Source, filterNames []string) ([]*model.WeightSpec, error) { +func collectWeightSpecs(src *model.Source, configFilename string, filterNames []string) ([]*model.WeightSpec, error) { if len(src.Config.Weights) == 0 { return nil, fmt.Errorf("no weights defined in %s", configFilename) } diff --git a/pkg/cli/weights_pull.go b/pkg/cli/weights_pull.go index 97ad2b79f1..30389ed0ff 100644 --- a/pkg/cli/weights_pull.go +++ b/pkg/cli/weights_pull.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "path/filepath" @@ -40,7 +41,7 @@ home directory is on a different filesystem than your project. 
Use --verbose to show per-layer and per-file progress.`, Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - return weightsPullCommand(cmd, args, verbose) + return RunWeightsPull(cmd.Context(), configFilename, args, verbose) }, } @@ -49,8 +50,10 @@ Use --verbose to show per-layer and per-file progress.`, return cmd } -func weightsPullCommand(cmd *cobra.Command, args []string, verbose bool) error { - ctx := cmd.Context() +// RunWeightsPull populates the local weight cache from the registry. It is +// shared by both the Cobra and Kong weights pull commands. +func RunWeightsPull(ctx context.Context, configFilename string, args []string, verbose bool) error { + weightsExperimentalWarning() src, err := model.NewSource(configFilename) if err != nil { diff --git a/pkg/cli/weights_status.go b/pkg/cli/weights_status.go index 7490964201..2102355107 100644 --- a/pkg/cli/weights_status.go +++ b/pkg/cli/weights_status.go @@ -1,6 +1,7 @@ package cli import ( + "context" "encoding/json" "errors" "fmt" @@ -81,7 +82,7 @@ Use --verbose to show per-layer status for each weight. Exit code is 0 when all weights are ready, 1 otherwise.`, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return weightsStatusCommand(cmd, jsonOutput, verbose) + return RunWeightsStatus(cmd.Context(), configFilename, jsonOutput, verbose) }, } @@ -92,8 +93,10 @@ Exit code is 0 when all weights are ready, 1 otherwise.`, return cmd } -func weightsStatusCommand(cmd *cobra.Command, jsonOutput, verbose bool) error { - ctx := cmd.Context() +// RunWeightsStatus shows the status of configured weights. It is shared by both +// the Cobra and Kong weights status commands. 
+func RunWeightsStatus(ctx context.Context, configFilename string, jsonOutput, verbose bool) error { + weightsExperimentalWarning() src, err := model.NewSource(configFilename) if err != nil { From acff1645399a4e125522eb8193fc707e09093464 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:53:48 -0400 Subject: [PATCH 12/19] feat: add kong base-image command parity and remove stubs --- cmd/cog-kong/baseimage.go | 62 +++++++++++++++++ cmd/cog-kong/main_test.go | 12 ++++ cmd/cog-kong/stubs.go | 23 ------- pkg/cli/baseimage.go | 138 ++++++++++++++++++++++++-------------- 4 files changed, 161 insertions(+), 74 deletions(-) create mode 100644 cmd/cog-kong/baseimage.go delete mode 100644 cmd/cog-kong/stubs.go diff --git a/cmd/cog-kong/baseimage.go b/cmd/cog-kong/baseimage.go new file mode 100644 index 0000000000..6d2e2fab74 --- /dev/null +++ b/cmd/cog-kong/baseimage.go @@ -0,0 +1,62 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" +) + +// BaseImageCmd implements the experimental "cog base-image" command group. +type BaseImageCmd struct { + Dockerfile BaseImageDockerfileCmd `cmd:"" help:"Display Cog base image Dockerfile."` + Build BaseImageBuildCmd `cmd:"" help:"Build Cog base image."` +} + +// baseImageVersionFlags groups the version-selecting flags shared by the +// base-image subcommands. +type baseImageVersionFlags struct { + CUDA string `name:"cuda" help:"CUDA version."` + Python string `name:"python" help:"Python version."` + Torch string `name:"torch" help:"Torch version."` + + // Hidden flags for parity with the Cobra base-image command. 
+ BreakSystemPackages bool `name:"break-system-packages" hidden:"" help:"Allow pip to modify uv-managed Python installs."` + BuildContextDir string `name:"build-context-dir" hidden:"" help:"Directory for generated Docker build context artifacts."` + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` +} + +func (f baseImageVersionFlags) options() cli.BaseImageOptions { + return cli.BaseImageOptions{ + CUDAVersion: f.CUDA, + PythonVersion: f.Python, + TorchVersion: f.Torch, + BreakSystemPackages: f.BreakSystemPackages, + BuildContextDir: f.BuildContextDir, + Timestamp: f.Timestamp, + } +} + +// BaseImageDockerfileCmd implements "cog base-image dockerfile". +type BaseImageDockerfileCmd struct { + baseImageVersionFlags `embed:""` + + NoCache bool `name:"no-cache" help:"Do not use cache when building the image."` + Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` +} + +func (cmd *BaseImageDockerfileCmd) Run(ctx context.Context) error { + opts := cmd.baseImageVersionFlags.options() + opts.NoCache = cmd.NoCache + opts.ProgressOutput = cmd.Progress + return cli.RunBaseImageDockerfile(ctx, opts) +} + +// BaseImageBuildCmd implements "cog base-image build". 
+type BaseImageBuildCmd struct { + baseImageVersionFlags `embed:""` +} + +func (cmd *BaseImageBuildCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunBaseImageBuild(ctx, dockerClient, cmd.baseImageVersionFlags.options()) +} diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index bc18ad9859..85b6f636aa 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -278,3 +278,15 @@ func TestKongWeightsCommandFlagsParse(t *testing.T) { require.NoErrorf(t, err, "parse %v", args) } } + +func TestKongBaseImageCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"base-image", "dockerfile", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--no-cache", "--progress", "plain", "--help"}, + {"base-image", "build", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--help"}, + } { + _, err := parser.Parse(args) + require.NoErrorf(t, err, "parse %v", args) + } +} diff --git a/cmd/cog-kong/stubs.go b/cmd/cog-kong/stubs.go deleted file mode 100644 index 27bb6245cf..0000000000 --- a/cmd/cog-kong/stubs.go +++ /dev/null @@ -1,23 +0,0 @@ -package main - -import "errors" - -var errKongCommandNotImplemented = errors.New("kong command not implemented") - -type BaseImageCmd struct { - Dockerfile BaseImageDockerfileCmd `cmd:"" help:"Generate a Dockerfile for a Cog base image."` - Build BaseImageBuildCmd `cmd:"" help:"Build a Cog base image."` -} - -type BaseImageDockerfileCmd struct{} - -func (cmd *BaseImageDockerfileCmd) Run() error { - return errKongCommandNotImplemented -} - -type BaseImageBuildCmd struct{} - -func (cmd *BaseImageBuildCmd) Run() error { - return errKongCommandNotImplemented -} - diff --git a/pkg/cli/baseimage.go b/pkg/cli/baseimage.go index f9831f6fff..6ce6f95e78 100644 --- a/pkg/cli/baseimage.go +++ b/pkg/cli/baseimage.go @@ -132,18 +132,7 @@ func newBaseImageDockerfileCommand() *cobra.Command { Use: "dockerfile", Short: "Display Cog base 
image Dockerfile", RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - - generator, err := baseImageGeneratorFromFlags(ctx) - if err != nil { - return err - } - dockerfile, err := generator.GenerateDockerfile(ctx) - if err != nil { - return err - } - fmt.Println(dockerfile) - return nil + return RunBaseImageDockerfile(cmd.Context(), baseImageOptionsFromFlags()) }, Args: cobra.MaximumNArgs(0), } @@ -159,42 +148,11 @@ func newBaseImageBuildCommand() *cobra.Command { Use: "build", Short: "Build Cog base image", RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - - dockerClient, err := docker.NewClient(ctx) - if err != nil { - return err - } - - generator, err := baseImageGeneratorFromFlags(ctx) - if err != nil { - return err - } - dockerfileContents, err := generator.GenerateDockerfile(ctx) - if err != nil { - return err - } - - cwd, err := os.Getwd() + dockerClient, err := docker.NewClient(cmd.Context()) if err != nil { return err } - baseImageName := dockerfile.BaseImageName(baseImageCUDAVersion, baseImagePythonVersion, baseImageTorchVersion) - - buildOpts := command.ImageBuildOptions{ - WorkingDir: cwd, - DockerfileContents: dockerfileContents, - ImageName: baseImageName, - NoCache: buildNoCache, - ProgressOutput: buildProgressOutput, - Epoch: &config.BuildSourceEpochTimestamp, - ContextDir: ".", - } - if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil { - return err - } - fmt.Println("Successfully built image: " + baseImageName) - return nil + return RunBaseImageBuild(cmd.Context(), dockerClient, baseImageOptionsFromFlags()) }, Args: cobra.MaximumNArgs(0), } @@ -203,6 +161,84 @@ func newBaseImageBuildCommand() *cobra.Command { return cmd } +// BaseImageOptions holds the parser-independent options shared by the +// base-image dockerfile and build commands. 
+type BaseImageOptions struct { + CUDAVersion string + PythonVersion string + TorchVersion string + BreakSystemPackages bool + BuildContextDir string + NoCache bool + ProgressOutput string + Timestamp int64 +} + +// baseImageOptionsFromFlags reads the package-level Cobra flag globals into a +// BaseImageOptions value. +func baseImageOptionsFromFlags() BaseImageOptions { + return BaseImageOptions{ + CUDAVersion: baseImageCUDAVersion, + PythonVersion: baseImagePythonVersion, + TorchVersion: baseImageTorchVersion, + BreakSystemPackages: baseImageBreakSystemPackages, + BuildContextDir: baseImageBuildContextDir, + NoCache: buildNoCache, + ProgressOutput: buildProgressOutput, + Timestamp: config.BuildSourceEpochTimestamp, + } +} + +// RunBaseImageDockerfile generates and prints the base image Dockerfile. It is +// shared by both the Cobra and Kong base-image dockerfile commands. +func RunBaseImageDockerfile(ctx context.Context, opts BaseImageOptions) error { + generator, err := baseImageGenerator(ctx, opts) + if err != nil { + return err + } + contents, err := generator.GenerateDockerfile(ctx) + if err != nil { + return err + } + fmt.Println(contents) + return nil +} + +// RunBaseImageBuild builds a Cog base image. It is shared by both the Cobra and +// Kong base-image build commands. 
+func RunBaseImageBuild(ctx context.Context, dockerClient command.Command, opts BaseImageOptions) error { + generator, err := baseImageGenerator(ctx, opts) + if err != nil { + return err + } + dockerfileContents, err := generator.GenerateDockerfile(ctx) + if err != nil { + return err + } + + cwd, err := os.Getwd() + if err != nil { + return err + } + baseImageName := dockerfile.BaseImageName(opts.CUDAVersion, opts.PythonVersion, opts.TorchVersion) + + timestamp := opts.Timestamp + buildOpts := command.ImageBuildOptions{ + WorkingDir: cwd, + DockerfileContents: dockerfileContents, + ImageName: baseImageName, + NoCache: opts.NoCache, + ProgressOutput: opts.ProgressOutput, + Epoch: &timestamp, + ContextDir: ".", + } + if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil { + return err + } + fmt.Println("Successfully built image: " + baseImageName) + return nil +} + func addBaseImageFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&baseImageCUDAVersion, "cuda", "", "CUDA version") cmd.Flags().StringVar(&baseImagePythonVersion, "python", "", "Python version") @@ -214,7 +250,7 @@ func addBaseImageFlags(cmd *cobra.Command) { addBuildTimestampFlag(cmd) } -func baseImageGeneratorFromFlags(ctx context.Context) (*dockerfile.BaseImageGenerator, error) { +func baseImageGenerator(ctx context.Context, opts BaseImageOptions) (*dockerfile.BaseImageGenerator, error) { dockerClient, err := docker.NewClient(ctx) if err != nil { return nil, err @@ -223,16 +259,16 @@ func baseImageGeneratorFromFlags(ctx context.Context) (*dockerfile.BaseImageGene generator, err := dockerfile.NewBaseImageGenerator( ctx, client, - baseImageCUDAVersion, - baseImagePythonVersion, - baseImageTorchVersion, + opts.CUDAVersion, + opts.PythonVersion, + opts.TorchVersion, dockerClient, true, ) if err != nil { return nil, err } - generator.SetBreakSystemPackages(baseImageBreakSystemPackages) - generator.SetBuildContextDir(baseImageBuildContextDir) +
generator.SetBreakSystemPackages(opts.BreakSystemPackages) + generator.SetBuildContextDir(opts.BuildContextDir) return generator, nil } From 92f0237fe6a84c21b3642d8eaca5a38c54d9d155 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Wed, 3 Jun 2026 16:55:59 -0400 Subject: [PATCH 13/19] test: verify kong cli parity and clean up unused cobra helpers --- cmd/cog-kong/baseimage.go | 4 ++-- cmd/cog-kong/build.go | 2 +- cmd/cog-kong/main_test.go | 17 +++++++++++++++++ cmd/cog-kong/predict.go | 6 +++--- cmd/cog-kong/push.go | 2 +- cmd/cog-kong/runtime.go | 4 ++-- cmd/cog-kong/train.go | 2 +- pkg/cli/build.go | 6 ------ pkg/cli/serve.go | 9 --------- 9 files changed, 27 insertions(+), 25 deletions(-) diff --git a/cmd/cog-kong/baseimage.go b/cmd/cog-kong/baseimage.go index 6d2e2fab74..0e07bff4ec 100644 --- a/cmd/cog-kong/baseimage.go +++ b/cmd/cog-kong/baseimage.go @@ -46,7 +46,7 @@ type BaseImageDockerfileCmd struct { } func (cmd *BaseImageDockerfileCmd) Run(ctx context.Context) error { - opts := cmd.baseImageVersionFlags.options() + opts := cmd.options() opts.NoCache = cmd.NoCache opts.ProgressOutput = cmd.Progress return cli.RunBaseImageDockerfile(ctx, opts) @@ -58,5 +58,5 @@ type BaseImageBuildCmd struct { } func (cmd *BaseImageBuildCmd) Run(ctx context.Context, dockerClient command.Command) error { - return cli.RunBaseImageBuild(ctx, dockerClient, cmd.baseImageVersionFlags.options()) + return cli.RunBaseImageBuild(ctx, dockerClient, cmd.options()) } diff --git a/cmd/cog-kong/build.go b/cmd/cog-kong/build.go index e73e2f72f9..15b5c07031 100644 --- a/cmd/cog-kong/build.go +++ b/cmd/cog-kong/build.go @@ -25,6 +25,6 @@ func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regC return cli.RunBuild(ctx, dockerClient, regClient, cli.BuildCommandOptions{ ConfigFilename: cmd.File, Tag: cmd.Tag, - Flags: cmd.BuildFlags.Options(), + Flags: cmd.Options(), }) } diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index 85b6f636aa..175e829635 100644 
--- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -290,3 +290,20 @@ func TestKongBaseImageCommandFlagsParse(t *testing.T) { require.NoErrorf(t, err, "parse %v", args) } } + +func TestKongCommandCoverageMatchesExpectedCobraSurface(t *testing.T) { + parser := newTestParser(t) + commands := map[string]bool{} + for _, node := range parser.Model.Children { + commands[node.Name] = true + } + + expected := []string{ + "base-image", "build", "debug", "doctor", "exec", "init", + "login", "predict", "push", "run", "serve", "train", "weights", + } + require.Len(t, commands, len(expected)) + for _, name := range expected { + require.Truef(t, commands[name], "missing command %q", name) + } +} diff --git a/cmd/cog-kong/predict.go b/cmd/cog-kong/predict.go index 0ca3ddf00e..6c086442e8 100644 --- a/cmd/cog-kong/predict.go +++ b/cmd/cog-kong/predict.go @@ -21,7 +21,7 @@ type predictionFlags struct { func (f predictionFlags) options(use string) cli.PredictionCommandOptions { return cli.PredictionCommandOptions{ - RuntimeBuildOptions: f.RuntimeFlags.Options(), + RuntimeBuildOptions: f.Options(), Use: use, Image: f.Image, Input: f.Input, @@ -38,7 +38,7 @@ type PredictCmd struct { } func (cmd *PredictCmd) Run(ctx context.Context, dockerClient command.Command) error { - return cli.RunPrediction(ctx, dockerClient, cmd.predictionFlags.options("predict")) + return cli.RunPrediction(ctx, dockerClient, cmd.options("predict")) } // RunCmd implements the "cog run" command. 
@@ -47,5 +47,5 @@ type RunCmd struct { } func (cmd *RunCmd) Run(ctx context.Context, dockerClient command.Command) error { - return cli.RunPrediction(ctx, dockerClient, cmd.predictionFlags.options("run")) + return cli.RunPrediction(ctx, dockerClient, cmd.options("run")) } diff --git a/cmd/cog-kong/push.go b/cmd/cog-kong/push.go index be8790c96d..30727f81a5 100644 --- a/cmd/cog-kong/push.go +++ b/cmd/cog-kong/push.go @@ -26,6 +26,6 @@ func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regCl return cli.RunPush(ctx, dockerClient, regClient, providerReg, cli.PushCommandOptions{ ConfigFilename: cmd.File, Image: cmd.Image, - Flags: cmd.BuildFlags.Options(), + Flags: cmd.Options(), }) } diff --git a/cmd/cog-kong/runtime.go b/cmd/cog-kong/runtime.go index 437a2d7098..8e7acb3e24 100644 --- a/cmd/cog-kong/runtime.go +++ b/cmd/cog-kong/runtime.go @@ -47,7 +47,7 @@ type ServeCmd struct { func (cmd *ServeCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { return cli.RunServe(ctx, dockerClient, regClient, cli.ServeCommandOptions{ - RuntimeBuildOptions: cmd.RuntimeFlags.Options(), + RuntimeBuildOptions: cmd.Options(), Port: cmd.Port, UploadURL: cmd.UploadURL, }) @@ -70,7 +70,7 @@ func (cmd *ExecCmd) Validate() error { func (cmd *ExecCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { return cli.RunExec(ctx, dockerClient, regClient, cli.ExecCommandOptions{ - RuntimeBuildOptions: cmd.RuntimeFlags.Options(), + RuntimeBuildOptions: cmd.Options(), Args: cmd.Args, Ports: cmd.Publish, }) diff --git a/cmd/cog-kong/train.go b/cmd/cog-kong/train.go index d49f9634fc..d79f16c09a 100644 --- a/cmd/cog-kong/train.go +++ b/cmd/cog-kong/train.go @@ -19,7 +19,7 @@ type TrainCmd struct { func (cmd *TrainCmd) Run(ctx context.Context, dockerClient command.Command) error { return cli.RunTrain(ctx, dockerClient, cli.TrainCommandOptions{ - RuntimeBuildOptions: cmd.RuntimeFlags.Options(), + 
RuntimeBuildOptions: cmd.Options(), Image: cmd.Image, Input: cmd.Input, OutputPath: cmd.Output, diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 73542bdc17..c4ba0cd52d 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -298,9 +298,3 @@ func addSkipSchemaValidationFlag(cmd *cobra.Command) { cmd.Flags().BoolVar(&buildSkipSchemaValidation, "skip-schema-validation", false, "Skip OpenAPI schema generation and validation") _ = cmd.Flags().MarkHidden("skip-schema-validation") } - -// buildOptionsFromFlags creates BuildOptions from the current CLI flag values. -// The imageName and annotations parameters vary by command and must be provided. -func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions { - return buildFlagsOptionsFromCobra(cmd).ModelBuildOptions(imageName, annotations) -} diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index c4b82e9cdd..21cbadf990 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -105,15 +105,6 @@ func runtimeEnv(env []string) []string { return out } -// serveBuildOptions creates BuildOptions for cog serve from Cobra flag globals. -func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { - return RuntimeBuildOptions{ - ProgressOutput: buildProgressOutput, - UseCudaBaseImage: buildUseCudaBaseImage, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - }.ServeBuildOptions() -} - // ServeCommandOptions holds everything RunServe needs that is independent of // the argument parser. type ServeCommandOptions struct { From 63c65dfcbe0f2507ded4f1518d41be9480f323fb Mon Sep 17 00:00:00 2001 From: Mark Phelps <209477+markphelps@users.noreply.github.com> Date: Fri, 5 Jun 2026 09:26:30 -0400 Subject: [PATCH 14/19] feat: support JSON-native union inputs (#3048) * feat: support JSON-native union inputs Add support for union input types such as `str | float` and `str | float | None`. 
Unions are restricted to JSON-native members (str, int, float, bool, dict/Any, list[T], None) so request validation happens at the HTTP edge against the OpenAPI schema. Unions involving Path, File, Secret, custom coders, and BaseModel are rejected at build/schema-generation time, and output unions remain unsupported. - pkg/schema: add a recursive InputType model and resolver, emit OpenAPI anyOf for union inputs, and keep multi-variant nullable unions required when no default is supplied - pkg/predict: parse numeric CLI values for unions that accept a number (schemaAcceptsNumber), and fall back to a string member when a numeric parse fails for number-first unions like `float | str` (schemaAcceptsString) - python/cog/_adt: add deterministic union normalisation with strict per-member compatibility (bool not int/float, scalars not dict/Any, list unions validate elements) and anyOf json_type emission - tests: Go unit tests, Python regression tests, and end-to-end txtar integration tests for HTTP, CLI, and list unions - docs: document union inputs and nullable semantics * fix: correct CLI numeric parsing for integer-only and float unions Address two related bugs in CLI `-i` parsing of union inputs: - `int | float` resolves to the integer member first, so a fractional value like `1.5` failed ParseInt and errored instead of falling back to the float member. - `str | int` resolves to the string member, then the schemaAcceptsNumber branch parsed `1.5` as a float even though the union only accepts an integer, sending an invalid float. Add a schemaAcceptsFloat helper that matches number members but not integer-only members, and gate float parsing behind it in both the integer branch (with a float fallback) and the schemaAcceptsNumber branch. Add unit tests for `int | float` and `str | int` unions. * docs: clarify optional vs multi-variant union required behaviour The previous table conflated plain single-type optionals with multi-variant nullable unions. 
A plain Optional[T] / T | None is never placed in required, while a multi-variant union like A | B | None stays required unless a default is supplied. Split these into separate rows and add a runtime caveat: an optional needs a Python-level default (via = Input(...) or default=None) so an omitted value resolves to None; a bare Optional[T] annotation raises TypeError when omitted. --- architecture/02-schema.md | 52 ++-- docs/llms.txt | 24 ++ docs/python.md | 24 ++ integration-tests/tests/union_input_cli.txtar | 45 ++++ .../tests/union_input_http.txtar | 54 ++++ .../tests/union_input_list_http.txtar | 48 ++++ pkg/predict/input.go | 121 ++++++++- pkg/predict/input_test.go | 226 ++++++++++++++++ pkg/schema/openapi.go | 54 +++- pkg/schema/openapi_test.go | 75 ++++++ pkg/schema/python/inputs.go | 33 ++- pkg/schema/python/parser_test.go | 123 +++++++++ pkg/schema/types.go | 255 ++++++++++++++++++ python/cog/_adt.py | 161 ++++++++++- python/tests/test_adt.py | 104 +++++++ python/tests/test_inspector.py | 16 ++ 16 files changed, 1372 insertions(+), 43 deletions(-) create mode 100644 integration-tests/tests/union_input_cli.txtar create mode 100644 integration-tests/tests/union_input_http.txtar create mode 100644 integration-tests/tests/union_input_list_http.txtar create mode 100644 pkg/predict/input_test.go diff --git a/architecture/02-schema.md b/architecture/02-schema.md index f038219360..6ed6d16742 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -102,15 +102,15 @@ class Runner(BaseRunner): The resolver handles local imports relative to the predictor file and project root: -| Import Style | File Resolved | -| ------------------------------------ | ----------------------------------------------------------- | -| `from output_types import X` | `/output_types.py` | -| `from .output_types import X` | `/output_types.py` | -| `from models.output import X` | `/models/output.py` | -| `from .models.output import X` | `/models/output.py` | -| `from 
output_types import X as Y` | `/output_types.py` (alias tracked) | -| `from .output_types import X as Y` | `/output_types.py` (alias tracked) | -| `from . import output_types` | `/output_types.py` (module alias tracked) | +| Import Style | File Resolved | +| ---------------------------------- | -------------------------------------------------------- | +| `from output_types import X` | `/output_types.py` | +| `from .output_types import X` | `/output_types.py` | +| `from models.output import X` | `/models/output.py` | +| `from .models.output import X` | `/models/output.py` | +| `from output_types import X as Y` | `/output_types.py` (alias tracked) | +| `from .output_types import X as Y` | `/output_types.py` (alias tracked) | +| `from . import output_types` | `/output_types.py` (module alias tracked) | **How it distinguishes local from external**: the resolver converts the module path to a filesystem path and checks if the file exists. If `output_types.py` exists in the project directory, it's local. If not (e.g., `from transformers import ...`), it's external. Known external packages (stdlib, torch, numpy, etc.) are skipped without a filesystem check. 
@@ -175,18 +175,28 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ### Input Types -| Python | JSON Schema | Notes | -| ------------------------------------- | ---------------------------------------------------------------- | -------------------------- | -| `str` | `{"type": "string"}` | | -| `int` | `{"type": "integer"}` | | -| `float` | `{"type": "number"}` | | -| `bool` | `{"type": "boolean"}` | | -| `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | -| `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | -| `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | -| `list[T]` | `{"type": "array", "items": {...}}` | | -| `Optional[T]` | Type T + not in `required` | Input fields only | -| `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | +| Python | JSON Schema | Notes | +| ------------------------------------- | ---------------------------------------------------------------- | --------------------------------------------------------------------- | +| `str` | `{"type": "string"}` | | +| `int` | `{"type": "integer"}` | | +| `float` | `{"type": "number"}` | | +| `bool` | `{"type": "boolean"}` | | +| `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | +| `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | +| `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | +| `list[T]` | `{"type": "array", "items": {...}}` | | +| `Optional[T]` / `T \| None` | Type T + `nullable: true`, not in `required` | Input fields only; never required | +| `A \| B` / `Union[A, B]` | `{"anyOf": [A, B]}` | Input-only, JSON-native unions only | +| `A \| B \| None` | `{"anyOf": [A, B]}` + `nullable: true` | Multi-variant union; stays in `required` unless a default is supplied | +| `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | + +Input unions are 
intentionally narrower than output types. Cog supports JSON-native input unions (`str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`) so request validation can happen at the HTTP boundary and Python normalisation can choose a deterministic value type. Cog rejects unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` because those cases are ambiguous for clients or runtime coercion. Output unions remain unsupported (see below). + +A plain single-type optional (`Optional[T]` or `T | None`) is **never** placed in `required`, regardless of whether a default is supplied. A multi-variant nullable union (`A | B | None`) is different: because the field carries a concrete `anyOf` value type, it stays in `required` unless a default makes it omittable. This is why the two rows above differ in their `required` behaviour. + +Nullable behaviour matches every other optional field: `nullable: true` (plus omission from `required`) means an **omitted** value falls back to the default. An **explicit** JSON `null` is still validated against the field type and is rejected at the HTTP edge, because the runtime validator does not treat OpenAPI's `nullable` keyword as an additional accepted value. "May be null" therefore means "may be omitted", not "accepts an explicit null payload". + +> **Runtime caveat:** Cog marks optionals as not-`required` in the schema, but the predictor still needs a Python-level default so the omitted value resolves to `None`. Use `value: Optional[T] = Input(...)` (the `Input(...)` supplies an implicit `None`) or `Input(default=None)`. A bare `value: Optional[T]` annotation with no `= Input(...)` generates a correct "optional" schema but raises `TypeError: missing 1 required positional argument` when the field is omitted at runtime. 
### Output Types diff --git a/docs/llms.txt b/docs/llms.txt index e8e9e1f7c7..92c90a36ab 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -2652,6 +2652,30 @@ def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: > [!NOTE] > `Optional[T]` is supported in `BaseModel` output fields but **not** as a top-level return type. Use a `BaseModel` with optional fields instead. +#### `Union` + +Use `A | B` or `Union[A, B]` to accept more than one type for a single input. Cog supports JSON-native union members: `str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`. + +```python +from cog import BaseRunner, Input + +class Runner(BaseRunner): + def run(self, + value: str | float = Input(description="A string or a number"), + ) -> str: + return f"{type(value).__name__}:{value}" +``` + +At runtime, Cog validates the request against the union and passes the value through as the matching type. For overlapping numeric types, Cog prefers the most specific match (e.g. `bool` before `int`, `int` before `float`), and a JSON integer is accepted for a `float` member. + +Combine a union with `None` to make it nullable: + +```python +def run(self, value: str | float | None = Input(default=None)) -> str: ... +``` + +Union inputs are validated at the HTTP boundary, so unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` are **not** supported, and the build fails if you use them. Union return types are also unsupported — use a `BaseModel` output instead. + #### `list` Use `list[T]` or `List[T]` to accept or return a list of values. `T` can be a supported Cog type, but nested container types are not supported. diff --git a/docs/python.md b/docs/python.md index 5f5619a9af..2603101ff9 100644 --- a/docs/python.md +++ b/docs/python.md @@ -620,6 +620,30 @@ def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: > [!NOTE] > `Optional[T]` is supported in `BaseModel` output fields but **not** as a top-level return type. 
Use a `BaseModel` with optional fields instead. +#### `Union` + +Use `A | B` or `Union[A, B]` to accept more than one type for a single input. Cog supports JSON-native union members: `str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`. + +```python +from cog import BaseRunner, Input + +class Runner(BaseRunner): + def run(self, + value: str | float = Input(description="A string or a number"), + ) -> str: + return f"{type(value).__name__}:{value}" +``` + +At runtime, Cog validates the request against the union and passes the value through as the matching type. For overlapping numeric types, Cog prefers the most specific match (e.g. `bool` before `int`, `int` before `float`), and a JSON integer is accepted for a `float` member. + +Combine a union with `None` to make it nullable: + +```python +def run(self, value: str | float | None = Input(default=None)) -> str: ... +``` + +Union inputs are validated at the HTTP boundary, so unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` are **not** supported, and the build fails if you use them. Union return types are also unsupported — use a `BaseModel` output instead. + #### `list` Use `list[T]` or `List[T]` to accept or return a list of values. `T` can be a supported Cog type, but nested container types are not supported. diff --git a/integration-tests/tests/union_input_cli.txtar b/integration-tests/tests/union_input_cli.txtar new file mode 100644 index 0000000000..aeebd7fa4d --- /dev/null +++ b/integration-tests/tests/union_input_cli.txtar @@ -0,0 +1,45 @@ +# Test schema-directed CLI parsing for JSON-native union inputs. 
+# +# value: str | float (string member first) +# flipped: float | str (number member first) +# +# - "hello" parses as a string because no numeric parse succeeds +# - "1.5" parses as a float because the union accepts a number +# - "1" parses as an integer (still a valid JSON number for the union) +# +# The `flipped` field exercises the case where resolveSchemaType resolves a +# union to its numeric member first: a non-numeric value must still fall back +# to the string member instead of erroring. +# +# Note: the worker does not coerce primitives at runtime (validation happens +# at the HTTP edge against the OpenAPI schema), so the CLI must choose the +# wire type. A bare integer stays an integer; only fractional values become +# floats. This matches how a plain `float` input also receives a Python int +# for `-i num=10`. + +cog build -t $TEST_IMAGE + +cog predict $TEST_IMAGE -i value=hello -i flipped=world +stdout 'value=str:hello flipped=str:world' + +cog predict $TEST_IMAGE -i value=1.5 -i flipped=2.5 +stdout 'value=float:1.5 flipped=float:2.5' + +cog predict $TEST_IMAGE -i value=1 -i flipped=2 +stdout 'value=int:1 flipped=int:2' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, value: str | float, flipped: float | str) -> str: + return ( + f"value={type(value).__name__}:{value} " + f"flipped={type(flipped).__name__}:{flipped}" + ) diff --git a/integration-tests/tests/union_input_http.txtar b/integration-tests/tests/union_input_http.txtar new file mode 100644 index 0000000000..7ec032c426 --- /dev/null +++ b/integration-tests/tests/union_input_http.txtar @@ -0,0 +1,54 @@ +# Test JSON-native union inputs over cog serve. +# value: str | float | None = Input(default=None) +# - string accepted, returns "str:..." +# - float accepted, returns "float:..." +# - integer accepted (valid JSON number), returns "int:..." 
+# - omitted optional defaults to None, returns "NoneType:none" +# - explicit null is rejected, matching how every optional field behaves +# today (validation happens at the HTTP edge and the runtime validator +# does not accept explicit JSON null for a typed field) +# - bool rejected (not a member of str | float), returns a validation error + +cog build -t $TEST_IMAGE + +cog serve + +# String member +curl POST /predictions '{"input":{"value":"hello"}}' +stdout '"output":"str:hello"' + +# Float member +curl POST /predictions '{"input":{"value":1.5}}' +stdout '"output":"float:1.5"' + +# Integer is a valid JSON number for the union; passed through as int +curl POST /predictions '{"input":{"value":1}}' +stdout '"output":"int:1"' + +# Omitted optional value defaults to None +curl POST /predictions '{"input":{}}' +stdout '"output":"NoneType:none"' + +# Explicit null is rejected, consistent with all optional fields +! curl POST /predictions '{"input":{"value":null}}' + +# bool is not a member of str | float -> rejected +! curl POST /predictions '{"input":{"value":true}}' + +# nested object is not a member of str | float -> rejected +! curl POST /predictions '{"input":{"value":{"x":1}}}' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, value: str | float | None = Input(default=None)) -> str: + if value is None: + return "NoneType:none" + return f"{type(value).__name__}:{value}" diff --git a/integration-tests/tests/union_input_list_http.txtar b/integration-tests/tests/union_input_list_http.txtar new file mode 100644 index 0000000000..e36fa90be5 --- /dev/null +++ b/integration-tests/tests/union_input_list_http.txtar @@ -0,0 +1,48 @@ +# Test list JSON-native union inputs over cog serve. 
+# +# nums: list[int] | list[float] (required list union) +# +# - list[int] | list[float] accepts [1] and [1.5] and the empty list [] +# - list element types are validated: ["3"] and [true] are rejected +# - integer elements stay int, fractional elements are float (no runtime +# coercion: the wire type is preserved) + +cog build -t $TEST_IMAGE + +cog serve + +# Integer list element kept as int +curl POST /predictions '{"input":{"nums":[1]}}' +stdout '"output":"int:1"' + +# Float list element +curl POST /predictions '{"input":{"nums":[1.5]}}' +stdout '"output":"float:1.5"' + +# Empty list is accepted +curl POST /predictions '{"input":{"nums":[]}}' +stdout '"output":"empty"' + +# String element is not valid for list[int] | list[float] -> rejected +! curl POST /predictions '{"input":{"nums":["3"]}}' + +# bool element is not valid for list[int] | list[float] -> rejected +! curl POST /predictions '{"input":{"nums":[true]}}' + +# A bare scalar is not a list -> rejected +! curl POST /predictions '{"input":{"nums":1}}' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, nums: list[int] | list[float]) -> str: + if not nums: + return "empty" + return f"{type(nums[0]).__name__}:{nums[0]}" diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 0e17993f31..e0af7d3253 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -66,11 +66,11 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b propertiesSchemas := properties.(openapi3.Schemas) property, err := propertiesSchemas.JSONLookup(key) if err == nil { - propertySchema := property.(*openapi3.Schema) + originalSchema := property.(*openapi3.Schema) // Resolve allOf/$ref to find the actual type. // cog-schema-gen emits allOf:[{$ref: ...}] for choices/enums, // where the referenced schema has the concrete type. 
- propertySchema = resolveSchemaType(propertySchema) + propertySchema := resolveSchemaType(originalSchema) switch { case propertySchema.Type.Is("object"): encodedVal := json.RawMessage(val) @@ -99,6 +99,13 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b } else { value, err := strconv.ParseFloat(val, 32) if err != nil { + // For a union like `float | str` the schema + // resolves to the numeric member first; a + // non-numeric value should fall back to the + // string member instead of erroring. + if schemaAcceptsString(originalSchema) { + break + } return input, err } float := float32(value) @@ -108,11 +115,48 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b case propertySchema.Type.Is("integer"): value, err := strconv.ParseInt(val, 10, 32) if err != nil { + // For a union like `int | float` the schema + // resolves to the integer member first; a + // fractional value should fall back to the float + // member instead of erroring. + if schemaAcceptsFloat(originalSchema) { + if value, err := strconv.ParseFloat(val, 32); err == nil { + float := float32(value) + input[key] = Input{Float: &float} + continue + } + } + // See the number case above: fall back to a string + // member for unions such as `int | str`. + if schemaAcceptsString(originalSchema) { + break + } return input, err } valueInt := int32(value) input[key] = Input{Int: &valueInt} continue + case schemaAcceptsNumber(originalSchema): + // Union input (anyOf) that includes a numeric member, e.g. + // `str | float`. Parse numeric-looking values as numbers so + // the runtime receives the intended type; otherwise fall + // through to the string member below. 
+ if value, err := strconv.ParseInt(val, 10, 32); err == nil { + valueInt := int32(value) + input[key] = Input{Int: &valueInt} + continue + } + // Only parse fractional values as float when the + // union actually accepts a float member; otherwise a + // value like `1.5` for `str | int` must fall back to + // the string member below. + if schemaAcceptsFloat(originalSchema) { + if value, err := strconv.ParseFloat(val, 32); err == nil { + float := float32(value) + input[key] = Input{Float: &float} + continue + } + } } } } @@ -188,6 +232,79 @@ func fileToDataURL(filePath string) (string, error) { return dataURL, nil } +// schemaAcceptsString reports whether the schema accepts a string value, +// including union (anyOf) members. This lets CLI `-i` parsing fall back to a +// string member when a numeric parse fails for unions such as `float | str`, +// where resolveSchemaType resolves to the numeric member. +func schemaAcceptsString(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && s.Type.Is("string") { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsString(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsString(ref.Value) { + return true + } + } + return false +} + +// schemaAcceptsFloat reports whether the schema accepts a floating-point +// value, including union (anyOf) members. Unlike schemaAcceptsNumber, it does +// not match integer-only members, so CLI `-i` parsing can decide whether a +// fractional value like `1.5` is valid for unions such as `int | float` +// (accepts float) versus `str | int` (does not). 
+func schemaAcceptsFloat(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && s.Type.Is("number") { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsFloat(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsFloat(ref.Value) { + return true + } + } + return false +} + +// schemaAcceptsNumber reports whether the schema accepts a numeric value, +// including union (anyOf) members. This lets CLI `-i` parsing coerce +// numeric-looking strings for union inputs such as `str | float`, where +// resolveSchemaType resolves to a non-numeric member. +func schemaAcceptsNumber(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && (s.Type.Is("number") || s.Type.Is("integer")) { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsNumber(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsNumber(ref.Value) { + return true + } + } + return false +} + // resolveSchemaType walks through allOf/anyOf/$ref wrappers to find a schema // that has a concrete Type set. This is needed because the static schema gen // emits allOf:[{$ref: "#/components/schemas/Foo"}] for enum/choices fields, diff --git a/pkg/predict/input_test.go b/pkg/predict/input_test.go new file mode 100644 index 0000000000..b383d3ff93 --- /dev/null +++ b/pkg/predict/input_test.go @@ -0,0 +1,226 @@ +package predict + +import ( + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +// unionInputSchema builds an OpenAPI doc whose single input field `value` +// is a union of string and number. The variant order is configurable so we +// can exercise both `str | float` (string first) and `float | str` (number +// first), which resolve differently via resolveSchemaType. 
+func unionInputSchema(numberFirst bool) *openapi3.T { + stringRef := openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + numberRef := openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}} + anyOf := openapi3.SchemaRefs{&stringRef, &numberRef} + if numberFirst { + anyOf = openapi3.SchemaRefs{&numberRef, &stringRef} + } + valueSchema := &openapi3.Schema{AnyOf: anyOf} + inputSchema := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "value": {Value: valueSchema}, + }, + } + return &openapi3.T{ + Components: &openapi3.Components{ + Schemas: openapi3.Schemas{ + "Input": {Value: inputSchema}, + }, + }, + } +} + +func TestNewInputsForMode_UnionParsesNumber(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + numberFirst bool + val string + wantInt *int32 + wantFlt *float32 + wantStr *string + }{ + // str | float (string member first) + {name: "str|float integer", val: "1", wantInt: ptrI32(1)}, + {name: "str|float float", val: "1.5", wantFlt: ptrF32(1.5)}, + {name: "str|float string", val: "hello", wantStr: ptrStr("hello")}, + // float | str (number member first) -- must still fall back to string + {name: "float|str integer", numberFirst: true, val: "1", wantInt: ptrI32(1)}, + {name: "float|str float", numberFirst: true, val: "1.5", wantFlt: ptrF32(1.5)}, + {name: "float|str string", numberFirst: true, val: "hello", wantStr: ptrStr("hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + schema := unionInputSchema(tt.numberFirst) + inputs, err := NewInputsForMode(map[string][]string{"value": {tt.val}}, schema, false) + require.NoError(t, err) + + got := inputs["value"] + switch { + case tt.wantInt != nil: + require.NotNil(t, got.Int) + require.Equal(t, *tt.wantInt, *got.Int) + case tt.wantFlt != nil: + require.NotNil(t, got.Float) + require.Equal(t, *tt.wantFlt, *got.Float) + case tt.wantStr != nil: + require.NotNil(t, 
got.String) + require.Equal(t, *tt.wantStr, *got.String) + } + }) + } +} + +// unionInputSchemaOf builds an OpenAPI doc whose single input field `value` +// is a union (anyOf) of the given JSON Schema types, in the given order. +func unionInputSchemaOf(types ...string) *openapi3.T { + anyOf := make(openapi3.SchemaRefs, len(types)) + for i, t := range types { + anyOf[i] = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{t}}} + } + inputSchema := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "value": {Value: &openapi3.Schema{AnyOf: anyOf}}, + }, + } + return &openapi3.T{ + Components: &openapi3.Components{ + Schemas: openapi3.Schemas{ + "Input": {Value: inputSchema}, + }, + }, + } +} + +func TestNewInputsForMode_UnionIntFloatAndStrInt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + types []string + val string + wantInt *int32 + wantFlt *float32 + wantStr *string + }{ + // int | float: integer member resolves first; a fractional value must + // fall back to the float member instead of erroring. + {name: "int|float integer", types: []string{"integer", "number"}, val: "1", wantInt: ptrI32(1)}, + {name: "int|float fractional", types: []string{"integer", "number"}, val: "1.5", wantFlt: ptrF32(1.5)}, + // str | int: string resolves first; a fractional value is not valid for + // the integer member and must fall back to the string member. + {name: "str|int integer", types: []string{"string", "integer"}, val: "1", wantInt: ptrI32(1)}, + {name: "str|int fractional", types: []string{"string", "integer"}, val: "1.5", wantStr: ptrStr("1.5")}, + {name: "str|int string", types: []string{"string", "integer"}, val: "hello", wantStr: ptrStr("hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + schema := unionInputSchemaOf(tt.types...) 
+ inputs, err := NewInputsForMode(map[string][]string{"value": {tt.val}}, schema, false) + require.NoError(t, err) + + got := inputs["value"] + switch { + case tt.wantInt != nil: + require.NotNil(t, got.Int, "expected int") + require.Equal(t, *tt.wantInt, *got.Int) + case tt.wantFlt != nil: + require.NotNil(t, got.Float, "expected float") + require.Equal(t, *tt.wantFlt, *got.Float) + case tt.wantStr != nil: + require.NotNil(t, got.String, "expected string") + require.Equal(t, *tt.wantStr, *got.String) + } + }) + } +} + +func TestSchemaAcceptsNumber(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.True(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"integer"}})) + require.False(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsNumber(nil)) + + union := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + }, + } + require.True(t, schemaAcceptsNumber(union)) + + stringOnlyUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"boolean"}}}, + }, + } + require.False(t, schemaAcceptsNumber(stringOnlyUnion)) +} + +func TestSchemaAcceptsString(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsString(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsString(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.False(t, schemaAcceptsString(nil)) + + union := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + }, + } + require.True(t, schemaAcceptsString(union)) + + numericOnlyUnion := &openapi3.Schema{ + AnyOf: 
openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + }, + } + require.False(t, schemaAcceptsString(numericOnlyUnion)) +} + +func TestSchemaAcceptsFloat(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.False(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"integer"}})) + require.False(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsFloat(nil)) + + intFloatUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + }, + } + require.True(t, schemaAcceptsFloat(intFloatUnion)) + + strIntUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + }, + } + require.False(t, schemaAcceptsFloat(strIntUnion)) +} + +func ptrI32(v int32) *int32 { return &v } +func ptrF32(v float32) *float32 { return &v } +func ptrStr(v string) *string { return &v } diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index f15f4471ca..41f604e462 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -268,6 +268,52 @@ type enumSchema struct { schema map[string]any } +func inputTypeJSONSchema(it InputType) map[string]any { + var schema map[string]any + switch it.Kind { + case InputKindPrimitive: + schema = it.Primitive.JSONType() + case InputKindAny: + schema = TypeAny.JSONType() + case InputKindArray: + items := TypeAny.JSONType() + if it.Elem != nil { + items = inputTypeJSONSchema(*it.Elem) + } + schema = map[string]any{ + "type": "array", + "items": items, + } + case InputKindUnion: + variants := make([]any, len(it.Variants)) + for i, variant := range it.Variants { + variantSchema := 
inputTypeJSONSchema(variant) + if it.Nullable { + variantSchema["nullable"] = true + } + variants[i] = variantSchema + } + // A nullable union is represented with OpenAPI's `nullable` keyword + // (set below), matching how plain optional fields behave: an omitted + // value yields the default, while explicit JSON `null` is validated + // against the field type just like any other optional input. + schema = map[string]any{"anyOf": variants} + default: + schema = TypeAny.JSONType() + } + if it.Nullable { + schema["nullable"] = true + } + return schema +} + +func inputSchemaForField(field InputField) map[string]any { + if field.InputType != nil { + return inputTypeJSONSchema(*field.InputType) + } + return field.FieldType.JSONType() +} + // buildInputSchema builds the Input schema object and any enum schemas for choices. func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { properties := newOrderedMapAny() @@ -314,7 +360,7 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { } else { // Regular field — inline type prop["title"] = TitleCase(name) - maps.Copy(prop, field.FieldType.JSONType()) + maps.Copy(prop, inputSchemaForField(field)) } // Determine effective default. A default of None on a non-nullable @@ -332,6 +378,9 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { // `Optional[Secret] = Input(default=None)`. See // TestNoneDefaultOnBareSecretIsOptional for the regression case. isNullable := field.FieldType.Repetition == Optional || field.FieldType.Repetition == OptionalRepeated + if field.InputType != nil && field.InputType.Nullable { + isNullable = true + } if field.FieldType.Primitive == TypeSecret && field.FieldType.Repetition == Required && field.Default != nil && field.Default.Kind == DefaultNone { @@ -343,7 +392,8 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { } // Required? 
- if !hasEffectiveDefault && (field.FieldType.Repetition == Required || field.FieldType.Repetition == Repeated) { + isUnionInput := field.InputType != nil && field.InputType.Kind == InputKindUnion + if !hasEffectiveDefault && (isUnionInput || field.FieldType.Repetition == Required || field.FieldType.Repetition == Repeated) { required = append(required, name) } diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index d9b528bb9f..57669b3e6c 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -40,6 +40,25 @@ func parseSpec(t *testing.T, info *PredictorInfo) map[string]any { return spec } +func extractInputProperty(t *testing.T, raw []byte, name string) string { + t.Helper() + var doc map[string]any + require.NoError(t, json.Unmarshal(raw, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties[name] + require.True(t, ok) + out, err := json.Marshal(prop) + require.NoError(t, err) + return string(out) +} + func getPath(m map[string]any, keys ...string) any { var cur any = m for _, k := range keys { @@ -534,6 +553,62 @@ func TestInputOptionalRepeatedType(t *testing.T) { assert.Nil(t, inputSchema["required"]) } +func TestOpenAPIUnionInputStringFloat(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: ptr(InputUnionOf( + InputPrimitive(TypeString), + InputPrimitive(TypeFloat), + )), + }) + + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + require.NoError(t, err) + require.JSONEq(t, 
`{"anyOf":[{"type":"string"},{"type":"number"}],"title":"Value","x-order":0}`, extractInputProperty(t, out, "value")) +} + +func TestOpenAPIRequiredNullableUnionInput(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + it := InputUnionOf(InputPrimitive(TypeString), InputPrimitive(TypeFloat)) + it.Nullable = true + inputs.Set("value", InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: &it, + }) + + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + input := doc["components"].(map[string]any)["schemas"].(map[string]any)["Input"].(map[string]any) + require.Equal(t, []any{"value"}, input["required"]) + prop := input["properties"].(map[string]any)["value"].(map[string]any) + require.JSONEq(t, `{ + "anyOf":[ + {"nullable":true,"type":"string"}, + {"nullable":true,"type":"number"} + ], + "nullable":true, + "title":"Value", + "x-order":0 + }`, extractInputProperty(t, out, "value")) + require.Equal(t, true, prop["nullable"]) +} + // --------------------------------------------------------------------------- // Tests: Choices / Enums // --------------------------------------------------------------------------- diff --git a/pkg/schema/python/inputs.go b/pkg/schema/python/inputs.go index 7275eed636..06669f089d 100644 --- a/pkg/schema/python/inputs.go +++ b/pkg/schema/python/inputs.go @@ -262,16 +262,17 @@ func typedParameterParts(node *sitter.Node, source []byte) (string, *sitter.Node return name, typeNode } -func inputField(name string, order int, fieldType schema.FieldType) schema.InputField { +func inputField(name string, order int, inputType schema.InputType, fieldType schema.FieldType) schema.InputField { return schema.InputField{ Name: name, Order: order, FieldType: fieldType, + InputType: &inputType, } } -func 
inputFieldWithInfo(name string, order int, fieldType schema.FieldType, info inputCallInfo) schema.InputField { - field := inputField(name, order, fieldType) +func inputFieldWithInfo(name string, order int, inputType schema.InputType, fieldType schema.FieldType, info inputCallInfo) schema.InputField { + field := inputField(name, order, inputType, fieldType) field.Default = info.Default field.Description = info.Description field.GE = info.GE @@ -293,12 +294,12 @@ func firstParamIsSelf(params *sitter.Node, source []byte) bool { return false } -func resolveParameterFieldType(typeNode *sitter.Node, source []byte, ctx *inputParseContext) (schema.FieldType, error) { +func resolveParameterInputTypes(typeNode *sitter.Node, source []byte, ctx *inputParseContext) (schema.InputType, schema.FieldType, error) { typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { - return schema.FieldType{}, err + return schema.InputType{}, schema.FieldType{}, err } - return schema.ResolveFieldType(typeAnn, ctx.imports, ctx.typedDicts) + return schema.ResolveInputType(typeAnn, ctx.imports, ctx.typedDicts) } func extractInputs( @@ -362,12 +363,13 @@ func parseTypedParameter(node *sitter.Node, source []byte, order int, ctx *input return schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, ctx.methodName), nil) } - fieldType, err := resolveParameterFieldType(typeNode, source, ctx) + inputType, fieldType, err := resolveParameterInputTypes(typeNode, source, ctx) if err != nil { return schema.InputField{}, err } - return inputField(name, order, fieldType), nil + field := inputField(name, order, inputType, fieldType) + return field, schema.ValidateInputField(field) } func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx *inputParseContext) (schema.InputField, error) { @@ -382,7 +384,7 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx return 
schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, ctx.methodName), nil) } - fieldType, err := resolveParameterFieldType(typeNode, source, ctx) + inputType, fieldType, err := resolveParameterInputTypes(typeNode, source, ctx) if err != nil { return schema.InputField{}, err } @@ -396,19 +398,21 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx if err != nil { return schema.InputField{}, err } - return inputFieldWithInfo(name, order, fieldType, info), nil + field := inputFieldWithInfo(name, order, inputType, fieldType, info) + return field, schema.ValidateInputField(field) } // 2. Reference to Input() via class attribute or static method if info, ok := resolveInputReference(valNode, source, ctx.registry); ok { - return inputFieldWithInfo(name, order, fieldType, info), nil + field := inputFieldWithInfo(name, order, inputType, fieldType, info) + return field, schema.ValidateInputField(field) } // 3. 
Plain default — must be statically resolvable if def, ok := resolveDefaultExpr(valNode, source, ctx.scope); ok { - field := inputField(name, order, fieldType) + field := inputField(name, order, inputType, fieldType) field.Default = &def - return field, nil + return field, schema.ValidateInputField(field) } // Can't resolve — hard error @@ -418,7 +422,8 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx } // No default — required parameter - return inputField(name, order, fieldType), nil + field := inputField(name, order, inputType, fieldType) + return field, schema.ValidateInputField(field) } func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 1d26dace12..de0a71f267 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -404,6 +404,129 @@ class Predictor(BasePredictor): require.Equal(t, schema.TypeString, name.FieldType.Primitive) } +func TestOptionalInputOpenAPINotRequired(t *testing.T) { + source := []byte(` +from typing import Optional + +class Predictor: + def predict(self, value: Optional[str]) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + out, err := schema.GenerateOpenAPISchema(info) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + _, hasRequired := input["required"] + require.False(t, hasRequired) + + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties["value"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, prop["nullable"]) +} + +func 
TestUnionInputStringFloat(t *testing.T) { + source := []byte(` +class Predictor: + def predict(self, value: str | float) -> str: + return str(value) +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + value, ok := info.Inputs.Get("value") + require.True(t, ok) + require.NotNil(t, value.InputType) + require.Equal(t, schema.InputKindUnion, value.InputType.Kind) + require.Len(t, value.InputType.Variants, 2) + require.Equal(t, schema.TypeString, value.InputType.Variants[0].Primitive) + require.Equal(t, schema.TypeFloat, value.InputType.Variants[1].Primitive) + require.False(t, value.InputType.Nullable) +} + +func TestUnionInputStringFloatNone(t *testing.T) { + source := []byte(` +from cog import Input + +class Predictor: + def predict(self, value: str | float | None = Input(default=None)) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + value, ok := info.Inputs.Get("value") + require.True(t, ok) + require.NotNil(t, value.InputType) + require.Equal(t, schema.InputKindUnion, value.InputType.Kind) + require.True(t, value.InputType.Nullable) + require.NotNil(t, value.Default) + require.Equal(t, schema.DefaultNone, value.Default.Kind) +} + +func TestUnionInputNullableWithoutDefaultOpenAPI(t *testing.T) { + source := []byte(` +class Predictor: + def predict(self, value: str | float | None) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + out, err := schema.GenerateOpenAPISchema(info) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + require.Equal(t, []any{"value"}, 
input["required"]) + + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties["value"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, prop["nullable"]) + require.Equal(t, []any{ + map[string]any{"nullable": true, "type": "string"}, + map[string]any{"nullable": true, "type": "number"}, + }, prop["anyOf"]) +} + +func TestUnionInputRejectsPathString(t *testing.T) { + source := []byte(` +from cog import Path + +class Predictor: + def predict(self, value: Path | str) -> str: + return "ok" +`) + + _, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.Error(t, err) + require.Contains(t, err.Error(), "Path") + require.Contains(t, err.Error(), "union") +} + // --------------------------------------------------------------------------- // List inputs // --------------------------------------------------------------------------- diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 4013424a69..0d86535801 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -98,6 +98,48 @@ type FieldType struct { Repetition Repetition } +// InputTypeKind tags the recursive input type representation. +type InputTypeKind int + +const ( + InputKindPrimitive InputTypeKind = iota + InputKindAny + InputKindArray + InputKindUnion +) + +// InputType represents JSON-native input types, including unions. +type InputType struct { + Kind InputTypeKind + Primitive PrimitiveType + Elem *InputType + Variants []InputType + Nullable bool +} + +// InputPrimitive creates a primitive input type. +func InputPrimitive(primitive PrimitiveType) InputType { + if primitive == TypeAny { + return InputAnyType() + } + return InputType{Kind: InputKindPrimitive, Primitive: primitive} +} + +// InputAnyType creates an opaque JSON input type. +func InputAnyType() InputType { + return InputType{Kind: InputKindAny, Primitive: TypeAny} +} + +// InputArrayOf creates an array input type. 
+func InputArrayOf(elem InputType) InputType { + return InputType{Kind: InputKindArray, Elem: &elem} +} + +// InputUnionOf creates a union input type. +func InputUnionOf(variants ...InputType) InputType { + return InputType{Kind: InputKindUnion, Variants: variants} +} + // JSONType returns the JSON Schema fragment for this field type. func (ft FieldType) JSONType() map[string]any { if ft.Repetition == Repeated || ft.Repetition == OptionalRepeated { @@ -179,6 +221,7 @@ type InputField struct { Name string Order int FieldType FieldType + InputType *InputType Default *DefaultValue Description *string GE *float64 @@ -195,6 +238,16 @@ func (f *InputField) IsRequired() bool { return f.Default == nil && (f.FieldType.Repetition == Required || f.FieldType.Repetition == Repeated) } +// ValidateInputField checks combinations unsupported by the static input model. +func ValidateInputField(field InputField) error { + if field.InputType != nil && field.InputType.Kind == InputKindUnion { + if len(field.Choices) > 0 || field.GE != nil || field.LE != nil || field.MinLength != nil || field.MaxLength != nil || field.Regex != nil { + return errUnsupportedType("constraints and choices are not supported on union inputs") + } + } + return nil +} + // PredictorInfo is the top-level extraction result. type PredictorInfo struct { Inputs *OrderedMap[string, InputField] @@ -434,6 +487,208 @@ func ResolveFieldType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[str return FieldType{}, errUnsupportedType("unknown type annotation") } +// ResolveInputType resolves a TypeAnnotation into the recursive input type model +// and the legacy FieldType compatibility layer. 
+func ResolveInputType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, FieldType, error) { + inputType, err := resolveInputType(ann, ctx, typedDicts) + if err != nil { + return InputType{}, FieldType{}, err + } + return inputType, fieldTypeFromInputType(inputType), nil +} + +func resolveInputType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, error) { + if inner, ok := unwrapOpaqueAnnotated(ann, ctx); ok { + return inputTypeFromFieldType(opaqueFieldType(inner, ctx)), nil + } + + switch ann.Kind { + case TypeAnnotSimple: + name := ann.Name + if typedDicts[name] { + return InputAnyType(), nil + } + qualifiedEntry := ImportEntry{} + if resolved, entry, ok := ctx.ResolveQualifiedName(name); ok { + name = resolved + qualifiedEntry = entry + if typedDicts[entry.Original+"."+name] { + return InputAnyType(), nil + } + } + if typedDicts[name] { + return InputAnyType(), nil + } + if name == "dict" || name == "Dict" { + return InputAnyType(), nil + } + prim, ok := PrimitiveFromName(name) + if !ok { + if qualifiedEntry.Module != "" { + return InputType{}, errUnresolvableImportedType(name, qualifiedEntry.Module) + } + if entry, imported := ctx.Names.Get(name); imported { + return InputType{}, errUnresolvableImportedType(name, entry.Module) + } + return InputType{}, errUnsupportedType(name) + } + return InputPrimitive(prim), nil + + case TypeAnnotGeneric: + outer := ann.Name + if resolved, _, ok := ctx.ResolveQualifiedName(outer); ok { + outer = resolved + } + if outer == "dict" || outer == "Dict" { + return InputAnyType(), nil + } + if outer == "Optional" { + if len(ann.Args) != 1 { + return InputType{}, errUnsupportedType(fmt.Sprintf("Optional expects exactly 1 type argument, got %d", len(ann.Args))) + } + inner, err := resolveInputType(ann.Args[0], ctx, typedDicts) + if err != nil { + return InputType{}, err + } + inner.Nullable = true + return inner, nil + } + if outer == "Union" { + return 
resolveInputUnion(ann.Args, ctx, typedDicts) + } + if ctx.isAnnotated(ann.Name) { + if len(ann.Args) == 0 { + return InputType{}, errUnsupportedType("Annotated expects at least 1 type argument") + } + return resolveInputType(ann.Args[0], ctx, typedDicts) + } + if outer == "List" || outer == "list" { + if len(ann.Args) != 1 { + return InputType{}, errUnsupportedType(fmt.Sprintf("List expects exactly 1 type argument, got %d", len(ann.Args))) + } + if opaqueInner, ok := unwrapOpaqueAnnotated(ann.Args[0], ctx); ok { + inner := inputTypeFromFieldType(opaqueFieldType(opaqueInner, ctx)) + if inner.Nullable || inner.Kind == InputKindArray || inner.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") + } + return InputArrayOf(inner), nil + } + inner, err := resolveInputType(ann.Args[0], ctx, typedDicts) + if err != nil { + return InputType{}, err + } + if inner.Nullable || inner.Kind == InputKindArray || inner.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") + } + return InputArrayOf(inner), nil + } + return InputType{}, errUnsupportedType(fmt.Sprintf("%s[...] 
is not a supported input type", outer)) + + case TypeAnnotUnion: + return resolveInputUnion(ann.Args, ctx, typedDicts) + } + return InputType{}, errUnsupportedType("unknown type annotation") +} + +func resolveInputUnion(args []TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, error) { + variants := make([]InputType, 0, len(args)) + nullable := false + for _, arg := range args { + if arg.Kind == TypeAnnotSimple && arg.Name == "None" { + nullable = true + continue + } + if arg.Kind == TypeAnnotUnion || (arg.Kind == TypeAnnotGeneric && arg.Name == "Union") { + return InputType{}, errUnsupportedType("nested union inputs are not supported") + } + variant, err := resolveInputType(arg, ctx, typedDicts) + if err != nil { + return InputType{}, err + } + if variant.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested union inputs are not supported") + } + variants = append(variants, variant) + } + + if len(variants) == 0 { + return InputType{}, errUnsupportedType("union inputs must include at least one non-None type") + } + if len(variants) == 1 { + variant := variants[0] + variant.Nullable = variant.Nullable || nullable + return variant, nil + } + for _, variant := range variants { + if err := validateUnionVariant(variant); err != nil { + return InputType{}, err + } + } + + union := InputUnionOf(variants...) 
+ union.Nullable = nullable + return union, nil +} + +func validateUnionVariant(inputType InputType) error { + if inputType.Nullable { + return errUnsupportedType("nested nullable variants are not supported in union inputs") + } + switch inputType.Kind { + case InputKindPrimitive: + if inputType.Primitive == TypePath || inputType.Primitive == TypeFile || inputType.Primitive == TypeSecret { + return errUnsupportedType(fmt.Sprintf("%s is not supported in union inputs", inputType.Primitive)) + } + case InputKindArray: + if inputType.Elem != nil { + return validateUnionVariant(*inputType.Elem) + } + case InputKindUnion: + return errUnsupportedType("nested union inputs are not supported") + } + return nil +} + +func inputTypeFromFieldType(fieldType FieldType) InputType { + var inputType InputType + if fieldType.Primitive == TypeAny { + inputType = InputAnyType() + } else { + inputType = InputPrimitive(fieldType.Primitive) + } + if fieldType.Repetition == Repeated || fieldType.Repetition == OptionalRepeated { + inputType = InputArrayOf(inputType) + } + if fieldType.Repetition == Optional || fieldType.Repetition == OptionalRepeated { + inputType.Nullable = true + } + return inputType +} + +func fieldTypeFromInputType(inputType InputType) FieldType { + repetition := Required + if inputType.Nullable { + repetition = Optional + } + switch inputType.Kind { + case InputKindPrimitive: + return FieldType{Primitive: inputType.Primitive, Repetition: repetition} + case InputKindArray: + arrayRepetition := Repeated + if inputType.Nullable { + arrayRepetition = OptionalRepeated + } + if inputType.Elem != nil && inputType.Elem.Kind == InputKindPrimitive { + return FieldType{Primitive: inputType.Elem.Primitive, Repetition: arrayRepetition} + } + return FieldType{Primitive: TypeAny, Repetition: arrayRepetition} + case InputKindAny, InputKindUnion: + return FieldType{Primitive: TypeAny, Repetition: repetition} + default: + return FieldType{Primitive: TypeAny, Repetition: repetition} + } 
+} + func unwrapOpaqueAnnotated(ann TypeAnnotation, ctx *ImportContext) (TypeAnnotation, bool) { if ann.Kind != TypeAnnotGeneric || !ctx.isAnnotated(ann.Name) || len(ann.Args) < 2 { return ann, false diff --git a/python/cog/_adt.py b/python/cog/_adt.py index f40f870e47..3fea80cc27 100644 --- a/python/cog/_adt.py +++ b/python/cog/_adt.py @@ -37,6 +37,10 @@ def _is_union(tpe: type) -> bool: return False +def _is_none_type(tpe: Any) -> bool: + return tpe is None or tpe is type(None) + + def _is_dict_like(tpe: Any) -> bool: """Check if a type should be treated like a dict, including TypedDict.""" if tpe is dict: @@ -52,7 +56,7 @@ def _is_dict_like(tpe: Any) -> bool: pass is_typeddict = getattr(typing, "is_typeddict", None) - return callable(is_typeddict) and is_typeddict(tpe) + return bool(callable(is_typeddict) and is_typeddict(tpe)) def _unwrap_opaque(tpe: Any) -> tuple[Any, bool]: @@ -249,6 +253,90 @@ class Repetition(Enum): OPTIONAL_REPEATED = 4 # list[X] | None +def _is_supported_union_variant(ft: "FieldType") -> bool: + if ft.union_variants is not None: + return False + if ft.primitive in { + PrimitiveType.PATH, + PrimitiveType.FILE, + PrimitiveType.SECRET, + PrimitiveType.CUSTOM, + }: + return False + return ft.repetition in {Repetition.REQUIRED, Repetition.REPEATED} + + +def _is_exact_union_match(value: Any, ft: "FieldType") -> bool: + if ft.repetition is Repetition.REPEATED: + return isinstance(value, list) + if ft.repetition is not Repetition.REQUIRED: + return False + if ft.primitive is PrimitiveType.BOOL: + return isinstance(value, bool) + if ft.primitive is PrimitiveType.INTEGER: + return isinstance(value, int) and not isinstance(value, bool) + if ft.primitive is PrimitiveType.FLOAT: + return isinstance(value, float) + if ft.primitive is PrimitiveType.STRING: + return isinstance(value, str) + if ft.primitive is PrimitiveType.ANY: + return isinstance(value, dict) + return False + + +def _union_primitive_accepts_value(value: Any, primitive: PrimitiveType) -> 
bool: + if primitive is PrimitiveType.BOOL: + return type(value) is bool + if primitive is PrimitiveType.INTEGER: + return type(value) is int + if primitive is PrimitiveType.FLOAT: + return type(value) is float or type(value) is int + if primitive is PrimitiveType.STRING: + return type(value) is str + if primitive is PrimitiveType.ANY: + return isinstance(value, dict) + return False + + +def _union_variant_accepts_value(value: Any, variant: "FieldType") -> bool: + if variant.repetition is Repetition.REPEATED: + return isinstance(value, list) and all( + _union_primitive_accepts_value(element, variant.primitive) + for element in value + ) + if variant.repetition is not Repetition.REQUIRED: + return False + return _union_primitive_accepts_value(value, variant.primitive) + + +def _union_variant_priority(ft: "FieldType") -> int: + primitive_priority = { + PrimitiveType.BOOL: 0, + PrimitiveType.INTEGER: 1, + PrimitiveType.FLOAT: 2, + PrimitiveType.STRING: 3, + PrimitiveType.ANY: 4, + } + repetition_offset = 10 if ft.repetition is Repetition.REPEATED else 0 + return repetition_offset + primitive_priority.get(ft.primitive, 100) + + +def _ordered_union_variants( + value: Any, variants: List["FieldType"] +) -> List["FieldType"]: + return [ + variant + for _, variant in sorted( + enumerate(variants), + key=lambda item: ( + not _is_exact_union_match(value, item[1]), + _union_variant_priority(item[1]), + item[0], + ), + ) + ] + + @dataclass(frozen=True) class FieldType: """Type information for an input/output field.""" @@ -256,6 +344,7 @@ class FieldType: primitive: PrimitiveType repetition: Repetition coder: Optional[Coder] + union_variants: Optional[List["FieldType"]] = None @staticmethod def from_type(tpe: type) -> "FieldType": @@ -317,9 +406,43 @@ def from_type(tpe: type) -> "FieldType": elif _is_union(tpe): t_args = typing.get_args(tpe) - if not (len(t_args) == 2 and type(None) in t_args): - raise ValueError(f"unsupported union type {tpe}") - elem_t = t_args[0] if 
t_args[1] is type(None) else t_args[1] + has_none = any(_is_none_type(arg) for arg in t_args) + non_none_args = [arg for arg in t_args if not _is_none_type(arg)] + + if len(non_none_args) != 1: + repetition = Repetition.OPTIONAL if has_none else Repetition.REQUIRED + variants = [] + for arg in non_none_args: + try: + variant = FieldType.from_type(arg) + except ValueError as exc: + raise ValueError( + f"unsupported union member {_type_name(arg)} in union {tpe}" + ) from exc + if not _is_supported_union_variant(variant): + raise ValueError( + f"unsupported union member {_type_name(arg)} in union {tpe}" + ) + variants.append(variant) + return FieldType( + primitive=PrimitiveType.ANY, + repetition=repetition, + coder=None, + union_variants=variants, + ) + + if not has_none: + elem_t = non_none_args[0] + repetition = Repetition.REQUIRED + cog_t = PrimitiveType.from_type(elem_t) + coder = None + if cog_t is PrimitiveType.CUSTOM: + coder = Coder.lookup(elem_t) + if coder is None: + raise ValueError(f"unsupported Cog type {_type_name(elem_t)}") + return FieldType(primitive=cog_t, repetition=repetition, coder=coder) + + elem_t = non_none_args[0] inner, elem_is_opaque = _unwrap_opaque(elem_t) if elem_is_opaque: return _opaque_field_type(inner | None) @@ -384,6 +507,20 @@ def from_type(tpe: type) -> "FieldType": def normalize(self, value: Any) -> Any: """Normalize a value according to this field type.""" + if self.union_variants is not None: + if value is None: + if self.repetition is Repetition.OPTIONAL: + return None + raise ValueError("missing value for required union field") + for variant in _ordered_union_variants(value, self.union_variants): + if not _union_variant_accepts_value(value, variant): + continue + try: + return variant.normalize(value) + except (TypeError, ValueError): + pass + raise ValueError(f"failed to normalize value as {self.python_type_name()}") + if self.repetition is Repetition.REQUIRED: return self.primitive.normalize(value) elif self.repetition is 
Repetition.OPTIONAL: @@ -398,6 +535,14 @@ def normalize(self, value: Any) -> Any: def json_type(self) -> Dict[str, Any]: """Get the JSON Schema type for this field.""" + if self.union_variants is not None: + jt: Dict[str, Any] = { + "anyOf": [variant.json_type() for variant in self.union_variants] + } + if self.repetition is Repetition.OPTIONAL: + jt["nullable"] = True + return jt + if self.repetition in (Repetition.REPEATED, Repetition.OPTIONAL_REPEATED): return {"type": "array", "items": self.primitive.json_type()} return self.primitive.json_type() @@ -428,6 +573,14 @@ def json_decode(self, value: Any) -> Any: def python_type_name(self) -> str: """Get the Python type name for this field.""" + if self.union_variants is not None: + name = " | ".join( + variant.python_type_name() for variant in self.union_variants + ) + if self.repetition is Repetition.OPTIONAL: + return f"Optional[{name}]" + return name + if self.repetition is Repetition.REQUIRED: return self.primitive.python_type_name() elif self.repetition is Repetition.OPTIONAL: diff --git a/python/tests/test_adt.py b/python/tests/test_adt.py index ebffeb7d72..22800bf36b 100644 --- a/python/tests/test_adt.py +++ b/python/tests/test_adt.py @@ -2,6 +2,7 @@ from typing import Annotated, Any, Dict, List, Optional, TypedDict +import pytest from typing_extensions import TypedDict as ExtensionsTypedDict from cog import Opaque @@ -322,3 +323,106 @@ def test_optional_str(self) -> None: ft = FieldType.from_type(Optional[str]) assert ft.primitive is PrimitiveType.STRING assert ft.repetition is Repetition.OPTIONAL + + +class TestUnionInputTypes: + def test_union_str_float_field_type(self) -> None: + ft = FieldType.from_type(str | float) + assert ft.primitive is PrimitiveType.ANY + assert ft.repetition is Repetition.REQUIRED + assert ft.union_variants is not None + assert [v.primitive for v in ft.union_variants] == [ + PrimitiveType.STRING, + PrimitiveType.FLOAT, + ] + + def test_union_str_float_none_field_type(self) -> 
None: + ft = FieldType.from_type(str | float | None) + assert ft.repetition is Repetition.OPTIONAL + assert ft.union_variants is not None + + def test_union_int_float_prefers_int(self) -> None: + ft = FieldType.from_type(int | float) + assert ft.normalize(1) == 1 + assert isinstance(ft.normalize(1), int) + + def test_union_bool_int_prefers_bool(self) -> None: + ft = FieldType.from_type(bool | int) + value = ft.normalize(True) + assert value is True + + def test_union_int_float_rejects_bool(self) -> None: + ft = FieldType.from_type(int | float) + with pytest.raises(ValueError): + ft.normalize(True) + + def test_union_str_bool_rejects_int(self) -> None: + ft = FieldType.from_type(str | bool) + with pytest.raises(ValueError): + ft.normalize(1) + + def test_union_str_dict_rejects_scalar(self) -> None: + ft = FieldType.from_type(str | dict) + with pytest.raises(ValueError): + ft.normalize(123) + + def test_union_list_int_float_rejects_bool_element(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + with pytest.raises(ValueError): + ft.normalize([True]) + + def test_union_list_int_float_rejects_string_element(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + with pytest.raises(ValueError): + ft.normalize(["3"]) + + def test_union_list_int_float_accepts_numeric_elements(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + assert ft.normalize([1]) == [1] + assert isinstance(ft.normalize([1])[0], int) + assert ft.normalize([1.5]) == [1.5] + + def test_union_optional_normalize_none(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.repetition is Repetition.OPTIONAL + assert ft.normalize(None) is None + + def test_union_required_normalize_none_raises(self) -> None: + ft = FieldType.from_type(str | float) + assert ft.repetition is Repetition.REQUIRED + with pytest.raises(ValueError): + ft.normalize(None) + + def test_union_required_json_type_omits_nullable(self) -> None: + ft = FieldType.from_type(int | 
str) + assert ft.json_type() == { + "anyOf": [{"type": "integer"}, {"type": "string"}], + } + + def test_union_list_int_float_accepts_empty_list(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + assert ft.normalize([]) == [] + + def test_union_mixed_scalar_and_list(self) -> None: + ft = FieldType.from_type(list[int] | int) + assert ft.normalize(5) == 5 + assert isinstance(ft.normalize(5), int) + assert ft.normalize([5]) == [5] + + def test_union_str_float_none_json_type(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.json_type() == { + "anyOf": [{"type": "string"}, {"type": "number"}], + "nullable": True, + } + + def test_union_rejects_path_string(self) -> None: + from cog import Path + + try: + FieldType.from_type(Path | str) + except ValueError as exc: + assert "Path" in str(exc) + assert "union" in str(exc) + else: + raise AssertionError("Expected ValueError for Path | str") diff --git a/python/tests/test_inspector.py b/python/tests/test_inspector.py index 7597a783c2..bf1307aae9 100644 --- a/python/tests/test_inspector.py +++ b/python/tests/test_inspector.py @@ -168,6 +168,22 @@ def predict(self, value: Annotated[ExternalObject, Opaque]) -> str: assert field.type.repetition is adt.Repetition.REQUIRED +def test_inspector_supports_union_input() -> None: + class Predictor: + def predict(self, value: str | float) -> str: + return str(value) + + info = _create_predictor_info( + "predict", "Predictor", Predictor.predict, "predict", True + ) + field = info.inputs["value"] + assert field.type.union_variants is not None + assert [v.primitive for v in field.type.union_variants] == [ + adt.PrimitiveType.STRING, + adt.PrimitiveType.FLOAT, + ] + + def test_inspector_preserves_opaque_list_input_metadata() -> None: class Predictor: def predict(self, value: Annotated[List[ExternalObject], Opaque]) -> str: From 2ab4cc4c13e600ed5eac9485dea5d2b2e6cde355 Mon Sep 17 00:00:00 2001 From: Mark Phelps 
<209477+markphelps@users.noreply.github.com> Date: Fri, 5 Jun 2026 10:18:51 -0400 Subject: [PATCH 15/19] test: fuzz schema generation for union inputs + docs fixup (#3049) * test: add fuzzing for input type resolution and OpenAPI generation Add fuzz coverage for the schema-gen codepaths exercised by union inputs, which previously had no fuzzing and no validity oracle: - FuzzResolveInputType: feeds arbitrary TypeAnnotation trees through ResolveInputType, then validates that any successfully resolved input type generates an OpenAPI document the build-time validator accepts. - FuzzInputTypeJSONSchema: builds arbitrary InputType trees directly (reaching shapes the resolver never produces) and validates the generated OpenAPI document. Both use an assertValidOpenAPI oracle (the same kin-openapi validator as writeAndValidateSchema), so a union shape that resolves cleanly but emits an invalid schema (e.g. an unsupported `type: null` branch) fails the fuzzer rather than surfacing as a confusing user build error. Make the test:fuzz task auto-discover every Fuzz* target via `go test -list` instead of hardcoding names, so new fuzz tests are picked up automatically. This also surfaced targets the hardcoded list missed (pkg/config's three targets and two parser helpers): 11 total vs the 4 previously run. Manage jq (used for discovery) as a mise tool and simplify the CI fuzz-go job to call the task. * docs: correct union input type support in python.md The union-inputs feature (#3048) added input union support, but the "Type limitations" section still claimed only Optional[T] was supported, contradicting the new Union section. Scope the limitation correctly: output unions remain unsupported, JSON-native input unions are supported, and input unions of Path/File/Secret/custom-coder/BaseModel members fail at build. Also add the missing Union table-of-contents entry and regenerate llms.txt. 
* test: use testify require in assertValidOpenAPI Replace raw t.Fatalf with require.NoError per the project testing conventions (AGENTS.md). --- .github/workflows/ci.yaml | 14 +- docs/llms.txt | 4 +- docs/python.md | 4 +- mise.lock | 39 +++++ mise.toml | 41 +++-- pkg/schema/input_type_fuzz_test.go | 244 +++++++++++++++++++++++++++++ 6 files changed, 322 insertions(+), 24 deletions(-) create mode 100644 pkg/schema/input_type_fuzz_test.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d3f82eadb0..025dfebda0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -440,7 +440,9 @@ jobs: fuzz-go: name: Fuzz Go runs-on: ubuntu-latest - timeout-minutes: 10 + # test:fuzz auto-discovers every Fuzz* target and runs each for FUZZTIME + # (30s default), so this scales as targets are added. + timeout-minutes: 15 env: CGO_ENABLED: "1" steps: @@ -449,14 +451,8 @@ jobs: with: version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - - name: Fuzz schema type resolution - run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime=30s - - name: Fuzz JSON schema generation - run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzJSONSchema -fuzztime=30s - - name: Fuzz Python parser - run: go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParsePredictor -fuzztime=30s - - name: Fuzz type annotation parsing - run: go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParseTypeAnnotation -fuzztime=30s + - name: Fuzz all targets (auto-discovered) + run: mise run test:fuzz test-rust: name: Test Rust diff --git a/docs/llms.txt b/docs/llms.txt index 92c90a36ab..2044b2f6d7 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -2070,6 +2070,7 @@ This document defines the API of the `cog` Python module, which is used to defin - [`cog.Secret`](#cogsecret) - [Wrapper types](#wrapper-types) - [`Optional`](#optional) + - [`Union`](#union) - [`list`](#list) - [`dict`](#dict) - [`cog.Opaque`](#cogopaque) @@ -2837,7 +2838,8 @@ Fields in a `BaseModel` 
output support these types: The following type patterns are **not** supported: - **Nested generics**: `list[list[str]]`, `list[Optional[str]]`, `Optional[list[str]]` are not supported. -- **Union types beyond Optional**: `str | int`, `Union[str, int, None]` — only `Optional[T]` (i.e. `T | None`) is supported. +- **Output union types beyond Optional**: union _return_ types and `BaseModel` union fields are not supported. Input unions of JSON-native types (`str | int`, `str | float | None`, etc.) _are_ supported — see [`Union`](#union). +- **Input unions of non-JSON-native types**: input unions involving `Path`, `File`, `Secret`, custom coders, or `BaseModel` (e.g. `Path | str`) are not supported and fail at build time. - **`Optional` as a top-level return type**: `-> Optional[str]` is not allowed. Use a `BaseModel` with optional fields instead. - **Nested `BaseModel` fields**: A `BaseModel` field typed as another `BaseModel` is not supported in Cog's type system for schema generation. - **Tuple, Set, or other collection types**: Only `list` and `dict` are supported as collection types. diff --git a/docs/python.md b/docs/python.md index 2603101ff9..fe9307aa3b 100644 --- a/docs/python.md +++ b/docs/python.md @@ -38,6 +38,7 @@ This document defines the API of the `cog` Python module, which is used to defin - [`cog.Secret`](#cogsecret) - [Wrapper types](#wrapper-types) - [`Optional`](#optional) + - [`Union`](#union) - [`list`](#list) - [`dict`](#dict) - [`cog.Opaque`](#cogopaque) @@ -805,7 +806,8 @@ Fields in a `BaseModel` output support these types: The following type patterns are **not** supported: - **Nested generics**: `list[list[str]]`, `list[Optional[str]]`, `Optional[list[str]]` are not supported. -- **Union types beyond Optional**: `str | int`, `Union[str, int, None]` — only `Optional[T]` (i.e. `T | None`) is supported. +- **Output union types beyond Optional**: union _return_ types and `BaseModel` union fields are not supported. 
Input unions of JSON-native types (`str | int`, `str | float | None`, etc.) _are_ supported — see [`Union`](#union). +- **Input unions of non-JSON-native types**: input unions involving `Path`, `File`, `Secret`, custom coders, or `BaseModel` (e.g. `Path | str`) are not supported and fail at build time. - **`Optional` as a top-level return type**: `-> Optional[str]` is not allowed. Use a `BaseModel` with optional fields instead. - **Nested `BaseModel` fields**: A `BaseModel` field typed as another `BaseModel` is not supported in Cog's type system for schema generation. - **Tuple, Set, or other collection types**: Only `list` and `dict` are supported as collection types. diff --git a/mise.lock b/mise.lock index 5c6661efd8..3fe3fb9ec2 100644 --- a/mise.lock +++ b/mise.lock @@ -103,6 +103,45 @@ url = "https://github.com/gotestyourself/gotestsum/releases/download/v1.13.0/got checksum = "sha256:fd5a6dc69e46a0970593e70d85a7e75f16714e9c61d6d72ccc324eb82df5bb8a" url = "https://github.com/gotestyourself/gotestsum/releases/download/v1.13.0/gotestsum_1.13.0_windows_amd64.tar.gz" +[[tools."aqua:jqlang/jq"]] +version = "1.8.1" +backend = "aqua:jqlang/jq" + +[tools."aqua:jqlang/jq"."platforms.linux-arm64"] +checksum = "sha256:6bc62f25981328edd3cfcfe6fe51b073f2d7e7710d7ef7fcdac28d4e384fc3d4" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-arm64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.linux-arm64-musl"] +checksum = "sha256:6bc62f25981328edd3cfcfe6fe51b073f2d7e7710d7ef7fcdac28d4e384fc3d4" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-arm64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.linux-x64"] +checksum = "sha256:020468de7539ce70ef1bceaf7cde2e8c4f2ca6c3afb84642aabc5c97d9fc2a0d" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-amd64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.linux-x64-musl"] +checksum = 
"sha256:020468de7539ce70ef1bceaf7cde2e8c4f2ca6c3afb84642aabc5c97d9fc2a0d" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-amd64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.macos-arm64"] +checksum = "sha256:a9fe3ea2f86dfc72f6728417521ec9067b343277152b114f4e98d8cb0e263603" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-macos-arm64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.macos-x64"] +checksum = "sha256:e80dbe0d2a2597e3c11c404f03337b981d74b4a8504b70586c354b7697a7c27f" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-macos-amd64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.windows-x64"] +checksum = "sha256:23cb60a1354eed6bcc8d9b9735e8c7b388cd1fdcb75726b93bc299ef22dd9334" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-windows-amd64.exe" +provenance = "github-attestations" + [[tools."aqua:mitsuhiko/insta"]] version = "1.46.0" backend = "aqua:mitsuhiko/insta" diff --git a/mise.toml b/mise.toml index f7a6fe99ab..e25f278c84 100644 --- a/mise.toml +++ b/mise.toml @@ -53,6 +53,7 @@ ruff = "0.14.13" ty = "0.0.10" "npm:prettier" = "3.6.2" "npm:markdownlint-cli2" = "0.22.0" +"aqua:jqlang/jq" = "1.8.1" "go:golang.org/x/tools/cmd/goimports" = "latest" zig = "0.15.2" @@ -325,21 +326,35 @@ depends = ["build:coglet:wheel"] run = "nox -s coglet" [tasks."test:fuzz"] -description = "Run Go fuzz tests (FUZZTIME=30s per target by default)" -run = """ +description = "Run all Go fuzz tests (auto-discovered; FUZZTIME=30s per target by default)" +run = ''' #!/usr/bin/env bash -set -e +set -euo pipefail FUZZTIME="${FUZZTIME:-30s}" -echo "Fuzzing schema type resolution ($FUZZTIME)..." -go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime="$FUZZTIME" -echo "Fuzzing JSON schema generation ($FUZZTIME)..." 
-go test ./pkg/schema/ -run='^$' -fuzz=FuzzJSONSchema -fuzztime="$FUZZTIME" -echo "Fuzzing Python parser ($FUZZTIME)..." -go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParsePredictor -fuzztime="$FUZZTIME" -echo "Fuzzing type annotation parsing ($FUZZTIME)..." -go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParseTypeAnnotation -fuzztime="$FUZZTIME" -echo "All fuzz targets passed." -""" + +# Auto-discover every Fuzz* target and the package it lives in, so new fuzz +# tests are picked up automatically without editing this task. go test only +# fuzzes one target per invocation, so we still loop and run them one at a time. +# jq emits "<package> <target>" per line; package and target names never +# contain spaces, so `read pkg target` splits them cleanly. +count=0 +while read -r pkg target; do + [ -z "$pkg" ] && continue + echo "Fuzzing $target in $pkg ($FUZZTIME)..." + go test "$pkg" -run="^$" -fuzz="^${target}$" -fuzztime="$FUZZTIME" + count=$((count + 1)) +done < <( + go test -list "^Fuzz" -json ./... 2>/dev/null \ + | jq -r 'select(.Action=="output" and (.Output|test("^Fuzz"))) | .Package + " " + (.Output | rtrimstr("\n"))' +) + +if [ "$count" -eq 0 ]; then + echo "No fuzz targets found." >&2 + exit 1 +fi + +echo "All $count fuzz targets passed." +''' [tasks."test:integration"] description = "Run integration tests (skips slow tests by default, set SHORT=0 for full suite)" diff --git a/pkg/schema/input_type_fuzz_test.go b/pkg/schema/input_type_fuzz_test.go new file mode 100644 index 0000000000..d8ae1ab064 --- /dev/null +++ b/pkg/schema/input_type_fuzz_test.go @@ -0,0 +1,244 @@ +package schema + +import ( + "context" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +// FuzzResolveInputType builds arbitrary TypeAnnotation trees from fuzz input +// and verifies that ResolveInputType never panics. 
When resolution succeeds, +// it feeds the resulting InputType through OpenAPI generation and validates +// the emitted document with the same kin-openapi validator used at build time +// (writeAndValidateSchema). This is the key oracle: a union input type that +// resolves cleanly but emits an OpenAPI document the build-time validator +// rejects (e.g. an unsupported `type: null` branch) is a real bug, not just a +// panic. +func FuzzResolveInputType(f *testing.F) { + // Seed corpus — union and JSON-native input shapes, plus tricky cases. + seeds := []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "int"}, + {Kind: TypeAnnotSimple, Name: "float"}, + {Kind: TypeAnnotSimple, Name: "bool"}, + {Kind: TypeAnnotSimple, Name: "dict"}, + {Kind: TypeAnnotSimple, Name: "Any"}, + // str | float + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "float"}, + }}, + // str | float | None + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "float"}, + {Kind: TypeAnnotSimple, Name: "None"}, + }}, + // str | None (single-variant collapse to nullable) + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "None"}, + }}, + // int | float + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "int"}, + {Kind: TypeAnnotSimple, Name: "float"}, + }}, + // list[int] | list[float] + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "int"}, + }}, + {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "float"}, + }}, + }}, + // dict | list[dict] + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "dict"}, + {Kind: TypeAnnotGeneric, Name: "list", Args: 
[]TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "dict"}, + }}, + }}, + // Unsupported member: Path | str (must be rejected, not panic) + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "Path"}, + {Kind: TypeAnnotSimple, Name: "str"}, + }}, + // Optional[str] + {Kind: TypeAnnotGeneric, Name: "Optional", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + }}, + } + for _, s := range seeds { + f.Add(encodeAnnotation(s)) + } + + ctx := NewImportContext() + typedDicts := map[string]bool{} + + f.Fuzz(func(t *testing.T, data []byte) { + ann, _ := decodeAnnotation(data, 0, 0) + + // Must not panic regardless of input. + it, ft, err := ResolveInputType(ann, ctx, typedDicts) + if err != nil { + return + } + + // A resolved input type must build a field that validates and emits a + // valid OpenAPI document. + field := InputField{ + Name: "value", + Order: 0, + FieldType: ft, + InputType: &it, + } + if err := ValidateInputField(field); err != nil { + return + } + + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", field) + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + if err != nil { + return + } + + // Oracle: the generated schema must be a valid OpenAPI document, the + // same check writeAndValidateSchema performs at build time. A schema + // that fails here would surface as a confusing build failure for users. + assertValidOpenAPI(t, out) + }) +} + +// FuzzInputTypeJSONSchema constructs arbitrary InputType trees directly (not via +// the annotation resolver) and ensures both the per-field JSON schema helper +// and full OpenAPI generation never panic and always emit a valid document. +// Building InputType directly reaches shapes the resolver may not produce, +// stressing inputTypeJSONSchema and buildInputSchema in isolation. 
+func FuzzInputTypeJSONSchema(f *testing.F) { + f.Add([]byte{0, 3}) // primitive string + f.Add([]byte{1}) // any + f.Add([]byte{2, 0, 3}) // array of string + f.Add([]byte{3, 2, 0, 3, 0, 1}) // union of string and float + f.Add([]byte{3, 2, 0, 3, 0, 1, 0xff}) // nullable union of string and float + + f.Fuzz(func(t *testing.T, data []byte) { + it, _ := decodeInputType(data, 0, 0) + + // Field with both a compat FieldType and the recursive InputType set, + // mirroring how the parser populates InputField. + field := InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: &it, + } + + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", field) + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + if err != nil { + return + } + assertValidOpenAPI(t, out) + }) +} + +// assertValidOpenAPI loads and validates a generated OpenAPI document with the +// same kin-openapi validator used by writeAndValidateSchema at build time. +// A document that fails validation is a generation bug. +func assertValidOpenAPI(t *testing.T, schemaJSON []byte) { + t.Helper() + loader := openapi3.NewLoader() + loader.IsExternalRefsAllowed = true + doc, err := loader.LoadFromData(schemaJSON) + require.NoError(t, err, "generated schema failed to load\n%s", string(schemaJSON)) + err = doc.Validate(context.Background()) + require.NoError(t, err, "generated schema is invalid\n%s", string(schemaJSON)) +} + +// decodeInputType builds an InputType tree from bytes, mirroring the encoding +// strategy of decodeSchemaType. The final byte of a primitive/union toggles +// nullability so the fuzzer reaches both nullable and non-nullable shapes. 
+func decodeInputType(data []byte, offset int, depth int) (InputType, int) { + if depth > maxFuzzDepth || offset >= len(data) { + return InputPrimitive(TypeString), offset + } + + kind := InputTypeKind(data[offset] % 4) + offset++ + + switch kind { + case InputKindPrimitive: + prim := PrimitiveType(0) + if offset < len(data) { + prim = PrimitiveType(data[offset] % 9) + offset++ + } + it := InputPrimitive(prim) + if offset < len(data) { + if data[offset]%2 == 1 { + it.Nullable = true + } + offset++ + } + return it, offset + + case InputKindAny: + it := InputAnyType() + if offset < len(data) { + if data[offset]%2 == 1 { + it.Nullable = true + } + offset++ + } + return it, offset + + case InputKindArray: + elem, newOffset := decodeInputType(data, offset, depth+1) + it := InputArrayOf(elem) + if newOffset < len(data) { + if data[newOffset]%2 == 1 { + it.Nullable = true + } + newOffset++ + } + return it, newOffset + + case InputKindUnion: + numVariants := 0 + if offset < len(data) { + numVariants = int(data[offset]) % 4 // cap at 3 variants + offset++ + } + variants := make([]InputType, 0, numVariants) + for i := 0; i < numVariants && offset < len(data); i++ { + v, newOffset := decodeInputType(data, offset, depth+1) + variants = append(variants, v) + offset = newOffset + } + it := InputUnionOf(variants...) 
+ if offset < len(data) { + if data[offset]%2 == 1 { + it.Nullable = true + } + offset++ + } + return it, offset + + default: + return InputPrimitive(TypeString), offset + } +} From d6c2b7064006b93cf4e608fd7157950bd5296417 Mon Sep 17 00:00:00 2001 From: Mark Phelps <209477+markphelps@users.noreply.github.com> Date: Fri, 5 Jun 2026 10:23:53 -0400 Subject: [PATCH 16/19] Bump version to 0.21.0-rc.3 (#3050) --- VERSION.txt | 2 +- crates/Cargo.lock | 4 ++-- crates/Cargo.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index c22d34c81e..0407900b35 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.21.0-rc.2 +0.21.0-rc.3 diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 002d9a53cf..e8b1bbb932 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -259,7 +259,7 @@ dependencies = [ [[package]] name = "coglet" -version = "0.21.0-rc.2" +version = "0.21.0-rc.3" dependencies = [ "anyhow", "async-trait", @@ -291,7 +291,7 @@ dependencies = [ [[package]] name = "coglet-python" -version = "0.21.0-rc.2" +version = "0.21.0-rc.3" dependencies = [ "async-trait", "base64", diff --git a/crates/Cargo.toml b/crates/Cargo.toml index ba042fd22a..619ee6aa35 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -3,7 +3,7 @@ resolver = "2" members = ["coglet", "coglet-python"] [workspace.package] -version = "0.21.0-rc.2" +version = "0.21.0-rc.3" edition = "2024" license = "Apache-2.0" repository = "https://github.com/replicate/cog" From bf8a249939a8499027c3651f3522013d050425a2 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 8 Jun 2026 14:18:07 -0400 Subject: [PATCH 17/19] fix: hide kong predict/train/weights/debug and drop dead os.Exit - Add hidden:"" tags to predict, train, weights, and debug commands to match Cobra's Hidden: true (run stays visible). - Remove unreachable os.Exit(1) after parser.FatalIfErrorf(err), which already exits the process. 
- Update root help test to assert hidden commands are hidden in the Kong model rather than absent as substrings. --- cmd/cog-kong/main.go | 10 ++++------ cmd/cog-kong/main_test.go | 13 ++++++++++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/cmd/cog-kong/main.go b/cmd/cog-kong/main.go index 63cf073a08..a4125b1fc5 100644 --- a/cmd/cog-kong/main.go +++ b/cmd/cog-kong/main.go @@ -29,17 +29,17 @@ type CLI struct { BaseImage BaseImageCmd `cmd:"" name:"base-image" help:"Tools for working with Cog base images."` Build BuildCmd `cmd:"" help:"Build an image from cog.yaml."` - Debug DebugCmd `cmd:"" help:"Debug Cog internals."` + Debug DebugCmd `cmd:"" hidden:"" help:"Debug Cog internals."` Doctor DoctorCmd `cmd:"" help:"Check your project for common issues and fix them (experimental)."` Exec ExecCmd `cmd:"" help:"Execute a command inside a Docker environment."` Init InitCmd `cmd:"" help:"Configure your project for use with Cog."` Login LoginCmd `cmd:"" help:"Log in to a container registry."` - Predict PredictCmd `cmd:"" help:"Run a prediction."` + Predict PredictCmd `cmd:"" hidden:"" help:"Run a prediction."` Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."` RunCommand RunCmd `cmd:"" name:"run" help:"Run a prediction."` Serve ServeCmd `cmd:"" help:"Run an HTTP server."` - Train TrainCmd `cmd:"" help:"Run a training job."` - Weights WeightsCmd `cmd:"" help:"Commands for managing model weight files."` + Train TrainCmd `cmd:"" hidden:"" help:"Run a training job."` + Weights WeightsCmd `cmd:"" hidden:"" help:"Commands for managing model weight files."` } func main() { @@ -67,7 +67,6 @@ func main() { // otherwise it's a real parse error (e.g. unexpected command or flag), so print the error and exit non-zero. parser.FatalIfErrorf(err) - os.Exit(1) } if cli.Help { _ = kctx.PrintUsage(false) @@ -80,7 +79,6 @@ func main() { // command returned an error. Print and exit non-zero. 
if err != nil { parser.FatalIfErrorf(err) - os.Exit(1) } } diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index 175e829635..b103c5382b 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -146,8 +146,19 @@ func TestKongRootHelpParses(t *testing.T) { require.Contains(t, help, "Usage: cog [flags]") require.Contains(t, help, "build") require.Contains(t, help, "push") - require.Contains(t, help, "weights") require.NotContains(t, help, "Usage: cog default") + + // Commands hidden in Cobra must also be hidden in the Kong model so they + // don't appear in the root help output. + hiddenByName := map[string]bool{} + for _, node := range kctx.Model.Children { + if node.Hidden { + hiddenByName[node.Name] = true + } + } + for _, name := range []string{"predict", "train", "weights", "debug"} { + require.Truef(t, hiddenByName[name], "command %q should be hidden", name) + } } func TestKongRootGlobalFlagsParse(t *testing.T) { From f83efb34bbfe313608967bd271af3520cde918c0 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 8 Jun 2026 15:13:35 -0400 Subject: [PATCH 18/19] fix: close kong/cobra CLI parity gaps from review - build: add hidden --skip-schema-validation; make --timestamp build-only (split build-only flags out of shared BuildFlags so they no longer leak onto push, matching Cobra). - push: no longer exposes --timestamp/--skip-schema-validation (parity). - serve: drop --env (Cobra serve has no --env); move --env to a dedicated EnvFlag embedded only by exec/predict/run/train. - train (kong): emit the deprecation warning Cobra prints. - base-image: add the generate-matrix subcommand; extract shared RunBaseImageGenerateMatrix runner. - exec (kong): fix missing-arg error text to match Cobra's MinimumNArgs(1). - openapi-schema: accept any string (Cobra defers file-existence to build time) instead of kong existingfile validation. 
- Remove dead --dockerfile flag from cobra predict/exec/train/debug (RuntimeBuildOptions never reads it); keep it on build/push where it is functional. Guard checkMutuallyExclusiveFlags against absent flags. - tests: derive command/hidden parity from the real Cobra surface (NewRootCommand/NewBaseImageRootCommand) instead of a hardcoded list; assert run/visible commands are not hidden; add value-asserting tests for build-only flags, --env placement, exec arg arity, and generate-matrix. --- cmd/cog-kong/baseimage.go | 14 +++- cmd/cog-kong/build.go | 18 ++++- cmd/cog-kong/flags.go | 24 +++---- cmd/cog-kong/main_test.go | 134 +++++++++++++++++++++++++++++++++++--- cmd/cog-kong/predict.go | 5 +- cmd/cog-kong/runtime.go | 33 ++++++---- cmd/cog-kong/simple.go | 6 +- cmd/cog-kong/train.go | 8 ++- pkg/cli/baseimage.go | 100 ++++++++++++---------------- pkg/cli/build.go | 5 +- pkg/cli/debug.go | 1 - pkg/cli/exec.go | 1 - pkg/cli/predict.go | 1 - pkg/cli/train.go | 1 - 14 files changed, 243 insertions(+), 108 deletions(-) diff --git a/cmd/cog-kong/baseimage.go b/cmd/cog-kong/baseimage.go index 0e07bff4ec..09a5d5ea13 100644 --- a/cmd/cog-kong/baseimage.go +++ b/cmd/cog-kong/baseimage.go @@ -9,8 +9,9 @@ import ( // BaseImageCmd implements the experimental "cog base-image" command group. 
type BaseImageCmd struct { - Dockerfile BaseImageDockerfileCmd `cmd:"" help:"Display Cog base image Dockerfile."` - Build BaseImageBuildCmd `cmd:"" help:"Build Cog base image."` + Dockerfile BaseImageDockerfileCmd `cmd:"" help:"Display Cog base image Dockerfile."` + Build BaseImageBuildCmd `cmd:"" help:"Build Cog base image."` + GenerateMatrix BaseImageGenerateMatrixCmd `cmd:"" name:"generate-matrix" help:"Generate a matrix of Cog base image versions (JSON)."` } // baseImageVersionFlags groups the version-selecting flags shared by the @@ -60,3 +61,12 @@ type BaseImageBuildCmd struct { func (cmd *BaseImageBuildCmd) Run(ctx context.Context, dockerClient command.Command) error { return cli.RunBaseImageBuild(ctx, dockerClient, cmd.options()) } + +// BaseImageGenerateMatrixCmd implements "cog base-image generate-matrix". +type BaseImageGenerateMatrixCmd struct { + baseImageVersionFlags `embed:""` +} + +func (cmd *BaseImageGenerateMatrixCmd) Run() error { + return cli.RunBaseImageGenerateMatrix(cmd.options()) +} diff --git a/cmd/cog-kong/build.go b/cmd/cog-kong/build.go index 15b5c07031..db6fa13595 100644 --- a/cmd/cog-kong/build.go +++ b/cmd/cog-kong/build.go @@ -8,11 +8,15 @@ import ( "github.com/replicate/cog/pkg/registry" ) -// BuildCmd implements the "cog build" command. +// BuildCmd implements the "cog build" command. Timestamp and +// SkipSchemaValidation are build-only hidden flags (push does not expose them), +// matching the Cobra CLI. 
type BuildCmd struct { BuildFlags `embed:""` - Tag string `name:"tag" short:"t" help:"A name for the built image in the form 'repository:tag'."` + Tag string `name:"tag" short:"t" help:"A name for the built image in the form 'repository:tag'."` + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` + SkipSchemaValidation bool `name:"skip-schema-validation" hidden:"" help:"Skip OpenAPI schema generation and validation."` } // Validate is called by Kong after parsing, before Run. It replaces Cobra's PreRunE. @@ -20,11 +24,19 @@ func (cmd *BuildCmd) Validate() error { return cmd.ValidateMutualExclusivity() } +// options returns the build flags with the build-only fields applied. +func (cmd *BuildCmd) options() cli.BuildFlagsOptions { + opts := cmd.Options() + opts.Timestamp = cmd.Timestamp + opts.SkipSchemaValidation = cmd.SkipSchemaValidation + return opts +} + // Run executes the build command via the shared cli.RunBuild runner. func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { return cli.RunBuild(ctx, dockerClient, regClient, cli.BuildCommandOptions{ ConfigFilename: cmd.File, Tag: cmd.Tag, - Flags: cmd.Options(), + Flags: cmd.options(), }) } diff --git a/cmd/cog-kong/flags.go b/cmd/cog-kong/flags.go index 72e8f9b6d6..6435ea4c1d 100644 --- a/cmd/cog-kong/flags.go +++ b/cmd/cog-kong/flags.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/replicate/cog/pkg/cli" - "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/model" ) @@ -23,8 +22,9 @@ func (c *ConfigFlag) ProvideModelSource() (*model.Source, error) { return model.NewSource(c.File) } -// BuildFlags groups all flags shared across commands that build images. -// Embed this in any command struct that calls resolver.Build(). +// BuildFlags groups the image-building flags shared by the build and push +// commands. 
Build-only flags (timestamp, skip-schema-validation) live on the +// BuildCmd struct itself so they don't leak onto push, matching the Cobra CLI. type BuildFlags struct { ConfigFlag `embed:""` @@ -34,24 +34,20 @@ type BuildFlags struct { Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image)."` UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` - OpenAPISchema string `name:"openapi-schema" type:"existingfile" help:"Load OpenAPI schema from a file."` + // Cobra accepts --openapi-schema as a plain string and defers the + // file-existence check to build time, so do not use type:"existingfile". + OpenAPISchema string `name:"openapi-schema" help:"Load OpenAPI schema from a file."` // Hidden flags Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml."` - Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` Strip bool `name:"strip" hidden:"" help:"Whether to strip shared libraries for faster inference times."` Precompile bool `name:"precompile" hidden:"" help:"Whether to precompile python files for faster load times."` } -// AfterApply syncs parsed flag values to package-level globals that the build -// pipeline reads. This runs after Kong parses flags but before Run(). -func (b *BuildFlags) AfterApply() error { - config.BuildSourceEpochTimestamp = b.Timestamp - return nil -} - // Options converts the Kong build flags into the parser-independent -// cli.BuildFlagsOptions shared with the Cobra CLI. +// cli.BuildFlagsOptions shared with the Cobra CLI. 
Timestamp defaults to -1 +// (timestamp rewriting disabled), matching push, which has no --timestamp flag; +// BuildCmd overrides Timestamp and SkipSchemaValidation from its own flags. func (b *BuildFlags) Options() cli.BuildFlagsOptions { return cli.BuildFlagsOptions{ NoCache: b.NoCache, @@ -64,7 +60,7 @@ func (b *BuildFlags) Options() cli.BuildFlagsOptions { DockerfileFile: b.Dockerfile, Strip: b.Strip, Precompile: b.Precompile, - Timestamp: b.Timestamp, + Timestamp: -1, } } diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index b103c5382b..063ee950ec 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -11,6 +11,7 @@ import ( "github.com/alecthomas/kong" "github.com/stretchr/testify/require" + "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/console" ) @@ -159,6 +160,11 @@ func TestKongRootHelpParses(t *testing.T) { for _, name := range []string{"predict", "train", "weights", "debug"} { require.Truef(t, hiddenByName[name], "command %q should be hidden", name) } + // Visible commands (notably run, the non-hidden twin of predict) must NOT be + // hidden — guards against accidentally hiding the wrong prediction command. 
+ for _, name := range []string{"run", "build", "push", "serve", "exec", "init", "login", "doctor", "base-image"} { + require.Falsef(t, hiddenByName[name], "command %q should be visible", name) + } } func TestKongRootGlobalFlagsParse(t *testing.T) { @@ -296,25 +302,135 @@ func TestKongBaseImageCommandFlagsParse(t *testing.T) { for _, args := range [][]string{ {"base-image", "dockerfile", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--no-cache", "--progress", "plain", "--help"}, {"base-image", "build", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--help"}, + {"base-image", "generate-matrix", "--cuda", "12.4", "--python", "3.12", "--help"}, } { _, err := parser.Parse(args) require.NoErrorf(t, err, "parse %v", args) } } -func TestKongCommandCoverageMatchesExpectedCobraSurface(t *testing.T) { +// TestKongBuildOnlyFlagsParity asserts that --timestamp and +// --skip-schema-validation are build-only (not on push), matching Cobra, and +// that build maps them through to the shared options. +func TestKongBuildOnlyFlagsParity(t *testing.T) { + preserveKongGlobalState(t) + + // build accepts the build-only flags and maps them to options. + var buildCLI CLI + buildParser, err := newParser(t.Context(), &buildCLI) + require.NoError(t, err) + _, err = buildParser.Parse([]string{"build", "--timestamp", "42", "--skip-schema-validation", "--tag", "img:latest"}) + require.NoError(t, err) + opts := buildCLI.Build.options() + require.Equal(t, int64(42), opts.Timestamp) + require.True(t, opts.SkipSchemaValidation) + + // push does NOT expose either flag (parity with Cobra push). 
+ for _, flag := range []string{"--timestamp", "--skip-schema-validation"} { + pushParser := newTestParser(t) + _, err := pushParser.Parse([]string{"push", flag, "x", "img:latest"}) + require.Errorf(t, err, "push should reject %s", flag) + require.Containsf(t, err.Error(), "unknown flag", "push should reject %s as unknown", flag) + } + + // push still defaults Timestamp to -1 (timestamp rewriting disabled), not 0. + require.Equal(t, int64(-1), (&BuildFlags{}).Options().Timestamp) +} + +// TestKongEnvFlagParity asserts --env is available on exec/predict/run/train but +// NOT on serve, matching the Cobra CLI (Cobra serve has no --env). +func TestKongEnvFlagParity(t *testing.T) { + // serve has no --env. + serveParser := newTestParser(t) + _, err := serveParser.Parse([]string{"serve", "--env", "A=B"}) + require.Error(t, err, "serve should reject --env") + require.Contains(t, err.Error(), "unknown flag") + + // exec/predict/run/train accept --env and thread it through. + var execCLI CLI + execParser, err := newParser(t.Context(), &execCLI) + require.NoError(t, err) + _, err = execParser.Parse([]string{"exec", "--env", "A=B", "echo", "hi"}) + require.NoError(t, err) + require.Equal(t, []string{"A=B"}, execCLI.Exec.Env) + + var runCLI CLI + runParser, err := newParser(t.Context(), &runCLI) + require.NoError(t, err) + _, err = runParser.Parse([]string{"run", "--env", "A=B", "img"}) + require.NoError(t, err) + require.Equal(t, []string{"A=B"}, runCLI.RunCommand.options("run").Env) +} + +// TestKongExecRequiresCommand asserts the zero-arg exec error matches Cobra's +// MinimumNArgs(1) message. 
+func TestKongExecRequiresCommand(t *testing.T) { + cmd := &ExecCmd{} + err := cmd.Validate() + require.Error(t, err) + require.Equal(t, "requires at least 1 arg(s), only received 0", err.Error()) +} + +// TestKongCommandCoverageMatchesCobraSurface derives the expected command set +// and hidden status directly from the real Cobra command tree +// (cli.NewRootCommand and cli.NewBaseImageRootCommand) and asserts the Kong +// model matches it node-by-node. Deriving from the actual Cobra surface (rather +// than a hand-maintained list) means any future Cobra command/hidden-flag change +// that the Kong CLI doesn't mirror will fail this test. +func TestKongCommandCoverageMatchesCobraSurface(t *testing.T) { parser := newTestParser(t) - commands := map[string]bool{} + + // Kong top-level command name -> hidden. + kongCmds := map[string]bool{} + kongChildren := map[string]*kong.Node{} for _, node := range parser.Model.Children { - commands[node.Name] = true + if node.Type != kong.CommandNode { + continue + } + kongCmds[node.Name] = node.Hidden + kongChildren[node.Name] = node } - expected := []string{ - "base-image", "build", "debug", "doctor", "exec", "init", - "login", "predict", "push", "run", "serve", "train", "weights", + // Expected top-level surface from the real Cobra root command. + root, err := cli.NewRootCommand() + require.NoError(t, err) + expected := map[string]bool{} + for _, c := range root.Commands() { + expected[c.Name()] = c.Hidden } - require.Len(t, commands, len(expected)) - for _, name := range expected { - require.Truef(t, commands[name], "missing command %q", name) + // Cobra ships base-image as a separate binary; Kong folds it under + // `cog base-image`, so add it to the expected top-level surface. 
+ expected["base-image"] = false + + require.Len(t, kongCmds, len(expected), "top-level command count mismatch: kong=%v expected=%v", keys(kongCmds), keys(expected)) + for name, hidden := range expected { + gotHidden, ok := kongCmds[name] + require.Truef(t, ok, "Kong is missing command %q present in Cobra", name) + require.Equalf(t, hidden, gotHidden, "hidden mismatch for command %q", name) + } + + // base-image subcommands must match the Cobra base-image binary. + bi, err := cli.NewBaseImageRootCommand() + require.NoError(t, err) + expectedBase := map[string]bool{} + for _, c := range bi.Commands() { + expectedBase[c.Name()] = c.Hidden + } + kongBase := map[string]bool{} + require.NotNil(t, kongChildren["base-image"], "Kong missing base-image group") + for _, node := range kongChildren["base-image"].Children { + if node.Type != kong.CommandNode { + continue + } + kongBase[node.Name] = node.Hidden + } + require.Equal(t, expectedBase, kongBase, "base-image subcommand surface mismatch") +} + +func keys(m map[string]bool) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) } + return out } diff --git a/cmd/cog-kong/predict.go b/cmd/cog-kong/predict.go index 6c086442e8..c273b85c04 100644 --- a/cmd/cog-kong/predict.go +++ b/cmd/cog-kong/predict.go @@ -10,6 +10,7 @@ import ( // predictionFlags groups the flags shared by the predict and run commands. type predictionFlags struct { RuntimeFlags `embed:""` + EnvFlag `embed:""` Image string `arg:"" optional:"" name:"image" help:"Image to run. If omitted, builds from cog.yaml in the current directory."` Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. 
-i path=@image.jpg."` @@ -20,8 +21,10 @@ type predictionFlags struct { } func (f predictionFlags) options(use string) cli.PredictionCommandOptions { + rbo := f.Options() + rbo.Env = f.Env return cli.PredictionCommandOptions{ - RuntimeBuildOptions: f.Options(), + RuntimeBuildOptions: rbo, Use: use, Image: f.Image, Input: f.Input, diff --git a/cmd/cog-kong/runtime.go b/cmd/cog-kong/runtime.go index 8e7acb3e24..d0a2413ae3 100644 --- a/cmd/cog-kong/runtime.go +++ b/cmd/cog-kong/runtime.go @@ -11,21 +11,30 @@ import ( // errMissingExecCommand matches Cobra's cobra.MinimumNArgs(1) error message for // the exec command. -var errMissingExecCommand = errors.New("accepts at least 1 arg(s), received 0") +var errMissingExecCommand = errors.New("requires at least 1 arg(s), only received 0") -// RuntimeFlags groups the build/run flags shared by the serve and exec -// commands. +// RuntimeFlags groups the build/run flags shared by the serve, exec, predict, +// run, and train commands. Note that --env is NOT part of this group: Cobra's +// serve command does not expose --env, so commands that accept it (exec, +// predict, run, train) declare their own EnvFlag. 
type RuntimeFlags struct { ConfigFlag `embed:""` - Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` - CudaBase string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` - CogBase *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` - GPUs string `name:"gpus" help:"GPU devices to add to the container, in the same format as docker run --gpus."` - Env []string `name:"env" short:"e" help:"Environment variables, in the form name=value."` + Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` + CudaBase string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` + CogBase *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` + GPUs string `name:"gpus" help:"GPU devices to add to the container, in the same format as docker run --gpus."` } -// Options converts the Kong runtime flags into cli.RuntimeBuildOptions. +// EnvFlag is an embeddable --env flag for the commands that accept it (exec, +// predict, run, train). Cobra's serve command has no --env, so ServeCmd omits +// this group. +type EnvFlag struct { + Env []string `name:"env" short:"e" help:"Environment variables, in the form name=value."` +} + +// Options converts the Kong runtime flags into cli.RuntimeBuildOptions. Env is +// populated by callers that embed EnvFlag. 
func (f RuntimeFlags) Options() cli.RuntimeBuildOptions { return cli.RuntimeBuildOptions{ ConfigFilename: f.File, @@ -33,7 +42,6 @@ func (f RuntimeFlags) Options() cli.RuntimeBuildOptions { UseCudaBaseImage: f.CudaBase, UseCogBaseImage: f.CogBase, GPUs: f.GPUs, - Env: f.Env, } } @@ -56,6 +64,7 @@ func (cmd *ServeCmd) Run(ctx context.Context, dockerClient command.Command, regC // ExecCmd implements the "cog exec" command. type ExecCmd struct { RuntimeFlags `embed:""` + EnvFlag `embed:""` Publish []string `name:"publish" short:"p" help:"Publish a container's port to the host, e.g. -p 8000."` Args []string `arg:"" passthrough:"" name:"command" help:"Command and arguments to execute."` @@ -69,8 +78,10 @@ func (cmd *ExecCmd) Validate() error { } func (cmd *ExecCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + opts := cmd.Options() + opts.Env = cmd.Env return cli.RunExec(ctx, dockerClient, regClient, cli.ExecCommandOptions{ - RuntimeBuildOptions: cmd.Options(), + RuntimeBuildOptions: opts, Args: cmd.Args, Ports: cmd.Publish, }) diff --git a/cmd/cog-kong/simple.go b/cmd/cog-kong/simple.go index ea91374d87..0c0a2d74ef 100644 --- a/cmd/cog-kong/simple.go +++ b/cmd/cog-kong/simple.go @@ -48,9 +48,9 @@ type DebugCmd struct { UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` - // Hidden flags for parity with the Cobra debug command. - Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile."` - Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` + // Hidden flag for parity with the Cobra debug command. RunDebug ignores it + // (it is a no-op on both CLIs), but it is accepted for surface parity. 
+ Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` } func (cmd *DebugCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { diff --git a/cmd/cog-kong/train.go b/cmd/cog-kong/train.go index d79f16c09a..4d3c115f99 100644 --- a/cmd/cog-kong/train.go +++ b/cmd/cog-kong/train.go @@ -5,11 +5,13 @@ import ( "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/util/console" ) // TrainCmd implements the hidden, deprecated "cog train" command. type TrainCmd struct { RuntimeFlags `embed:""` + EnvFlag `embed:""` Image string `arg:"" optional:"" name:"image" help:"Image to train. If omitted, builds from cog.yaml in the current directory."` Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. -i path=@image.jpg."` @@ -18,8 +20,12 @@ type TrainCmd struct { } func (cmd *TrainCmd) Run(ctx context.Context, dockerClient command.Command) error { + // Match Cobra's Deprecated notice for the train command. 
+ console.Warn("Command \"train\" is deprecated, the train command will be removed in a future version of Cog") + opts := cmd.Options() + opts.Env = cmd.Env return cli.RunTrain(ctx, dockerClient, cli.TrainCommandOptions{ - RuntimeBuildOptions: cmd.Options(), + RuntimeBuildOptions: opts, Image: cmd.Image, Input: cmd.Input, OutputPath: cmd.Output, diff --git a/pkg/cli/baseimage.go b/pkg/cli/baseimage.go index 6ce6f95e78..85d2dfbdfd 100644 --- a/pkg/cli/baseimage.go +++ b/pkg/cli/baseimage.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "slices" "strings" "github.com/spf13/cobra" @@ -61,65 +62,7 @@ func newBaseImageGenerateMatrix() *cobra.Command { Use: "generate-matrix", Short: "Generate a matrix of Cog base image versions (JSON)", RunE: func(cmd *cobra.Command, args []string) error { - validCudaVersions := strings.FieldsFunc(baseImageCUDAVersion, func(c rune) bool { - return c == ',' - }) - validPythonVersions := strings.FieldsFunc(baseImagePythonVersion, func(c rune) bool { - return c == ',' - }) - validTorchVersions := strings.FieldsFunc(baseImageTorchVersion, func(c rune) bool { - return c == ',' - }) - - allConfigurations := dockerfile.BaseImageConfigurations() - filteredMatrix := make([]dockerfile.BaseImageConfiguration, 0, len(allConfigurations)) - for _, config := range allConfigurations { - var found bool - if len(validCudaVersions) > 0 { - found = false - for _, validCudaVersion := range validCudaVersions { - if config.CUDAVersion == validCudaVersion { - found = true - } - } - if !found { - continue - } - } - - if len(validPythonVersions) > 0 { - found = false - for _, validPythonVersion := range validPythonVersions { - if config.PythonVersion == validPythonVersion { - found = true - } - } - if !found { - continue - } - } - - if len(validTorchVersions) > 0 { - found = false - for _, validTorchVersion := range validTorchVersions { - if config.TorchVersion == validTorchVersion { - found = true - } - } - if !found { - continue - } - } - - 
filteredMatrix = append(filteredMatrix, config) - } - - output, err := json.Marshal(filteredMatrix) - if err != nil { - return err - } - fmt.Println(string(output)) - return nil + return RunBaseImageGenerateMatrix(baseImageOptionsFromFlags()) }, Args: cobra.MaximumNArgs(0), } @@ -127,6 +70,45 @@ func newBaseImageGenerateMatrix() *cobra.Command { return cmd } +// RunBaseImageGenerateMatrix prints, as JSON, the matrix of supported base image +// configurations filtered by the comma-separated CUDA/Python/Torch versions in +// opts. Empty version filters match all values. It is shared by both the Cobra +// and Kong base-image generate-matrix commands. +func RunBaseImageGenerateMatrix(opts BaseImageOptions) error { + split := func(s string) []string { + return strings.FieldsFunc(s, func(c rune) bool { return c == ',' }) + } + validCudaVersions := split(opts.CUDAVersion) + validPythonVersions := split(opts.PythonVersion) + validTorchVersions := split(opts.TorchVersion) + + matches := func(filters []string, value string) bool { + return len(filters) == 0 || slices.Contains(filters, value) + } + + allConfigurations := dockerfile.BaseImageConfigurations() + filteredMatrix := make([]dockerfile.BaseImageConfiguration, 0, len(allConfigurations)) + for _, config := range allConfigurations { + if !matches(validCudaVersions, config.CUDAVersion) { + continue + } + if !matches(validPythonVersions, config.PythonVersion) { + continue + } + if !matches(validTorchVersions, config.TorchVersion) { + continue + } + filteredMatrix = append(filteredMatrix, config) + } + + output, err := json.Marshal(filteredMatrix) + if err != nil { + return err + } + fmt.Println(string(output)) + return nil +} + func newBaseImageDockerfileCommand() *cobra.Command { var cmd = &cobra.Command{ Use: "dockerfile", diff --git a/pkg/cli/build.go b/pkg/cli/build.go index c4ba0cd52d..0fdde3a079 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -275,7 +275,10 @@ func checkMutuallyExclusiveFlags(cmd 
*cobra.Command, args []string) error { flags := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"} var flagsSet []string for _, flag := range flags { - if cmd.Flag(flag).Changed { + // Not every command that runs this check registers all of these flags + // (e.g. exec has no --dockerfile), so skip flags that aren't defined. + f := cmd.Flag(flag) + if f != nil && f.Changed { flagsSet = append(flagsSet, "--"+flag) } } diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index b158ba2889..5a7572230c 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -28,7 +28,6 @@ func newDebugCommand() *cobra.Command { addSeparateWeightsFlag(cmd) addUseCudaBaseImageFlag(cmd) - addDockerfileFlag(cmd) addUseCogBaseImageFlag(cmd) addBuildTimestampFlag(cmd) addConfigFlag(cmd) diff --git a/pkg/cli/exec.go b/pkg/cli/exec.go index 9212015047..7ad33b95ff 100644 --- a/pkg/cli/exec.go +++ b/pkg/cli/exec.go @@ -49,7 +49,6 @@ exploring the environment your model will run in.`, Args: cobra.MinimumNArgs(1), } addBuildProgressOutputFlag(cmd) - addDockerfileFlag(cmd) addUseCudaBaseImageFlag(cmd) addUseCogBaseImageFlag(cmd) addGpusFlag(cmd) diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 27fea328f8..a4e394b85b 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -91,7 +91,6 @@ it.`, addUseCudaBaseImageFlag(cmd) addUseCogBaseImageFlag(cmd) addBuildProgressOutputFlag(cmd) - addDockerfileFlag(cmd) addGpusFlag(cmd) addSetupTimeoutFlag(cmd) addConfigFlag(cmd) diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 337ca8e7ae..571abd49d9 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -41,7 +41,6 @@ Otherwise, it will build the model in the current directory and train it.`, } addBuildProgressOutputFlag(cmd) - addDockerfileFlag(cmd) addUseCudaBaseImageFlag(cmd) addGpusFlag(cmd) addUseCogBaseImageFlag(cmd) From 89369f175b521035d71c5d6133563dfd97611487 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 11 Jun 2026 11:14:30 -0400 Subject: [PATCH 19/19] fix: 
address Kong/Cobra CLI parity review feedback - Make --help short-circuit before validation by using Kong's built-in help flag (drop NoDefaultHelp + post-parse Help check). Help now prints and exits 0 even with mutex/enum/arg/file validation errors, matching Cobra. - Drop --setup-timeout from kong train (Cobra train has no such flag). - Remove push mutual-exclusivity Validate (Cobra push has no such check). - Drop type:"existingfile" from --dockerfile so file checks defer to build time, matching Cobra and letting --help short-circuit. --- cmd/cog-kong/cli.go | 1 - cmd/cog-kong/flags.go | 4 ++- cmd/cog-kong/main.go | 5 --- cmd/cog-kong/main_test.go | 75 ++++++++++++++++++++++----------------- cmd/cog-kong/push.go | 5 --- cmd/cog-kong/train.go | 8 ++--- 6 files changed, 48 insertions(+), 50 deletions(-) diff --git a/cmd/cog-kong/cli.go b/cmd/cog-kong/cli.go index f0ec1c0cee..b6a21ce0d1 100644 --- a/cmd/cog-kong/cli.go +++ b/cmd/cog-kong/cli.go @@ -14,7 +14,6 @@ import ( type Globals struct { Debug bool `name:"debug" short:"d" env:"COG_DEBUG" help:"Show debugging output."` NoColor bool `name:"no-color" help:"Disable colored output."` - Help bool `name:"help" short:"h" help:"Show context-sensitive help."` Registry string `name:"registry" default:"${registry_default}" env:"COG_REGISTRY_HOST" hidden:"" help:"Registry host."` Profile bool `name:"profile" hidden:"" help:"Enable profiling."` Version kong.VersionFlag `name:"version" help:"Show version of Cog."` diff --git a/cmd/cog-kong/flags.go b/cmd/cog-kong/flags.go index 6435ea4c1d..51e4dc89b9 100644 --- a/cmd/cog-kong/flags.go +++ b/cmd/cog-kong/flags.go @@ -39,7 +39,9 @@ type BuildFlags struct { OpenAPISchema string `name:"openapi-schema" help:"Load OpenAPI schema from a file."` // Hidden flags - Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile. 
If set, cog will use this Dockerfile instead of generating one from cog.yaml."` + // Cobra accepts --dockerfile as a plain string and defers the + // file-existence check to build time, so do not use type:"existingfile". + Dockerfile string `name:"dockerfile" hidden:"" help:"Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml."` Strip bool `name:"strip" hidden:"" help:"Whether to strip shared libraries for faster inference times."` Precompile bool `name:"precompile" hidden:"" help:"Whether to precompile python files for faster load times."` } diff --git a/cmd/cog-kong/main.go b/cmd/cog-kong/main.go index a4125b1fc5..dabad287b9 100644 --- a/cmd/cog-kong/main.go +++ b/cmd/cog-kong/main.go @@ -68,10 +68,6 @@ func main() { // otherwise it's a real parse error (e.g. unexpected command or flag), so print the error and exit non-zero. parser.FatalIfErrorf(err) } - if cli.Help { - _ = kctx.PrintUsage(false) - return - } displayUpdateCheck(ctx) err = kctx.Run() @@ -99,7 +95,6 @@ func newParser(ctx context.Context, cli *CLI, options ...kong.Option) (*kong.Kon "registry_default": global.DefaultReplicateRegistryHost, }, kong.UsageOnError(), - kong.NoDefaultHelp(), kong.BindTo(ctx, (*context.Context)(nil)), kong.BindSingletonProvider(provideDockerClient), kong.BindToProvider(provideRegistryClient), diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go index 063ee950ec..187fffae29 100644 --- a/cmd/cog-kong/main_test.go +++ b/cmd/cog-kong/main_test.go @@ -2,10 +2,9 @@ package main import ( "bytes" - "errors" + "io" "os" "path/filepath" - "strings" "testing" "github.com/alecthomas/kong" @@ -18,11 +17,33 @@ import ( func newTestParser(t *testing.T) *kong.Kong { t.Helper() - parser, err := newParser(t.Context(), &CLI{}) + // Override Exit so Kong's built-in --help flag (which exits the process + // after printing help, before validation) doesn't kill the test process. 
+ parser, err := newParser(t.Context(), &CLI{}, kong.Exit(func(code int) { + panic(testExitCode(code)) + })) require.NoError(t, err) + parser.Stdout = io.Discard + parser.Stderr = io.Discard return parser } +// parseForTest parses args, tolerating Kong's built-in --help short-circuit, +// which prints help and calls Exit(0) (translated to a testExitCode(0) panic by +// newTestParser). It returns the parse error for non-help cases. +func parseForTest(t *testing.T, parser *kong.Kong, args []string) (err error) { + t.Helper() + defer func() { + if r := recover(); r != nil { + code, ok := r.(testExitCode) + require.Truef(t, ok, "unexpected panic parsing %v: %v", args, r) + require.Equalf(t, testExitCode(0), code, "non-zero exit parsing %v", args) + } + }() + _, err = parser.Parse(args) + return err +} + type kongGlobalState struct { debug bool noColor bool @@ -122,26 +143,25 @@ func TestKongRegistersNestedCommands(t *testing.T) { {"base-image", "dockerfile", "--help"}, {"base-image", "build", "--help"}, } { - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } } func TestKongRootHelpParses(t *testing.T) { preserveKongGlobalState(t) - parser := newTestParser(t) + // Override Exit so the built-in --help flag doesn't terminate the process, + // and capture stdout to assert on the rendered help output. 
var stdout bytes.Buffer + parser, err := newParser(t.Context(), &CLI{}, kong.Exit(func(code int) { + panic(testExitCode(code)) + })) + require.NoError(t, err) parser.Stdout = &stdout parser.Stderr = &stdout - kctx, err := parser.Parse([]string{"--help"}) - if err != nil { - var parseErr *kong.ParseError - require.True(t, errors.As(err, &parseErr), "expected ParseError, got %T", err) - require.True(t, strings.HasPrefix(parseErr.Error(), "expected"), "expected command selection error, got %q", parseErr.Error()) - kctx = parseErr.Context - } - require.NoError(t, kctx.PrintUsage(false)) + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = parser.Parse([]string{"--help"}) + }) help := stdout.String() require.Contains(t, help, "Usage: cog [flags]") @@ -152,7 +172,7 @@ func TestKongRootHelpParses(t *testing.T) { // Commands hidden in Cobra must also be hidden in the Kong model so they // don't appear in the root help output. hiddenByName := map[string]bool{} - for _, node := range kctx.Model.Children { + for _, node := range parser.Model.Children { if node.Hidden { hiddenByName[node.Name] = true } @@ -178,8 +198,7 @@ func TestKongRootGlobalFlagsParse(t *testing.T) { } { restoreKongGlobalState(t, state) parser := newTestParser(t) - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } restoreKongGlobalState(t, state) @@ -201,12 +220,7 @@ func TestKongHelpParsingDoesNotWriteUpdateState(t *testing.T) { {"build", "--help"}, } { parser := newTestParser(t) - _, err := parser.Parse(args) - if err != nil { - var parseErr *kong.ParseError - require.True(t, errors.As(err, &parseErr), "expected ParseError, got %T", err) - require.True(t, strings.HasPrefix(parseErr.Error(), "expected"), "expected command selection error, got %q", parseErr.Error()) - } + _ = parseForTest(t, parser, args) require.NoFileExists(t, filepath.Join(home, ".config", "cog", "update-state.json"), "parse %v", 
args) } } @@ -250,8 +264,7 @@ func TestKongRuntimeCommandFlagsParse(t *testing.T) { {"serve", "--port", "5000", "--upload-url", "https://example.com/upload", "--gpus", "all", "--help"}, {"exec", "--gpus", "all", "--publish", "8888", "--env", "A=B", "python", "-c", "print(1)"}, } { - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } } @@ -264,8 +277,7 @@ func TestKongSimpleCommandFlagsParse(t *testing.T) { {"doctor", "--fix", "--file", "custom.yaml", "--help"}, {"debug", "--image-name", "myimage", "--help"}, } { - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } } @@ -278,8 +290,7 @@ func TestKongPredictionCommandFlagsParse(t *testing.T) { {"run", "--json", "@inputs.json", "--gpus", "all", "--help"}, {"train", "example/image", "--input", "dataset=@data.json", "--help"}, } { - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } } @@ -291,8 +302,7 @@ func TestKongWeightsCommandFlagsParse(t *testing.T) { {"weights", "pull", "--verbose", "weights-name", "--help"}, {"weights", "status", "--json", "--verbose", "--help"}, } { - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } } @@ -304,8 +314,7 @@ func TestKongBaseImageCommandFlagsParse(t *testing.T) { {"base-image", "build", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--help"}, {"base-image", "generate-matrix", "--cuda", "12.4", "--python", "3.12", "--help"}, } { - _, err := parser.Parse(args) - require.NoErrorf(t, err, "parse %v", args) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) } } diff --git a/cmd/cog-kong/push.go b/cmd/cog-kong/push.go index 30727f81a5..d505871b78 100644 --- 
a/cmd/cog-kong/push.go +++ b/cmd/cog-kong/push.go @@ -16,11 +16,6 @@ type PushCmd struct { Image string `arg:"" optional:"" help:"Image name to push (e.g. registry.example.com/user/model)."` } -// Validate is called by Kong after parsing, before Run. -func (cmd *PushCmd) Validate() error { - return cmd.ValidateMutualExclusivity() -} - // Run executes the push command via the shared cli.RunPush runner. func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry) error { return cli.RunPush(ctx, dockerClient, regClient, providerReg, cli.PushCommandOptions{ diff --git a/cmd/cog-kong/train.go b/cmd/cog-kong/train.go index 4d3c115f99..1caaa15e45 100644 --- a/cmd/cog-kong/train.go +++ b/cmd/cog-kong/train.go @@ -13,10 +13,9 @@ type TrainCmd struct { RuntimeFlags `embed:""` EnvFlag `embed:""` - Image string `arg:"" optional:"" name:"image" help:"Image to train. If omitted, builds from cog.yaml in the current directory."` - Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. -i path=@image.jpg."` - Output string `name:"output" short:"o" default:"weights" help:"Output path."` - SetupTimeout uint32 `name:"setup-timeout" default:"300" help:"The timeout for a container to setup (in seconds)."` + Image string `arg:"" optional:"" name:"image" help:"Image to train. If omitted, builds from cog.yaml in the current directory."` + Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. 
-i path=@image.jpg."` + Output string `name:"output" short:"o" default:"weights" help:"Output path."` } func (cmd *TrainCmd) Run(ctx context.Context, dockerClient command.Command) error { @@ -29,6 +28,5 @@ func (cmd *TrainCmd) Run(ctx context.Context, dockerClient command.Command) erro Image: cmd.Image, Input: cmd.Input, OutputPath: cmd.Output, - SetupTimeout: cmd.SetupTimeout, }) }