feat: implement JIT warmup and upgrade global kernel cache with persistence #81
goog00 wants to merge 13 commits into
Conversation
Hi @goog00, the warmup and persistence work looks great! We're about to merge a large API redesign. I'll follow up here once it's merged so you know when to rebase. Happy to help with any conflicts.
OK, thanks. I'll rebase my changes accordingly.
Force-pushed from 7943ef4 to f90e850
@elibol I've rebased the PR and resolved all conflicts with the latest changes in main.
Hi @elibol,
Verification
Hi @goog00! This is a big one 🙂 Here are some preliminary high-priority thoughts. I like the architecture. Two areas worth looking at:
- Naming and API consistency
- Correctness
Force-pushed from 32270e4 to 0868b89
@elibol Thanks for the detailed reviews! I've addressed all the points. Ready for another look!
Force-pushed from 174ae1e to 5e28715
Hi @elibol, just a gentle ping -- I've addressed all the feedback from your last review (builder API, corrupted cubin self-healing, naming consistency, etc.) and the branch is ready. No rush, but wanted to make sure this didn't fall off your radar. Happy to clarify anything or make further adjustments!
Thanks for the heads-up! I'll resolve the conflicts.
Force-pushed from 4105154 to dd23a27
@elibol Conflicts resolved (rebased on main). Ready for review! Happy to make further adjustments if needed.
Update warmup and warmup_bench tests to use current tensor creation signature patterns, import required prelude traits for .sync().partition()
…s after rebased
- Extend WarmupSpec with spec_args and add with_spec_args(...) builder
- Update compile_warmup() to use spec.spec_args when building TileFunctionKey in CUDATileFunctionCompiler
- Migrate warmup benchmarks to explicit runtime-derived specialization bits
- Upgrade GPU warmup tests from placeholder vec[] to real specialization args
…compile_options
- Adapt cpu-only cache-key tests to the 11-arg TileFunctionKey signature after rebasing on the auto-infer alignment change
- Rename __ prefix to _ for macro-generated symbols
- Add TileFunctionKey::builder() API; migrate all tests off TileFunctionKey::new()
- Add FunctionKey::get_disk_hash_string default impl (falls back to get_hash_string)
- Add clear_kernel_cache() and evict_kernel() named operations
- Recover from corrupted disk CUBIN: delete + fall through to JIT recompile
Covers the recovery path where a corrupted .cubin on disk is detected,
deleted, and JIT-recompiled instead of hard-erroring.
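The recovery path described in this commit can be sketched std-only with placeholder types; `is_valid_cubin`, `jit_compile`, and the `HashMap` standing in for the disk store are hypothetical stand-ins, not the PR's actual API:

```rust
use std::collections::HashMap;

// Sketch of the self-healing path: a corrupted on-disk cubin is
// detected, deleted, and JIT-recompiled instead of hard-erroring.
fn load_or_recompile(
    disk: &mut HashMap<String, Vec<u8>>, // stand-in for the JitStore
    key: &str,
    recompiles: &mut u32,
) -> Vec<u8> {
    if let Some(bytes) = disk.get(key) {
        if is_valid_cubin(bytes) {
            return bytes.clone(); // disk hit
        }
        disk.remove(key); // corrupted: delete and fall through to recompile
    }
    *recompiles += 1;
    let fresh = jit_compile(key);
    disk.insert(key.to_string(), fresh.clone()); // repair the disk entry
    fresh
}

// Hypothetical validity check; the real code would detect corruption
// when module loading fails, not via a magic prefix.
fn is_valid_cubin(bytes: &[u8]) -> bool {
    bytes.starts_with(b"CUBIN")
}

fn jit_compile(key: &str) -> Vec<u8> {
    [b"CUBIN ".as_ref(), key.as_bytes()].concat()
}
```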
Summary
This PR implements global thread-safe kernel caching with JIT warmup APIs and disk
persistence for cutile-rs. It solves cold-start latency in production GPU inference by
enabling three key capabilities:
Thread-Safety Design
`CudaContext`: A single per-device context (via `OnceLock<Mutex<HashMap>>`) is shared across all threads, enabling compiled modules and functions to be used from any thread without reloading.
`KERNEL_CACHE` uses `DashMap<String, Arc<OnceCell<CompiledKernel>>>` to ensure that when multiple threads request the same kernel specialization, only one thread performs JIT compilation while others wait and reuse the result. A shard-lock release pattern prevents lock contention during expensive compilation.
`FileSystemJitStore` uses atomic write-then-rename with PID + `AtomicU64` counter-based tmp filenames, guaranteeing that concurrent processes/threads produce correct cubins without partial files or collisions.
Pre-JIT warmup (`compile_warmup`): Pre-compile kernel specializations at startup without launching, populating both memory and disk caches. Users specify only generics and strides.
Realistic execution warmup (`execute_warmup`): Run a user-provided closure that launches kernels with production shapes and data, warming both the JIT cache and the CUDA runtime.
Cross-process cubin reuse via a `JitStore` trait: Compiled cubins persist to disk and are loaded in subsequent processes, eliminating recompilation. Cache keys encode source content (compile-time `_SOURCE_HASH`), GPU architecture, compiler version, and CUDA toolkit version, making false cache hits effectively impossible.
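As a hedged illustration of how such a composite key might be assembled (the function and field order here are illustrative, not the actual `TileFunctionKey` layout; the real disk key additionally applies SHA-256 to the result):

```rust
// Illustrative cache-key assembly: every dimension that could change the
// generated cubin is folded into the key, so any mismatch misses the cache.
fn hash_string(
    source_hash: &str, // compile-time hash of the module source
    gpu: &str,         // GPU architecture/name
    compiler: &str,    // compiler version
    toolkit: &str,     // CUDA toolkit version
    generics: &str,    // generic instantiation
    strides: &str,     // stride specialization
) -> String {
    format!("{source_hash}|{gpu}|{compiler}|{toolkit}|{generics}|{strides}")
}
```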
The `#[cutile::module]` macro now emits `_SOURCE_HASH`, `_entries()`, and `_compile_warmup()`, reducing caller boilerplate.
Addressing Feedback
This PR directly addresses all points from #5 comments:
- `OnceLock<Mutex<HashMap>>` in `device_context.rs`, enabling thread-safe kernel reuse across all threads
- `DashMap<Key, Arc<OnceCell<CompiledKernel>>>` pattern with shard-lock release optimization
- `_SOURCE_HASH` (SHA-256) emitted by the macro folds into the cache key, preventing manual cache-clear workflows
- `_entries()`, `_SOURCE_HASH`, and `_compile_warmup()` so users only specify generics and strides
- `TileFunctionKey` includes `gpu_name` for architecture-specific cubin artifacts
- Cache keys combine source hash + `gpu_name` + `compiler_version` + `cuda_toolkit_version`, with SHA-256 for disk keys
Major Changes
cuda-async
- `jit_store.rs`: `JitStore` trait with object-store interface (get/put/contains/delete/clear) and `FileSystemJitStore` implementation with atomic writes
- Global `JitStore` configuration (`set_jit_store`, `set_jit_store_if_unset`, `ensure_default_jit_store`) with `CUTILE_NO_DISK_CACHE=1` opt-out
- `device_context.rs`: global `CUDA_CONTEXTS` (process-wide, per-device) and `KERNEL_CACHE` (`DashMap<String, Arc<OnceCell<CompiledKernel>>>`)
- `CompiledKernel` struct and `FunctionKey` trait with `get_hash_string`/`get_disk_hash_string`
- `load_module_from_bytes`
- tmp filenames unique with PID + `AtomicU64` counter
cutile
- `compile_warmup`: pre-compile kernel specializations into the global cache and disk store without launching
- `execute_warmup`: run user-provided closure for realistic kernel launches (warms both JIT cache and CUDA runtime)
- `TileFunctionKey` with full cache key (`source_hash`, `gpu_name`, `compiler_version`, `cuda_toolkit_version`)
- `WarmupSpec` builder and `EntryMeta` types
- `compile_from_context`
cutile-macro
- `pub const _SOURCE_HASH: &str` (SHA-256 of module source at compile time)
- `pub fn _entries() -> Vec<EntryMeta>` with all entry point metadata
- `pub fn _compile_warmup(specs)` convenience helper so callers avoid passing module internals
Tests
- `cutile/tests/warmup.rs`
- `cutile/tests/gpu/warmup.rs`
- `cutile/tests/gpu/warmup_bench.rs`
- `cuda-async/tests/jit_store.rs`
Design Decisions
1. Global `CUDA_CONTEXTS` for process-wide context sharing
Previously, each thread maintained its own `CudaContext` (thread-local). This prevented modules or functions loaded against one context from being used by another thread. Moving to `OnceLock<Mutex<HashMap<device_id, Arc<CudaContext>>>>` creates a single per-device context shared across all threads. Once a module is loaded against a context, any thread can call its functions. The mutex protects the HashMap during the rare insert operation; actual kernel launches remain lock-free since each thread has an `Arc` to the same context.
2. Global `KERNEL_CACHE` with `DashMap` + `OnceCell` for single-flight dedup
The cache is `DashMap<String, Arc<OnceCell<CompiledKernel>>>`, where the outer `DashMap` enables lock-free reads and fine-grained sharding across 16 shards (by default). The inner `OnceCell<CompiledKernel>` guarantees single-flight: when multiple threads hash to the same (shard, key), only one calls `get_or_try_init` (the JIT compilation closure) while others block and receive the same result. This prevents N threads from independently compiling the same kernel. See decision 8 (shard-lock pattern) for how we avoid holding the shard write lock during compilation.
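A std-only sketch of the single-flight idea, substituting `std::sync::OnceLock` plus a `Mutex<HashMap>` "shard" for `DashMap` and the crate's fallible `OnceCell` (all names here are stand-ins, not the PR's types): the `Arc` is cloned and the map lock dropped before initialization, so compilation never holds the shard lock.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, OnceLock};

// Stand-in for the compiled artifact.
#[derive(Clone, Debug, PartialEq)]
struct CompiledKernel(String);

fn get_or_compile(
    cache: &Mutex<HashMap<String, Arc<OnceLock<CompiledKernel>>>>,
    key: &str,
    compiles: &AtomicUsize, // counts how many real compilations happen
) -> CompiledKernel {
    // Insert-or-lookup under the lock, then release it before compiling
    // (the shard-lock release pattern from decision 8).
    let cell = {
        let mut map = cache.lock().unwrap();
        map.entry(key.to_string()).or_default().clone()
    };
    // Only one thread runs the closure; the rest block and reuse the result.
    cell.get_or_init(|| {
        compiles.fetch_add(1, Ordering::SeqCst);
        CompiledKernel(format!("cubin for {key}")) // stand-in for JIT compile
    })
    .clone()
}
```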
3. `compile_warmup` for pre-JIT without launches
`compile_warmup(module_asts, entries, specs)` builds the cache AST once, iterates over specs, and for each spec: looks up the key in the global cache, tries disk (`JitStore`), then JIT-compiles if needed. Results go into memory and disk caches. No kernels are launched; it's purely compilation. This decouples warmup from realistic execution shapes, allowing deployment scenarios to pre-warm all target shapes at startup (e.g., during container initialization) and amortize compilation cost before the first real request.
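The lookup order (memory cache, then disk, then JIT) can be sketched with plain `HashMap`s standing in for the global cache and the `JitStore`; all types and names here are illustrative, not the real API:

```rust
use std::collections::HashMap;

// Per-spec lookup order during warmup: memory -> disk -> JIT compile.
// No kernels are launched; only caches are populated.
fn compile_warmup(
    memory: &mut HashMap<String, Vec<u8>>, // stand-in for the global cache
    disk: &mut HashMap<String, Vec<u8>>,   // stand-in for the JitStore
    specs: &[String],
    jit_compiles: &mut u32,
) {
    for key in specs {
        if memory.contains_key(key) {
            continue; // already warm in memory
        }
        let cubin = if let Some(bytes) = disk.get(key) {
            bytes.clone() // disk hit: no recompilation
        } else {
            *jit_compiles += 1;
            let fresh = format!("cubin:{key}").into_bytes(); // stand-in for JIT
            disk.insert(key.clone(), fresh.clone()); // persist for next process
            fresh
        };
        memory.insert(key.clone(), cubin);
    }
}
```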
4. `execute_warmup` for realistic end-to-end warmup
`execute_warmup` accepts a closure where users launch actual kernels with production-representative shapes and data. Kernels inside the closure auto-JIT via the normal `compile_from_context` path, but now they hit the pre-populated cache (from `compile_warmup`) or compile on first encounter. This warms not just the JIT cache but also the CUDA runtime (driver initialization, occupancy calculation, shared memory allocation, stream scheduling). The two-stage approach (pre-compile, then execute realistically) warms more than pure `compile_warmup` alone.
5. `JitStore` trait for pluggable disk persistence
Instead of hardcoding cubin serialization to a specific filesystem location, we expose a `JitStore` trait with five operations: `get`, `put`, `contains`, `delete`, `clear`. This allows alternative backends (e.g., database, object storage, network cache) without modifying core code. The default `FileSystemJitStore` uses atomic write-then-rename to prevent corrupted partial files. The trait is marked `Send + Sync`, enabling thread-safe access from multiple threads and processes (e.g., N processes each trying to `put` the same cubin). Configuration is global (`set_jit_store`, `set_jit_store_if_unset`) and lazy-initialized via `ensure_default_jit_store()`.
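A minimal sketch of such a trait with an in-memory backend; the signatures are illustrative (the real trait presumably returns `Result`s), but it shows why a database or object-storage backend only needs to implement five methods:

```rust
use std::collections::HashMap;
use std::sync::Mutex;

// Illustrative object-store interface for compiled cubins.
trait JitStore: Send + Sync {
    fn get(&self, key: &str) -> Option<Vec<u8>>;
    fn put(&self, key: &str, bytes: &[u8]);
    fn contains(&self, key: &str) -> bool;
    fn delete(&self, key: &str);
    fn clear(&self);
}

// In-memory backend: any alternative to FileSystemJitStore slots in
// the same way, without touching core caching code.
struct MemJitStore(Mutex<HashMap<String, Vec<u8>>>);

impl JitStore for MemJitStore {
    fn get(&self, key: &str) -> Option<Vec<u8>> {
        self.0.lock().unwrap().get(key).cloned()
    }
    fn put(&self, key: &str, bytes: &[u8]) {
        self.0.lock().unwrap().insert(key.to_string(), bytes.to_vec());
    }
    fn contains(&self, key: &str) -> bool {
        self.0.lock().unwrap().contains_key(key)
    }
    fn delete(&self, key: &str) {
        self.0.lock().unwrap().remove(key);
    }
    fn clear(&self) {
        self.0.lock().unwrap().clear();
    }
}
```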
6. `CompiledKernel` bundles module + function + validator
Instead of caching just the `CudaFunction`, we cache a struct holding three fields: `module: Arc<CudaModule>`, `function: Arc<CudaFunction>`, `validator: Arc<Validator>`. The validator (parameter types/shapes) is tied to a specific specialization; it would be incorrect to reuse a validator from a different generic instantiation. By bundling all three, cache hits guarantee correctness: get one `CompiledKernel` and you get a matched triple. The `Arc` wrappers enable cheap cloning and cross-thread sharing.
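A sketch of the bundling, with placeholder types standing in for the CUDA types (`Module`, `Function`, `Validator` here are not the real structs): cloning only bumps three reference counts, no deep copy.

```rust
use std::sync::Arc;

// Placeholder types for CudaModule, CudaFunction, and Validator.
struct Module(String);
struct Function(String);
struct Validator(String);

// Bundling all three means a cache hit always yields a matched triple:
// the validator can never be paired with the wrong specialization.
#[derive(Clone)]
struct CompiledKernel {
    module: Arc<Module>,
    function: Arc<Function>,
    validator: Arc<Validator>,
}
```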
7. `once_cell::sync::OnceCell` (crate) vs `std::sync::OnceLock`
The core dedup pattern requires `get_or_try_init`: fallible initialization that propagates errors without poisoning the cell. `std::sync::OnceLock::get_or_try_init` exists but is gated behind an unstable feature flag, which is unavailable on stable rustc. The `once_cell` crate provides a stable, well-tested implementation. Once the std feature stabilizes, migration is a one-line change per call site.
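The semantics being relied on can be illustrated with a hand-rolled cell. This toy holds its lock during initialization, unlike the real `OnceCell`; it only demonstrates the key property that an `Err` from the initializer leaves the cell empty and retryable rather than poisoned:

```rust
use std::sync::Mutex;

// Toy illustration of get_or_try_init semantics: a failed initializer
// propagates its error and leaves the cell empty so a later call can retry.
struct TryCell<T>(Mutex<Option<T>>);

impl<T: Clone> TryCell<T> {
    fn new() -> Self {
        TryCell(Mutex::new(None))
    }

    fn get_or_try_init<E>(&self, init: impl FnOnce() -> Result<T, E>) -> Result<T, E> {
        let mut slot = self.0.lock().unwrap();
        if let Some(v) = &*slot {
            return Ok(v.clone()); // already initialized: initializer not run
        }
        let v = init()?; // on Err, the cell stays empty (no poisoning)
        *slot = Some(v.clone());
        Ok(v)
    }
}
```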
8. Shard-lock release pattern
`DashMap::entry()` holds a shard write lock for the duration of the returned `Entry` reference. If `get_or_try_init` (which blocks during JIT compilation) were called while holding that reference, all other threads hashing to the same shard would be blocked, even for unrelated keys. The fix: clone the `Arc<OnceCell>` and drop the entry reference before calling `get_or_try_init`. This limits shard contention to the brief insert-or-lookup, not the entire compilation.
9. Atomic write for `FileSystemJitStore::put`
Cubin persistence uses write-to-temporary-file-then-rename. The temporary filename includes the PID and an `AtomicU64` counter, ensuring uniqueness across threads and processes. On POSIX, `rename()` is atomic: readers either see the old file or the new complete file, never a partial write. This prevents corrupted cubins from half-written disk operations during concurrent warmup or crashes.
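A std-only sketch of this put path; the directory layout and tmp-filename format are illustrative, not the actual `FileSystemJitStore` naming scheme:

```rust
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};

static TMP_COUNTER: AtomicU64 = AtomicU64::new(0);

// Atomic put: write to a uniquely named temp file in the same directory,
// then rename() into place (atomic on POSIX, so readers never see a
// partially written cubin).
fn atomic_put(dir: &Path, name: &str, bytes: &[u8]) -> std::io::Result<PathBuf> {
    fs::create_dir_all(dir)?;
    let tmp = dir.join(format!(
        ".{}.{}.{}.tmp",
        name,
        std::process::id(),                          // unique across processes
        TMP_COUNTER.fetch_add(1, Ordering::Relaxed), // unique across threads
    ));
    fs::write(&tmp, bytes)?; // partial writes only ever hit the tmp file
    let dst = dir.join(name);
    fs::rename(&tmp, &dst)?; // readers see the old file or the complete new one
    Ok(dst)
}
```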
10. Module-level `source_hash` granularity
The `#[cutile::module]` macro computes a SHA-256 over the entire module source item at expansion time. This means changing any function in the module invalidates all cached kernels for that module. A finer-grained per-function hash would require tracking cross-function dependencies (shared types, constants, helper functions), which adds significant macro complexity. Module-level granularity is conservative (never serves stale code) and acceptable for v1, where modules are typically small and focused.
11. Warn-only on `JitStore::put` failure
Disk write failures (permissions, full disk, read-only filesystem) log a warning but do not propagate as errors. The JIT compilation has already succeeded and the result is in memory; the only consequence of a `put` failure is that the next process will re-compile instead of hitting disk. This keeps the disk cache purely advisory: it can only help, never block execution. The tradeoff is that silent write failures could mask a misconfigured cache directory.
Open Questions
- Module-level `source_hash` granularity: Any source change to a `#[cutile::module]` invalidates all kernels in that module. For modules with many kernels, this could cause unnecessary recompilation of unrelated functions. Is this acceptable for v1, or should we invest in per-function hashing before merge?
- Subprocess spawning in cache key construction: `get_gpu_name()` spawns `nvidia-smi` and `get_cuda_toolkit_version()` spawns `nvcc` on every call. While the code currently calls these once per `compile_warmup` (outside the loop), these are still expensive I/O operations. Should we provide cached versions (`get_gpu_name_cached()`, `get_cuda_toolkit_version_cached()`) that use `OnceLock` to memoize results? Or integrate these into a process-level context object to avoid repeated lookups?
- `JitStore::put` is warn-only: A silent disk write failure means the user may never know their cache is not being populated (e.g., wrong permissions on `~/.cache/cutile/`). Should we surface a one-time warning to stderr, make it configurable, or keep the current behavior?
- Cross-process test infrastructure: The `cross_process_warmup_worker` test re-invokes the test binary as a subprocess with role-based dispatch via `CUTILE_WARMUP_WORKER_ROLE`. This is somewhat fragile (depends on `--exact` test name matching). Is this acceptable long-term, or should we extract a standalone test binary?
- Cache directory configuration: The default is `~/.cache/cutile/` (Linux) with `CUTILE_NO_DISK_CACHE=1` to disable entirely. Should we also support `CUTILE_CACHE_DIR` for custom paths, or defer that to a follow-up?
Test Plan
JitStore unit tests (CPU-only, no GPU required)
Validates `FileSystemJitStore` atomic write safety (write-to-tmp-then-rename prevents partial files), concurrent `put()` operations via PID + `AtomicU64` counter uniqueness, and all trait methods (get/contains/delete/clear) for basic functionality. Confirms directory creation and permission handling work correctly across concurrent access patterns.
CPU-only property tests (no GPU required)
Validates cache key construction, determinism, and collision-resistance across all dimensions: source content, GPU architecture, compiler version, CUDA toolkit version, generics, and strides. Confirms that module-level SHA-256 hash correctly encodes source changes and that cache keys are stable.
GPU integration tests
Verifies core functionality: `compile_warmup` populates the memory cache, second-call cache hits, disk persistence, `execute_warmup` launches kernels correctly, and error handling on unknown functions or corrupted cubins (returns `Error`, not panic). Tests edge cases: empty spec lists, mixed cache hits and cold compiles in a batch.
Multi-thread dedup tests
Confirms that the `DashMap` + `OnceCell` single-flight dedup works: when N threads request the same kernel, only one thread compiles (timing ratio ≤ 2.5x a single compile). Validates that different kernel specializations can compile in parallel without interference.
Cross-process disk-hit tests
Spawns subprocess workers with different roles to prove that Process A's compiled cubins persist to disk and Process B loads them without recompilation, with support for mixed disk-hit and cold-compile scenarios in a single batch. Validates that `CUTILE_NO_DISK_CACHE=1` prevents any disk writes.
Performance Benchmarks (`warmup_bench.rs`)
Timing-based proof that warmup and caching eliminate cold-start latency (RTX 4090):
@elibol This is a large PR with an extensive description — thank you for your time and careful review. Please feel free to raise any questions or concerns, and I will adjust and refine accordingly.