diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 66be3e736b..025dfebda0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -138,6 +138,7 @@ jobs: fetch-depth: 0 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Build SDK run: mise run ci:build:sdk @@ -190,6 +191,7 @@ jobs: fetch-depth: 0 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Get version from VERSION.txt id: version @@ -221,6 +223,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check Go formatting run: mise run fmt:go @@ -233,6 +236,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Ensure Rust components run: rustup component add rustfmt clippy @@ -247,6 +251,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check Python formatting run: mise run fmt:python @@ -259,6 +264,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check llms.txt is up to date run: mise run docs:llm:check @@ -283,6 +289,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check stubs are up to date run: | @@ -302,6 +309,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Lint Go run: mise run lint:go @@ -318,6 +326,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Ensure Rust components 
run: rustup component add rustfmt clippy @@ -332,6 +341,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check licenses, bans, and sources run: cargo deny --manifest-path crates/Cargo.toml check bans licenses sources @@ -349,6 +359,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check advisories (informational) run: cargo deny --manifest-path crates/Cargo.toml check advisories @@ -373,6 +384,7 @@ jobs: run: tar xf dist/*.tar.gz --strip-components=1 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Ensure nox is installed run: mise install pipx:nox --force @@ -387,6 +399,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Lint Markdown run: mise run lint:docs @@ -407,6 +420,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Test Go shell: bash @@ -426,22 +440,19 @@ jobs: fuzz-go: name: Fuzz Go runs-on: ubuntu-latest - timeout-minutes: 10 + # test:fuzz auto-discovers every Fuzz* target and runs each for FUZZTIME + # (30s default), so this scales as targets are added. 
+ timeout-minutes: 15 env: CGO_ENABLED: "1" steps: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - - name: Fuzz schema type resolution - run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime=30s - - name: Fuzz JSON schema generation - run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzJSONSchema -fuzztime=30s - - name: Fuzz Python parser - run: go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParsePredictor -fuzztime=30s - - name: Fuzz type annotation parsing - run: go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParseTypeAnnotation -fuzztime=30s + - name: Fuzz all targets (auto-discovered) + run: mise run test:fuzz test-rust: name: Test Rust @@ -455,6 +466,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Test Rust run: mise run test:rust @@ -478,6 +490,7 @@ jobs: run: tar xf dist/*.tar.gz --strip-components=1 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Remove src to ensure tests run against wheel run: rm -rf python/cog @@ -508,6 +521,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Test coglet-python bindings run: uvx nox -s coglet -p ${{ matrix.python-version }} @@ -628,6 +642,7 @@ jobs: chmod +x ./cog - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Set wheel environment run: | @@ -787,6 +802,7 @@ jobs: save-if: ${{ github.ref == 'refs/heads/main' }} - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ github.job }} - name: Check coglet crates.io publish run: cargo publish --dry-run -p coglet --manifest-path crates/Cargo.toml diff --git a/.github/workflows/docs.yaml 
b/.github/workflows/docs.yaml index 1010debde1..8fc3d52412 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -15,6 +15,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-docs-${{ github.job }} - name: Generate CLI docs diff --git a/.github/workflows/mirror-cog-base-images.yaml b/.github/workflows/mirror-cog-base-images.yaml index 1a08bb2d87..9c5f739158 100644 --- a/.github/workflows/mirror-cog-base-images.yaml +++ b/.github/workflows/mirror-cog-base-images.yaml @@ -26,6 +26,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-mirror-${{ github.job }} - name: Install crane diff --git a/.github/workflows/release-build.yaml b/.github/workflows/release-build.yaml index 81f3ca45ce..eb4f5f9865 100644 --- a/.github/workflows/release-build.yaml +++ b/.github/workflows/release-build.yaml @@ -129,6 +129,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-rel-${{ github.job }} - name: Update coglet version constraint @@ -219,6 +220,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-rel-${{ github.job }} - name: Check for existing release diff --git a/.github/workflows/release-publish.yaml b/.github/workflows/release-publish.yaml index 7a5a08a32c..9cef41f876 100644 --- a/.github/workflows/release-publish.yaml +++ b/.github/workflows/release-publish.yaml @@ -115,6 +115,7 @@ jobs: - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-pub-${{ github.job }} - uses: rust-lang/crates-io-auth-action@v1 id: auth diff --git a/.github/workflows/rust-advisories.yaml b/.github/workflows/rust-advisories.yaml index 143595999d..d2c5fe5a41 100644 --- a/.github/workflows/rust-advisories.yaml +++ b/.github/workflows/rust-advisories.yaml @@ -21,6 +21,7 @@ jobs: - uses: actions/checkout@v6 - uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache_key_prefix: mise-ci-${{ 
github.job }} - name: Check advisories id: advisories diff --git a/.github/workflows/update-compatibility-matrices.yaml b/.github/workflows/update-compatibility-matrices.yaml index 716095687d..2e9694b1aa 100644 --- a/.github/workflows/update-compatibility-matrices.yaml +++ b/.github/workflows/update-compatibility-matrices.yaml @@ -27,6 +27,7 @@ jobs: - name: Install tools uses: jdx/mise-action@v4 with: + version: 2026.4.27 cache: true cache_key_prefix: mise-compatibility-matrices diff --git a/VERSION.txt b/VERSION.txt index 1e7ac68eed..0407900b35 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.21.0-rc.1 +0.21.0-rc.3 diff --git a/architecture/02-schema.md b/architecture/02-schema.md index f038219360..6ed6d16742 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -102,15 +102,15 @@ class Runner(BaseRunner): The resolver handles local imports relative to the predictor file and project root: -| Import Style | File Resolved | -| ------------------------------------ | ----------------------------------------------------------- | -| `from output_types import X` | `/output_types.py` | -| `from .output_types import X` | `/output_types.py` | -| `from models.output import X` | `/models/output.py` | -| `from .models.output import X` | `/models/output.py` | -| `from output_types import X as Y` | `/output_types.py` (alias tracked) | -| `from .output_types import X as Y` | `/output_types.py` (alias tracked) | -| `from . 
import output_types` | `/output_types.py` (module alias tracked) | +| Import Style | File Resolved | +| ---------------------------------- | -------------------------------------------------------- | +| `from output_types import X` | `/output_types.py` | +| `from .output_types import X` | `/output_types.py` | +| `from models.output import X` | `/models/output.py` | +| `from .models.output import X` | `/models/output.py` | +| `from output_types import X as Y` | `/output_types.py` (alias tracked) | +| `from .output_types import X as Y` | `/output_types.py` (alias tracked) | +| `from . import output_types` | `/output_types.py` (module alias tracked) | **How it distinguishes local from external**: the resolver converts the module path to a filesystem path and checks if the file exists. If `output_types.py` exists in the project directory, it's local. If not (e.g., `from transformers import ...`), it's external. Known external packages (stdlib, torch, numpy, etc.) are skipped without a filesystem check. 
@@ -175,18 +175,28 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ### Input Types -| Python | JSON Schema | Notes | -| ------------------------------------- | ---------------------------------------------------------------- | -------------------------- | -| `str` | `{"type": "string"}` | | -| `int` | `{"type": "integer"}` | | -| `float` | `{"type": "number"}` | | -| `bool` | `{"type": "boolean"}` | | -| `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | -| `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | -| `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | -| `list[T]` | `{"type": "array", "items": {...}}` | | -| `Optional[T]` | Type T + not in `required` | Input fields only | -| `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | +| Python | JSON Schema | Notes | +| ------------------------------------- | ---------------------------------------------------------------- | --------------------------------------------------------------------- | +| `str` | `{"type": "string"}` | | +| `int` | `{"type": "integer"}` | | +| `float` | `{"type": "number"}` | | +| `bool` | `{"type": "boolean"}` | | +| `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | +| `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | +| `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | +| `list[T]` | `{"type": "array", "items": {...}}` | | +| `Optional[T]` / `T \| None` | Type T + `nullable: true`, not in `required` | Input fields only; never required | +| `A \| B` / `Union[A, B]` | `{"anyOf": [A, B]}` | Input-only, JSON-native unions only | +| `A \| B \| None` | `{"anyOf": [A, B]}` + `nullable: true` | Multi-variant union; stays in `required` unless a default is supplied | +| `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | + +Input unions are 
intentionally narrower than output types. Cog supports JSON-native input unions (`str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`) so request validation can happen at the HTTP boundary and Python normalisation can choose a deterministic value type. Cog rejects unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` because those cases are ambiguous for clients or runtime coercion. Output unions remain unsupported (see below). + +A plain single-type optional (`Optional[T]` or `T | None`) is **never** placed in `required`, regardless of whether a default is supplied. A multi-variant nullable union (`A | B | None`) is different: because the field carries a concrete `anyOf` value type, it stays in `required` unless a default makes it omittable. This is why the two rows above differ in their `required` behaviour. + +Nullable behaviour matches every other optional field: `nullable: true` (plus omission from `required`) means an **omitted** value falls back to the default. An **explicit** JSON `null` is still validated against the field type and is rejected at the HTTP edge, because the runtime validator does not treat OpenAPI's `nullable` keyword as an additional accepted value. "May be null" therefore means "may be omitted", not "accepts an explicit null payload". + +> **Runtime caveat:** Cog marks optionals as not-`required` in the schema, but the predictor still needs a Python-level default so the omitted value resolves to `None`. Use `value: Optional[T] = Input(...)` (the `Input(...)` supplies an implicit `None`) or `Input(default=None)`. A bare `value: Optional[T]` annotation with no `= Input(...)` generates a correct "optional" schema but raises `TypeError: missing 1 required positional argument` when the field is omitted at runtime. 
### Output Types diff --git a/cmd/cog-kong/baseimage.go b/cmd/cog-kong/baseimage.go new file mode 100644 index 0000000000..09a5d5ea13 --- /dev/null +++ b/cmd/cog-kong/baseimage.go @@ -0,0 +1,72 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" +) + +// BaseImageCmd implements the experimental "cog base-image" command group. +type BaseImageCmd struct { + Dockerfile BaseImageDockerfileCmd `cmd:"" help:"Display Cog base image Dockerfile."` + Build BaseImageBuildCmd `cmd:"" help:"Build Cog base image."` + GenerateMatrix BaseImageGenerateMatrixCmd `cmd:"" name:"generate-matrix" help:"Generate a matrix of Cog base image versions (JSON)."` +} + +// baseImageVersionFlags groups the version-selecting flags shared by the +// base-image subcommands. +type baseImageVersionFlags struct { + CUDA string `name:"cuda" help:"CUDA version."` + Python string `name:"python" help:"Python version."` + Torch string `name:"torch" help:"Torch version."` + + // Hidden flags for parity with the Cobra base-image command. + BreakSystemPackages bool `name:"break-system-packages" hidden:"" help:"Allow pip to modify uv-managed Python installs."` + BuildContextDir string `name:"build-context-dir" hidden:"" help:"Directory for generated Docker build context artifacts."` + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` +} + +func (f baseImageVersionFlags) options() cli.BaseImageOptions { + return cli.BaseImageOptions{ + CUDAVersion: f.CUDA, + PythonVersion: f.Python, + TorchVersion: f.Torch, + BreakSystemPackages: f.BreakSystemPackages, + BuildContextDir: f.BuildContextDir, + Timestamp: f.Timestamp, + } +} + +// BaseImageDockerfileCmd implements "cog base-image dockerfile". 
+type BaseImageDockerfileCmd struct { + baseImageVersionFlags `embed:""` + + NoCache bool `name:"no-cache" help:"Do not use cache when building the image."` + Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` +} + +func (cmd *BaseImageDockerfileCmd) Run(ctx context.Context) error { + opts := cmd.options() + opts.NoCache = cmd.NoCache + opts.ProgressOutput = cmd.Progress + return cli.RunBaseImageDockerfile(ctx, opts) +} + +// BaseImageBuildCmd implements "cog base-image build". +type BaseImageBuildCmd struct { + baseImageVersionFlags `embed:""` +} + +func (cmd *BaseImageBuildCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunBaseImageBuild(ctx, dockerClient, cmd.options()) +} + +// BaseImageGenerateMatrixCmd implements "cog base-image generate-matrix". +type BaseImageGenerateMatrixCmd struct { + baseImageVersionFlags `embed:""` +} + +func (cmd *BaseImageGenerateMatrixCmd) Run() error { + return cli.RunBaseImageGenerateMatrix(cmd.options()) +} diff --git a/cmd/cog-kong/build.go b/cmd/cog-kong/build.go index e5afb7054c..db6fa13595 100644 --- a/cmd/cog-kong/build.go +++ b/cmd/cog-kong/build.go @@ -3,18 +3,20 @@ package main import ( "context" - "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" - "github.com/replicate/cog/pkg/util/console" ) -// BuildCmd implements the "cog build" command. +// BuildCmd implements the "cog build" command. Timestamp and +// SkipSchemaValidation are build-only hidden flags (push does not expose them), +// matching the Cobra CLI. 
type BuildCmd struct { BuildFlags `embed:""` - Tag string `name:"tag" short:"t" help:"A name for the built image in the form 'repository:tag'."` + Tag string `name:"tag" short:"t" help:"A name for the built image in the form 'repository:tag'."` + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` + SkipSchemaValidation bool `name:"skip-schema-validation" hidden:"" help:"Skip OpenAPI schema generation and validation."` } // Validate is called by Kong after parsing, before Run. It replaces Cobra's PreRunE. @@ -22,23 +24,19 @@ func (cmd *BuildCmd) Validate() error { return cmd.ValidateMutualExclusivity() } -// Run executes the build command. -func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, src *model.Source) error { - imageName := src.Config.Image - if cmd.Tag != "" { - imageName = cmd.Tag - } - if imageName == "" { - imageName = config.DockerImageName(src.ProjectDir) - } - - resolver := model.NewResolver(dockerClient, regClient) - m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, nil)) - if err != nil { - return err - } - - console.Infof("\nImage built as %s", m.ImageRef()) +// options returns the build flags with the build-only fields applied. +func (cmd *BuildCmd) options() cli.BuildFlagsOptions { + opts := cmd.Options() + opts.Timestamp = cmd.Timestamp + opts.SkipSchemaValidation = cmd.SkipSchemaValidation + return opts +} - return nil +// Run executes the build command via the shared cli.RunBuild runner. 
+func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunBuild(ctx, dockerClient, regClient, cli.BuildCommandOptions{ + ConfigFilename: cmd.File, + Tag: cmd.Tag, + Flags: cmd.options(), + }) } diff --git a/cmd/cog-kong/cli.go b/cmd/cog-kong/cli.go index 78790ec139..b6a21ce0d1 100644 --- a/cmd/cog-kong/cli.go +++ b/cmd/cog-kong/cli.go @@ -1,12 +1,11 @@ package main import ( - "context" + "os" "github.com/alecthomas/kong" "github.com/replicate/cog/pkg/global" - "github.com/replicate/cog/pkg/update" "github.com/replicate/cog/pkg/util/console" ) @@ -14,25 +13,31 @@ import ( // The AfterApply hook replaces Cobra's PersistentPreRun. type Globals struct { Debug bool `name:"debug" short:"d" env:"COG_DEBUG" help:"Show debugging output."` + NoColor bool `name:"no-color" help:"Disable colored output."` Registry string `name:"registry" default:"${registry_default}" env:"COG_REGISTRY_HOST" hidden:"" help:"Registry host."` Profile bool `name:"profile" hidden:"" help:"Enable profiling."` - Version kong.VersionFlag `name:"version" short:"v" help:"Show version of Cog."` + Version kong.VersionFlag `name:"version" help:"Show version of Cog."` } // AfterApply runs after flag parsing, before the command's Run. // This is the Kong equivalent of Cobra's PersistentPreRun. 
-func (g *Globals) AfterApply(ctx context.Context) error { +func (g *Globals) AfterApply() error { if g.Debug { global.Debug = true console.SetLevel(console.DebugLevel) } + if g.NoColor { + global.NoColor = true + } + if global.NoColor || !console.ShouldUseColor() { + console.SetColor(false) + } + if global.NoColor { + _ = os.Setenv("NO_COLOR", "1") + } if g.Profile { global.ProfilingEnabled = true } global.ReplicateRegistryHost = g.Registry - - if err := update.DisplayAndCheckForRelease(ctx); err != nil { - console.Debugf("%s", err) - } return nil } diff --git a/cmd/cog-kong/flags.go b/cmd/cog-kong/flags.go index 11c67b1aa7..51e4dc89b9 100644 --- a/cmd/cog-kong/flags.go +++ b/cmd/cog-kong/flags.go @@ -5,7 +5,7 @@ import ( "os" "strings" - "github.com/replicate/cog/pkg/config" + "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/model" ) @@ -22,8 +22,9 @@ func (c *ConfigFlag) ProvideModelSource() (*model.Source, error) { return model.NewSource(c.File) } -// BuildFlags groups all flags shared across commands that build images. -// Embed this in any command struct that calls resolver.Build(). +// BuildFlags groups the image-building flags shared by the build and push +// commands. Build-only flags (timestamp, skip-schema-validation) live on the +// BuildCmd struct itself so they don't leak onto push, matching the Cobra CLI. 
type BuildFlags struct { ConfigFlag `embed:""` @@ -33,42 +34,44 @@ type BuildFlags struct { Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image)."` UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` - OpenAPISchema string `name:"openapi-schema" type:"existingfile" help:"Load OpenAPI schema from a file."` + // Cobra accepts --openapi-schema as a plain string and defers the + // file-existence check to build time, so do not use type:"existingfile". + OpenAPISchema string `name:"openapi-schema" help:"Load OpenAPI schema from a file."` // Hidden flags - Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml."` - Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` + // Cobra accepts --dockerfile as a plain string and defers the + // file-existence check to build time, so do not use type:"existingfile". + Dockerfile string `name:"dockerfile" hidden:"" help:"Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml."` Strip bool `name:"strip" hidden:"" help:"Whether to strip shared libraries for faster inference times."` Precompile bool `name:"precompile" hidden:"" help:"Whether to precompile python files for faster load times."` } -// AfterApply syncs parsed flag values to package-level globals that the build -// pipeline reads. This runs after Kong parses flags but before Run(). 
-func (b *BuildFlags) AfterApply() error { - config.BuildSourceEpochTimestamp = b.Timestamp - return nil -} - -// BuildOptions constructs a model.BuildOptions from the current flag values. -// The imageName and annotations parameters vary by caller (build vs push). -func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]string) model.BuildOptions { - return model.BuildOptions{ - ImageName: imageName, - Secrets: b.Secrets, +// Options converts the Kong build flags into the parser-independent +// cli.BuildFlagsOptions shared with the Cobra CLI. Timestamp defaults to -1 +// (timestamp rewriting disabled), matching push, which has no --timestamp flag; +// BuildCmd overrides Timestamp and SkipSchemaValidation from its own flags. +func (b *BuildFlags) Options() cli.BuildFlagsOptions { + return cli.BuildFlagsOptions{ NoCache: b.NoCache, SeparateWeights: b.SeparateWeights, - UseCudaBaseImage: b.UseCudaBaseImage, + Secrets: b.Secrets, ProgressOutput: b.Progress, - SchemaFile: b.OpenAPISchema, - DockerfileFile: b.Dockerfile, + UseCudaBaseImage: b.UseCudaBaseImage, UseCogBaseImage: b.UseCogBaseImage, + OpenAPISchema: b.OpenAPISchema, + DockerfileFile: b.Dockerfile, Strip: b.Strip, Precompile: b.Precompile, - Annotations: annotations, - OCIIndex: model.OCIIndexEnabled(), + Timestamp: -1, } } +// BuildOptions constructs a model.BuildOptions from the current flag values. +// The imageName and annotations parameters vary by caller (build vs push). +func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]string) model.BuildOptions { + return b.Options().ModelBuildOptions(imageName, annotations) +} + // ValidateMutualExclusivity ensures that at most one of --use-cog-base-image, // --use-cuda-base-image, and --dockerfile is explicitly set. 
func (b *BuildFlags) ValidateMutualExclusivity() error { diff --git a/cmd/cog-kong/main.go b/cmd/cog-kong/main.go index 332468b4aa..dabad287b9 100644 --- a/cmd/cog-kong/main.go +++ b/cmd/cog-kong/main.go @@ -12,6 +12,7 @@ import ( "github.com/alecthomas/kong" "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/update" "github.com/replicate/cog/pkg/util/console" ) @@ -26,8 +27,19 @@ var ( type CLI struct { Globals - Build BuildCmd `cmd:"" help:"Build an image from cog.yaml."` - Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."` + BaseImage BaseImageCmd `cmd:"" name:"base-image" help:"Tools for working with Cog base images."` + Build BuildCmd `cmd:"" help:"Build an image from cog.yaml."` + Debug DebugCmd `cmd:"" hidden:"" help:"Debug Cog internals."` + Doctor DoctorCmd `cmd:"" help:"Check your project for common issues and fix them (experimental)."` + Exec ExecCmd `cmd:"" help:"Execute a command inside a Docker environment."` + Init InitCmd `cmd:"" help:"Configure your project for use with Cog."` + Login LoginCmd `cmd:"" help:"Log in to a container registry."` + Predict PredictCmd `cmd:"" hidden:"" help:"Run a prediction."` + Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."` + RunCommand RunCmd `cmd:"" name:"run" help:"Run a prediction."` + Serve ServeCmd `cmd:"" help:"Run an HTTP server."` + Train TrainCmd `cmd:"" hidden:"" help:"Run a training job."` + Weights WeightsCmd `cmd:"" hidden:"" help:"Commands for managing model weight files."` } func main() { @@ -35,26 +47,7 @@ func main() { var cli CLI - initOpts := []kong.Option{ - // CLI metadata and variable interpolation for struct tags - kong.Name("cog"), - kong.Description("Containers for machine learning."), - kong.Vars{ - "version": fmt.Sprintf("cog version %s (built %s)", version, buildTime), - "commit": commit, - "progress_default": progressDefault(), - "registry_default": global.DefaultReplicateRegistryHost, 
- }, - kong.UsageOnError(), - - // bindings for lazily injecting dependencies into Run() methods - kong.BindTo(ctx, (*context.Context)(nil)), - kong.BindSingletonProvider(provideDockerClient), - kong.BindToProvider(provideRegistryClient), - kong.BindSingletonProvider(provideProviderRegistry), - } - - parser, err := kong.New(&cli, initOpts...) + parser, err := newParser(ctx, &cli) if err != nil { // Fatal error creating the parser — this is a bug, so panic to get a stack trace. panic(err) @@ -76,6 +69,7 @@ func main() { parser.FatalIfErrorf(err) } + displayUpdateCheck(ctx) err = kctx.Run() cancel() // command returned an error. Print and exit non-zero. @@ -84,6 +78,31 @@ func main() { } } +func displayUpdateCheck(ctx context.Context) { + if err := update.DisplayAndCheckForRelease(ctx); err != nil { + console.Debugf("%s", err) + } +} + +func newParser(ctx context.Context, cli *CLI, options ...kong.Option) (*kong.Kong, error) { + defaultOptions := []kong.Option{ + kong.Name("cog"), + kong.Description("Containers for machine learning."), + kong.Vars{ + "version": fmt.Sprintf("cog version %s (built %s)", version, buildTime), + "commit": commit, + "progress_default": progressDefault(), + "registry_default": global.DefaultReplicateRegistryHost, + }, + kong.UsageOnError(), + kong.BindTo(ctx, (*context.Context)(nil)), + kong.BindSingletonProvider(provideDockerClient), + kong.BindToProvider(provideRegistryClient), + kong.BindSingletonProvider(provideProviderRegistry), + } + return kong.New(cli, append(defaultOptions, options...)...) +} + func newCancellationContext() (context.Context, context.CancelFunc) { // First signal cancels the context, giving commands a chance to clean up. // Second signal force-exits immediately. 
diff --git a/cmd/cog-kong/main_test.go b/cmd/cog-kong/main_test.go new file mode 100644 index 0000000000..187fffae29 --- /dev/null +++ b/cmd/cog-kong/main_test.go @@ -0,0 +1,445 @@ +package main + +import ( + "bytes" + "io" + "os" + "path/filepath" + "testing" + + "github.com/alecthomas/kong" + "github.com/stretchr/testify/require" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/util/console" +) + +func newTestParser(t *testing.T) *kong.Kong { + t.Helper() + // Override Exit so Kong's built-in --help flag (which exits the process + // after printing help, before validation) doesn't kill the test process. + parser, err := newParser(t.Context(), &CLI{}, kong.Exit(func(code int) { + panic(testExitCode(code)) + })) + require.NoError(t, err) + parser.Stdout = io.Discard + parser.Stderr = io.Discard + return parser +} + +// parseForTest parses args, tolerating Kong's built-in --help short-circuit, +// which prints help and calls Exit(0) (translated to a testExitCode(0) panic by +// newTestParser). It returns the parse error for non-help cases. 
+func parseForTest(t *testing.T, parser *kong.Kong, args []string) (err error) { + t.Helper() + defer func() { + if r := recover(); r != nil { + code, ok := r.(testExitCode) + require.Truef(t, ok, "unexpected panic parsing %v: %v", args, r) + require.Equalf(t, testExitCode(0), code, "non-zero exit parsing %v", args) + } + }() + _, err = parser.Parse(args) + return err +} + +type kongGlobalState struct { + debug bool + noColor bool + profilingEnabled bool + registry string + consoleColor bool + consoleLevel console.Level + noColorEnv string + hadNoColorEnv bool +} + +func snapshotKongGlobalState() kongGlobalState { + noColorEnv, hadNoColorEnv := os.LookupEnv("NO_COLOR") + return kongGlobalState{ + debug: global.Debug, + noColor: global.NoColor, + profilingEnabled: global.ProfilingEnabled, + registry: global.ReplicateRegistryHost, + consoleColor: console.ConsoleInstance.Color, + consoleLevel: console.ConsoleInstance.Level, + noColorEnv: noColorEnv, + hadNoColorEnv: hadNoColorEnv, + } +} + +func restoreKongGlobalState(t *testing.T, state kongGlobalState) { + t.Helper() + global.Debug = state.debug + global.NoColor = state.noColor + global.ProfilingEnabled = state.profilingEnabled + global.ReplicateRegistryHost = state.registry + console.SetColor(state.consoleColor) + console.SetLevel(state.consoleLevel) + if state.hadNoColorEnv { + require.NoError(t, os.Setenv("NO_COLOR", state.noColorEnv)) + } else { + require.NoError(t, os.Unsetenv("NO_COLOR")) + } +} + +func preserveKongGlobalState(t *testing.T) kongGlobalState { + t.Helper() + state := snapshotKongGlobalState() + t.Cleanup(func() { + restoreKongGlobalState(t, state) + }) + return state +} + +type testExitCode int + +func newVersionTestParser(t *testing.T, stdout *bytes.Buffer) *kong.Kong { + t.Helper() + parser, err := newParser(t.Context(), &CLI{}, kong.Exit(func(code int) { + panic(testExitCode(code)) + })) + require.NoError(t, err) + parser.Stdout = stdout + parser.Stderr = stdout + return parser +} + +func 
TestKongRegistersAllTopLevelCommands(t *testing.T) { + parser := newTestParser(t) + commands := map[string]bool{} + for _, node := range parser.Model.Children { + commands[node.Name] = true + } + + for _, name := range []string{ + "base-image", + "build", + "debug", + "doctor", + "exec", + "init", + "login", + "predict", + "push", + "run", + "serve", + "train", + "weights", + } { + require.Truef(t, commands[name], "missing command %q", name) + } +} + +func TestKongRegistersNestedCommands(t *testing.T) { + preserveKongGlobalState(t) + parser := newTestParser(t) + + for _, args := range [][]string{ + {"weights", "import", "--help"}, + {"weights", "pull", "--help"}, + {"weights", "status", "--help"}, + {"base-image", "dockerfile", "--help"}, + {"base-image", "build", "--help"}, + } { + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } +} + +func TestKongRootHelpParses(t *testing.T) { + preserveKongGlobalState(t) + // Override Exit so the built-in --help flag doesn't terminate the process, + // and capture stdout to assert on the rendered help output. + var stdout bytes.Buffer + parser, err := newParser(t.Context(), &CLI{}, kong.Exit(func(code int) { + panic(testExitCode(code)) + })) + require.NoError(t, err) + parser.Stdout = &stdout + parser.Stderr = &stdout + + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = parser.Parse([]string{"--help"}) + }) + + help := stdout.String() + require.Contains(t, help, "Usage: cog [flags]") + require.Contains(t, help, "build") + require.Contains(t, help, "push") + require.NotContains(t, help, "Usage: cog default") + + // Commands hidden in Cobra must also be hidden in the Kong model so they + // don't appear in the root help output. 
+ hiddenByName := map[string]bool{} + for _, node := range parser.Model.Children { + if node.Hidden { + hiddenByName[node.Name] = true + } + } + for _, name := range []string{"predict", "train", "weights", "debug"} { + require.Truef(t, hiddenByName[name], "command %q should be hidden", name) + } + // Visible commands (notably run, the non-hidden twin of predict) must NOT be + // hidden — guards against accidentally hiding the wrong prediction command. + for _, name := range []string{"run", "build", "push", "serve", "exec", "init", "login", "doctor", "base-image"} { + require.Falsef(t, hiddenByName[name], "command %q should be visible", name) + } +} + +func TestKongRootGlobalFlagsParse(t *testing.T) { + state := preserveKongGlobalState(t) + + for _, args := range [][]string{ + {"--debug", "build", "--help"}, + {"--no-color", "build", "--help"}, + {"--profile", "build", "--help"}, + {"--registry", "example.com", "build", "--help"}, + } { + restoreKongGlobalState(t, state) + parser := newTestParser(t) + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } + + restoreKongGlobalState(t, state) + var stdout bytes.Buffer + versionParser := newVersionTestParser(t, &stdout) + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = versionParser.Parse([]string{"--version"}) + }) +} + +func TestKongHelpParsingDoesNotWriteUpdateState(t *testing.T) { + preserveKongGlobalState(t) + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("COG_NO_UPDATE_CHECK", "") + + for _, args := range [][]string{ + {"--help"}, + {"build", "--help"}, + } { + parser := newTestParser(t) + _ = parseForTest(t, parser, args) + require.NoFileExists(t, filepath.Join(home, ".config", "cog", "update-state.json"), "parse %v", args) + } +} + +func TestKongNoColorAfterApplySetsGlobalNoColor(t *testing.T) { + preserveKongGlobalState(t) + require.NoError(t, os.Unsetenv("NO_COLOR")) + global.NoColor = false + + globals := Globals{NoColor: true, Registry: 
global.ReplicateRegistryHost} + require.NoError(t, globals.AfterApply()) + + require.True(t, global.NoColor) + require.Equal(t, "1", os.Getenv("NO_COLOR")) +} + +func TestKongVersionExitsBeforeRootUsage(t *testing.T) { + var stdout bytes.Buffer + parser := newVersionTestParser(t, &stdout) + + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = parser.Parse([]string{"--version"}) + }) + require.Equal(t, "cog version dev (built none)\n", stdout.String()) +} + +func TestKongCommandVersionExitsBeforeCommandRun(t *testing.T) { + var stdout bytes.Buffer + parser := newVersionTestParser(t, &stdout) + + require.PanicsWithValue(t, testExitCode(0), func() { + _, _ = parser.Parse([]string{"build", "--version"}) + }) + require.Equal(t, "cog version dev (built none)\n", stdout.String()) +} + +func TestKongRuntimeCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"serve", "--port", "5000", "--upload-url", "https://example.com/upload", "--gpus", "all", "--help"}, + {"exec", "--gpus", "all", "--publish", "8888", "--env", "A=B", "python", "-c", "print(1)"}, + } { + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } +} + +func TestKongSimpleCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"init", "--help"}, + {"login", "--token-stdin", "--help"}, + {"doctor", "--fix", "--file", "custom.yaml", "--help"}, + {"debug", "--image-name", "myimage", "--help"}, + } { + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } +} + +func TestKongPredictionCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"predict", "example/image", "--input", "prompt=cat", "--help"}, + {"run", "example/image", "--input", "prompt=cat", "--output", "out.png", "--help"}, + {"run", "--json", "@inputs.json", "--gpus", "all", "--help"}, + {"train", "example/image", "--input", "dataset=@data.json", "--help"}, + } 
{ + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } +} + +func TestKongWeightsCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"weights", "import", "--dry-run", "--verbose", "model.safetensors", "--help"}, + {"weights", "pull", "--verbose", "weights-name", "--help"}, + {"weights", "status", "--json", "--verbose", "--help"}, + } { + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } +} + +func TestKongBaseImageCommandFlagsParse(t *testing.T) { + parser := newTestParser(t) + + for _, args := range [][]string{ + {"base-image", "dockerfile", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--no-cache", "--progress", "plain", "--help"}, + {"base-image", "build", "--cuda", "12.4", "--python", "3.12", "--torch", "2.5.0", "--help"}, + {"base-image", "generate-matrix", "--cuda", "12.4", "--python", "3.12", "--help"}, + } { + require.NoErrorf(t, parseForTest(t, parser, args), "parse %v", args) + } +} + +// TestKongBuildOnlyFlagsParity asserts that --timestamp and +// --skip-schema-validation are build-only (not on push), matching Cobra, and +// that build maps them through to the shared options. +func TestKongBuildOnlyFlagsParity(t *testing.T) { + preserveKongGlobalState(t) + + // build accepts the build-only flags and maps them to options. + var buildCLI CLI + buildParser, err := newParser(t.Context(), &buildCLI) + require.NoError(t, err) + _, err = buildParser.Parse([]string{"build", "--timestamp", "42", "--skip-schema-validation", "--tag", "img:latest"}) + require.NoError(t, err) + opts := buildCLI.Build.options() + require.Equal(t, int64(42), opts.Timestamp) + require.True(t, opts.SkipSchemaValidation) + + // push does NOT expose either flag (parity with Cobra push). 
+ for _, flag := range []string{"--timestamp", "--skip-schema-validation"} { + pushParser := newTestParser(t) + _, err := pushParser.Parse([]string{"push", flag, "x", "img:latest"}) + require.Errorf(t, err, "push should reject %s", flag) + require.Containsf(t, err.Error(), "unknown flag", "push should reject %s as unknown", flag) + } + + // push still defaults Timestamp to -1 (timestamp rewriting disabled), not 0. + require.Equal(t, int64(-1), (&BuildFlags{}).Options().Timestamp) +} + +// TestKongEnvFlagParity asserts --env is available on exec/predict/run/train but +// NOT on serve, matching the Cobra CLI (Cobra serve has no --env). +func TestKongEnvFlagParity(t *testing.T) { + // serve has no --env. + serveParser := newTestParser(t) + _, err := serveParser.Parse([]string{"serve", "--env", "A=B"}) + require.Error(t, err, "serve should reject --env") + require.Contains(t, err.Error(), "unknown flag") + + // exec/predict/run/train accept --env and thread it through. + var execCLI CLI + execParser, err := newParser(t.Context(), &execCLI) + require.NoError(t, err) + _, err = execParser.Parse([]string{"exec", "--env", "A=B", "echo", "hi"}) + require.NoError(t, err) + require.Equal(t, []string{"A=B"}, execCLI.Exec.Env) + + var runCLI CLI + runParser, err := newParser(t.Context(), &runCLI) + require.NoError(t, err) + _, err = runParser.Parse([]string{"run", "--env", "A=B", "img"}) + require.NoError(t, err) + require.Equal(t, []string{"A=B"}, runCLI.RunCommand.options("run").Env) +} + +// TestKongExecRequiresCommand asserts the zero-arg exec error matches Cobra's +// MinimumNArgs(1) message. 
+func TestKongExecRequiresCommand(t *testing.T) { + cmd := &ExecCmd{} + err := cmd.Validate() + require.Error(t, err) + require.Equal(t, "requires at least 1 arg(s), only received 0", err.Error()) +} + +// TestKongCommandCoverageMatchesCobraSurface derives the expected command set +// and hidden status directly from the real Cobra command tree +// (cli.NewRootCommand and cli.NewBaseImageRootCommand) and asserts the Kong +// model matches it node-by-node. Deriving from the actual Cobra surface (rather +// than a hand-maintained list) means any future Cobra command/hidden-flag change +// that the Kong CLI doesn't mirror will fail this test. +func TestKongCommandCoverageMatchesCobraSurface(t *testing.T) { + parser := newTestParser(t) + + // Kong top-level command name -> hidden. + kongCmds := map[string]bool{} + kongChildren := map[string]*kong.Node{} + for _, node := range parser.Model.Children { + if node.Type != kong.CommandNode { + continue + } + kongCmds[node.Name] = node.Hidden + kongChildren[node.Name] = node + } + + // Expected top-level surface from the real Cobra root command. + root, err := cli.NewRootCommand() + require.NoError(t, err) + expected := map[string]bool{} + for _, c := range root.Commands() { + expected[c.Name()] = c.Hidden + } + // Cobra ships base-image as a separate binary; Kong folds it under + // `cog base-image`, so add it to the expected top-level surface. + expected["base-image"] = false + + require.Len(t, kongCmds, len(expected), "top-level command count mismatch: kong=%v expected=%v", keys(kongCmds), keys(expected)) + for name, hidden := range expected { + gotHidden, ok := kongCmds[name] + require.Truef(t, ok, "Kong is missing command %q present in Cobra", name) + require.Equalf(t, hidden, gotHidden, "hidden mismatch for command %q", name) + } + + // base-image subcommands must match the Cobra base-image binary. 
+ bi, err := cli.NewBaseImageRootCommand() + require.NoError(t, err) + expectedBase := map[string]bool{} + for _, c := range bi.Commands() { + expectedBase[c.Name()] = c.Hidden + } + kongBase := map[string]bool{} + require.NotNil(t, kongChildren["base-image"], "Kong missing base-image group") + for _, node := range kongChildren["base-image"].Children { + if node.Type != kong.CommandNode { + continue + } + kongBase[node.Name] = node.Hidden + } + require.Equal(t, expectedBase, kongBase, "base-image subcommand surface mismatch") +} + +func keys(m map[string]bool) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} diff --git a/cmd/cog-kong/predict.go b/cmd/cog-kong/predict.go new file mode 100644 index 0000000000..c273b85c04 --- /dev/null +++ b/cmd/cog-kong/predict.go @@ -0,0 +1,54 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" +) + +// predictionFlags groups the flags shared by the predict and run commands. +type predictionFlags struct { + RuntimeFlags `embed:""` + EnvFlag `embed:""` + + Image string `arg:"" optional:"" name:"image" help:"Image to run. If omitted, builds from cog.yaml in the current directory."` + Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. 
-i path=@image.jpg."` + Output string `name:"output" short:"o" help:"Output path."` + JSON string `name:"json" help:"Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-)."` + UseReplicate bool `name:"use-replicate-token" help:"Pass REPLICATE_API_TOKEN from local environment into the model context."` + SetupTimeout uint32 `name:"setup-timeout" default:"300" help:"The timeout for a container to setup (in seconds)."` +} + +func (f predictionFlags) options(use string) cli.PredictionCommandOptions { + rbo := f.Options() + rbo.Env = f.Env + return cli.PredictionCommandOptions{ + RuntimeBuildOptions: rbo, + Use: use, + Image: f.Image, + Input: f.Input, + InputJSON: f.JSON, + OutputPath: f.Output, + SetupTimeout: f.SetupTimeout, + UseReplicateAPI: f.UseReplicate, + } +} + +// PredictCmd implements the hidden, deprecated "cog predict" command. +type PredictCmd struct { + predictionFlags `embed:""` +} + +func (cmd *PredictCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunPrediction(ctx, dockerClient, cmd.options("predict")) +} + +// RunCmd implements the "cog run" command. +type RunCmd struct { + predictionFlags `embed:""` +} + +func (cmd *RunCmd) Run(ctx context.Context, dockerClient command.Command) error { + return cli.RunPrediction(ctx, dockerClient, cmd.options("run")) +} diff --git a/cmd/cog-kong/push.go b/cmd/cog-kong/push.go index 9a92411024..d505871b78 100644 --- a/cmd/cog-kong/push.go +++ b/cmd/cog-kong/push.go @@ -2,15 +2,11 @@ package main import ( "context" - "fmt" - - "github.com/replicate/go/uuid" + "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/provider" "github.com/replicate/cog/pkg/registry" - "github.com/replicate/cog/pkg/util/console" ) // PushCmd implements the "cog push" command. @@ -20,67 +16,11 @@ type PushCmd struct { Image string `arg:"" optional:"" help:"Image name to push (e.g. 
registry.example.com/user/model)."` } -// Validate is called by Kong after parsing, before Run. -func (cmd *PushCmd) Validate() error { - return cmd.ValidateMutualExclusivity() -} - -// Run executes the push command: build then push. -func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry, src *model.Source) error { - imageName := src.Config.Image - if cmd.Image != "" { - imageName = cmd.Image - } - - if imageName == "" { - return fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.example.com/your-username/model-name'") - } - - // Look up the provider for the target registry - p := providerReg.ForImage(imageName) - if p == nil { - return fmt.Errorf("no provider found for image '%s'", imageName) - } - - pushOpts := provider.PushOptions{ - Image: imageName, - Config: src.Config, - ProjectDir: src.ProjectDir, - } - - // Generate a push ID for annotations - buildID, _ := uuid.NewV7() - annotations := map[string]string{} - if buildID.String() != "" { - annotations["run.cog.push_id"] = buildID.String() - } - - // Build the model - resolver := model.NewResolver(dockerClient, regClient) - m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, annotations)) - if err != nil { - _ = p.PostPush(ctx, pushOpts, err) - return err - } - - // Log weights info - weights := m.WeightArtifacts() - if len(weights) > 0 { - console.Infof("\n%d weight artifact(s)", len(weights)) - } - - // Push the model - console.Infof("\nPushing image '%s'...", m.ImageRef()) - pushErr := resolver.Push(ctx, m, model.PushOptions{}) - - // PostPush: the provider handles formatting errors and showing success messages - if err := p.PostPush(ctx, pushOpts, pushErr); err != nil { - return err - } - - if pushErr != nil { - return fmt.Errorf("failed to push image: %w", pushErr) - } - - return nil +// Run executes the push command 
via the shared cli.RunPush runner. +func (cmd *PushCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry) error { + return cli.RunPush(ctx, dockerClient, regClient, providerReg, cli.PushCommandOptions{ + ConfigFilename: cmd.File, + Image: cmd.Image, + Flags: cmd.Options(), + }) } diff --git a/cmd/cog-kong/runtime.go b/cmd/cog-kong/runtime.go new file mode 100644 index 0000000000..d0a2413ae3 --- /dev/null +++ b/cmd/cog-kong/runtime.go @@ -0,0 +1,88 @@ +package main + +import ( + "context" + "errors" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/registry" +) + +// errMissingExecCommand matches Cobra's cobra.MinimumNArgs(1) error message for +// the exec command. +var errMissingExecCommand = errors.New("requires at least 1 arg(s), only received 0") + +// RuntimeFlags groups the build/run flags shared by the serve, exec, predict, +// run, and train commands. Note that --env is NOT part of this group: Cobra's +// serve command does not expose --env, so commands that accept it (exec, +// predict, run, train) declare their own EnvFlag. +type RuntimeFlags struct { + ConfigFlag `embed:""` + + Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."` + CudaBase string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` + CogBase *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` + GPUs string `name:"gpus" help:"GPU devices to add to the container, in the same format as docker run --gpus."` +} + +// EnvFlag is an embeddable --env flag for the commands that accept it (exec, +// predict, run, train). Cobra's serve command has no --env, so ServeCmd omits +// this group. 
+type EnvFlag struct { + Env []string `name:"env" short:"e" help:"Environment variables, in the form name=value."` +} + +// Options converts the Kong runtime flags into cli.RuntimeBuildOptions. Env is +// populated by callers that embed EnvFlag. +func (f RuntimeFlags) Options() cli.RuntimeBuildOptions { + return cli.RuntimeBuildOptions{ + ConfigFilename: f.File, + ProgressOutput: f.Progress, + UseCudaBaseImage: f.CudaBase, + UseCogBaseImage: f.CogBase, + GPUs: f.GPUs, + } +} + +// ServeCmd implements the "cog serve" command. +type ServeCmd struct { + RuntimeFlags `embed:""` + + Port int `name:"port" short:"p" default:"8393" help:"Port on which to listen."` + UploadURL string `name:"upload-url" help:"Upload URL for file outputs (e.g. https://example.com/upload/)."` +} + +func (cmd *ServeCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunServe(ctx, dockerClient, regClient, cli.ServeCommandOptions{ + RuntimeBuildOptions: cmd.Options(), + Port: cmd.Port, + UploadURL: cmd.UploadURL, + }) +} + +// ExecCmd implements the "cog exec" command. +type ExecCmd struct { + RuntimeFlags `embed:""` + EnvFlag `embed:""` + + Publish []string `name:"publish" short:"p" help:"Publish a container's port to the host, e.g. 
-p 8000."` + Args []string `arg:"" passthrough:"" name:"command" help:"Command and arguments to execute."` +} + +func (cmd *ExecCmd) Validate() error { + if len(cmd.Args) == 0 { + return errMissingExecCommand + } + return nil +} + +func (cmd *ExecCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + opts := cmd.Options() + opts.Env = cmd.Env + return cli.RunExec(ctx, dockerClient, regClient, cli.ExecCommandOptions{ + RuntimeBuildOptions: opts, + Args: cmd.Args, + Ports: cmd.Publish, + }) +} diff --git a/cmd/cog-kong/simple.go b/cmd/cog-kong/simple.go new file mode 100644 index 0000000000..0c0a2d74ef --- /dev/null +++ b/cmd/cog-kong/simple.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/global" + "github.com/replicate/cog/pkg/provider" + "github.com/replicate/cog/pkg/registry" +) + +// InitCmd implements the "cog init" command. +type InitCmd struct{} + +func (cmd *InitCmd) Run() error { return cli.RunInit() } + +// LoginCmd implements the "cog login" command. +type LoginCmd struct { + TokenStdin bool `name:"token-stdin" help:"Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token"` +} + +func (cmd *LoginCmd) Run(ctx context.Context, providerReg *provider.Registry) error { + return cli.RunLogin(ctx, providerReg, cli.LoginOptions{ + TokenStdin: cmd.TokenStdin, + Host: global.ReplicateRegistryHost, + }) +} + +// DoctorCmd implements the "cog doctor" command. +type DoctorCmd struct { + ConfigFlag `embed:""` + + Fix bool `name:"fix" help:"Automatically apply fixes."` +} + +func (cmd *DoctorCmd) Run(ctx context.Context) error { + return cli.RunDoctor(ctx, cmd.File, cmd.Fix) +} + +// DebugCmd implements the hidden "cog debug" command, which prints a generated +// Dockerfile. 
+type DebugCmd struct { + ConfigFlag `embed:""` + + ImageName string `name:"image-name" help:"The image name to use for the generated Dockerfile."` + SeparateWeights bool `name:"separate-weights" help:"Separate model weights from code in image layers."` + UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image."` + UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."` + + // Hidden flag for parity with the Cobra debug command. RunDebug ignores it + // (it is a no-op on both CLIs), but it is accepted for surface parity. + Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."` +} + +func (cmd *DebugCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client) error { + return cli.RunDebug(ctx, dockerClient, regClient, cli.DebugCommandOptions{ + ConfigFilename: cmd.File, + ImageName: cmd.ImageName, + SeparateWeights: cmd.SeparateWeights, + UseCudaBaseImage: cmd.UseCudaBaseImage, + UseCogBaseImage: cmd.UseCogBaseImage, + }) +} diff --git a/cmd/cog-kong/train.go b/cmd/cog-kong/train.go new file mode 100644 index 0000000000..1caaa15e45 --- /dev/null +++ b/cmd/cog-kong/train.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" + "github.com/replicate/cog/pkg/docker/command" + "github.com/replicate/cog/pkg/util/console" +) + +// TrainCmd implements the hidden, deprecated "cog train" command. +type TrainCmd struct { + RuntimeFlags `embed:""` + EnvFlag `embed:""` + + Image string `arg:"" optional:"" name:"image" help:"Image to train. If omitted, builds from cog.yaml in the current directory."` + Input []string `name:"input" short:"i" help:"Inputs, in the form name=value. If value is prefixed with @, it is read from a file on disk, e.g. 
-i path=@image.jpg."` + Output string `name:"output" short:"o" default:"weights" help:"Output path."` +} + +func (cmd *TrainCmd) Run(ctx context.Context, dockerClient command.Command) error { + // Match Cobra's Deprecated notice for the train command. + console.Warn("Command \"train\" is deprecated, the train command will be removed in a future version of Cog") + opts := cmd.Options() + opts.Env = cmd.Env + return cli.RunTrain(ctx, dockerClient, cli.TrainCommandOptions{ + RuntimeBuildOptions: opts, + Image: cmd.Image, + Input: cmd.Input, + OutputPath: cmd.Output, + }) +} diff --git a/cmd/cog-kong/weights.go b/cmd/cog-kong/weights.go new file mode 100644 index 0000000000..900b40b6f2 --- /dev/null +++ b/cmd/cog-kong/weights.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + + "github.com/replicate/cog/pkg/cli" +) + +// WeightsCmd implements the hidden, experimental "cog weights" command group. +type WeightsCmd struct { + Import WeightsImportCmd `cmd:"" help:"Build and push weights to a registry."` + Pull WeightsPullCmd `cmd:"" help:"Populate the local weight cache from the registry."` + Status WeightsStatusCmd `cmd:"" help:"Show the status of configured weights."` +} + +// WeightsImportCmd implements "cog weights import". +type WeightsImportCmd struct { + ConfigFlag `embed:""` + + DryRun bool `name:"dry-run" help:"Show what would be imported without making changes."` + Verbose bool `name:"verbose" short:"v" help:"Show per-file details."` + Names []string `arg:"" optional:"" name:"name" help:"Weight names to import. If omitted, all weights in cog.yaml are imported."` +} + +func (cmd *WeightsImportCmd) Run(ctx context.Context) error { + return cli.RunWeightsImport(ctx, cmd.File, cmd.Names, cmd.DryRun, cmd.Verbose) +} + +// WeightsPullCmd implements "cog weights pull". 
+type WeightsPullCmd struct { + ConfigFlag `embed:""` + + Verbose bool `name:"verbose" short:"v" help:"Show per-layer and per-file progress."` + Names []string `arg:"" optional:"" name:"name" help:"Weight names to pull. If omitted, all weights in cog.yaml are pulled."` +} + +func (cmd *WeightsPullCmd) Run(ctx context.Context) error { + return cli.RunWeightsPull(ctx, cmd.File, cmd.Names, cmd.Verbose) +} + +// WeightsStatusCmd implements "cog weights status". +type WeightsStatusCmd struct { + ConfigFlag `embed:""` + + JSON bool `name:"json" help:"Output as JSON."` + Verbose bool `name:"verbose" short:"v" help:"Show per-layer status."` +} + +func (cmd *WeightsStatusCmd) Run(ctx context.Context) error { + return cli.RunWeightsStatus(ctx, cmd.File, cmd.JSON, cmd.Verbose) +} diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 74e832ef9e..e8b1bbb932 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -259,7 +259,7 @@ dependencies = [ [[package]] name = "coglet" -version = "0.21.0-rc.1" +version = "0.21.0-rc.3" dependencies = [ "anyhow", "async-trait", @@ -291,7 +291,7 @@ dependencies = [ [[package]] name = "coglet-python" -version = "0.21.0-rc.1" +version = "0.21.0-rc.3" dependencies = [ "async-trait", "base64", diff --git a/crates/Cargo.toml b/crates/Cargo.toml index b95a914547..619ee6aa35 100644 --- a/crates/Cargo.toml +++ b/crates/Cargo.toml @@ -3,7 +3,7 @@ resolver = "2" members = ["coglet", "coglet-python"] [workspace.package] -version = "0.21.0-rc.1" +version = "0.21.0-rc.3" edition = "2024" license = "Apache-2.0" repository = "https://github.com/replicate/cog" diff --git a/docs/llms.txt b/docs/llms.txt index e8e9e1f7c7..2044b2f6d7 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -2070,6 +2070,7 @@ This document defines the API of the `cog` Python module, which is used to defin - [`cog.Secret`](#cogsecret) - [Wrapper types](#wrapper-types) - [`Optional`](#optional) + - [`Union`](#union) - [`list`](#list) - [`dict`](#dict) - [`cog.Opaque`](#cogopaque) @@ 
-2652,6 +2653,30 @@ def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: > [!NOTE] > `Optional[T]` is supported in `BaseModel` output fields but **not** as a top-level return type. Use a `BaseModel` with optional fields instead. +#### `Union` + +Use `A | B` or `Union[A, B]` to accept more than one type for a single input. Cog supports JSON-native union members: `str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`. + +```python +from cog import BaseRunner, Input + +class Runner(BaseRunner): + def run(self, + value: str | float = Input(description="A string or a number"), + ) -> str: + return f"{type(value).__name__}:{value}" +``` + +At runtime, Cog validates the request against the union and passes the value through as the matching type. For overlapping numeric types, Cog prefers the most specific match (e.g. `bool` before `int`, `int` before `float`), and a JSON integer is accepted for a `float` member. + +Combine a union with `None` to make it nullable: + +```python +def run(self, value: str | float | None = Input(default=None)) -> str: ... +``` + +Union inputs are validated at the HTTP boundary, so unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` are **not** supported, and the build fails if you use them. Union return types are also unsupported — use a `BaseModel` output instead. + #### `list` Use `list[T]` or `List[T]` to accept or return a list of values. `T` can be a supported Cog type, but nested container types are not supported. @@ -2813,7 +2838,8 @@ Fields in a `BaseModel` output support these types: The following type patterns are **not** supported: - **Nested generics**: `list[list[str]]`, `list[Optional[str]]`, `Optional[list[str]]` are not supported. -- **Union types beyond Optional**: `str | int`, `Union[str, int, None]` — only `Optional[T]` (i.e. `T | None`) is supported. +- **Output union types beyond Optional**: union _return_ types and `BaseModel` union fields are not supported. 
Input unions of JSON-native types (`str | int`, `str | float | None`, etc.) _are_ supported — see [`Union`](#union). +- **Input unions of non-JSON-native types**: input unions involving `Path`, `File`, `Secret`, custom coders, or `BaseModel` (e.g. `Path | str`) are not supported and fail at build time. - **`Optional` as a top-level return type**: `-> Optional[str]` is not allowed. Use a `BaseModel` with optional fields instead. - **Nested `BaseModel` fields**: A `BaseModel` field typed as another `BaseModel` is not supported in Cog's type system for schema generation. - **Tuple, Set, or other collection types**: Only `list` and `dict` are supported as collection types. diff --git a/docs/python.md b/docs/python.md index 5f5619a9af..fe9307aa3b 100644 --- a/docs/python.md +++ b/docs/python.md @@ -38,6 +38,7 @@ This document defines the API of the `cog` Python module, which is used to defin - [`cog.Secret`](#cogsecret) - [Wrapper types](#wrapper-types) - [`Optional`](#optional) + - [`Union`](#union) - [`list`](#list) - [`dict`](#dict) - [`cog.Opaque`](#cogopaque) @@ -620,6 +621,30 @@ def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: > [!NOTE] > `Optional[T]` is supported in `BaseModel` output fields but **not** as a top-level return type. Use a `BaseModel` with optional fields instead. +#### `Union` + +Use `A | B` or `Union[A, B]` to accept more than one type for a single input. Cog supports JSON-native union members: `str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`. + +```python +from cog import BaseRunner, Input + +class Runner(BaseRunner): + def run(self, + value: str | float = Input(description="A string or a number"), + ) -> str: + return f"{type(value).__name__}:{value}" +``` + +At runtime, Cog validates the request against the union and passes the value through as the matching type. For overlapping numeric types, Cog prefers the most specific match (e.g. 
`bool` before `int`, `int` before `float`), and a JSON integer is accepted for a `float` member. + +Combine a union with `None` to make it nullable: + +```python +def run(self, value: str | float | None = Input(default=None)) -> str: ... +``` + +Union inputs are validated at the HTTP boundary, so unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` are **not** supported, and the build fails if you use them. Union return types are also unsupported — use a `BaseModel` output instead. + #### `list` Use `list[T]` or `List[T]` to accept or return a list of values. `T` can be a supported Cog type, but nested container types are not supported. @@ -781,7 +806,8 @@ Fields in a `BaseModel` output support these types: The following type patterns are **not** supported: - **Nested generics**: `list[list[str]]`, `list[Optional[str]]`, `Optional[list[str]]` are not supported. -- **Union types beyond Optional**: `str | int`, `Union[str, int, None]` — only `Optional[T]` (i.e. `T | None`) is supported. +- **Output union types beyond Optional**: union _return_ types and `BaseModel` union fields are not supported. Input unions of JSON-native types (`str | int`, `str | float | None`, etc.) _are_ supported — see [`Union`](#union). +- **Input unions of non-JSON-native types**: input unions involving `Path`, `File`, `Secret`, custom coders, or `BaseModel` (e.g. `Path | str`) are not supported and fail at build time. - **`Optional` as a top-level return type**: `-> Optional[str]` is not allowed. Use a `BaseModel` with optional fields instead. - **Nested `BaseModel` fields**: A `BaseModel` field typed as another `BaseModel` is not supported in Cog's type system for schema generation. - **Tuple, Set, or other collection types**: Only `list` and `dict` are supported as collection types. 
diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 014d770ed6..6512897783 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -29,12 +29,14 @@ import ( // into testscript environments (Setup) and background processes (cmdCogServe). // Keep this list in sync: if you add a new env var to propagate, add it here. var propagatedEnvVars = []string{ - "COG_SDK_WHEEL", // SDK wheel override - "COGLET_WHEEL", // coglet wheel override - "RUST_LOG", // Rust logging control - "COG_CA_CERT", // custom CA certificates (e.g. Cloudflare WARP) - "BUILDKIT_PROGRESS", // Docker build output format - "COG_REGISTRY_HOST", // registry host for cog base image resolution + "COG_SDK_WHEEL", // SDK wheel override + "COGLET_WHEEL", // coglet wheel override + "COG_CACHE_DIR", // isolated cache for managed weights + "COG_MODEL_REGISTRY", // test registry override for model refs + "RUST_LOG", // Rust logging control + "COG_CA_CERT", // custom CA certificates (e.g. Cloudflare WARP) + "BUILDKIT_PROGRESS", // Docker build output format + "COG_REGISTRY_HOST", // registry host for cog base image resolution } // Harness provides utilities for running cog integration tests. @@ -242,6 +244,7 @@ func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool // Built-in commands (defined in this file) NewCommand("cog", h.cmdCog), NewCommand("curl", h.cmdCurl), + NewCommand("server-stop", h.cmdServerStop), NewCommand("wait-for", h.cmdWaitFor), NewCommand("docker-run", h.cmdDockerRun), @@ -618,6 +621,17 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { ts.Fatalf("curl: all %d attempts failed with status %d: %s", maxAttempts, lastStatus, errorMsg) } +// cmdServerStop stops the background 'cog serve' process for the current test. 
+func (h *Harness) cmdServerStop(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("server-stop: negation is not supported") + } + if len(args) != 0 { + ts.Fatalf("server-stop: usage: server-stop") + } + h.StopServer(ts) +} + // StopServer stops the background server process for a test script. func (h *Harness) StopServer(ts *testscript.TestScript) { workDir := ts.Getenv("WORK") @@ -643,11 +657,24 @@ func (h *Harness) stopServerByWorkDir(workDir string) { _ = resp.Body.Close() } - // Force kill the cog process if still running + // Give the cog process time to exit after /shutdown so defers can run. + done := make(chan struct{}) + go func() { + _ = info.cmd.Wait() + close(done) + }() + + // Force kill the cog process if still running. if info.cmd.Process != nil { - _ = info.cmd.Process.Kill() + select { + case <-done: + case <-time.After(5 * time.Second): + _ = info.cmd.Process.Kill() + <-done + } + } else { + <-done } - _ = info.cmd.Wait() // Also kill any Docker container that may still be running on this port // Find container by port and kill it diff --git a/integration-tests/tests/union_input_cli.txtar b/integration-tests/tests/union_input_cli.txtar new file mode 100644 index 0000000000..aeebd7fa4d --- /dev/null +++ b/integration-tests/tests/union_input_cli.txtar @@ -0,0 +1,45 @@ +# Test schema-directed CLI parsing for JSON-native union inputs. +# +# value: str | float (string member first) +# flipped: float | str (number member first) +# +# - "hello" parses as a string because no numeric parse succeeds +# - "1.5" parses as a float because the union accepts a number +# - "1" parses as an integer (still a valid JSON number for the union) +# +# The `flipped` field exercises the case where resolveSchemaType resolves a +# union to its numeric member first: a non-numeric value must still fall back +# to the string member instead of erroring. 
+# +# Note: the worker does not coerce primitives at runtime (validation happens +# at the HTTP edge against the OpenAPI schema), so the CLI must choose the +# wire type. A bare integer stays an integer; only fractional values become +# floats. This matches how a plain `float` input also receives a Python int +# for `-i num=10`. + +cog build -t $TEST_IMAGE + +cog predict $TEST_IMAGE -i value=hello -i flipped=world +stdout 'value=str:hello flipped=str:world' + +cog predict $TEST_IMAGE -i value=1.5 -i flipped=2.5 +stdout 'value=float:1.5 flipped=float:2.5' + +cog predict $TEST_IMAGE -i value=1 -i flipped=2 +stdout 'value=int:1 flipped=int:2' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, value: str | float, flipped: float | str) -> str: + return ( + f"value={type(value).__name__}:{value} " + f"flipped={type(flipped).__name__}:{flipped}" + ) diff --git a/integration-tests/tests/union_input_http.txtar b/integration-tests/tests/union_input_http.txtar new file mode 100644 index 0000000000..7ec032c426 --- /dev/null +++ b/integration-tests/tests/union_input_http.txtar @@ -0,0 +1,54 @@ +# Test JSON-native union inputs over cog serve. +# value: str | float | None = Input(default=None) +# - string accepted, returns "str:..." +# - float accepted, returns "float:..." +# - integer accepted (valid JSON number), returns "int:..." 
+# - omitted optional defaults to None, returns "NoneType:none" +# - explicit null is rejected, matching how every optional field behaves +# today (validation happens at the HTTP edge and the runtime validator +# does not accept explicit JSON null for a typed field) +# - bool rejected (not a member of str | float), returns a validation error + +cog build -t $TEST_IMAGE + +cog serve + +# String member +curl POST /predictions '{"input":{"value":"hello"}}' +stdout '"output":"str:hello"' + +# Float member +curl POST /predictions '{"input":{"value":1.5}}' +stdout '"output":"float:1.5"' + +# Integer is a valid JSON number for the union; passed through as int +curl POST /predictions '{"input":{"value":1}}' +stdout '"output":"int:1"' + +# Omitted optional value defaults to None +curl POST /predictions '{"input":{}}' +stdout '"output":"NoneType:none"' + +# Explicit null is rejected, consistent with all optional fields +! curl POST /predictions '{"input":{"value":null}}' + +# bool is not a member of str | float -> rejected +! curl POST /predictions '{"input":{"value":true}}' + +# nested object is not a member of str | float -> rejected +! curl POST /predictions '{"input":{"value":{"x":1}}}' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, value: str | float | None = Input(default=None)) -> str: + if value is None: + return "NoneType:none" + return f"{type(value).__name__}:{value}" diff --git a/integration-tests/tests/union_input_list_http.txtar b/integration-tests/tests/union_input_list_http.txtar new file mode 100644 index 0000000000..e36fa90be5 --- /dev/null +++ b/integration-tests/tests/union_input_list_http.txtar @@ -0,0 +1,48 @@ +# Test list JSON-native union inputs over cog serve. 
+# +# nums: list[int] | list[float] (required list union) +# +# - list[int] | list[float] accepts [1] and [1.5] and the empty list [] +# - list element types are validated: ["3"] and [true] are rejected +# - integer elements stay int, fractional elements are float (no runtime +# coercion: the wire type is preserved) + +cog build -t $TEST_IMAGE + +cog serve + +# Integer list element kept as int +curl POST /predictions '{"input":{"nums":[1]}}' +stdout '"output":"int:1"' + +# Float list element +curl POST /predictions '{"input":{"nums":[1.5]}}' +stdout '"output":"float:1.5"' + +# Empty list is accepted +curl POST /predictions '{"input":{"nums":[]}}' +stdout '"output":"empty"' + +# String element is not valid for list[int] | list[float] -> rejected +! curl POST /predictions '{"input":{"nums":["3"]}}' + +# bool element is not valid for list[int] | list[float] -> rejected +! curl POST /predictions '{"input":{"nums":[true]}}' + +# A bare scalar is not a list -> rejected +! curl POST /predictions '{"input":{"nums":1}}' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, nums: list[int] | list[float]) -> str: + if not nums: + return "empty" + return f"{type(nums[0]).__name__}:{nums[0]}" diff --git a/integration-tests/tests/weights_import_serve.txtar b/integration-tests/tests/weights_import_serve.txtar new file mode 100644 index 0000000000..47ada1ffe0 --- /dev/null +++ b/integration-tests/tests/weights_import_serve.txtar @@ -0,0 +1,69 @@ +# End-to-end test that `cog weights import` warms the local store +# enough for `cog serve` to mount the weights without a separate +# `cog weights pull`. +# +# This is the serve-side companion to weights_import_predict.txtar. + +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Point the model ref at the ephemeral test registry. 
cog.yaml carries +# the bare repo via 'model:'; COG_MODEL_REGISTRY swaps in the host +# now that registry-start has assigned a port. +env COG_MODEL_REGISTRY=$TEST_REGISTRY + +# Build a deterministic weight directory. +mkdir weights-src +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "X" > weights-src/greeting.txt' + +# Step 1: import. After this, the local content store under +# $COG_CACHE_DIR/weights/files/sha256/ must contain greeting.txt's +# bytes — that's the cog-i12u guarantee. +cog weights import +exists weights.lock + +# Step 2: serve, with NO intervening `cog weights pull`. The +# import path warmed the store; serve's hardlink-assemble must +# work straight off the back of import. +cog serve + +# Step 3: predict via HTTP API and verify the weight is mounted. +curl POST /predictions '{"input":{"s":"world"}}' +stdout '"status":"succeeded"' +stdout 'hello world \(weight-size=1024\)' + +# Per-invocation mount dir must be cleaned up after serve exits. +server-stop +exec sh -c 'test ! -d .cog/mounts || test -z "$(ls -A .cog/mounts)"' + +-- cog.yaml -- +# 'model:' carries the bare repo. The registry host is set via +# COG_MODEL_REGISTRY in the test body since $TEST_REGISTRY varies +# per run. +build: + python_version: "3.12" +predict: "predict.py:Predictor" +model: test/import-serve-model +weights: + - name: greeting + target: /src/mounted-weights/greeting + source: + uri: file://./weights-src + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + # Prove the weight was mounted at the target path by reading + # its size. Any failure to mount would raise FileNotFoundError + # here and the prediction would fail. 
+ path = "/src/mounted-weights/greeting/greeting.txt" + size = os.path.getsize(path) + return f"hello {s} (weight-size={size})" diff --git a/integration-tests/tests/weights_pull_serve.txtar b/integration-tests/tests/weights_pull_serve.txtar new file mode 100644 index 0000000000..38388119f2 --- /dev/null +++ b/integration-tests/tests/weights_pull_serve.txtar @@ -0,0 +1,78 @@ +# End-to-end test of the managed-weights local-run flow via `cog serve`. +# Verifies: cog weights import -> cog weights pull -> cog serve, where +# the predictor container sees the weight files at the configured target +# path via the HTTP API. Also verifies mount cleanup after serve exits. +# +# This is the serve-side companion to weights_pull_predict.txtar. + +[short] skip 'requires local registry' + +env COG_CACHE_DIR=$WORK/cache + +registry-start + +# Point the model ref at the ephemeral test registry. cog serve +# picks this up via newWeightManager → resolveWeightRepo → ResolveModelRef. +env COG_MODEL_REGISTRY=$TEST_REGISTRY + +# Build a deterministic weight directory. +mkdir weights-src +exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "X" > weights-src/greeting.txt' + +# Step 1: import to the local registry. This also warms the local +# cache as a side effect of import (cog-i12u guarantee). +cog weights import +exists weights.lock + +# Purge the cache so step 2 exercises pull's cold-fetch path. The +# realistic scenario for `cog weights pull` is "lockfile checked in, +# cache empty" — e.g. a fresh clone. +exec sh -c 'rm -rf $WORK/cache/weights/files' + +# Step 2: pull into the isolated cache. +cog weights pull +stderr 'Pulled' + +# Step 3: serve. The predictor opens the file at the configured +# target path and returns its size, proving the mount is visible +# and readable inside the container. 
+cog serve + +curl POST /predictions '{"input":{"s":"world"}}' +stdout '"status":"succeeded"' +stdout 'hello world \(weight-size=1024\)' + +# Step 4: the per-invocation mount dir must be cleaned up after +# serve exits. .cog/mounts/ either doesn't exist (everything +# cleaned) or is empty. +server-stop +exec sh -c 'test ! -d .cog/mounts || test -z "$(ls -A .cog/mounts)"' + +-- cog.yaml -- +# 'model:' carries the bare repo. The registry host is set via +# COG_MODEL_REGISTRY in the test body since $TEST_REGISTRY varies +# per run. +build: + python_version: "3.12" +predict: "predict.py:Predictor" +model: test/pull-serve-model +weights: + - name: greeting + target: /src/mounted-weights/greeting + source: + uri: file://./weights-src + +-- predict.py -- +import os + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, s: str) -> str: + # Prove the weight was mounted at the target path by reading + # its size. Any failure to mount would raise FileNotFoundError + # here and the prediction would fail. 
+ path = "/src/mounted-weights/greeting/greeting.txt" + size = os.path.getsize(path) + return f"hello {s} (weight-size={size})" diff --git a/mise.lock b/mise.lock index 57b2918a40..3fe3fb9ec2 100644 --- a/mise.lock +++ b/mise.lock @@ -39,30 +39,37 @@ backend = "aqua:golangci/golangci-lint" [tools."aqua:golangci/golangci-lint"."platforms.linux-arm64"] checksum = "sha256:6652b42ae02915eb2f9cb2a2e0cac99514c8eded8388d88ae3e06e1a52c00de8" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-arm64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-arm64-musl"] checksum = "sha256:6652b42ae02915eb2f9cb2a2e0cac99514c8eded8388d88ae3e06e1a52c00de8" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-arm64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-x64"] checksum = "sha256:dfa775874cf0561b404a02a8f4481fc69b28091da95aa697259820d429b09c99" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-amd64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.linux-x64-musl"] checksum = "sha256:dfa775874cf0561b404a02a8f4481fc69b28091da95aa697259820d429b09c99" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-linux-amd64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.macos-arm64"] checksum = "sha256:03bfadf67e52b441b7ec21305e501c717df93c959836d66c7f97312654acb297" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-darwin-arm64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.macos-x64"] checksum = "sha256:66fb0da81b8033b477f97eea420d4b46b230ca172b8bb87c6610109f3772b6b6" url = 
"https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-darwin-amd64.tar.gz" +provenance = "github-attestations" [tools."aqua:golangci/golangci-lint"."platforms.windows-x64"] checksum = "sha256:c60c87695e79db8e320f0e5be885059859de52bb5ee5f11be5577828570bc2a3" url = "https://github.com/golangci/golangci-lint/releases/download/v2.10.1/golangci-lint-2.10.1-windows-amd64.zip" +provenance = "github-attestations" [[tools."aqua:gotestyourself/gotestsum"]] version = "1.13.0" @@ -96,6 +103,45 @@ url = "https://github.com/gotestyourself/gotestsum/releases/download/v1.13.0/got checksum = "sha256:fd5a6dc69e46a0970593e70d85a7e75f16714e9c61d6d72ccc324eb82df5bb8a" url = "https://github.com/gotestyourself/gotestsum/releases/download/v1.13.0/gotestsum_1.13.0_windows_amd64.tar.gz" +[[tools."aqua:jqlang/jq"]] +version = "1.8.1" +backend = "aqua:jqlang/jq" + +[tools."aqua:jqlang/jq"."platforms.linux-arm64"] +checksum = "sha256:6bc62f25981328edd3cfcfe6fe51b073f2d7e7710d7ef7fcdac28d4e384fc3d4" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-arm64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.linux-arm64-musl"] +checksum = "sha256:6bc62f25981328edd3cfcfe6fe51b073f2d7e7710d7ef7fcdac28d4e384fc3d4" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-arm64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.linux-x64"] +checksum = "sha256:020468de7539ce70ef1bceaf7cde2e8c4f2ca6c3afb84642aabc5c97d9fc2a0d" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-amd64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.linux-x64-musl"] +checksum = "sha256:020468de7539ce70ef1bceaf7cde2e8c4f2ca6c3afb84642aabc5c97d9fc2a0d" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-amd64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.macos-arm64"] +checksum = 
"sha256:a9fe3ea2f86dfc72f6728417521ec9067b343277152b114f4e98d8cb0e263603" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-macos-arm64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.macos-x64"] +checksum = "sha256:e80dbe0d2a2597e3c11c404f03337b981d74b4a8504b70586c354b7697a7c27f" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-macos-amd64" +provenance = "github-attestations" + +[tools."aqua:jqlang/jq"."platforms.windows-x64"] +checksum = "sha256:23cb60a1354eed6bcc8d9b9735e8c7b388cd1fdcb75726b93bc299ef22dd9334" +url = "https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-windows-amd64.exe" +provenance = "github-attestations" + [[tools."aqua:mitsuhiko/insta"]] version = "1.46.0" backend = "aqua:mitsuhiko/insta" @@ -161,16 +207,16 @@ checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f8 url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-arm64-musl"] -checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f841c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" +checksum = "sha256:a97c8f56d7462908695348dd8c71ea6740c138ce303715793a690503a94fc9a9" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-x64"] checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.linux-x64-musl"] -checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" +checksum = "sha256:e6599a1c7be58a2d8eaca66a80e0dc006d87bbcf780a58b7343d6e14c1605cb2" +url = 
"https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup"."platforms.macos-arm64"] checksum = "sha256:20ef5516c31b1ac2290084199ba77dbbcaa1406c45c1d978ca68558ef5964ef5" @@ -193,16 +239,16 @@ checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f8 url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-arm64-musl"] -checksum = "sha256:e3853c5a252fca15252d07cb23a1bdd9377a8c6f3efa01531109281ae47f841c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-gnu/rustup-init" +checksum = "sha256:a97c8f56d7462908695348dd8c71ea6740c138ce303715793a690503a94fc9a9" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/aarch64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-x64"] checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.linux-x64-musl"] -checksum = "sha256:20a06e644b0d9bd2fbdbfd52d42540bdde820ea7df86e92e533c073da0cdd43c" -url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-gnu/rustup-init" +checksum = "sha256:e6599a1c7be58a2d8eaca66a80e0dc006d87bbcf780a58b7343d6e14c1605cb2" +url = "https://static.rust-lang.org/rustup/archive/1.28.2/x86_64-unknown-linux-musl/rustup-init" [tools."aqua:rust-lang/rustup/rustup-init"."platforms.macos-arm64"] checksum = "sha256:20ef5516c31b1ac2290084199ba77dbbcaa1406c45c1d978ca68558ef5964ef5" @@ -248,32 +294,6 @@ url = "https://github.com/vektra/mockery/releases/download/v3.7.0/mockery_3.7.0_ checksum = "sha256:76d524ad1740cd02ed621d90015f538fdb53cf6cd4a1ad4d289db017fb69bd0e" url = 
"https://github.com/vektra/mockery/releases/download/v3.7.0/mockery_3.7.0_Windows_x86_64.tar.gz" -[[tools."aqua:ziglang/zig"]] -version = "0.15.2" -backend = "aqua:ziglang/zig" - -[tools."aqua:ziglang/zig"."platforms.linux-arm64"] -url = "https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.linux-arm64-musl"] -url = "https://ziglang.org/download/0.15.2/zig-aarch64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.linux-x64"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.linux-x64-musl"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-linux-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.macos-arm64"] -checksum = "blake3:c7d2fb746701fea2c070f66c29a0300ba42a0d5d2c09c493462b8c0f4f0bd604" -url = "https://ziglang.org/download/0.15.2/zig-aarch64-macos-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.macos-x64"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-macos-0.15.2.tar.xz" - -[tools."aqua:ziglang/zig"."platforms.windows-x64"] -url = "https://ziglang.org/download/0.15.2/zig-x86_64-windows-0.15.2.zip" - [[tools.cargo-binstall]] version = "1.16.6" backend = "aqua:cargo-bins/cargo-binstall" @@ -282,10 +302,18 @@ backend = "aqua:cargo-bins/cargo-binstall" checksum = "sha256:b556421835ba67fa98ca1570c85b5511457956b7836ce938b47d3f73899517a3" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-unknown-linux-musl.tgz" +[tools.cargo-binstall."platforms.linux-arm64-musl"] +checksum = "sha256:b556421835ba67fa98ca1570c85b5511457956b7836ce938b47d3f73899517a3" +url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-unknown-linux-musl.tgz" + [tools.cargo-binstall."platforms.linux-x64"] checksum = "sha256:3225eea8041c30d7462761a481883e3aa8fe31c58def4b6c8dd91b7c80973df0" url = 
"https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-unknown-linux-musl.tgz" +[tools.cargo-binstall."platforms.linux-x64-musl"] +checksum = "sha256:3225eea8041c30d7462761a481883e3aa8fe31c58def4b6c8dd91b7c80973df0" +url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-unknown-linux-musl.tgz" + [tools.cargo-binstall."platforms.macos-arm64"] checksum = "sha256:30543b378b96fbddabee1edfaccde7914dd2f851f02c560de859f81a21ab665b" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-aarch64-apple-darwin.zip" @@ -298,22 +326,10 @@ url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/ca checksum = "sha256:fca962c3d12ae6192280111074db073c15abad3ba162a1a5a2af0f6f01872114" url = "https://github.com/cargo-bins/cargo-binstall/releases/download/v1.16.6/cargo-binstall-x86_64-pc-windows-msvc.zip" -[[tools."cargo:cargo-deny"]] -version = "0.19.0" -backend = "cargo:cargo-deny" - -[[tools."cargo:cargo-insta"]] -version = "1.46.0" -backend = "cargo:cargo-insta" - [[tools."cargo:cargo-nextest"]] version = "0.9.120" backend = "cargo:cargo-nextest" -[[tools."cargo:cargo-zigbuild"]] -version = "0.20.1" -backend = "cargo:cargo-zigbuild" - [[tools."cargo:maturin"]] version = "1.11.5" backend = "cargo:maturin" @@ -370,10 +386,6 @@ backend = "pipx:nox" uvx = "true" uvx_args = "--python-preference=managed -p 3.13" -[[tools.python]] -version = "3.12.12" -backend = "core:python" - [[tools.ruff]] version = "0.14.13" backend = "aqua:astral-sh/ruff" @@ -449,30 +461,37 @@ backend = "aqua:astral-sh/uv" [tools.uv."platforms.linux-arm64"] checksum = "sha256:ba8698c36c00c22efed4bd3506339b03c95604d001f02eaf6fbc814c9224d801" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.linux-arm64-musl"] checksum = 
"sha256:ba8698c36c00c22efed4bd3506339b03c95604d001f02eaf6fbc814c9224d801" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.linux-x64"] checksum = "sha256:708b752876aeeb753257e1d55470569789e465684c1d3bc1760db26360b6c28b" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.linux-x64-musl"] checksum = "sha256:708b752876aeeb753257e1d55470569789e465684c1d3bc1760db26360b6c28b" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-unknown-linux-musl.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.macos-arm64"] checksum = "sha256:fcf0a9ea6599c6ae28a4c854ac6da76f2c889354d7c36ce136ef071f7ab9721f" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-aarch64-apple-darwin.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.macos-x64"] checksum = "sha256:171eb8c518313e157c5b4cec7b4f743bc6bab1bd23e09b646679a02d096a047f" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-apple-darwin.tar.gz" +provenance = "github-attestations" [tools.uv."platforms.windows-x64"] checksum = "sha256:eb02fd95d8e0eed462b4a67ecdd320d865b38c560bffcda9a0b87ec944bdf036" url = "https://github.com/astral-sh/uv/releases/download/0.9.26/uv-x86_64-pc-windows-msvc.zip" +provenance = "github-attestations" [[tools.zig]] version = "0.15.2" diff --git a/mise.toml b/mise.toml index f7a6fe99ab..e25f278c84 100644 --- a/mise.toml +++ b/mise.toml @@ -53,6 +53,7 @@ ruff = "0.14.13" ty = "0.0.10" "npm:prettier" = "3.6.2" "npm:markdownlint-cli2" = "0.22.0" +"aqua:jqlang/jq" = "1.8.1" "go:golang.org/x/tools/cmd/goimports" = "latest" zig = "0.15.2" @@ -325,21 +326,35 @@ depends = ["build:coglet:wheel"] run = "nox -s coglet" [tasks."test:fuzz"] -description = "Run Go fuzz tests (FUZZTIME=30s per target by default)" -run = """ 
+description = "Run all Go fuzz tests (auto-discovered; FUZZTIME=30s per target by default)" +run = ''' #!/usr/bin/env bash -set -e +set -euo pipefail FUZZTIME="${FUZZTIME:-30s}" -echo "Fuzzing schema type resolution ($FUZZTIME)..." -go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime="$FUZZTIME" -echo "Fuzzing JSON schema generation ($FUZZTIME)..." -go test ./pkg/schema/ -run='^$' -fuzz=FuzzJSONSchema -fuzztime="$FUZZTIME" -echo "Fuzzing Python parser ($FUZZTIME)..." -go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParsePredictor -fuzztime="$FUZZTIME" -echo "Fuzzing type annotation parsing ($FUZZTIME)..." -go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParseTypeAnnotation -fuzztime="$FUZZTIME" -echo "All fuzz targets passed." -""" + +# Auto-discover every Fuzz* target and the package it lives in, so new fuzz +# tests are picked up automatically without editing this task. go test only +# fuzzes one target per invocation, so we still loop and run them one at a time. +# jq emits "<package> <target>" per line; package and target names never +# contain spaces, so `read pkg target` splits them cleanly. +count=0 +while read -r pkg target; do + [ -z "$pkg" ] && continue + echo "Fuzzing $target in $pkg ($FUZZTIME)..." + go test "$pkg" -run="^$" -fuzz="^${target}$" -fuzztime="$FUZZTIME" + count=$((count + 1)) +done < <( + go test -list "^Fuzz" -json ./... 2>/dev/null \ + | jq -r 'select(.Action=="output" and (.Output|test("^Fuzz"))) | .Package + " " + (.Output | rtrimstr("\n"))' +) + +if [ "$count" -eq 0 ]; then + echo "No fuzz targets found." >&2 + exit 1 +fi + +echo "All $count fuzz targets passed." 
+''' [tasks."test:integration"] description = "Run integration tests (skips slow tests by default, set SHORT=0 for full suite)" diff --git a/pkg/cli/baseimage.go b/pkg/cli/baseimage.go index f9831f6fff..85d2dfbdfd 100644 --- a/pkg/cli/baseimage.go +++ b/pkg/cli/baseimage.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "slices" "strings" "github.com/spf13/cobra" @@ -61,65 +62,7 @@ func newBaseImageGenerateMatrix() *cobra.Command { Use: "generate-matrix", Short: "Generate a matrix of Cog base image versions (JSON)", RunE: func(cmd *cobra.Command, args []string) error { - validCudaVersions := strings.FieldsFunc(baseImageCUDAVersion, func(c rune) bool { - return c == ',' - }) - validPythonVersions := strings.FieldsFunc(baseImagePythonVersion, func(c rune) bool { - return c == ',' - }) - validTorchVersions := strings.FieldsFunc(baseImageTorchVersion, func(c rune) bool { - return c == ',' - }) - - allConfigurations := dockerfile.BaseImageConfigurations() - filteredMatrix := make([]dockerfile.BaseImageConfiguration, 0, len(allConfigurations)) - for _, config := range allConfigurations { - var found bool - if len(validCudaVersions) > 0 { - found = false - for _, validCudaVersion := range validCudaVersions { - if config.CUDAVersion == validCudaVersion { - found = true - } - } - if !found { - continue - } - } - - if len(validPythonVersions) > 0 { - found = false - for _, validPythonVersion := range validPythonVersions { - if config.PythonVersion == validPythonVersion { - found = true - } - } - if !found { - continue - } - } - - if len(validTorchVersions) > 0 { - found = false - for _, validTorchVersion := range validTorchVersions { - if config.TorchVersion == validTorchVersion { - found = true - } - } - if !found { - continue - } - } - - filteredMatrix = append(filteredMatrix, config) - } - - output, err := json.Marshal(filteredMatrix) - if err != nil { - return err - } - fmt.Println(string(output)) - return nil + return 
RunBaseImageGenerateMatrix(baseImageOptionsFromFlags()) }, Args: cobra.MaximumNArgs(0), } @@ -127,23 +70,51 @@ func newBaseImageGenerateMatrix() *cobra.Command { return cmd } +// RunBaseImageGenerateMatrix prints, as JSON, the matrix of supported base image +// configurations filtered by the comma-separated CUDA/Python/Torch versions in +// opts. Empty version filters match all values. It is shared by both the Cobra +// and Kong base-image generate-matrix commands. +func RunBaseImageGenerateMatrix(opts BaseImageOptions) error { + split := func(s string) []string { + return strings.FieldsFunc(s, func(c rune) bool { return c == ',' }) + } + validCudaVersions := split(opts.CUDAVersion) + validPythonVersions := split(opts.PythonVersion) + validTorchVersions := split(opts.TorchVersion) + + matches := func(filters []string, value string) bool { + return len(filters) == 0 || slices.Contains(filters, value) + } + + allConfigurations := dockerfile.BaseImageConfigurations() + filteredMatrix := make([]dockerfile.BaseImageConfiguration, 0, len(allConfigurations)) + for _, config := range allConfigurations { + if !matches(validCudaVersions, config.CUDAVersion) { + continue + } + if !matches(validPythonVersions, config.PythonVersion) { + continue + } + if !matches(validTorchVersions, config.TorchVersion) { + continue + } + filteredMatrix = append(filteredMatrix, config) + } + + output, err := json.Marshal(filteredMatrix) + if err != nil { + return err + } + fmt.Println(string(output)) + return nil +} + func newBaseImageDockerfileCommand() *cobra.Command { var cmd = &cobra.Command{ Use: "dockerfile", Short: "Display Cog base image Dockerfile", RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - - generator, err := baseImageGeneratorFromFlags(ctx) - if err != nil { - return err - } - dockerfile, err := generator.GenerateDockerfile(ctx) - if err != nil { - return err - } - fmt.Println(dockerfile) - return nil + return RunBaseImageDockerfile(cmd.Context(), 
baseImageOptionsFromFlags()) }, Args: cobra.MaximumNArgs(0), } @@ -159,42 +130,11 @@ func newBaseImageBuildCommand() *cobra.Command { Use: "build", Short: "Build Cog base image", RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - - dockerClient, err := docker.NewClient(ctx) - if err != nil { - return err - } - - generator, err := baseImageGeneratorFromFlags(ctx) - if err != nil { - return err - } - dockerfileContents, err := generator.GenerateDockerfile(ctx) + dockerClient, err := docker.NewClient(cmd.Context()) if err != nil { return err } - - cwd, err := os.Getwd() - if err != nil { - return err - } - baseImageName := dockerfile.BaseImageName(baseImageCUDAVersion, baseImagePythonVersion, baseImageTorchVersion) - - buildOpts := command.ImageBuildOptions{ - WorkingDir: cwd, - DockerfileContents: dockerfileContents, - ImageName: baseImageName, - NoCache: buildNoCache, - ProgressOutput: buildProgressOutput, - Epoch: &config.BuildSourceEpochTimestamp, - ContextDir: ".", - } - if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil { - return err - } - fmt.Println("Successfully built image: " + baseImageName) - return nil + return RunBaseImageBuild(cmd.Context(), dockerClient, baseImageOptionsFromFlags()) }, Args: cobra.MaximumNArgs(0), } @@ -203,6 +143,84 @@ func newBaseImageBuildCommand() *cobra.Command { return cmd } +// BaseImageOptions holds the parser-independent options shared by the +// base-image dockerfile and build commands. +type BaseImageOptions struct { + CUDAVersion string + PythonVersion string + TorchVersion string + BreakSystemPackages bool + BuildContextDir string + NoCache bool + ProgressOutput string + Timestamp int64 +} + +// baseImageOptionsFromFlags reads the package-level Cobra flag globals into a +// BaseImageOptions value. 
+func baseImageOptionsFromFlags() BaseImageOptions { + return BaseImageOptions{ + CUDAVersion: baseImageCUDAVersion, + PythonVersion: baseImagePythonVersion, + TorchVersion: baseImageTorchVersion, + BreakSystemPackages: baseImageBreakSystemPackages, + BuildContextDir: baseImageBuildContextDir, + NoCache: buildNoCache, + ProgressOutput: buildProgressOutput, + Timestamp: config.BuildSourceEpochTimestamp, + } +} + +// RunBaseImageDockerfile generates and prints the base image Dockerfile. It is +// shared by both the Cobra and Kong base-image dockerfile commands. +func RunBaseImageDockerfile(ctx context.Context, opts BaseImageOptions) error { + generator, err := baseImageGenerator(ctx, opts) + if err != nil { + return err + } + contents, err := generator.GenerateDockerfile(ctx) + if err != nil { + return err + } + fmt.Println(contents) + return nil +} + +// RunBaseImageBuild builds a Cog base image. It is shared by both the Cobra and +// Kong base-image build commands. +func RunBaseImageBuild(ctx context.Context, dockerClient command.Command, opts BaseImageOptions) error { + generator, err := baseImageGenerator(ctx, opts) + if err != nil { + return err + } + dockerfileContents, err := generator.GenerateDockerfile(ctx) + if err != nil { + return err + } + + cwd, err := os.Getwd() + if err != nil { + return err + } + baseImageName := dockerfile.BaseImageName(opts.CUDAVersion, opts.PythonVersion, opts.TorchVersion) + + timestamp := opts.Timestamp + buildOpts := command.ImageBuildOptions{ + WorkingDir: cwd, + DockerfileContents: dockerfileContents, + ImageName: baseImageName, + NoCache: opts.NoCache, + ProgressOutput: opts.ProgressOutput, + Epoch: &timestamp, + ContextDir: ".", + } + if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil { + return err + } + fmt.Println("Successfully built image: " + baseImageName) + return nil +} + func addBaseImageFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&baseImageCUDAVersion, "cuda", "", "CUDA version") 
cmd.Flags().StringVar(&baseImagePythonVersion, "python", "", "Python version") @@ -214,7 +232,7 @@ func addBaseImageFlags(cmd *cobra.Command) { addBuildTimestampFlag(cmd) } -func baseImageGeneratorFromFlags(ctx context.Context) (*dockerfile.BaseImageGenerator, error) { +func baseImageGenerator(ctx context.Context, opts BaseImageOptions) (*dockerfile.BaseImageGenerator, error) { dockerClient, err := docker.NewClient(ctx) if err != nil { return nil, err @@ -223,16 +241,16 @@ func baseImageGeneratorFromFlags(ctx context.Context) (*dockerfile.BaseImageGene generator, err := dockerfile.NewBaseImageGenerator( ctx, client, - baseImageCUDAVersion, - baseImagePythonVersion, - baseImageTorchVersion, + opts.CUDAVersion, + opts.PythonVersion, + opts.TorchVersion, dockerClient, true, ) if err != nil { return nil, err } - generator.SetBreakSystemPackages(baseImageBreakSystemPackages) - generator.SetBuildContextDir(baseImageBuildContextDir) + generator.SetBreakSystemPackages(opts.BreakSystemPackages) + generator.SetBuildContextDir(opts.BuildContextDir) return generator, nil } diff --git a/pkg/cli/build.go b/pkg/cli/build.go index bbdc5f37f7..0fdde3a079 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "os" "strings" @@ -10,6 +11,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" @@ -83,7 +85,101 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } - src, err := model.NewSource(configFilename) + return RunBuild(ctx, dockerClient, registry.NewRegistryClient(), BuildCommandOptions{ + ConfigFilename: configFilename, + Tag: buildTag, + Flags: buildFlagsOptionsFromCobra(cmd), + }) +} + +// BuildFlagsOptions holds the parser-independent build flag values shared by +// both the Cobra and Kong CLIs. 
It mirrors the flags registered on the build +// and push commands. +type BuildFlagsOptions struct { + NoCache bool + SeparateWeights bool + Secrets []string + ProgressOutput string + UseCudaBaseImage string + UseCogBaseImage *bool + OpenAPISchema string + DockerfileFile string + Strip bool + Precompile bool + SkipSchemaValidation bool + Timestamp int64 +} + +// BuildCommandOptions holds everything RunBuild needs that is independent of +// the argument parser. +type BuildCommandOptions struct { + ConfigFilename string + Tag string + Flags BuildFlagsOptions +} + +// ResolveBuildImageName resolves the output image name from the config image, +// config model, an explicit tag, and the project directory, matching the +// historical precedence: tag > config image > config model > derived name. +func ResolveBuildImageName(configImage, configModel, tag, projectDir string) string { + imageName := configImage + if imageName == "" { + imageName = configModel + } + if tag != "" { + imageName = tag + } + if imageName == "" { + imageName = config.DockerImageName(projectDir) + } + return imageName +} + +// ModelBuildOptions converts the shared build flags into model.BuildOptions for +// the given image name and annotations. +func (o BuildFlagsOptions) ModelBuildOptions(imageName string, annotations map[string]string) model.BuildOptions { + return model.BuildOptions{ + ImageName: imageName, + Secrets: o.Secrets, + NoCache: o.NoCache, + SeparateWeights: o.SeparateWeights, + UseCudaBaseImage: o.UseCudaBaseImage, + ProgressOutput: o.ProgressOutput, + SchemaFile: o.OpenAPISchema, + DockerfileFile: o.DockerfileFile, + UseCogBaseImage: o.UseCogBaseImage, + Strip: o.Strip, + Precompile: o.Precompile, + Annotations: annotations, + SkipSchemaValidation: o.SkipSchemaValidation, + } +} + +// buildFlagsOptionsFromCobra reads the package-level Cobra flag globals into a +// BuildFlagsOptions value. 
+func buildFlagsOptionsFromCobra(cmd *cobra.Command) BuildFlagsOptions { + return BuildFlagsOptions{ + NoCache: buildNoCache, + SeparateWeights: buildSeparateWeights, + Secrets: buildSecrets, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + OpenAPISchema: buildSchemaFile, + DockerfileFile: buildDockerfileFile, + Strip: buildStrip, + Precompile: buildPrecompile, + SkipSchemaValidation: buildSkipSchemaValidation, + Timestamp: config.BuildSourceEpochTimestamp, + } +} + +// RunBuild builds a Docker image from the model source described by opts. It is +// shared by both the Cobra and Kong build commands. +func RunBuild(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts BuildCommandOptions) error { + config.BuildSourceEpochTimestamp = opts.Flags.Timestamp + + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -93,22 +189,13 @@ func buildCommand(cmd *cobra.Command, args []string) error { return err } - imageName := src.Config.Image - if imageName == "" { - imageName = src.Config.Model - } - if buildTag != "" { - imageName = buildTag - } - if imageName == "" { - imageName = config.DockerImageName(src.ProjectDir) - } + imageName := ResolveBuildImageName(src.Config.Image, src.Config.Model, opts.Tag, src.ProjectDir) console.Infof("Building Docker image from environment in cog.yaml as %s...", console.Bold(imageName)) console.Info("") - resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.Build(ctx, src, buildOptionsFromFlags(cmd, imageName, nil)) + resolver := model.NewResolver(dockerClient, regClient) + m, err := resolver.Build(ctx, src, opts.Flags.ModelBuildOptions(imageName, nil)) if err != nil { return err } @@ -188,7 +275,10 @@ func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error { flags := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"} var 
flagsSet []string for _, flag := range flags { - if cmd.Flag(flag).Changed { + // Not every command that runs this check registers all of these flags + // (e.g. exec has no --dockerfile), so skip flags that aren't defined. + f := cmd.Flag(flag) + if f != nil && f.Changed { flagsSet = append(flagsSet, "--"+flag) } } @@ -211,23 +301,3 @@ func addSkipSchemaValidationFlag(cmd *cobra.Command) { cmd.Flags().BoolVar(&buildSkipSchemaValidation, "skip-schema-validation", false, "Skip OpenAPI schema generation and validation") _ = cmd.Flags().MarkHidden("skip-schema-validation") } - -// buildOptionsFromFlags creates BuildOptions from the current CLI flag values. -// The imageName and annotations parameters vary by command and must be provided. -func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions { - return model.BuildOptions{ - ImageName: imageName, - Secrets: buildSecrets, - NoCache: buildNoCache, - SeparateWeights: buildSeparateWeights, - UseCudaBaseImage: buildUseCudaBaseImage, - ProgressOutput: buildProgressOutput, - SchemaFile: buildSchemaFile, - DockerfileFile: buildDockerfileFile, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - Strip: buildStrip, - Precompile: buildPrecompile, - Annotations: annotations, - SkipSchemaValidation: buildSkipSchemaValidation, - } -} diff --git a/pkg/cli/build_test.go b/pkg/cli/build_test.go new file mode 100644 index 0000000000..80a3554ba8 --- /dev/null +++ b/pkg/cli/build_test.go @@ -0,0 +1,60 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolveBuildImageNameFallsBackToModel(t *testing.T) { + name := ResolveBuildImageName("", "r8.im/user/model", "", "/tmp/project") + require.Equal(t, "r8.im/user/model", name) +} + +func TestResolveBuildImageNamePrefersConfigImageOverModel(t *testing.T) { + name := ResolveBuildImageName("config-image", "r8.im/user/model", "", "/tmp/project") + require.Equal(t, "config-image", name) +} + 
+func TestResolveBuildImageNamePrefersTag(t *testing.T) { + name := ResolveBuildImageName("config-image", "r8.im/user/model", "custom:tag", "/tmp/project") + require.Equal(t, "custom:tag", name) +} + +func TestUseCogBaseImageExplicitness(t *testing.T) { + falseValue := false + opts := BuildFlagsOptions{UseCogBaseImage: &falseValue} + require.NotNil(t, opts.UseCogBaseImage) + require.False(t, *opts.UseCogBaseImage) + + opts = BuildFlagsOptions{} + require.Nil(t, opts.UseCogBaseImage) +} + +func TestBuildFlagsOptionsModelBuildOptions(t *testing.T) { + opts := BuildFlagsOptions{ + NoCache: true, + SeparateWeights: true, + Secrets: []string{"id=foo"}, + ProgressOutput: "plain", + UseCudaBaseImage: "false", + OpenAPISchema: "schema.json", + DockerfileFile: "Dockerfile", + Strip: true, + Precompile: true, + } + annotations := map[string]string{"k": "v"} + bo := opts.ModelBuildOptions("my-image", annotations) + + require.Equal(t, "my-image", bo.ImageName) + require.True(t, bo.NoCache) + require.True(t, bo.SeparateWeights) + require.Equal(t, []string{"id=foo"}, bo.Secrets) + require.Equal(t, "plain", bo.ProgressOutput) + require.Equal(t, "false", bo.UseCudaBaseImage) + require.Equal(t, "schema.json", bo.SchemaFile) + require.Equal(t, "Dockerfile", bo.DockerfileFile) + require.True(t, bo.Strip) + require.True(t, bo.Precompile) + require.Equal(t, annotations, bo.Annotations) +} diff --git a/pkg/cli/debug.go b/pkg/cli/debug.go index 4b83a2c366..5a7572230c 100644 --- a/pkg/cli/debug.go +++ b/pkg/cli/debug.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "strings" @@ -8,6 +9,7 @@ import ( "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/dockerfile" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" @@ -26,7 +28,6 @@ func newDebugCommand() *cobra.Command { addSeparateWeightsFlag(cmd) addUseCudaBaseImageFlag(cmd) - addDockerfileFlag(cmd) 
addUseCogBaseImageFlag(cmd) addBuildTimestampFlag(cmd) addConfigFlag(cmd) @@ -38,35 +39,56 @@ func newDebugCommand() *cobra.Command { func cmdDockerfile(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - src, err := model.NewSource(configFilename) + dockerClient, err := docker.NewClient(ctx) if err != nil { return err } - defer src.Close() - dockerClient, err := docker.NewClient(ctx) + return RunDebug(ctx, dockerClient, registry.NewRegistryClient(), DebugCommandOptions{ + ConfigFilename: configFilename, + ImageName: imageName, + SeparateWeights: buildSeparateWeights, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + }) +} + +// DebugCommandOptions holds the parser-independent options for the debug +// command, which generates and prints a Dockerfile. +type DebugCommandOptions struct { + ConfigFilename string + ImageName string + SeparateWeights bool + UseCudaBaseImage string + UseCogBaseImage *bool +} + +// RunDebug generates the Dockerfile(s) for the model and prints them to stdout. +// It is shared by both the Cobra and Kong debug commands. 
+func RunDebug(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts DebugCommandOptions) error { + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } + defer src.Close() buildDir, err := src.DotCog.TempPath("build") if err != nil { return err } - client := registry.NewRegistryClient() - generator, err := dockerfile.NewStandardGenerator(src.Config, src.ProjectDir, buildDir, src.ConfigFilename, dockerClient, client, true) + generator, err := dockerfile.NewStandardGenerator(src.Config, src.ProjectDir, buildDir, src.ConfigFilename, dockerClient, regClient, true) if err != nil { return fmt.Errorf("Error creating Dockerfile generator: %w", err) } - generator.SetUseCudaBaseImage(buildUseCudaBaseImage) - useCogBaseImage := DetermineUseCogBaseImage(cmd) - if useCogBaseImage != nil { - generator.SetUseCogBaseImage(*useCogBaseImage) + generator.SetUseCudaBaseImage(opts.UseCudaBaseImage) + if opts.UseCogBaseImage != nil { + generator.SetUseCogBaseImage(*opts.UseCogBaseImage) } - if buildSeparateWeights { + if opts.SeparateWeights { + imageName := opts.ImageName if imageName == "" { imageName = config.DockerImageName(src.ProjectDir) } diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 62a6e114b7..3d93bdea27 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -32,7 +32,7 @@ NOTE: cog doctor is experimental. Behavior and checks may change in future versi By default, cog doctor reports problems without modifying any files. 
Pass --fix to automatically apply safe fixes.`, RunE: func(cmd *cobra.Command, args []string) error { - return runDoctor(cmd.Context(), fix) + return RunDoctor(cmd.Context(), configFilename, fix) }, Args: cobra.NoArgs, // printDoctorResults already prints findings to the user; suppress @@ -47,7 +47,10 @@ Pass --fix to automatically apply safe fixes.`, return cmd } -func runDoctor(ctx context.Context, fix bool) error { +// RunDoctor diagnoses (and optionally fixes) common issues in the Cog project +// described by configFilename. It is shared by both the Cobra and Kong doctor +// commands. +func RunDoctor(ctx context.Context, configFilename string, fix bool) error { console.Warnf("NOTE: cog doctor is experimental. Behavior and checks may change in future versions.") console.Info("") diff --git a/pkg/cli/exec.go b/pkg/cli/exec.go index 4d14d5b320..7ad33b95ff 100644 --- a/pkg/cli/exec.go +++ b/pkg/cli/exec.go @@ -1,8 +1,8 @@ package cli import ( + "context" "errors" - "os" "strconv" "strings" @@ -49,7 +49,6 @@ exploring the environment your model will run in.`, Args: cobra.MinimumNArgs(1), } addBuildProgressOutputFlag(cmd) - addDockerfileFlag(cmd) addUseCudaBaseImageFlag(cmd) addUseCogBaseImageFlag(cmd) addGpusFlag(cmd) @@ -75,46 +74,59 @@ func execCmd(cmd *cobra.Command, args []string) error { return err } - src, err := model.NewSource(configFilename) + return RunExec(ctx, dockerClient, registry.NewRegistryClient(), ExecCommandOptions{ + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: envFlags, + }, + Args: args, + Ports: execPorts, + }) +} + +// ExecCommandOptions holds everything RunExec needs that is independent of the +// argument parser. 
+type ExecCommandOptions struct { + RuntimeBuildOptions + Args []string + Ports []string +} + +// RunExec builds a local image and executes an arbitrary command inside it with +// the project directory volume-mounted. It is shared by both the Cobra and Kong +// exec commands. +func RunExec(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts ExecCommandOptions) error { + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } defer src.Close() - resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) + resolver := model.NewResolver(dockerClient, regClient) console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - opts := serveBuildOptions(cmd) - opts.SkipSchemaValidation = true - m, err := resolver.Build(ctx, src, opts) + buildOpts := opts.ServeBuildOptions() + buildOpts.SkipSchemaValidation = true + m, err := resolver.Build(ctx, src, buildOpts) if err != nil { return err } - gpus := "" - if gpusFlag != "" { - gpus = gpusFlag - } else if m.HasGPU() { - gpus = "all" - } - - // Automatically propagate RUST_LOG for Rust coglet debugging - env := envFlags - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env = append(env, "RUST_LOG="+rustLog) - } - runOptions := command.RunOptions{ - Args: args, - Env: env, - GPUs: gpus, + Args: opts.Args, + Env: runtimeEnv(opts.Env), + GPUs: runtimeGPUs(opts.GPUs, m), Image: m.ImageRef(), Volumes: []command.Volume{{Source: src.ProjectDir, Destination: "/src"}}, Workdir: "/src", } - for _, portString := range execPorts { + for _, portString := range opts.Ports { port, err := strconv.Atoi(portString) if err != nil { return err @@ -124,7 +136,7 @@ func execCmd(cmd *cobra.Command, args []string) error { } console.Info("") - console.Infof("Running %s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " "))) + console.Infof("Running %s in Docker with the current directory mounted as a 
volume...", console.Bold(strings.Join(opts.Args, " "))) console.Info("") err = docker.Run(ctx, dockerClient, runOptions) diff --git a/pkg/cli/init.go b/pkg/cli/init.go index b75f864794..98b00c5f9b 100644 --- a/pkg/cli/init.go +++ b/pkg/cli/init.go @@ -37,6 +37,12 @@ and run interface. Edit them to match your model's requirements.`, } func initCommand(cmd *cobra.Command, args []string) error { + return RunInit() +} + +// RunInit sets up the current directory for use with Cog by writing the base +// project template. It is shared by both the Cobra and Kong init commands. +func RunInit() error { console.Info("Setting up the current directory for use with Cog...") console.Info("") diff --git a/pkg/cli/login.go b/pkg/cli/login.go index 721032b0c3..d5dbe82b07 100644 --- a/pkg/cli/login.go +++ b/pkg/cli/login.go @@ -1,6 +1,8 @@ package cli import ( + "context" + "github.com/spf13/cobra" "github.com/replicate/cog/pkg/global" @@ -36,25 +38,38 @@ func login(cmd *cobra.Command, args []string) error { // Initialize the provider registry setup.Init() - // Use global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var) - registryHost := global.ReplicateRegistryHost - tokenStdin, err := cmd.Flags().GetBool("token-stdin") if err != nil { return err } + return RunLogin(ctx, provider.DefaultRegistry(), LoginOptions{ + TokenStdin: tokenStdin, + // Use global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var) + Host: global.ReplicateRegistryHost, + }) +} + +// LoginOptions holds the parser-independent options for the login command. +type LoginOptions struct { + TokenStdin bool + Host string +} + +// RunLogin logs in to the container registry for opts.Host. It is shared by +// both the Cobra and Kong login commands. 
+func RunLogin(ctx context.Context, providerReg *provider.Registry, opts LoginOptions) error { // Look up the provider for this registry - p := provider.DefaultRegistry().ForHost(registryHost) + p := providerReg.ForHost(opts.Host) if p == nil { // This shouldn't happen since GenericProvider matches everything - console.Warnf("No provider found for registry '%s'.", registryHost) - console.Infof("Please use 'docker login %s' to authenticate.", registryHost) + console.Warnf("No provider found for registry '%s'.", opts.Host) + console.Infof("Please use 'docker login %s' to authenticate.", opts.Host) return nil } return p.Login(ctx, provider.LoginOptions{ - TokenStdin: tokenStdin, - Host: registryHost, + TokenStdin: opts.TokenStdin, + Host: opts.Host, }) } diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index a8c322215c..a4e394b85b 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -91,7 +91,6 @@ it.`, addUseCudaBaseImageFlag(cmd) addUseCogBaseImageFlag(cmd) addBuildProgressOutputFlag(cmd) - addDockerfileFlag(cmd) addGpusFlag(cmd) addSetupTimeoutFlag(cmd) addConfigFlag(cmd) @@ -189,21 +188,64 @@ func transformPathsToBase64URLs(inputs map[string]any) (map[string]any, error) { } func cmdPredict(cmd *cobra.Command, args []string) error { - if cmd.CalledAs() == "predict" || cmd.Name() == "predict" { - console.Warn(`"cog predict" is deprecated, use "cog run"`) + dockerClient, err := docker.NewClient(cmd.Context()) + if err != nil { + return err } - ctx, stop := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) - defer stop() + image := "" + if len(args) > 0 { + image = args[0] + } - dockerClient, err := docker.NewClient(ctx) - if err != nil { - return err + return RunPrediction(cmd.Context(), dockerClient, PredictionCommandOptions{ + Use: cmd.CalledAs(), + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: 
DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: envFlags, + }, + Image: image, + Input: inputFlags, + InputJSON: inputJSON, + OutputPath: outPath, + SetupTimeout: setupTimeout, + UseReplicateAPI: useReplicateAPIToken, + }) +} + +// PredictionCommandOptions holds everything RunPrediction needs that is +// independent of the argument parser. +type PredictionCommandOptions struct { + RuntimeBuildOptions + + // Use is the invoked command name ("predict" or "run"); when "predict" a + // deprecation warning is printed. + Use string + Image string + Input []string + InputJSON string + OutputPath string + SetupTimeout uint32 + UseReplicateAPI bool +} + +// RunPrediction builds or pulls a model image and runs a prediction. It is +// shared by both the Cobra and Kong predict/run commands. +func RunPrediction(parent context.Context, dockerClient command.Command, opts PredictionCommandOptions) error { + if opts.Use == "predict" { + console.Warn(`"cog predict" is deprecated, use "cog run"`) } + ctx, stop := signal.NotifyContext(parent, syscall.SIGINT, syscall.SIGTERM) + defer stop() + imageName := "" volumes := []command.Volume{} - gpus := gpusFlag + gpus := opts.GPUs // The Manager is built only when we have cog.yaml in scope (the // build-from-source path). 
Pre-built images are opaque to Cog and @@ -212,9 +254,9 @@ func cmdPredict(cmd *cobra.Command, args []string) error { resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - if len(args) == 0 { + if opts.Image == "" { // Build image - src, err := model.NewSource(configFilename) + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -226,7 +268,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + m, err := resolver.Build(ctx, src, opts.ServeBuildOptions()) if err != nil { return err } @@ -248,7 +290,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { } } else { // Use existing image - imageName = args[0] + imageName = opts.Image // If the image name contains '=', then it's probably a mistake if strings.Contains(imageName, "=") { @@ -273,11 +315,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { console.Info("") console.Info("Starting Docker image and running setup()...") - // Automatically propagate RUST_LOG for Rust coglet debugging - env := envFlags - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env = append(env, "RUST_LOG="+rustLog) - } + env := runtimeEnv(opts.Env) predictor, err := predict.NewPredictor(ctx, predict.PredictorOptions{ RunOptions: command.RunOptions{ @@ -294,7 +332,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - timeout := time.Duration(setupTimeout) * time.Second + timeout := time.Duration(opts.SetupTimeout) * time.Second if err := predictor.Start(ctx, os.Stderr, timeout); err != nil { // Only retry if we're using a GPU but the user didn't explicitly select a GPU with --gpus // If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it @@ -332,21 +370,21 @@ func cmdPredict(cmd *cobra.Command, args []string) error { } }() - if 
inputJSON != "" { - if len(inputFlags) > 0 { + if opts.InputJSON != "" { + if len(opts.Input) > 0 { return fmt.Errorf("Must use one of --json or --input to provide model inputs") } - return predictJSONInputs(*predictor, inputJSON, outPath, false) + return predictJSONInputs(*predictor, opts.InputJSON, opts.OutputPath, false, opts.UseReplicateAPI) } - return predictIndividualInputs(*predictor, inputFlags, outPath, false) + return predictIndividualInputs(*predictor, opts.Input, opts.OutputPath, false, opts.UseReplicateAPI) } func isURI(ref *openapi3.Schema) bool { return ref != nil && ref.Type.Is("string") && ref.Format == "uri" } -func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath string, isTrain bool) error { +func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath string, isTrain bool, useReplicateAPI bool) error { jsonInputs, err := parseJSONInput(jsonInput) if err != nil { return err @@ -373,10 +411,10 @@ func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath } } - return runPrediction(predictor, inputs, outputPath, isTrain, true) + return runPrediction(predictor, inputs, outputPath, isTrain, true, useReplicateAPI) } -func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool) error { +func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool, useReplicateAPI bool) error { schema, err := predictor.GetSchema() if err != nil { return err @@ -387,10 +425,10 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - return runPrediction(predictor, inputs, outputPath, isTrain, false) + return runPrediction(predictor, inputs, outputPath, isTrain, false, useReplicateAPI) } -func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPath string, isTrain bool, needsJSON bool) error { +func runPrediction(predictor predict.Predictor, 
inputs predict.Inputs, outputPath string, isTrain bool, needsJSON bool, useReplicateAPI bool) error { if isTrain { console.Info("Running training...") } else { @@ -421,7 +459,7 @@ func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPat context := predict.RequestContext{} - if useReplicateAPIToken { + if useReplicateAPI { context.ReplicateAPIToken = os.Getenv("REPLICATE_API_TOKEN") if context.ReplicateAPIToken == "" { return fmt.Errorf("Failed to find REPLICATE_API_TOKEN in the current environment when called with --use-replicate-token") diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 6246c01772..23ff391f1b 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -10,7 +11,9 @@ import ( "github.com/replicate/go/uuid" + "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/provider" "github.com/replicate/cog/pkg/provider/setup" @@ -64,7 +67,53 @@ func push(cmd *cobra.Command, args []string) error { return err } - src, err := model.NewSource(configFilename) + image := "" + if len(args) > 0 { + image = args[0] + } + + return RunPush(ctx, dockerClient, registry.NewRegistryClient(), provider.DefaultRegistry(), PushCommandOptions{ + ConfigFilename: configFilename, + Image: image, + Flags: buildFlagsOptionsFromCobra(cmd), + }) +} + +// PushCommandOptions holds everything RunPush needs that is independent of the +// argument parser. +type PushCommandOptions struct { + ConfigFilename string + Image string + Flags BuildFlagsOptions +} + +// ResolvePushTarget resolves the push target image reference from the config +// image, config model, and positional args, running the push-specific +// user-input checks. It returns the target string, the resolved model ref (or +// nil for FormatImage paths), and any validation error. 
+func ResolvePushTarget(configImage, configModel string, args []string) (string, *model.ResolvedRef, error) { + modelRef, err := validatePushArgs(configImage, configModel, args) + if err != nil { + return "", nil, err + } + switch { + case modelRef != nil: + return modelRef.String(), modelRef, nil + case len(args) > 0: + return args[0], nil, nil + case configImage != "": + return configImage, nil, nil + default: + return "", nil, errors.New("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.example.com/your-username/model-name'") + } +} + +// RunPush builds and pushes a model image. It is shared by both the Cobra and +// Kong push commands. +func RunPush(ctx context.Context, dockerClient command.Command, regClient registry.Client, providerReg *provider.Registry, opts PushCommandOptions) error { + config.BuildSourceEpochTimestamp = opts.Flags.Timestamp + + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -81,24 +130,17 @@ func push(cmd *cobra.Command, args []string) error { // can drive credential selection from the correct host even when // Build fails (the build-error path calls p.PostPush(...) for // Replicate-specific guidance). - modelRef, err := validatePushArgs(src.Config.Image, src.Config.Model, args) + var args []string + if opts.Image != "" { + args = []string{opts.Image} + } + pushTarget, _, err := ResolvePushTarget(src.Config.Image, src.Config.Model, args) if err != nil { return err } - var pushTarget string - switch { - case modelRef != nil: - pushTarget = modelRef.String() - case len(args) > 0: - pushTarget = args[0] - case src.Config.Image != "": - pushTarget = src.Config.Image - default: - return errors.New("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. 
For example, 'cog push registry.example.com/your-username/model-name'") - } // Look up the provider for the target registry - p := provider.DefaultRegistry().ForImage(pushTarget) + p := providerReg.ForImage(pushTarget) if p == nil { return fmt.Errorf("no provider found for image '%s'", pushTarget) } @@ -116,14 +158,12 @@ func push(cmd *cobra.Command, args []string) error { annotations["run.cog.push_id"] = buildID.String() } - regClient := registry.NewRegistryClient() resolver := model.NewResolver(dockerClient, regClient) // Build the model console.Infof("Building Docker image from environment in cog.yaml as %s...", console.Bold(pushTarget)) console.Info("") - buildOpts := buildOptionsFromFlags(cmd, pushTarget, annotations) - m, err := resolver.Build(ctx, src, buildOpts) + m, err := resolver.Build(ctx, src, opts.Flags.ModelBuildOptions(pushTarget, annotations)) if err != nil { // Call PostPush to handle error logging/analytics _ = p.PostPush(ctx, pushOpts, err) diff --git a/pkg/cli/push_test.go b/pkg/cli/push_test.go index 1e567f36a2..f0ef4a2b18 100644 --- a/pkg/cli/push_test.go +++ b/pkg/cli/push_test.go @@ -331,3 +331,34 @@ func TestFormatPushResult_DefensiveGuards(t *testing.T) { assert.NotContains(t, out, "image") }) } + +func TestResolvePushTargetUsesModelRef(t *testing.T) { + target, ref, err := ResolvePushTarget("", "r8.im/user/model", nil) + require.NoError(t, err) + require.NotNil(t, ref) + require.Equal(t, ref.String(), target) +} + +func TestResolvePushTargetRejectsPositionalWithModel(t *testing.T) { + _, _, err := ResolvePushTarget("", "r8.im/user/model", []string{"example.com/image"}) + require.ErrorContains(t, err, "positional image argument not supported") +} + +func TestResolvePushTargetUsesConfigImage(t *testing.T) { + target, ref, err := ResolvePushTarget("registry.example.com/user/model", "", nil) + require.NoError(t, err) + require.Nil(t, ref) + require.Equal(t, "registry.example.com/user/model", target) +} + +func 
TestResolvePushTargetUsesPositionalArg(t *testing.T) { + target, ref, err := ResolvePushTarget("", "", []string{"registry.example.com/user/model"}) + require.NoError(t, err) + require.Nil(t, ref) + require.Equal(t, "registry.example.com/user/model", target) +} + +func TestResolvePushTargetRequiresTarget(t *testing.T) { + _, _, err := ResolvePushTarget("", "", nil) + require.ErrorContains(t, err, "you must either set the 'image' option") +} diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index b99562ce0c..264eb9cc94 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "os" @@ -13,6 +14,7 @@ import ( "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/registry" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/weights" ) var ( @@ -56,20 +58,62 @@ and outputs as a REST API. Compatible with the Cog HTTP protocol.`, return cmd } -// serveBuildOptions creates BuildOptions for cog serve. +// RuntimeBuildOptions holds the parser-independent options shared by the +// runtime commands (serve, exec) that build a local image and run it with the +// project directory volume-mounted. +type RuntimeBuildOptions struct { + ConfigFilename string + ProgressOutput string + UseCudaBaseImage string + UseCogBaseImage *bool + GPUs string + Env []string +} + +// ServeBuildOptions creates BuildOptions for cog serve and cog exec. // Same build path as cog build, but with ExcludeSource so COPY . /src is // skipped — source is volume-mounted at runtime instead. All other layers // (wheels, apt, etc.) share Docker layer cache with cog build. 
-func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { +func (o RuntimeBuildOptions) ServeBuildOptions() model.BuildOptions { return model.BuildOptions{ - UseCudaBaseImage: buildUseCudaBaseImage, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - ProgressOutput: buildProgressOutput, + UseCudaBaseImage: o.UseCudaBaseImage, + UseCogBaseImage: o.UseCogBaseImage, + ProgressOutput: o.ProgressOutput, ExcludeSource: true, SkipLabels: true, } } +// runtimeGPUs resolves the GPU spec: an explicit request wins, otherwise +// "all" if the model declares GPU usage, otherwise empty. +func runtimeGPUs(requested string, m *model.Model) string { + if requested != "" { + return requested + } + if m.HasGPU() { + return "all" + } + return "" +} + +// runtimeEnv appends RUST_LOG from the host environment (for Rust coglet +// debugging) to the provided env list without mutating the input slice. +func runtimeEnv(env []string) []string { + out := append([]string{}, env...) + if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { + out = append(out, "RUST_LOG="+rustLog) + } + return out +} + +// ServeCommandOptions holds everything RunServe needs that is independent of +// the argument parser. +type ServeCommandOptions struct { + RuntimeBuildOptions + Port int + UploadURL string +} + func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() @@ -78,27 +122,41 @@ func cmdServe(cmd *cobra.Command, arg []string) error { return err } - src, err := model.NewSource(configFilename) + return RunServe(ctx, dockerClient, registry.NewRegistryClient(), ServeCommandOptions{ + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: envFlags, + }, + Port: port, + UploadURL: uploadURL, + }) +} + +// RunServe builds the model image and runs the Cog HTTP server. 
It is shared by +// both the Cobra and Kong serve commands. +func RunServe(ctx context.Context, dockerClient command.Command, regClient registry.Client, opts ServeCommandOptions) error { + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } defer src.Close() + if err := weights.CheckDrift(src.ProjectDir, src.Config.Weights); err != nil { + return err + } + console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + resolver := model.NewResolver(dockerClient, regClient) + m, err := resolver.Build(ctx, src, opts.ServeBuildOptions()) if err != nil { return err } - gpus := "" - if gpusFlag != "" { - gpus = gpusFlag - } else if m.HasGPU() { - gpus = "all" - } - args := []string{ "python", "--check-hash-based-pycs", "never", @@ -106,38 +164,52 @@ func cmdServe(cmd *cobra.Command, arg []string) error { "--await-explicit-shutdown", "true", } - if uploadURL != "" { - args = append(args, "--upload-url", uploadURL) - } - - // Automatically propagate RUST_LOG for Rust coglet debugging - env := envFlags - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env = append(env, "RUST_LOG="+rustLog) + if opts.UploadURL != "" { + args = append(args, "--upload-url", opts.UploadURL) } runOptions := command.RunOptions{ Args: args, - Env: env, - GPUs: gpus, + Env: runtimeEnv(opts.Env), + GPUs: runtimeGPUs(opts.GPUs, m), Image: m.ImageRef(), Volumes: []command.Volume{{Source: src.ProjectDir, Destination: "/src"}}, Workdir: "/src", + Ports: []command.Port{{HostPort: opts.Port, ContainerPort: 5000}}, + } + + wm, err := newWeightManager(src) + if err != nil { + return err + } + mounts, err := wm.Prepare(ctx) + if err != nil { + return fmt.Errorf("prepare weights: %w", err) + } + defer func() { + if err := mounts.Release(); err != nil { + console.Warnf("Failed to clean up weight mounts: %s", err) + } 
+ }() + for _, spec := range mounts.Specs { + runOptions.Volumes = append(runOptions.Volumes, command.Volume{ + Source: spec.Source, + Destination: spec.Target, + ReadOnly: true, + }) } // On Linux, host.docker.internal is not available by default — add it. // This allows the container to reach services running on the host, // e.g. when --upload-url points to a local upload server. - if uploadURL != "" { + if opts.UploadURL != "" { runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"} } - runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000}) - console.Info("") console.Infof("Running %[1]s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " "))) console.Info("") - console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://127.0.0.1:%v", port))) + console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://127.0.0.1:%v", opts.Port))) console.Info("") err = docker.Run(ctx, dockerClient, runOptions) diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 2fa4824296..571abd49d9 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -41,7 +41,6 @@ Otherwise, it will build the model in the current directory and train it.`, } addBuildProgressOutputFlag(cmd) - addDockerfileFlag(cmd) addUseCudaBaseImageFlag(cmd) addGpusFlag(cmd) addUseCogBaseImageFlag(cmd) @@ -55,26 +54,61 @@ Otherwise, it will build the model in the current directory and train it.`, } func cmdTrain(cmd *cobra.Command, args []string) error { - ctx, stop := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) - defer stop() - - dockerClient, err := docker.NewClient(ctx) + dockerClient, err := docker.NewClient(cmd.Context()) if err != nil { return err } + image := "" + if len(args) > 0 { + image = args[0] + } + + return RunTrain(cmd.Context(), dockerClient, TrainCommandOptions{ + RuntimeBuildOptions: RuntimeBuildOptions{ + ConfigFilename: configFilename, + ProgressOutput: 
buildProgressOutput, + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + GPUs: gpusFlag, + Env: trainEnvFlags, + }, + Image: image, + Input: trainInputFlags, + OutputPath: trainOutPath, + SetupTimeout: setupTimeout, + }) +} + +// TrainCommandOptions holds everything RunTrain needs that is independent of +// the argument parser. +type TrainCommandOptions struct { + RuntimeBuildOptions + + Image string + Input []string + OutputPath string + SetupTimeout uint32 +} + +// RunTrain builds or pulls a model image and runs a training job. It is shared +// by both the Cobra and Kong train commands. +func RunTrain(parent context.Context, dockerClient command.Command, opts TrainCommandOptions) error { + ctx, stop := signal.NotifyContext(parent, syscall.SIGINT, syscall.SIGTERM) + defer stop() + imageName := "" volumes := []command.Volume{} - gpus := gpusFlag + gpus := opts.GPUs // Managed-weight mounts only apply when we have cog.yaml in scope. var wm *weights.Manager resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - if len(args) == 0 { + if opts.Image == "" { // Build image - src, err := model.NewSource(configFilename) + src, err := model.NewSource(opts.ConfigFilename) if err != nil { return err } @@ -86,7 +120,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { console.Info("Building Docker image from environment in cog.yaml...") console.Info("") - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + m, err := resolver.Build(ctx, src, opts.ServeBuildOptions()) if err != nil { return err } @@ -108,7 +142,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { } } else { // Use existing image - imageName = args[0] + imageName = opts.Image // Pull the image (if needed) and validate it's a Cog model ref, err := model.ParseRef(imageName) @@ -133,7 +167,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { GPUs: gpus, Image: imageName, Volumes: volumes, - Env: trainEnvFlags, + Env: 
opts.Env, Args: []string{"python", "-m", "cog.server.http", "--x-mode", "train"}, }, IsTrain: true, @@ -144,7 +178,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } - if err := predictor.Start(ctx, os.Stderr, time.Duration(setupTimeout)*time.Second); err != nil { + if err := predictor.Start(ctx, os.Stderr, time.Duration(opts.SetupTimeout)*time.Second); err != nil { return err } @@ -156,5 +190,5 @@ func cmdTrain(cmd *cobra.Command, args []string) error { } }() - return predictIndividualInputs(*predictor, trainInputFlags, trainOutPath, true) + return predictIndividualInputs(*predictor, opts.Input, opts.OutputPath, true, false) } diff --git a/pkg/cli/weights.go b/pkg/cli/weights.go index de9c1941a4..344d0e4f4b 100644 --- a/pkg/cli/weights.go +++ b/pkg/cli/weights.go @@ -27,11 +27,6 @@ func newWeightsCommand() *cobra.Command { NOTE: cog weights is experimental. Behavior may change in future versions. Do not rely on it in production workflows.`, Hidden: true, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - console.Warnf("NOTE: cog weights is experimental. Behavior may change in future versions.") - console.Warnf("Do not rely on it in production workflows.") - console.Info("") - }, } cmd.AddCommand(newWeightsImportCommand()) @@ -66,7 +61,7 @@ Use --dry-run to preview what would change without importing anything. Add --verbose to see per-file details including which files pass the filter.`, Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - return weightsImportCommand(cmd, args, dryRun, verbose) + return RunWeightsImport(cmd.Context(), configFilename, args, dryRun, verbose) }, } @@ -76,8 +71,18 @@ Add --verbose to see per-file details including which files pass the filter.`, return cmd } -func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose bool) error { - ctx := cmd.Context() +// weightsExperimentalWarning prints the experimental notice shared by all +// weights subcommands. 
It runs once per subcommand invocation. +func weightsExperimentalWarning() { + console.Warnf("NOTE: cog weights is experimental. Behavior may change in future versions.") + console.Warnf("Do not rely on it in production workflows.") + console.Info("") +} + +// RunWeightsImport builds and pushes weights to a registry. It is shared by +// both the Cobra and Kong weights import commands. +func RunWeightsImport(ctx context.Context, configFilename string, args []string, dryRun, verbose bool) error { + weightsExperimentalWarning() src, err := model.NewSource(configFilename) if err != nil { @@ -92,7 +97,7 @@ func weightsImportCommand(cmd *cobra.Command, args []string, dryRun, verbose boo return err } - weightSpecs, err := collectWeightSpecs(src, args) + weightSpecs, err := collectWeightSpecs(src, configFilename, args) if err != nil { return err } @@ -233,7 +238,7 @@ func planStatusIcon(status model.WeightImportPlanStatus) string { // collectWeightSpecs extracts WeightSpecs from the source, optionally // filtered to only the names listed in filterNames. An error is returned // if no weights match or if a requested name doesn't exist. -func collectWeightSpecs(src *model.Source, filterNames []string) ([]*model.WeightSpec, error) { +func collectWeightSpecs(src *model.Source, configFilename string, filterNames []string) ([]*model.WeightSpec, error) { if len(src.Config.Weights) == 0 { return nil, fmt.Errorf("no weights defined in %s", configFilename) } diff --git a/pkg/cli/weights_pull.go b/pkg/cli/weights_pull.go index 97ad2b79f1..30389ed0ff 100644 --- a/pkg/cli/weights_pull.go +++ b/pkg/cli/weights_pull.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "path/filepath" @@ -40,7 +41,7 @@ home directory is on a different filesystem than your project. 
Use --verbose to show per-layer and per-file progress.`, Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - return weightsPullCommand(cmd, args, verbose) + return RunWeightsPull(cmd.Context(), configFilename, args, verbose) }, } @@ -49,8 +50,10 @@ Use --verbose to show per-layer and per-file progress.`, return cmd } -func weightsPullCommand(cmd *cobra.Command, args []string, verbose bool) error { - ctx := cmd.Context() +// RunWeightsPull populates the local weight cache from the registry. It is +// shared by both the Cobra and Kong weights pull commands. +func RunWeightsPull(ctx context.Context, configFilename string, args []string, verbose bool) error { + weightsExperimentalWarning() src, err := model.NewSource(configFilename) if err != nil { diff --git a/pkg/cli/weights_status.go b/pkg/cli/weights_status.go index 7490964201..2102355107 100644 --- a/pkg/cli/weights_status.go +++ b/pkg/cli/weights_status.go @@ -1,6 +1,7 @@ package cli import ( + "context" "encoding/json" "errors" "fmt" @@ -81,7 +82,7 @@ Use --verbose to show per-layer status for each weight. Exit code is 0 when all weights are ready, 1 otherwise.`, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return weightsStatusCommand(cmd, jsonOutput, verbose) + return RunWeightsStatus(cmd.Context(), configFilename, jsonOutput, verbose) }, } @@ -92,8 +93,10 @@ Exit code is 0 when all weights are ready, 1 otherwise.`, return cmd } -func weightsStatusCommand(cmd *cobra.Command, jsonOutput, verbose bool) error { - ctx := cmd.Context() +// RunWeightsStatus shows the status of configured weights. It is shared by both +// the Cobra and Kong weights status commands. 
+func RunWeightsStatus(ctx context.Context, configFilename string, jsonOutput, verbose bool) error { + weightsExperimentalWarning() src, err := model.NewSource(configFilename) if err != nil { diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 0e17993f31..e0af7d3253 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -66,11 +66,11 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b propertiesSchemas := properties.(openapi3.Schemas) property, err := propertiesSchemas.JSONLookup(key) if err == nil { - propertySchema := property.(*openapi3.Schema) + originalSchema := property.(*openapi3.Schema) // Resolve allOf/$ref to find the actual type. // cog-schema-gen emits allOf:[{$ref: ...}] for choices/enums, // where the referenced schema has the concrete type. - propertySchema = resolveSchemaType(propertySchema) + propertySchema := resolveSchemaType(originalSchema) switch { case propertySchema.Type.Is("object"): encodedVal := json.RawMessage(val) @@ -99,6 +99,13 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b } else { value, err := strconv.ParseFloat(val, 32) if err != nil { + // For a union like `float | str` the schema + // resolves to the numeric member first; a + // non-numeric value should fall back to the + // string member instead of erroring. + if schemaAcceptsString(originalSchema) { + break + } return input, err } float := float32(value) @@ -108,11 +115,48 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b case propertySchema.Type.Is("integer"): value, err := strconv.ParseInt(val, 10, 32) if err != nil { + // For a union like `int | float` the schema + // resolves to the integer member first; a + // fractional value should fall back to the float + // member instead of erroring. 
+ if schemaAcceptsFloat(originalSchema) { + if value, err := strconv.ParseFloat(val, 32); err == nil { + float := float32(value) + input[key] = Input{Float: &float} + continue + } + } + // See the number case above: fall back to a string + // member for unions such as `int | str`. + if schemaAcceptsString(originalSchema) { + break + } return input, err } valueInt := int32(value) input[key] = Input{Int: &valueInt} continue + case schemaAcceptsNumber(originalSchema): + // Union input (anyOf) that includes a numeric member, e.g. + // `str | float`. Parse numeric-looking values as numbers so + // the runtime receives the intended type; otherwise fall + // through to the string member below. + if value, err := strconv.ParseInt(val, 10, 32); err == nil { + valueInt := int32(value) + input[key] = Input{Int: &valueInt} + continue + } + // Only parse fractional values as float when the + // union actually accepts a float member; otherwise a + // value like `1.5` for `str | int` must fall back to + // the string member below. + if schemaAcceptsFloat(originalSchema) { + if value, err := strconv.ParseFloat(val, 32); err == nil { + float := float32(value) + input[key] = Input{Float: &float} + continue + } + } } } } @@ -188,6 +232,79 @@ func fileToDataURL(filePath string) (string, error) { return dataURL, nil } +// schemaAcceptsString reports whether the schema accepts a string value, +// including union (anyOf) members. This lets CLI `-i` parsing fall back to a +// string member when a numeric parse fails for unions such as `float | str`, +// where resolveSchemaType resolves to the numeric member. 
+func schemaAcceptsString(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && s.Type.Is("string") { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsString(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsString(ref.Value) { + return true + } + } + return false +} + +// schemaAcceptsFloat reports whether the schema accepts a floating-point +// value, including union (anyOf) members. Unlike schemaAcceptsNumber, it does +// not match integer-only members, so CLI `-i` parsing can decide whether a +// fractional value like `1.5` is valid for unions such as `int | float` +// (accepts float) versus `str | int` (does not). +func schemaAcceptsFloat(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && s.Type.Is("number") { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsFloat(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsFloat(ref.Value) { + return true + } + } + return false +} + +// schemaAcceptsNumber reports whether the schema accepts a numeric value, +// including union (anyOf) members. This lets CLI `-i` parsing coerce +// numeric-looking strings for union inputs such as `str | float`, where +// resolveSchemaType resolves to a non-numeric member. +func schemaAcceptsNumber(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && (s.Type.Is("number") || s.Type.Is("integer")) { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsNumber(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsNumber(ref.Value) { + return true + } + } + return false +} + // resolveSchemaType walks through allOf/anyOf/$ref wrappers to find a schema // that has a concrete Type set. 
This is needed because the static schema gen // emits allOf:[{$ref: "#/components/schemas/Foo"}] for enum/choices fields, diff --git a/pkg/predict/input_test.go b/pkg/predict/input_test.go new file mode 100644 index 0000000000..b383d3ff93 --- /dev/null +++ b/pkg/predict/input_test.go @@ -0,0 +1,226 @@ +package predict + +import ( + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +// unionInputSchema builds an OpenAPI doc whose single input field `value` +// is a union of string and number. The variant order is configurable so we +// can exercise both `str | float` (string first) and `float | str` (number +// first), which resolve differently via resolveSchemaType. +func unionInputSchema(numberFirst bool) *openapi3.T { + stringRef := openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + numberRef := openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}} + anyOf := openapi3.SchemaRefs{&stringRef, &numberRef} + if numberFirst { + anyOf = openapi3.SchemaRefs{&numberRef, &stringRef} + } + valueSchema := &openapi3.Schema{AnyOf: anyOf} + inputSchema := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "value": {Value: valueSchema}, + }, + } + return &openapi3.T{ + Components: &openapi3.Components{ + Schemas: openapi3.Schemas{ + "Input": {Value: inputSchema}, + }, + }, + } +} + +func TestNewInputsForMode_UnionParsesNumber(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + numberFirst bool + val string + wantInt *int32 + wantFlt *float32 + wantStr *string + }{ + // str | float (string member first) + {name: "str|float integer", val: "1", wantInt: ptrI32(1)}, + {name: "str|float float", val: "1.5", wantFlt: ptrF32(1.5)}, + {name: "str|float string", val: "hello", wantStr: ptrStr("hello")}, + // float | str (number member first) -- must still fall back to string + {name: "float|str integer", numberFirst: true, val: "1", 
wantInt: ptrI32(1)}, + {name: "float|str float", numberFirst: true, val: "1.5", wantFlt: ptrF32(1.5)}, + {name: "float|str string", numberFirst: true, val: "hello", wantStr: ptrStr("hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + schema := unionInputSchema(tt.numberFirst) + inputs, err := NewInputsForMode(map[string][]string{"value": {tt.val}}, schema, false) + require.NoError(t, err) + + got := inputs["value"] + switch { + case tt.wantInt != nil: + require.NotNil(t, got.Int) + require.Equal(t, *tt.wantInt, *got.Int) + case tt.wantFlt != nil: + require.NotNil(t, got.Float) + require.Equal(t, *tt.wantFlt, *got.Float) + case tt.wantStr != nil: + require.NotNil(t, got.String) + require.Equal(t, *tt.wantStr, *got.String) + } + }) + } +} + +// unionInputSchemaOf builds an OpenAPI doc whose single input field `value` +// is a union (anyOf) of the given JSON Schema types, in the given order. +func unionInputSchemaOf(types ...string) *openapi3.T { + anyOf := make(openapi3.SchemaRefs, len(types)) + for i, t := range types { + anyOf[i] = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{t}}} + } + inputSchema := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "value": {Value: &openapi3.Schema{AnyOf: anyOf}}, + }, + } + return &openapi3.T{ + Components: &openapi3.Components{ + Schemas: openapi3.Schemas{ + "Input": {Value: inputSchema}, + }, + }, + } +} + +func TestNewInputsForMode_UnionIntFloatAndStrInt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + types []string + val string + wantInt *int32 + wantFlt *float32 + wantStr *string + }{ + // int | float: integer member resolves first; a fractional value must + // fall back to the float member instead of erroring. 
+ {name: "int|float integer", types: []string{"integer", "number"}, val: "1", wantInt: ptrI32(1)}, + {name: "int|float fractional", types: []string{"integer", "number"}, val: "1.5", wantFlt: ptrF32(1.5)}, + // str | int: string resolves first; a fractional value is not valid for + // the integer member and must fall back to the string member. + {name: "str|int integer", types: []string{"string", "integer"}, val: "1", wantInt: ptrI32(1)}, + {name: "str|int fractional", types: []string{"string", "integer"}, val: "1.5", wantStr: ptrStr("1.5")}, + {name: "str|int string", types: []string{"string", "integer"}, val: "hello", wantStr: ptrStr("hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + schema := unionInputSchemaOf(tt.types...) + inputs, err := NewInputsForMode(map[string][]string{"value": {tt.val}}, schema, false) + require.NoError(t, err) + + got := inputs["value"] + switch { + case tt.wantInt != nil: + require.NotNil(t, got.Int, "expected int") + require.Equal(t, *tt.wantInt, *got.Int) + case tt.wantFlt != nil: + require.NotNil(t, got.Float, "expected float") + require.Equal(t, *tt.wantFlt, *got.Float) + case tt.wantStr != nil: + require.NotNil(t, got.String, "expected string") + require.Equal(t, *tt.wantStr, *got.String) + } + }) + } +} + +func TestSchemaAcceptsNumber(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.True(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"integer"}})) + require.False(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsNumber(nil)) + + union := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + }, + } + require.True(t, schemaAcceptsNumber(union)) + + stringOnlyUnion := &openapi3.Schema{ + AnyOf: 
openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"boolean"}}}, + }, + } + require.False(t, schemaAcceptsNumber(stringOnlyUnion)) +} + +func TestSchemaAcceptsString(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsString(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsString(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.False(t, schemaAcceptsString(nil)) + + union := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + }, + } + require.True(t, schemaAcceptsString(union)) + + numericOnlyUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + }, + } + require.False(t, schemaAcceptsString(numericOnlyUnion)) +} + +func TestSchemaAcceptsFloat(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.False(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"integer"}})) + require.False(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsFloat(nil)) + + intFloatUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + }, + } + require.True(t, schemaAcceptsFloat(intFloatUnion)) + + strIntUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + }, + } + require.False(t, schemaAcceptsFloat(strIntUnion)) +} + +func ptrI32(v int32) *int32 { return &v } +func ptrF32(v float32) *float32 { return &v } 
+func ptrStr(v string) *string { return &v } diff --git a/pkg/schema/input_type_fuzz_test.go b/pkg/schema/input_type_fuzz_test.go new file mode 100644 index 0000000000..d8ae1ab064 --- /dev/null +++ b/pkg/schema/input_type_fuzz_test.go @@ -0,0 +1,244 @@ +package schema + +import ( + "context" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +// FuzzResolveInputType builds arbitrary TypeAnnotation trees from fuzz input +// and verifies that ResolveInputType never panics. When resolution succeeds, +// it feeds the resulting InputType through OpenAPI generation and validates +// the emitted document with the same kin-openapi validator used at build time +// (writeAndValidateSchema). This is the key oracle: a union input type that +// resolves cleanly but emits an OpenAPI document the build-time validator +// rejects (e.g. an unsupported `type: null` branch) is a real bug, not just a +// panic. +func FuzzResolveInputType(f *testing.F) { + // Seed corpus — union and JSON-native input shapes, plus tricky cases. 
+ seeds := []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "int"}, + {Kind: TypeAnnotSimple, Name: "float"}, + {Kind: TypeAnnotSimple, Name: "bool"}, + {Kind: TypeAnnotSimple, Name: "dict"}, + {Kind: TypeAnnotSimple, Name: "Any"}, + // str | float + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "float"}, + }}, + // str | float | None + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "float"}, + {Kind: TypeAnnotSimple, Name: "None"}, + }}, + // str | None (single-variant collapse to nullable) + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + {Kind: TypeAnnotSimple, Name: "None"}, + }}, + // int | float + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "int"}, + {Kind: TypeAnnotSimple, Name: "float"}, + }}, + // list[int] | list[float] + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "int"}, + }}, + {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "float"}, + }}, + }}, + // dict | list[dict] + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "dict"}, + {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "dict"}, + }}, + }}, + // Unsupported member: Path | str (must be rejected, not panic) + {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "Path"}, + {Kind: TypeAnnotSimple, Name: "str"}, + }}, + // Optional[str] + {Kind: TypeAnnotGeneric, Name: "Optional", Args: []TypeAnnotation{ + {Kind: TypeAnnotSimple, Name: "str"}, + }}, + } + for _, s := range seeds { + f.Add(encodeAnnotation(s)) + } + + ctx := NewImportContext() + typedDicts := map[string]bool{} + + f.Fuzz(func(t 
*testing.T, data []byte) { + ann, _ := decodeAnnotation(data, 0, 0) + + // Must not panic regardless of input. + it, ft, err := ResolveInputType(ann, ctx, typedDicts) + if err != nil { + return + } + + // A resolved input type must build a field that validates and emits a + // valid OpenAPI document. + field := InputField{ + Name: "value", + Order: 0, + FieldType: ft, + InputType: &it, + } + if err := ValidateInputField(field); err != nil { + return + } + + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", field) + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + if err != nil { + return + } + + // Oracle: the generated schema must be a valid OpenAPI document, the + // same check writeAndValidateSchema performs at build time. A schema + // that fails here would surface as a confusing build failure for users. + assertValidOpenAPI(t, out) + }) +} + +// FuzzInputTypeJSONSchema constructs arbitrary InputType trees directly (not via +// the annotation resolver) and ensures both the per-field JSON schema helper +// and full OpenAPI generation never panic and always emit a valid document. +// Building InputType directly reaches shapes the resolver may not produce, +// stressing inputTypeJSONSchema and buildInputSchema in isolation. +func FuzzInputTypeJSONSchema(f *testing.F) { + f.Add([]byte{0, 3}) // primitive string + f.Add([]byte{1}) // any + f.Add([]byte{2, 0, 3}) // array of string + f.Add([]byte{3, 2, 0, 3, 0, 1}) // union of string and float + f.Add([]byte{3, 2, 0, 3, 0, 1, 0xff}) // nullable union of string and float + + f.Fuzz(func(t *testing.T, data []byte) { + it, _ := decodeInputType(data, 0, 0) + + // Field with both a compat FieldType and the recursive InputType set, + // mirroring how the parser populates InputField. 
+ field := InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: &it, + } + + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", field) + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + if err != nil { + return + } + assertValidOpenAPI(t, out) + }) +} + +// assertValidOpenAPI loads and validates a generated OpenAPI document with the +// same kin-openapi validator used by writeAndValidateSchema at build time. +// A document that fails validation is a generation bug. +func assertValidOpenAPI(t *testing.T, schemaJSON []byte) { + t.Helper() + loader := openapi3.NewLoader() + loader.IsExternalRefsAllowed = true + doc, err := loader.LoadFromData(schemaJSON) + require.NoError(t, err, "generated schema failed to load\n%s", string(schemaJSON)) + err = doc.Validate(context.Background()) + require.NoError(t, err, "generated schema is invalid\n%s", string(schemaJSON)) +} + +// decodeInputType builds an InputType tree from bytes, mirroring the encoding +// strategy of decodeSchemaType. The final byte of a primitive/union toggles +// nullability so the fuzzer reaches both nullable and non-nullable shapes. 
+func decodeInputType(data []byte, offset int, depth int) (InputType, int) { + if depth > maxFuzzDepth || offset >= len(data) { + return InputPrimitive(TypeString), offset + } + + kind := InputTypeKind(data[offset] % 4) + offset++ + + switch kind { + case InputKindPrimitive: + prim := PrimitiveType(0) + if offset < len(data) { + prim = PrimitiveType(data[offset] % 9) + offset++ + } + it := InputPrimitive(prim) + if offset < len(data) { + if data[offset]%2 == 1 { + it.Nullable = true + } + offset++ + } + return it, offset + + case InputKindAny: + it := InputAnyType() + if offset < len(data) { + if data[offset]%2 == 1 { + it.Nullable = true + } + offset++ + } + return it, offset + + case InputKindArray: + elem, newOffset := decodeInputType(data, offset, depth+1) + it := InputArrayOf(elem) + if newOffset < len(data) { + if data[newOffset]%2 == 1 { + it.Nullable = true + } + newOffset++ + } + return it, newOffset + + case InputKindUnion: + numVariants := 0 + if offset < len(data) { + numVariants = int(data[offset]) % 4 // cap at 3 variants + offset++ + } + variants := make([]InputType, 0, numVariants) + for i := 0; i < numVariants && offset < len(data); i++ { + v, newOffset := decodeInputType(data, offset, depth+1) + variants = append(variants, v) + offset = newOffset + } + it := InputUnionOf(variants...) 
+ if offset < len(data) { + if data[offset]%2 == 1 { + it.Nullable = true + } + offset++ + } + return it, offset + + default: + return InputPrimitive(TypeString), offset + } +} diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index f15f4471ca..41f604e462 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -268,6 +268,52 @@ type enumSchema struct { schema map[string]any } +func inputTypeJSONSchema(it InputType) map[string]any { + var schema map[string]any + switch it.Kind { + case InputKindPrimitive: + schema = it.Primitive.JSONType() + case InputKindAny: + schema = TypeAny.JSONType() + case InputKindArray: + items := TypeAny.JSONType() + if it.Elem != nil { + items = inputTypeJSONSchema(*it.Elem) + } + schema = map[string]any{ + "type": "array", + "items": items, + } + case InputKindUnion: + variants := make([]any, len(it.Variants)) + for i, variant := range it.Variants { + variantSchema := inputTypeJSONSchema(variant) + if it.Nullable { + variantSchema["nullable"] = true + } + variants[i] = variantSchema + } + // A nullable union is represented with OpenAPI's `nullable` keyword + // (set below), matching how plain optional fields behave: an omitted + // value yields the default, while explicit JSON `null` is validated + // against the field type just like any other optional input. + schema = map[string]any{"anyOf": variants} + default: + schema = TypeAny.JSONType() + } + if it.Nullable { + schema["nullable"] = true + } + return schema +} + +func inputSchemaForField(field InputField) map[string]any { + if field.InputType != nil { + return inputTypeJSONSchema(*field.InputType) + } + return field.FieldType.JSONType() +} + // buildInputSchema builds the Input schema object and any enum schemas for choices. 
func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { properties := newOrderedMapAny() @@ -314,7 +360,7 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { } else { // Regular field — inline type prop["title"] = TitleCase(name) - maps.Copy(prop, field.FieldType.JSONType()) + maps.Copy(prop, inputSchemaForField(field)) } // Determine effective default. A default of None on a non-nullable @@ -332,6 +378,9 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { // `Optional[Secret] = Input(default=None)`. See // TestNoneDefaultOnBareSecretIsOptional for the regression case. isNullable := field.FieldType.Repetition == Optional || field.FieldType.Repetition == OptionalRepeated + if field.InputType != nil && field.InputType.Nullable { + isNullable = true + } if field.FieldType.Primitive == TypeSecret && field.FieldType.Repetition == Required && field.Default != nil && field.Default.Kind == DefaultNone { @@ -343,7 +392,8 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { } // Required? 
- if !hasEffectiveDefault && (field.FieldType.Repetition == Required || field.FieldType.Repetition == Repeated) { + isUnionInput := field.InputType != nil && field.InputType.Kind == InputKindUnion + if !hasEffectiveDefault && (isUnionInput || field.FieldType.Repetition == Required || field.FieldType.Repetition == Repeated) { required = append(required, name) } diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index d9b528bb9f..57669b3e6c 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -40,6 +40,25 @@ func parseSpec(t *testing.T, info *PredictorInfo) map[string]any { return spec } +func extractInputProperty(t *testing.T, raw []byte, name string) string { + t.Helper() + var doc map[string]any + require.NoError(t, json.Unmarshal(raw, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties[name] + require.True(t, ok) + out, err := json.Marshal(prop) + require.NoError(t, err) + return string(out) +} + func getPath(m map[string]any, keys ...string) any { var cur any = m for _, k := range keys { @@ -534,6 +553,62 @@ func TestInputOptionalRepeatedType(t *testing.T) { assert.Nil(t, inputSchema["required"]) } +func TestOpenAPIUnionInputStringFloat(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: ptr(InputUnionOf( + InputPrimitive(TypeString), + InputPrimitive(TypeFloat), + )), + }) + + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + require.NoError(t, err) + require.JSONEq(t, 
`{"anyOf":[{"type":"string"},{"type":"number"}],"title":"Value","x-order":0}`, extractInputProperty(t, out, "value")) +} + +func TestOpenAPIRequiredNullableUnionInput(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + it := InputUnionOf(InputPrimitive(TypeString), InputPrimitive(TypeFloat)) + it.Nullable = true + inputs.Set("value", InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: &it, + }) + + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + input := doc["components"].(map[string]any)["schemas"].(map[string]any)["Input"].(map[string]any) + require.Equal(t, []any{"value"}, input["required"]) + prop := input["properties"].(map[string]any)["value"].(map[string]any) + require.JSONEq(t, `{ + "anyOf":[ + {"nullable":true,"type":"string"}, + {"nullable":true,"type":"number"} + ], + "nullable":true, + "title":"Value", + "x-order":0 + }`, extractInputProperty(t, out, "value")) + require.Equal(t, true, prop["nullable"]) +} + // --------------------------------------------------------------------------- // Tests: Choices / Enums // --------------------------------------------------------------------------- diff --git a/pkg/schema/python/inputs.go b/pkg/schema/python/inputs.go index 7275eed636..06669f089d 100644 --- a/pkg/schema/python/inputs.go +++ b/pkg/schema/python/inputs.go @@ -262,16 +262,17 @@ func typedParameterParts(node *sitter.Node, source []byte) (string, *sitter.Node return name, typeNode } -func inputField(name string, order int, fieldType schema.FieldType) schema.InputField { +func inputField(name string, order int, inputType schema.InputType, fieldType schema.FieldType) schema.InputField { return schema.InputField{ Name: name, Order: order, FieldType: fieldType, + InputType: &inputType, } } -func 
inputFieldWithInfo(name string, order int, fieldType schema.FieldType, info inputCallInfo) schema.InputField { - field := inputField(name, order, fieldType) +func inputFieldWithInfo(name string, order int, inputType schema.InputType, fieldType schema.FieldType, info inputCallInfo) schema.InputField { + field := inputField(name, order, inputType, fieldType) field.Default = info.Default field.Description = info.Description field.GE = info.GE @@ -293,12 +294,12 @@ func firstParamIsSelf(params *sitter.Node, source []byte) bool { return false } -func resolveParameterFieldType(typeNode *sitter.Node, source []byte, ctx *inputParseContext) (schema.FieldType, error) { +func resolveParameterInputTypes(typeNode *sitter.Node, source []byte, ctx *inputParseContext) (schema.InputType, schema.FieldType, error) { typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { - return schema.FieldType{}, err + return schema.InputType{}, schema.FieldType{}, err } - return schema.ResolveFieldType(typeAnn, ctx.imports, ctx.typedDicts) + return schema.ResolveInputType(typeAnn, ctx.imports, ctx.typedDicts) } func extractInputs( @@ -362,12 +363,13 @@ func parseTypedParameter(node *sitter.Node, source []byte, order int, ctx *input return schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, ctx.methodName), nil) } - fieldType, err := resolveParameterFieldType(typeNode, source, ctx) + inputType, fieldType, err := resolveParameterInputTypes(typeNode, source, ctx) if err != nil { return schema.InputField{}, err } - return inputField(name, order, fieldType), nil + field := inputField(name, order, inputType, fieldType) + return field, schema.ValidateInputField(field) } func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx *inputParseContext) (schema.InputField, error) { @@ -382,7 +384,7 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx return 
schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, ctx.methodName), nil) } - fieldType, err := resolveParameterFieldType(typeNode, source, ctx) + inputType, fieldType, err := resolveParameterInputTypes(typeNode, source, ctx) if err != nil { return schema.InputField{}, err } @@ -396,19 +398,21 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx if err != nil { return schema.InputField{}, err } - return inputFieldWithInfo(name, order, fieldType, info), nil + field := inputFieldWithInfo(name, order, inputType, fieldType, info) + return field, schema.ValidateInputField(field) } // 2. Reference to Input() via class attribute or static method if info, ok := resolveInputReference(valNode, source, ctx.registry); ok { - return inputFieldWithInfo(name, order, fieldType, info), nil + field := inputFieldWithInfo(name, order, inputType, fieldType, info) + return field, schema.ValidateInputField(field) } // 3. 
Plain default — must be statically resolvable if def, ok := resolveDefaultExpr(valNode, source, ctx.scope); ok { - field := inputField(name, order, fieldType) + field := inputField(name, order, inputType, fieldType) field.Default = &def - return field, nil + return field, schema.ValidateInputField(field) } // Can't resolve — hard error @@ -418,7 +422,8 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx } // No default — required parameter - return inputField(name, order, fieldType), nil + field := inputField(name, order, inputType, fieldType) + return field, schema.ValidateInputField(field) } func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 1d26dace12..de0a71f267 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -404,6 +404,129 @@ class Predictor(BasePredictor): require.Equal(t, schema.TypeString, name.FieldType.Primitive) } +func TestOptionalInputOpenAPINotRequired(t *testing.T) { + source := []byte(` +from typing import Optional + +class Predictor: + def predict(self, value: Optional[str]) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + out, err := schema.GenerateOpenAPISchema(info) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + _, hasRequired := input["required"] + require.False(t, hasRequired) + + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties["value"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, prop["nullable"]) +} + +func 
TestUnionInputStringFloat(t *testing.T) { + source := []byte(` +class Predictor: + def predict(self, value: str | float) -> str: + return str(value) +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + value, ok := info.Inputs.Get("value") + require.True(t, ok) + require.NotNil(t, value.InputType) + require.Equal(t, schema.InputKindUnion, value.InputType.Kind) + require.Len(t, value.InputType.Variants, 2) + require.Equal(t, schema.TypeString, value.InputType.Variants[0].Primitive) + require.Equal(t, schema.TypeFloat, value.InputType.Variants[1].Primitive) + require.False(t, value.InputType.Nullable) +} + +func TestUnionInputStringFloatNone(t *testing.T) { + source := []byte(` +from cog import Input + +class Predictor: + def predict(self, value: str | float | None = Input(default=None)) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + value, ok := info.Inputs.Get("value") + require.True(t, ok) + require.NotNil(t, value.InputType) + require.Equal(t, schema.InputKindUnion, value.InputType.Kind) + require.True(t, value.InputType.Nullable) + require.NotNil(t, value.Default) + require.Equal(t, schema.DefaultNone, value.Default.Kind) +} + +func TestUnionInputNullableWithoutDefaultOpenAPI(t *testing.T) { + source := []byte(` +class Predictor: + def predict(self, value: str | float | None) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + out, err := schema.GenerateOpenAPISchema(info) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + require.Equal(t, []any{"value"}, 
input["required"]) + + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties["value"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, prop["nullable"]) + require.Equal(t, []any{ + map[string]any{"nullable": true, "type": "string"}, + map[string]any{"nullable": true, "type": "number"}, + }, prop["anyOf"]) +} + +func TestUnionInputRejectsPathString(t *testing.T) { + source := []byte(` +from cog import Path + +class Predictor: + def predict(self, value: Path | str) -> str: + return "ok" +`) + + _, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.Error(t, err) + require.Contains(t, err.Error(), "Path") + require.Contains(t, err.Error(), "union") +} + // --------------------------------------------------------------------------- // List inputs // --------------------------------------------------------------------------- diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 4013424a69..0d86535801 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -98,6 +98,48 @@ type FieldType struct { Repetition Repetition } +// InputTypeKind tags the recursive input type representation. +type InputTypeKind int + +const ( + InputKindPrimitive InputTypeKind = iota + InputKindAny + InputKindArray + InputKindUnion +) + +// InputType represents JSON-native input types, including unions. +type InputType struct { + Kind InputTypeKind + Primitive PrimitiveType + Elem *InputType + Variants []InputType + Nullable bool +} + +// InputPrimitive creates a primitive input type. +func InputPrimitive(primitive PrimitiveType) InputType { + if primitive == TypeAny { + return InputAnyType() + } + return InputType{Kind: InputKindPrimitive, Primitive: primitive} +} + +// InputAnyType creates an opaque JSON input type. +func InputAnyType() InputType { + return InputType{Kind: InputKindAny, Primitive: TypeAny} +} + +// InputArrayOf creates an array input type. 
+func InputArrayOf(elem InputType) InputType { + return InputType{Kind: InputKindArray, Elem: &elem} +} + +// InputUnionOf creates a union input type. +func InputUnionOf(variants ...InputType) InputType { + return InputType{Kind: InputKindUnion, Variants: variants} +} + // JSONType returns the JSON Schema fragment for this field type. func (ft FieldType) JSONType() map[string]any { if ft.Repetition == Repeated || ft.Repetition == OptionalRepeated { @@ -179,6 +221,7 @@ type InputField struct { Name string Order int FieldType FieldType + InputType *InputType Default *DefaultValue Description *string GE *float64 @@ -195,6 +238,16 @@ func (f *InputField) IsRequired() bool { return f.Default == nil && (f.FieldType.Repetition == Required || f.FieldType.Repetition == Repeated) } +// ValidateInputField checks combinations unsupported by the static input model. +func ValidateInputField(field InputField) error { + if field.InputType != nil && field.InputType.Kind == InputKindUnion { + if len(field.Choices) > 0 || field.GE != nil || field.LE != nil || field.MinLength != nil || field.MaxLength != nil || field.Regex != nil { + return errUnsupportedType("constraints and choices are not supported on union inputs") + } + } + return nil +} + // PredictorInfo is the top-level extraction result. type PredictorInfo struct { Inputs *OrderedMap[string, InputField] @@ -434,6 +487,208 @@ func ResolveFieldType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[str return FieldType{}, errUnsupportedType("unknown type annotation") } +// ResolveInputType resolves a TypeAnnotation into the recursive input type model +// and the legacy FieldType compatibility layer. 
+func ResolveInputType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, FieldType, error) { + inputType, err := resolveInputType(ann, ctx, typedDicts) + if err != nil { + return InputType{}, FieldType{}, err + } + return inputType, fieldTypeFromInputType(inputType), nil +} + +func resolveInputType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, error) { + if inner, ok := unwrapOpaqueAnnotated(ann, ctx); ok { + return inputTypeFromFieldType(opaqueFieldType(inner, ctx)), nil + } + + switch ann.Kind { + case TypeAnnotSimple: + name := ann.Name + if typedDicts[name] { + return InputAnyType(), nil + } + qualifiedEntry := ImportEntry{} + if resolved, entry, ok := ctx.ResolveQualifiedName(name); ok { + name = resolved + qualifiedEntry = entry + if typedDicts[entry.Original+"."+name] { + return InputAnyType(), nil + } + } + if typedDicts[name] { + return InputAnyType(), nil + } + if name == "dict" || name == "Dict" { + return InputAnyType(), nil + } + prim, ok := PrimitiveFromName(name) + if !ok { + if qualifiedEntry.Module != "" { + return InputType{}, errUnresolvableImportedType(name, qualifiedEntry.Module) + } + if entry, imported := ctx.Names.Get(name); imported { + return InputType{}, errUnresolvableImportedType(name, entry.Module) + } + return InputType{}, errUnsupportedType(name) + } + return InputPrimitive(prim), nil + + case TypeAnnotGeneric: + outer := ann.Name + if resolved, _, ok := ctx.ResolveQualifiedName(outer); ok { + outer = resolved + } + if outer == "dict" || outer == "Dict" { + return InputAnyType(), nil + } + if outer == "Optional" { + if len(ann.Args) != 1 { + return InputType{}, errUnsupportedType(fmt.Sprintf("Optional expects exactly 1 type argument, got %d", len(ann.Args))) + } + inner, err := resolveInputType(ann.Args[0], ctx, typedDicts) + if err != nil { + return InputType{}, err + } + inner.Nullable = true + return inner, nil + } + if outer == "Union" { + return 
resolveInputUnion(ann.Args, ctx, typedDicts) + } + if ctx.isAnnotated(ann.Name) { + if len(ann.Args) == 0 { + return InputType{}, errUnsupportedType("Annotated expects at least 1 type argument") + } + return resolveInputType(ann.Args[0], ctx, typedDicts) + } + if outer == "List" || outer == "list" { + if len(ann.Args) != 1 { + return InputType{}, errUnsupportedType(fmt.Sprintf("List expects exactly 1 type argument, got %d", len(ann.Args))) + } + if opaqueInner, ok := unwrapOpaqueAnnotated(ann.Args[0], ctx); ok { + inner := inputTypeFromFieldType(opaqueFieldType(opaqueInner, ctx)) + if inner.Nullable || inner.Kind == InputKindArray || inner.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") + } + return InputArrayOf(inner), nil + } + inner, err := resolveInputType(ann.Args[0], ctx, typedDicts) + if err != nil { + return InputType{}, err + } + if inner.Nullable || inner.Kind == InputKindArray || inner.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") + } + return InputArrayOf(inner), nil + } + return InputType{}, errUnsupportedType(fmt.Sprintf("%s[...] 
is not a supported input type", outer)) + + case TypeAnnotUnion: + return resolveInputUnion(ann.Args, ctx, typedDicts) + } + return InputType{}, errUnsupportedType("unknown type annotation") +} + +func resolveInputUnion(args []TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, error) { + variants := make([]InputType, 0, len(args)) + nullable := false + for _, arg := range args { + if arg.Kind == TypeAnnotSimple && arg.Name == "None" { + nullable = true + continue + } + if arg.Kind == TypeAnnotUnion || (arg.Kind == TypeAnnotGeneric && arg.Name == "Union") { + return InputType{}, errUnsupportedType("nested union inputs are not supported") + } + variant, err := resolveInputType(arg, ctx, typedDicts) + if err != nil { + return InputType{}, err + } + if variant.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested union inputs are not supported") + } + variants = append(variants, variant) + } + + if len(variants) == 0 { + return InputType{}, errUnsupportedType("union inputs must include at least one non-None type") + } + if len(variants) == 1 { + variant := variants[0] + variant.Nullable = variant.Nullable || nullable + return variant, nil + } + for _, variant := range variants { + if err := validateUnionVariant(variant); err != nil { + return InputType{}, err + } + } + + union := InputUnionOf(variants...) 
+ union.Nullable = nullable + return union, nil +} + +func validateUnionVariant(inputType InputType) error { + if inputType.Nullable { + return errUnsupportedType("nested nullable variants are not supported in union inputs") + } + switch inputType.Kind { + case InputKindPrimitive: + if inputType.Primitive == TypePath || inputType.Primitive == TypeFile || inputType.Primitive == TypeSecret { + return errUnsupportedType(fmt.Sprintf("%s is not supported in union inputs", inputType.Primitive)) + } + case InputKindArray: + if inputType.Elem != nil { + return validateUnionVariant(*inputType.Elem) + } + case InputKindUnion: + return errUnsupportedType("nested union inputs are not supported") + } + return nil +} + +func inputTypeFromFieldType(fieldType FieldType) InputType { + var inputType InputType + if fieldType.Primitive == TypeAny { + inputType = InputAnyType() + } else { + inputType = InputPrimitive(fieldType.Primitive) + } + if fieldType.Repetition == Repeated || fieldType.Repetition == OptionalRepeated { + inputType = InputArrayOf(inputType) + } + if fieldType.Repetition == Optional || fieldType.Repetition == OptionalRepeated { + inputType.Nullable = true + } + return inputType +} + +func fieldTypeFromInputType(inputType InputType) FieldType { + repetition := Required + if inputType.Nullable { + repetition = Optional + } + switch inputType.Kind { + case InputKindPrimitive: + return FieldType{Primitive: inputType.Primitive, Repetition: repetition} + case InputKindArray: + arrayRepetition := Repeated + if inputType.Nullable { + arrayRepetition = OptionalRepeated + } + if inputType.Elem != nil && inputType.Elem.Kind == InputKindPrimitive { + return FieldType{Primitive: inputType.Elem.Primitive, Repetition: arrayRepetition} + } + return FieldType{Primitive: TypeAny, Repetition: arrayRepetition} + case InputKindAny, InputKindUnion: + return FieldType{Primitive: TypeAny, Repetition: repetition} + default: + return FieldType{Primitive: TypeAny, Repetition: repetition} + } 
+} + func unwrapOpaqueAnnotated(ann TypeAnnotation, ctx *ImportContext) (TypeAnnotation, bool) { if ann.Kind != TypeAnnotGeneric || !ctx.isAnnotated(ann.Name) || len(ann.Args) < 2 { return ann, false diff --git a/python/cog/_adt.py b/python/cog/_adt.py index f40f870e47..3fea80cc27 100644 --- a/python/cog/_adt.py +++ b/python/cog/_adt.py @@ -37,6 +37,10 @@ def _is_union(tpe: type) -> bool: return False +def _is_none_type(tpe: Any) -> bool: + return tpe is None or tpe is type(None) + + def _is_dict_like(tpe: Any) -> bool: """Check if a type should be treated like a dict, including TypedDict.""" if tpe is dict: @@ -52,7 +56,7 @@ def _is_dict_like(tpe: Any) -> bool: pass is_typeddict = getattr(typing, "is_typeddict", None) - return callable(is_typeddict) and is_typeddict(tpe) + return bool(callable(is_typeddict) and is_typeddict(tpe)) def _unwrap_opaque(tpe: Any) -> tuple[Any, bool]: @@ -249,6 +253,90 @@ class Repetition(Enum): OPTIONAL_REPEATED = 4 # list[X] | None +def _is_supported_union_variant(ft: "FieldType") -> bool: + if ft.union_variants is not None: + return False + if ft.primitive in { + PrimitiveType.PATH, + PrimitiveType.FILE, + PrimitiveType.SECRET, + PrimitiveType.CUSTOM, + }: + return False + return ft.repetition in {Repetition.REQUIRED, Repetition.REPEATED} + + +def _is_exact_union_match(value: Any, ft: "FieldType") -> bool: + if ft.repetition is Repetition.REPEATED: + return isinstance(value, list) + if ft.repetition is not Repetition.REQUIRED: + return False + if ft.primitive is PrimitiveType.BOOL: + return isinstance(value, bool) + if ft.primitive is PrimitiveType.INTEGER: + return isinstance(value, int) and not isinstance(value, bool) + if ft.primitive is PrimitiveType.FLOAT: + return isinstance(value, float) + if ft.primitive is PrimitiveType.STRING: + return isinstance(value, str) + if ft.primitive is PrimitiveType.ANY: + return isinstance(value, dict) + return False + + +def _union_primitive_accepts_value(value: Any, primitive: PrimitiveType) -> 
bool: + if primitive is PrimitiveType.BOOL: + return type(value) is bool + if primitive is PrimitiveType.INTEGER: + return type(value) is int + if primitive is PrimitiveType.FLOAT: + return type(value) is float or type(value) is int + if primitive is PrimitiveType.STRING: + return type(value) is str + if primitive is PrimitiveType.ANY: + return isinstance(value, dict) + return False + + +def _union_variant_accepts_value(value: Any, variant: "FieldType") -> bool: + if variant.repetition is Repetition.REPEATED: + return isinstance(value, list) and all( + _union_primitive_accepts_value(element, variant.primitive) + for element in value + ) + if variant.repetition is not Repetition.REQUIRED: + return False + return _union_primitive_accepts_value(value, variant.primitive) + + +def _union_variant_priority(ft: "FieldType") -> int: + primitive_priority = { + PrimitiveType.BOOL: 0, + PrimitiveType.INTEGER: 1, + PrimitiveType.FLOAT: 2, + PrimitiveType.STRING: 3, + PrimitiveType.ANY: 4, + } + repetition_offset = 10 if ft.repetition is Repetition.REPEATED else 0 + return repetition_offset + primitive_priority.get(ft.primitive, 100) + + +def _ordered_union_variants( + value: Any, variants: List["FieldType"] +) -> List["FieldType"]: + return [ + variant + for _, variant in sorted( + enumerate(variants), + key=lambda item: ( + not _is_exact_union_match(value, item[1]), + _union_variant_priority(item[1]), + item[0], + ), + ) + ] + + @dataclass(frozen=True) class FieldType: """Type information for an input/output field.""" @@ -256,6 +344,7 @@ class FieldType: primitive: PrimitiveType repetition: Repetition coder: Optional[Coder] + union_variants: Optional[List["FieldType"]] = None @staticmethod def from_type(tpe: type) -> "FieldType": @@ -317,9 +406,43 @@ def from_type(tpe: type) -> "FieldType": elif _is_union(tpe): t_args = typing.get_args(tpe) - if not (len(t_args) == 2 and type(None) in t_args): - raise ValueError(f"unsupported union type {tpe}") - elem_t = t_args[0] if 
t_args[1] is type(None) else t_args[1] + has_none = any(_is_none_type(arg) for arg in t_args) + non_none_args = [arg for arg in t_args if not _is_none_type(arg)] + + if len(non_none_args) != 1: + repetition = Repetition.OPTIONAL if has_none else Repetition.REQUIRED + variants = [] + for arg in non_none_args: + try: + variant = FieldType.from_type(arg) + except ValueError as exc: + raise ValueError( + f"unsupported union member {_type_name(arg)} in union {tpe}" + ) from exc + if not _is_supported_union_variant(variant): + raise ValueError( + f"unsupported union member {_type_name(arg)} in union {tpe}" + ) + variants.append(variant) + return FieldType( + primitive=PrimitiveType.ANY, + repetition=repetition, + coder=None, + union_variants=variants, + ) + + if not has_none: + elem_t = non_none_args[0] + repetition = Repetition.REQUIRED + cog_t = PrimitiveType.from_type(elem_t) + coder = None + if cog_t is PrimitiveType.CUSTOM: + coder = Coder.lookup(elem_t) + if coder is None: + raise ValueError(f"unsupported Cog type {_type_name(elem_t)}") + return FieldType(primitive=cog_t, repetition=repetition, coder=coder) + + elem_t = non_none_args[0] inner, elem_is_opaque = _unwrap_opaque(elem_t) if elem_is_opaque: return _opaque_field_type(inner | None) @@ -384,6 +507,20 @@ def from_type(tpe: type) -> "FieldType": def normalize(self, value: Any) -> Any: """Normalize a value according to this field type.""" + if self.union_variants is not None: + if value is None: + if self.repetition is Repetition.OPTIONAL: + return None + raise ValueError("missing value for required union field") + for variant in _ordered_union_variants(value, self.union_variants): + if not _union_variant_accepts_value(value, variant): + continue + try: + return variant.normalize(value) + except (TypeError, ValueError): + pass + raise ValueError(f"failed to normalize value as {self.python_type_name()}") + if self.repetition is Repetition.REQUIRED: return self.primitive.normalize(value) elif self.repetition is 
Repetition.OPTIONAL: @@ -398,6 +535,14 @@ def normalize(self, value: Any) -> Any: def json_type(self) -> Dict[str, Any]: """Get the JSON Schema type for this field.""" + if self.union_variants is not None: + jt: Dict[str, Any] = { + "anyOf": [variant.json_type() for variant in self.union_variants] + } + if self.repetition is Repetition.OPTIONAL: + jt["nullable"] = True + return jt + if self.repetition in (Repetition.REPEATED, Repetition.OPTIONAL_REPEATED): return {"type": "array", "items": self.primitive.json_type()} return self.primitive.json_type() @@ -428,6 +573,14 @@ def json_decode(self, value: Any) -> Any: def python_type_name(self) -> str: """Get the Python type name for this field.""" + if self.union_variants is not None: + name = " | ".join( + variant.python_type_name() for variant in self.union_variants + ) + if self.repetition is Repetition.OPTIONAL: + return f"Optional[{name}]" + return name + if self.repetition is Repetition.REQUIRED: return self.primitive.python_type_name() elif self.repetition is Repetition.OPTIONAL: diff --git a/python/tests/test_adt.py b/python/tests/test_adt.py index ebffeb7d72..22800bf36b 100644 --- a/python/tests/test_adt.py +++ b/python/tests/test_adt.py @@ -2,6 +2,7 @@ from typing import Annotated, Any, Dict, List, Optional, TypedDict +import pytest from typing_extensions import TypedDict as ExtensionsTypedDict from cog import Opaque @@ -322,3 +323,106 @@ def test_optional_str(self) -> None: ft = FieldType.from_type(Optional[str]) assert ft.primitive is PrimitiveType.STRING assert ft.repetition is Repetition.OPTIONAL + + +class TestUnionInputTypes: + def test_union_str_float_field_type(self) -> None: + ft = FieldType.from_type(str | float) + assert ft.primitive is PrimitiveType.ANY + assert ft.repetition is Repetition.REQUIRED + assert ft.union_variants is not None + assert [v.primitive for v in ft.union_variants] == [ + PrimitiveType.STRING, + PrimitiveType.FLOAT, + ] + + def test_union_str_float_none_field_type(self) -> 
None: + ft = FieldType.from_type(str | float | None) + assert ft.repetition is Repetition.OPTIONAL + assert ft.union_variants is not None + + def test_union_int_float_prefers_int(self) -> None: + ft = FieldType.from_type(int | float) + assert ft.normalize(1) == 1 + assert isinstance(ft.normalize(1), int) + + def test_union_bool_int_prefers_bool(self) -> None: + ft = FieldType.from_type(bool | int) + value = ft.normalize(True) + assert value is True + + def test_union_int_float_rejects_bool(self) -> None: + ft = FieldType.from_type(int | float) + with pytest.raises(ValueError): + ft.normalize(True) + + def test_union_str_bool_rejects_int(self) -> None: + ft = FieldType.from_type(str | bool) + with pytest.raises(ValueError): + ft.normalize(1) + + def test_union_str_dict_rejects_scalar(self) -> None: + ft = FieldType.from_type(str | dict) + with pytest.raises(ValueError): + ft.normalize(123) + + def test_union_list_int_float_rejects_bool_element(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + with pytest.raises(ValueError): + ft.normalize([True]) + + def test_union_list_int_float_rejects_string_element(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + with pytest.raises(ValueError): + ft.normalize(["3"]) + + def test_union_list_int_float_accepts_numeric_elements(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + assert ft.normalize([1]) == [1] + assert isinstance(ft.normalize([1])[0], int) + assert ft.normalize([1.5]) == [1.5] + + def test_union_optional_normalize_none(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.repetition is Repetition.OPTIONAL + assert ft.normalize(None) is None + + def test_union_required_normalize_none_raises(self) -> None: + ft = FieldType.from_type(str | float) + assert ft.repetition is Repetition.REQUIRED + with pytest.raises(ValueError): + ft.normalize(None) + + def test_union_required_json_type_omits_nullable(self) -> None: + ft = FieldType.from_type(int | 
str) + assert ft.json_type() == { + "anyOf": [{"type": "integer"}, {"type": "string"}], + } + + def test_union_list_int_float_accepts_empty_list(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + assert ft.normalize([]) == [] + + def test_union_mixed_scalar_and_list(self) -> None: + ft = FieldType.from_type(list[int] | int) + assert ft.normalize(5) == 5 + assert isinstance(ft.normalize(5), int) + assert ft.normalize([5]) == [5] + + def test_union_str_float_none_json_type(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.json_type() == { + "anyOf": [{"type": "string"}, {"type": "number"}], + "nullable": True, + } + + def test_union_rejects_path_string(self) -> None: + from cog import Path + + try: + FieldType.from_type(Path | str) + except ValueError as exc: + assert "Path" in str(exc) + assert "union" in str(exc) + else: + raise AssertionError("Expected ValueError for Path | str") diff --git a/python/tests/test_inspector.py b/python/tests/test_inspector.py index 7597a783c2..bf1307aae9 100644 --- a/python/tests/test_inspector.py +++ b/python/tests/test_inspector.py @@ -168,6 +168,22 @@ def predict(self, value: Annotated[ExternalObject, Opaque]) -> str: assert field.type.repetition is adt.Repetition.REQUIRED +def test_inspector_supports_union_input() -> None: + class Predictor: + def predict(self, value: str | float) -> str: + return str(value) + + info = _create_predictor_info( + "predict", "Predictor", Predictor.predict, "predict", True + ) + field = info.inputs["value"] + assert field.type.union_variants is not None + assert [v.primitive for v in field.type.union_variants] == [ + adt.PrimitiveType.STRING, + adt.PrimitiveType.FLOAT, + ] + + def test_inspector_preserves_opaque_list_input_metadata() -> None: class Predictor: def predict(self, value: Annotated[List[ExternalObject], Opaque]) -> str: