
feat: implement JIT warmup and upgrade global kernel cache with persistence #81

Open
goog00 wants to merge 13 commits into NVlabs:main from goog00:warmup

Conversation

Contributor

@goog00 goog00 commented Apr 8, 2026

Summary

This PR implements global thread-safe kernel caching with JIT warmup APIs and disk
persistence for cutile-rs. It solves cold-start latency in production GPU inference by
enabling four key capabilities:

  1. Thread-Safety Design

    • Global process-wide CudaContext: A single per-device context (via
      OnceLock<Mutex<HashMap>>) is shared across all threads, enabling compiled modules
      and functions to be used from any thread without reloading.
    • Single-flight compilation dedup: The global KERNEL_CACHE uses
      DashMap<String, Arc<OnceCell<CompiledKernel>>> to ensure when multiple threads
      request the same kernel specialization, only one thread performs JIT compilation
      while others wait and reuse the result. A shard-lock release pattern prevents lock
      contention during expensive compilation.
    • Atomic disk writes: FileSystemJitStore uses atomic write-then-rename with
      PID + AtomicU64 counter–based tmp filenames, guaranteeing that concurrent
      processes/threads produce correct cubins without partial files or collisions.
  2. Pre-JIT warmup (compile_warmup): Pre-compile kernel specializations at startup
    without launching, populating both memory and disk caches. Users specify only generics
    and strides.

  3. Realistic execution warmup (execute_warmup): Run a user-provided closure that
    launches kernels with production shapes and data, warming both the JIT cache and CUDA
    runtime.

  4. Cross-process cubin reuse via a JitStore trait: Compiled cubins persist to disk
    and are loaded in subsequent processes, eliminating recompilation. Cache keys encode
    source content (compile-time _SOURCE_HASH), GPU architecture, compiler version, and
    CUDA toolkit version, preventing false cache hits.

The #[cutile::module] macro now emits _SOURCE_HASH, _entries(), and _compile_warmup(), reducing caller boilerplate.


Addressing Feedback

This PR directly addresses all points from #5 comments:

  • Global CudaContext: Implemented via OnceLock<Mutex<HashMap>> in device_context.rs, enabling thread-safe kernel reuse across all threads
  • Single-flight dedup: Used the suggested DashMap<Key, Arc<OnceLock<CompiledKernel>>> pattern with shard-lock release optimization
  • Automatic cache invalidation: Compile-time _SOURCE_HASH (SHA-256) emitted by the macro folds into the cache key, preventing manual cache-clear workflows
  • Metadata exposure: Macro generates _entries(), _SOURCE_HASH, and _compile_warmup() so users only specify generics and strides
  • GPU architecture in key: TileFunctionKey includes gpu_name for architecture-specific cubin artifacts
  • False cache hit prevention: Cache key encodes source hash + gpu_name + compiler_version + cuda_toolkit_version, with SHA-256 for disk keys
  • Compiler version tracking: Prevents stale cubins after cutile-rs updates (all crate versions kept identical)
  • CUDA toolkit version tracking: Different toolkit versions produce different PTX/cubins — now part of the key

Major Changes

cuda-async

  • Add jit_store.rs: JitStore trait with object-store interface (get/put/contains/delete/clear) and FileSystemJitStore implementation with atomic writes
  • Add global JitStore configuration (set_jit_store, set_jit_store_if_unset, ensure_default_jit_store) with CUTILE_NO_DISK_CACHE=1 opt-out
  • Upgrade device_context.rs: global CUDA_CONTEXTS (process-wide, per-device) and KERNEL_CACHE (DashMap<String, Arc<OnceCell<CompiledKernel>>>)
  • Add CompiledKernel struct and FunctionKey trait with get_hash_string / get_disk_hash_string
  • Make load_module_from_bytes tmp filenames unique with PID + AtomicU64 counter

cutile

  • Add compile_warmup: pre-compile kernel specializations into the global cache and disk store without launching
  • Add execute_warmup: run user-provided closure for realistic kernel launches (warms both JIT cache and CUDA runtime)
  • Add TileFunctionKey with full cache key (source_hash, gpu_name, compiler_version, cuda_toolkit_version)
  • Add WarmupSpec builder and EntryMeta types
  • Integrate 3-tier cache lookup into compile_from_context

cutile-macro

  • Emit pub const _SOURCE_HASH: &str (SHA-256 of module source at compile time)
  • Emit pub fn _entries() -> Vec<EntryMeta> with all entry point metadata
  • Emit pub fn _compile_warmup(specs) convenience helper so callers avoid passing module internals

Tests

  • Add CPU-only cache key property tests (cutile/tests/warmup.rs)
  • Add GPU integration tests including cross-process disk-hit verification (cutile/tests/gpu/warmup.rs)
  • Add performance verification benchmarks (cutile/tests/gpu/warmup_bench.rs)
  • Add jit store traits tests (cuda-async/tests/jit_store.rs)

Design Decisions

1. Global CUDA_CONTEXTS for process-wide context sharing

Previously, each thread maintained its own CudaContext (thread-local). This prevented modules or functions loaded against one context from being used by another thread. Moving to OnceLock<Mutex<HashMap<device_id, Arc<CudaContext>>>> creates a single per-device context shared across all threads. Once a module is loaded against a context, any thread can call its functions. The mutex protects the HashMap during the rare insert operation; actual kernel launches remain lock-free since each thread has an Arc to the same context.

2. Global KERNEL_CACHE with DashMap + OnceCell for single-flight dedup

The cache is DashMap<String, Arc<OnceCell<CompiledKernel>>>, where the outer DashMap enables lock-free reads with fine-grained sharding (DashMap picks the shard count from the available parallelism). The inner OnceCell<CompiledKernel> guarantees single-flight: when multiple threads request the same key, only one calls get_or_try_init (the JIT compilation closure) while the others block and receive the same result. This prevents N threads from independently compiling the same kernel. See decision 8 (shard-lock pattern) for how we avoid holding the shard write lock during compilation.

3. compile_warmup for pre-JIT without launches

compile_warmup(module_asts, entries, specs) builds the cache AST once, iterates over specs, and for each spec: looks up the key in the global cache, tries disk (JitStore), then JIT-compiles if needed. Results go into memory and disk caches. No kernels are launched — it's purely compilation. This decouples warmup from realistic execution shapes, allowing deployment scenarios to pre-warm all target shapes at startup (e.g., during container initialization) and amortize compilation cost before the first real request.

4. execute_warmup for realistic end-to-end warmup

execute_warmup accepts a closure in which users launch actual kernels with production-representative shapes and data. Kernels inside the closure auto-JIT via the normal compile_from_context path, but now they hit the pre-populated cache (from compile_warmup) or compile on first encounter. This warms not just the JIT cache but also the CUDA runtime (driver initialization, occupancy calculation, shared memory allocation, stream scheduling). The two-stage approach (pre-compile, then execute with realistic shapes) warms more state than compile_warmup alone.

5. JitStore trait for pluggable disk persistence

Instead of hardcoding cubin serialization to a specific filesystem location, we expose a JitStore trait with five operations: get, put, contains, delete, clear. This allows alternative backends (e.g., database, object storage, network cache) without modifying core code. The default FileSystemJitStore uses atomic write-then-rename to prevent corrupted partial files. The trait is marked Send + Sync, enabling thread-safe access from multiple threads and processes (e.g., N processes each trying to put the same cubin). Configuration is global (set_jit_store, set_jit_store_if_unset) and lazy-initialized via ensure_default_jit_store().

6. CompiledKernel bundles module + function + validator

Instead of caching just the CudaFunction, we cache a struct holding three fields: module: Arc<CudaModule>, function: Arc<CudaFunction>, validator: Arc<Validator>. The validator (parameter types/shapes) is tied to a specific specialization — it would be incorrect to reuse a validator from a different generic instantiation. By bundling all three, cache hits guarantee correctness: get one CompiledKernel and you get a matched triple. The Arc wrappers enable cheap cloning and cross-thread sharing.

7. once_cell::sync::OnceCell (crate) vs std::sync::OnceLock

The core dedup pattern requires get_or_try_init — fallible initialization that propagates errors without poisoning the cell. std::sync::OnceLock::get_or_try_init exists but is gated behind #![feature(once_cell_try)] (rust-lang/rust#109737), which is unavailable on stable rustc. The once_cell crate provides a stable, well-tested implementation. Once the std feature stabilizes, migration is a one-line change per call site.

8. Shard-lock release pattern

DashMap::entry() holds a shard write lock for the duration of the returned Entry reference. If get_or_try_init (which blocks during JIT compilation) were called while holding that reference, all other threads hashing to the same shard would be blocked — even for unrelated keys. The fix: clone the Arc<OnceCell> and drop the entry reference before calling get_or_try_init. This limits shard contention to the brief insert-or-lookup, not the entire compilation.

9. Atomic write for FileSystemJitStore::put

Cubin persistence uses write-to-temporary-file-then-rename. The temporary filename includes the PID and an AtomicU64 counter, ensuring uniqueness across threads and processes. On POSIX, rename() is atomic — readers either see the old file or the new complete file, never a partial write. This prevents corrupted cubins from half-written disk operations during concurrent warmup or crashes.

10. Module-level source_hash granularity

The #[cutile::module] macro computes a SHA-256 over the entire module source item at expansion time. This means changing any function in the module invalidates all cached kernels for that module. A finer-grained per-function hash would require tracking cross-function dependencies (shared types, constants, helper functions), which adds significant macro complexity. Module-level granularity is conservative (never serves stale code) and acceptable for v1 where modules are typically small and focused.

11. Warn-only on JitStore::put failure

Disk write failures (permissions, full disk, read-only filesystem) log a warning but do not propagate as errors. The JIT compilation has already succeeded and the result is in memory — the only consequence of a put failure is that the next process will re-compile instead of hitting disk. This keeps the disk cache purely advisory: it can only help, never block execution. The tradeoff is that silent write failures could mask a misconfigured cache directory.


Open Questions

  1. Module-level source_hash granularity: Any source change to a #[cutile::module] invalidates all kernels in that module. For modules with many kernels, this could cause unnecessary recompilation of unrelated functions. Is this acceptable for v1, or should we invest in per-function hashing before merge?

  2. Subprocess spawning in cache key construction: get_gpu_name() spawns nvidia-smi and get_cuda_toolkit_version() spawns nvcc on every call. While the code currently calls these once per compile_warmup (outside the loop), these are still expensive I/O operations. Should we provide cached versions (get_gpu_name_cached(), get_cuda_toolkit_version_cached()) that use OnceLock to memoize results? Or integrate these into a process-level context object to avoid repeated lookups?

  3. JitStore::put is warn-only: A silent disk write failure means the user may never know their cache is not being populated (e.g., wrong permissions on ~/.cache/cutile/). Should we surface a one-time warning to stderr, make it configurable, or keep the current behavior?

  4. Cross-process test infrastructure: The cross_process_warmup_worker test re-invokes the test binary as a subprocess with role-based dispatch via CUTILE_WARMUP_WORKER_ROLE. This is somewhat fragile (depends on --exact test name matching). Is this acceptable long-term, or should we extract a standalone test binary?

  5. Cache directory configuration: The default is ~/.cache/cutile/ (Linux) with CUTILE_NO_DISK_CACHE=1 to disable entirely. Should we also support CUTILE_CACHE_DIR for custom paths, or defer that to a follow-up?


Test Plan

JitStore unit tests (CPU-only, no GPU required)

Validates FileSystemJitStore atomic write safety (write-to-tmp-then-rename prevents partial files), concurrent put() operations via PID + AtomicU64 counter uniqueness, and all trait methods (get/contains/delete/clear) for basic functionality. Confirms directory creation and permission handling work correctly across concurrent access patterns.

CPU-only property tests (no GPU required)

Validates cache key construction, determinism, and collision-resistance across all dimensions: source content, GPU architecture, compiler version, CUDA toolkit version, generics, and strides. Confirms that module-level SHA-256 hash correctly encodes source changes and that cache keys are stable.

GPU integration tests

Verifies core functionality: compile_warmup populates memory cache, second-call cache hits, disk persistence, execute_warmup launches kernels correctly, and error handling on unknown functions or corrupted cubins (returns Error, not panic). Tests edge cases: empty spec lists, mixed cache hits and cold compiles in a batch.

Multi-thread dedup tests

Confirms that DashMap + OnceCell single-flight dedup works: when N threads request the same kernel, only one thread compiles (timing ratio ≤ 2.5x single compile). Validates that different kernel specializations can compile in parallel without interference.

Cross-process disk-hit tests

Spawns subprocess workers with different roles to prove that Process A's compiled cubins persist to disk and Process B loads them without recompilation, with support for mixed disk-hit and cold-compile scenarios in a single batch. Validates that CUTILE_NO_DISK_CACHE=1 prevents any disk writes.


Performance Benchmarks (warmup_bench.rs)

Test Result: All tests pass. Rebased on main.

Timing-based proof that warmup and caching eliminate cold-start latency (RTX 4090):

╔══════════════════════════════════════════════════════════╗
║           Full Warmup Workflow Verification              ║
╠══════════════════════════════════════════════════════════╣
║  1. compile_warmup:      202.7ms  (JIT to cache)        ║
║  2. execute_warmup:      102.5ms  (cache + CUDA init)   ║
║  3. production call:      11.2ms  (fully warm)          ║
╚══════════════════════════════════════════════════════════╝

╔══════════════════════════════════════════════════════════╗
║       Memory Cache Verification: 1st vs 2nd Call        ║
╠══════════════════════════════════════════════════════════╣
║  First  call (tile=16):    405.1ms  (JIT compile)       ║
║  Second call (tile=16):     13.8ms  (memory cache)      ║
╠══════════════════════════════════════════════════════════╣
║  ✓ Cache hit is 29x faster                              ║
╚══════════════════════════════════════════════════════════╝

╔══════════════════════════════════════════════════════════╗
║         Warmup Verification: First-Call Latency         ║
╠══════════════════════════════════════════════════════════╣
║  Without warmup (tile=32):    403.2ms  (includes JIT)   ║
║  Warmup step     (tile=64):    110.9ms  (pre-compile)   ║
║  With warmup     (tile=64):     14.5ms  (cache hit)     ║
╠══════════════════════════════════════════════════════════╣
║  ✓ Warmed-up call is 27.8x faster                       ║
╚══════════════════════════════════════════════════════════╝

@elibol This is a large PR with an extensive description — thank you for your time and careful review. Please feel free to raise any questions or concerns, and I will adjust and refine accordingly.


copy-pr-bot Bot commented Apr 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@goog00 goog00 changed the title feat: implement JIT warmup and upgrade global kernel cache with persi… feat: implement JIT warmup and upgrade global kernel cache with persistence Apr 10, 2026
Collaborator

elibol commented Apr 10, 2026

Hi @goog00,

The warmup and persistence work looks great!

We're about to merge a large API redesign to main (see discussion #80). It touches cuda-async, cutile, and cutile-macro — all three crates your PR modifies — so you should expect conflicts and some API changes when rebasing.

I'll follow up here once it's merged so you know when to rebase. Happy to help with any conflicts.

Contributor Author

goog00 commented Apr 10, 2026

OK, thanks. I'll rebase my changes accordingly.

@goog00 goog00 force-pushed the warmup branch 2 times, most recently from 7943ef4 to f90e850 Compare April 10, 2026 08:06
Contributor Author

goog00 commented Apr 10, 2026

@elibol I've rebased the PR and resolved all conflicts with the latest changes in main.
All tests are now passing.
Ready for review.

Contributor Author

goog00 commented Apr 11, 2026

Hi @elibol,
Rebased on #79 and the auto-infer alignment change. A few notes on the adaptations:

  • updated warmup GPU tests for the new DeviceOp API — mechanical, no design interaction with the global kernel cache.

  • folded spec_args (SpecializationBits) and compile_options (CompileOptions) into TileFunctionKey, so both participate in the memory hash and the SHA-256 disk key.

  • migrated CPU cache-key tests to a TestKey builder and added regression tests for spec_args / compile_options.

Verification

  • cargo test -p cutile --test warmup passed
  • cargo test -p cutile --test gpu -- --nocapture passed
  • cargo test -p cuda-async --test jit_store passed

Collaborator

elibol commented Apr 11, 2026

Hi @goog00! This is a big one 🙂

Here are some preliminary high-priority thoughts.

I like the architecture -- the DashMap + OnceCell single-flight dedup with the shard-lock release is well designed, the cache key covering all 11 dimensions is thorough, and the test suite (especially cross-process disk-hit and multi-thread dedup) gives real confidence. The atomic write-then-rename for disk persistence is the right call.

Two areas worth looking at:

Naming and API consistency

  • The existing macro-generated items use single underscore (_module_asts). It might be worth aligning __SOURCE_HASH, __entries(), __compile_warmup() to the same convention.

  • The TestKey builder in the tests is a nice pattern -- have you considered promoting that into the real API for TileFunctionKey? With 11 positional String args, it's easy to silently transpose two and get a wrong-but-valid key. Something like:

    TileFunctionKey::builder("module", "function")
        .generics(vec!["f32".into()])
        .source_hash(hash)
        .gpu_name(gpu)
        .build()
  • For get_kernel_cache(), wrapping it behind named operations (clear_kernel_cache(), remove_kernel(), etc.) instead of exposing the raw DashMap could give more flexibility to change internals later.

  • FunctionKey::get_disk_hash_string could use a default impl that falls back to get_hash_string() -- that way existing downstream impls don't break.

Correctness

  • A corrupted .cubin on disk would currently hard-error at load_module_from_bytes. It might be worth catching that failure, deleting the bad entry, and falling through to JIT recompilation.

  • Small one: compile_from_context computes key.get_hash_string() twice into both key_str and cache_hash_str.

@goog00 goog00 force-pushed the warmup branch 3 times, most recently from 32270e4 to 0868b89 Compare April 11, 2026 10:04
Contributor Author

goog00 commented Apr 11, 2026

@elibol Thanks for the detailed reviews! I've addressed all the points:

  • Renamed macro-generated symbols from __ to _ prefix (_SOURCE_HASH, _entries, _compile_warmup)
  • Added TileFunctionKey::builder() API and migrated all test usages off TileFunctionKey::new()
  • Added FunctionKey::get_disk_hash_string() default impl (falls back to get_hash_string)
  • Added evict_kernel() and clear_kernel_cache() named operations
  • Implemented corrupted cubin self-healing: delete bad entry and fall through to JIT recompile
  • Added disk_cache_corrupted_cubin_self_heals test to cover the recovery path

Ready for another look!

@goog00 goog00 force-pushed the warmup branch 2 times, most recently from 174ae1e to 5e28715 Compare April 17, 2026 08:01
Contributor Author

goog00 commented Apr 21, 2026

Hi @elibol , just a gentle ping -- I've addressed all the feedback from your last review (builder API, corrupted cubin self-healing, naming consistency, etc.) and the branch is ready. No rush, but wanted to make sure this didn't fall off your radar. Happy to clarify anything or make further adjustments!

Collaborator

elibol commented Apr 29, 2026

@goog00 I'm ready to begin looking at this, but there are conflicts. I don't expect many more to come through since we've done many of the API surface changes in v0.0.2, and some other key PRs were merged. The one that immediately interacts with this PR is #114

Contributor Author

goog00 commented Apr 30, 2026

Thanks for the heads-up! I'll resolve the conflicts.

@goog00 goog00 force-pushed the warmup branch 4 times, most recently from 4105154 to dd23a27 Compare April 30, 2026 04:31
Contributor Author

goog00 commented Apr 30, 2026

@elibol Conflicts resolved (rebased main). Ready for review! Happy to make further adjustments if needed.

goog00 added 10 commits May 12, 2026 16:32
Update warmup and warmup_bench tests to use current tensor creation signature patterns,
import required prelude traits for .sync().partition()
…s after rebased

- Extend WarmupSpec with spec_args and add with_spec_args(...) builder.
- Update compile_warmup() to use spec.spec_args when building TileFunctionKey in CUDATileFunctionCompiler.
- Migrate warmup benchmarks to explicit runtime-derived specialization bits.
- Upgrade GPU warmup tests from placeholder vec![] to real specialization args.
…compile_options

- Adapt CPU-only cache-key tests to the 11-arg TileFunctionKey signature after rebasing on the auto-infer alignment change.
- Rename __ prefix to _ for macro-generated symbols
- Add TileFunctionKey::builder() API; migrate all tests off TileFunctionKey::new()
- Add FunctionKey::get_disk_hash_string default impl (falls back to get_hash_string)
- Add clear_kernel_cache() and evict_kernel() named operations
- Recover from corrupted disk CUBIN: delete + fall through to JIT recompile
    Covers the recovery path where a corrupted .cubin on disk is detected,
    deleted, and JIT-recompiled instead of hard-erroring.