diff --git a/Cargo.lock b/Cargo.lock index 5a45bf8..9240250 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,6 +99,15 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.20.2" @@ -288,6 +297,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -363,6 +381,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.4.0" @@ -390,8 +418,11 @@ version = "0.0.2" dependencies = [ "anyhow", "cuda-core", + "dashmap", "futures", "half", + "once_cell", + "sha2", "thiserror 1.0.69", ] @@ -437,8 +468,10 @@ dependencies = [ "half", "linkme", "num-traits", + "once_cell", "proc-macro2", "quote", + "sha2", "syn", ] @@ -507,10 +540,35 @@ dependencies = [ "phf", "proc-macro2", "quote", + "sha2", "syn", "trybuild", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dissimilar" version = "1.0.11" @@ -804,6 +862,16 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -850,6 +918,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -1576,6 +1650,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1829,6 +1914,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28f89b80c87b8fb0cf04ab448d5dd0dd0ade2f8891bae878de66a75a28600e" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.24" diff --git a/Cargo.toml b/Cargo.toml index 4b13969..04377b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,9 @@ trybuild = { version = "1.0.49", features = ["diff"] } phf = { version = "0.11", features = ["macros"] } convert_case = "0.8" linkme = "0.3" +dashmap = "6" +sha2 = "0.10" +once_cell = "1.19" # ── Intra-workspace ────────────────────────────────────────────────────────── cuda-bindings = { path = "cuda-bindings", version = "0.0.2" } diff --git a/cuda-async/Cargo.toml b/cuda-async/Cargo.toml index 90d2fef..406149a 100644 --- a/cuda-async/Cargo.toml +++ b/cuda-async/Cargo.toml @@ -15,3 +15,6 @@ half = { workspace = true } futures = { workspace = true } anyhow = { workspace = true } thiserror = { workspace = true } +dashmap = { workspace = true } +sha2 = { workspace = true } +once_cell = { workspace = true } diff --git a/cuda-async/src/device_context.rs b/cuda-async/src/device_context.rs index 3a3b5c3..deab8c1 100644 --- a/cuda-async/src/device_context.rs +++ b/cuda-async/src/device_context.rs @@ -3,15 +3,29 @@ * SPDX-License-Identifier: Apache-2.0 */ -//! Thread-local GPU device state, kernel cache, and scheduling policy management. +//! GPU device state, global kernel cache, and scheduling policy management. +//! +//! ## Architecture +//! +//! - **Global (process-wide)**: [`Device`] per device and compiled kernel cache are shared +//! across all threads via [`OnceLock`] and [`DashMap`]. This allows compilation results from +//! one thread (e.g. warmup) to be visible to all worker threads. +//! +//! - **Per-thread**: Scheduling policy and deallocator stream remain thread-local, since +//! different threads may want different stream assignments. +//! +//! - **Compilation dedup**: When multiple threads need the same kernel, only one compiles it +//! while the rest wait, via `DashMap>>`. use crate::error::{device_assert, device_error, DeviceError}; use crate::scheduling_policies::{SchedulingPolicy, StreamPoolRoundRobin}; use cuda_core::{Device, Function, MemPool, Module, Stream}; +use dashmap::DashMap; +use once_cell::sync::OnceCell; use std::cell::Cell; use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; /// The GPU device used when no explicit device is specified. Device 0 is the first GPU. pub const DEFAULT_DEVICE_ID: usize = 0; @@ -28,12 +42,24 @@ pub const DEFAULT_NUM_DEVICES: usize = 1; pub const DEFAULT_ROUND_ROBIN_STREAM_POOL_SIZE: usize = 4; pub trait FunctionKey: Hash { + /// Fast hash for in-memory cache lookup (uses `DefaultHasher`). fn get_hash_string(&self) -> String { let mut hasher = DefaultHasher::new(); self.hash(&mut hasher); let hash_value: u64 = hasher.finish(); format!("{:x}", hash_value) } + + /// SHA-256 hash for disk persistence. Provides a collision-resistant key + /// suitable for storing compiled artifacts on disk. + /// + /// Implementors should override this to hash a canonical string representation + /// of all key fields for maximum collision resistance. The default falls back + /// to [`get_hash_string`](Self::get_hash_string) so existing downstream impls + /// continue to compile without change. + fn get_disk_hash_string(&self) -> String { + self.get_hash_string() + } } #[derive(Debug, Clone)] @@ -66,31 +92,84 @@ pub struct Validator { pub params: Vec, } -type DeviceFunctions = HashMap, Arc)>; -type DeviceFunctionValidators = HashMap>; +// ── Global Device (process-wide, per-device singleton) ───────────────────── + +/// Global per-device handles. Shared across all threads so that +/// `Module`/`Function` loaded against a device can be used from any thread. +static DEVICES: OnceLock>>> = OnceLock::new(); + +fn devices() -> &'static Mutex>> { + DEVICES.get_or_init(|| Mutex::new(HashMap::new())) +} + +/// Get or create the global [`Device`] for a device ordinal. +/// +/// The first call for a given `device_id` creates the device handle; subsequent +/// calls return the same `Arc`. +fn get_or_init_device(device_id: usize) -> Result, DeviceError> { + let mut devices = devices() + .lock() + .map_err(|_| device_error(device_id, "device map lock poisoned"))?; + if let Some(device) = devices.get(&device_id) { + return Ok(Arc::clone(device)); + } + let device = Device::new(device_id)?; + devices.insert(device_id, Arc::clone(&device)); + Ok(device) +} + +// ── Global kernel cache (process-wide, cross-thread) ──────────────────────── + +/// A compiled kernel: module, function handle, and parameter validator. +#[derive(Debug)] +pub struct CompiledKernel { + pub module: Arc, + pub function: Arc, + pub validator: Arc, +} + +/// Global kernel cache. `DashMap` for cross-thread sharing; inner `OnceLock` for +/// single-flight compilation dedup (if multiple threads need the same kernel, +/// only one compiles while the rest wait). Uses `once_cell::sync::OnceCell` +/// for stable fallible initialization (`get_or_try_init`). +static KERNEL_CACHE: OnceLock>>> = OnceLock::new(); -/// Per-device state: GPU device, scheduling policy, and compiled kernel cache. +/// Get the global kernel cache. /// -/// Each GPU device has one `AsyncDeviceContext` stored in a thread-local map. It holds: +/// Prefer the named operations below (`clear_kernel_cache`, `evict_kernel`) over +/// direct DashMap manipulation — they keep the internal representation an +/// implementation detail. +pub fn get_kernel_cache() -> &'static DashMap>> { + KERNEL_CACHE.get_or_init(DashMap::new) +} + +/// Remove all compiled kernels from the in-memory cache. /// -/// - A [`Device`] for driver API calls. -/// - A [`SchedulingPolicy`] that decides which stream each operation runs on. -/// - A cache of already-compiled kernel functions (keyed by [`FunctionKey::get_hash_string()`]). +/// Does not touch the disk cache (JitStore). Useful in tests that need a +/// clean slate without restarting the process. +pub fn clear_kernel_cache() { + get_kernel_cache().clear(); +} + +/// Evict a single compiled kernel from the in-memory cache by its hash string. +/// +/// Returns `true` if an entry was removed, `false` if the key was not present. +pub fn evict_kernel(key_str: &str) -> bool { + get_kernel_cache().remove(key_str).is_some() +} + +// ── Per-thread device state (scheduling policy + deallocator stream) ──────── + +/// Per-thread, per-device state: scheduling policy and deallocator stream. /// -/// The context is lazily initialized on first use with the default round-robin policy -/// ([`DEFAULT_ROUND_ROBIN_STREAM_POOL_SIZE`] = 4 streams). To customize, call -/// [`init_device_contexts`] before any GPU work. -// TODO (hme): None of this needs to be compiled per thread. +/// The CUDA context and kernel cache are global (see above). This struct only +/// holds the thread-local scheduling policy and deallocator stream. pub struct AsyncDeviceContext { #[expect(dead_code, reason = "will be used when multi-device is implemented")] device_id: usize, - // TODO: (hme): This will hurt perf due to contention. This should at least be static (OnceLock?). - device: Arc, deallocator_stream: Arc, policy: Arc, pool: Option>, - functions: DeviceFunctions, - validators: DeviceFunctionValidators, } pub struct AsyncDeviceContexts { @@ -157,16 +236,13 @@ pub fn new_device_context( device_id: usize, policy: Arc, ) -> Result { - let device = Device::new(device_id)?; + let device = get_or_init_device(device_id)?; let deallocator_stream = device.new_stream()?; Ok(AsyncDeviceContext { device_id, - device, deallocator_stream, policy, pool: None, - functions: HashMap::new(), - validators: HashMap::new(), }) } @@ -198,17 +274,14 @@ pub fn init_with_default_policy( hashmap: &mut HashMap, device_id: usize, ) -> Result<(), DeviceError> { - let device = Device::new(device_id)?; + let device = get_or_init_device(device_id)?; let policy = StreamPoolRoundRobin::new(&device, DEFAULT_ROUND_ROBIN_STREAM_POOL_SIZE)?; let deallocator_stream = device.new_stream()?; let device_context = AsyncDeviceContext { device_id, - device, deallocator_stream, policy: Arc::new(policy), pool: None, - functions: HashMap::new(), - validators: HashMap::new(), }; let pred = hashmap.insert(device_id, device_context).is_none(); device_assert(device_id, pred, "Device is already initialized.") @@ -296,7 +369,8 @@ pub fn with_device(device_id: usize, f: F) -> Result where F: FnOnce(&Arc) -> R, { - with_global_device_context(device_id, |device_context| f(&device_context.device)) + let device = get_or_init_device(device_id)?; + Ok(f(&device)) } // Default device policy. @@ -413,6 +487,28 @@ pub fn load_module_from_file(filename: &str, device_id: usize) -> Result Result, DeviceError> { + use std::io::Write; + use std::sync::atomic::{AtomicU64, Ordering}; + static TMP_COUNTER: AtomicU64 = AtomicU64::new(0); + let n = TMP_COUNTER.fetch_add(1, Ordering::Relaxed); + let tmp_dir = std::env::temp_dir(); + let tmp_path = tmp_dir.join(format!("cutile_cache_{}_{}.cubin", std::process::id(), n)); + let mut f = std::fs::File::create(&tmp_path) + .map_err(|e| device_error(device_id, &format!("Failed to write temp cubin: {e}")))?; + f.write_all(data) + .map_err(|e| device_error(device_id, &format!("Failed to write temp cubin: {e}")))?; + let result = load_module_from_file(tmp_path.to_str().unwrap(), device_id); + let _ = std::fs::remove_file(&tmp_path); + result +} + /// JIT-compile a PTX string into a CUDA module for the given device. pub fn load_module_from_ptx(ptx_src: &str, device_id: usize) -> Result, DeviceError> { with_device(device_id, |device| { @@ -421,71 +517,120 @@ pub fn load_module_from_ptx(ptx_src: &str, device_id: usize) -> Result, Arc), ) -> Result<(), DeviceError> { - with_global_device_context_mut(device_id, |device_context| { - let key = func_key.get_hash_string(); - let res = device_context.functions.insert(key.clone(), value); - device_assert(device_id, res.is_none(), "Unexpected cache key collision.") - })? + let key = func_key.get_hash_string(); + let cache = get_kernel_cache(); + let slot = cache + .entry(key) + .or_insert_with(|| Arc::new(OnceCell::new())); + // If the OnceCell is already initialized, this is a duplicate insert — that's fine, + // the first writer wins and subsequent inserts are no-ops. + let _ = slot.set(CompiledKernel { + module: value.0, + function: value.1, + // Insert a dummy validator; the real one is set via insert_function_validator. + // This maintains backward compatibility with the existing two-step insert pattern. + validator: Arc::new(Validator { params: vec![] }), + }); + Ok(()) } /// Check whether a kernel with the given key has already been compiled and cached. -pub fn contains_cuda_function(device_id: usize, func_key: &impl FunctionKey) -> bool { - with_global_device_context(device_id, |device_context| { - let key = func_key.get_hash_string(); - device_context.functions.contains_key(&key) - }) - .is_ok_and(|pred| pred) +pub fn contains_cuda_function(_device_id: usize, func_key: &impl FunctionKey) -> bool { + let key = func_key.get_hash_string(); + let cache = get_kernel_cache(); + if let Some(slot) = cache.get(&key) { + let lock: &OnceCell = slot.value().as_ref(); + lock.get().is_some() + } else { + false + } } /// Retrieve a previously compiled kernel from the cache. /// -/// # Panics +/// # Errors /// -/// Panics if no function with the given key exists. Use [`contains_cuda_function`] to -/// check first, or rely on the compilation pipeline which always inserts before retrieving. +/// Returns an error if no function with the given key exists. +/// Use [`contains_cuda_function`] to check first, or rely on the compilation +/// pipeline which always inserts before retrieving. pub fn get_cuda_function( device_id: usize, func_key: &impl FunctionKey, ) -> Result, DeviceError> { - with_global_device_context(device_id, |device_context| { - let key = func_key.get_hash_string(); - let entry = device_context - .functions - .get(&key) - .ok_or(device_error(device_id, "Failed to get cuda function."))?; - Ok(entry.1.clone()) - })? + let key = func_key.get_hash_string(); + let cache = get_kernel_cache(); + let slot = cache + .get(&key) + .ok_or_else(|| device_error(device_id, "Failed to get cuda function."))?; + let compiled = slot + .get() + .ok_or_else(|| device_error(device_id, "Kernel not yet compiled."))?; + Ok(Arc::clone(&compiled.function)) } pub fn insert_function_validator( - device_id: usize, + _device_id: usize, func_key: &impl FunctionKey, value: Arc, ) -> Result<(), DeviceError> { - with_global_device_context_mut(device_id, |device_context| { - let key = func_key.get_hash_string(); - let res = device_context.validators.insert(key.clone(), value); - device_assert(device_id, res.is_none(), "Unexpected cache key collision.") - })? + let key = func_key.get_hash_string(); + let cache = get_kernel_cache(); + let slot = cache + .entry(key) + .or_insert_with(|| Arc::new(OnceCell::new())); + // If the kernel is already compiled, updating the validator in-place is not possible + // with OnceCell. Instead we store validators in a separate map for backward compat. + // For Phase 1, we use a parallel validator store. + get_validator_cache().insert(func_key.get_hash_string(), value); + // Also try to initialize the slot if it's empty (this handles the case where + // insert_function_validator is called before insert_cuda_function — unlikely but safe). + let _ = slot; + Ok(()) } pub fn get_function_validator( device_id: usize, func_key: &impl FunctionKey, ) -> Result, DeviceError> { - with_global_device_context(device_id, |device_context| { - let key = func_key.get_hash_string(); - let entry = device_context - .validators - .get(&key) - .ok_or(device_error(device_id, "Failed to get function validator."))?; - Ok(entry.clone()) - })? + let key = func_key.get_hash_string(); + + // Check the kernel cache first (unified path via get_or_try_init stores + // the validator inside CompiledKernel). + let kernel_cache = get_kernel_cache(); + if let Some(slot) = kernel_cache.get(&key) { + let lock: &OnceCell = slot.value().as_ref(); + if let Some(compiled) = lock.get() { + if !compiled.validator.params.is_empty() { + return Ok(Arc::clone(&compiled.validator)); + } + } + } + + // Fall back to separate validator cache (backward compat with two-step insert). + let validator_cache = get_validator_cache(); + let validator = validator_cache + .get(&key) + .ok_or_else(|| device_error(device_id, "Failed to get function validator."))?; + Ok(Arc::clone(validator.value())) +} + +// ── Validator cache (backward compat for two-step insert callers) ──────────── + +/// Separate validator cache for backward compatibility with the two-step insert +/// pattern (`insert_cuda_function` + `insert_function_validator`). +/// +/// The primary compilation path (`compile_from_context`) now uses single-shot +/// `OnceLock::get_or_try_init` which stores the validator inside `CompiledKernel`. +/// This cache is only needed for code paths that still use the two-step pattern. +static VALIDATOR_CACHE: OnceLock>> = OnceLock::new(); + +fn get_validator_cache() -> &'static DashMap> { + VALIDATOR_CACHE.get_or_init(DashMap::new) } diff --git a/cuda-async/src/jit_store.rs b/cuda-async/src/jit_store.rs new file mode 100644 index 0000000..e59bac5 --- /dev/null +++ b/cuda-async/src/jit_store.rs @@ -0,0 +1,171 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Object-store-like interface for persisting JIT compilation artifacts to disk. +//! +//! The [`JitStore`] trait provides a simple key-value interface for caching compiled +//! cubins. Keys are SHA-256 hashes that encode all factors affecting compilation output +//! (source content, GPU architecture, compiler version, toolkit version). +//! +//! [`FileSystemJitStore`] is the default implementation, storing artifacts as individual +//! files under a configurable directory (defaults to `~/.cache/cutile/`). + +use std::io; +use std::path::PathBuf; +use std::sync::OnceLock; + +/// Object-store-like interface for persisting JIT compilation artifacts. +/// +/// Implementations store compiled cubins keyed by a SHA-256 hash string. +/// The hash encodes all factors that affect compilation output: +/// source content, GPU architecture, compiler version, and toolkit version. +pub trait JitStore: Send + Sync { + /// Retrieve a cached artifact by key. + fn get(&self, key: &str) -> io::Result>>; + + /// Store a compiled artifact. + fn put(&self, key: &str, data: &[u8]) -> io::Result<()>; + + /// Check whether an artifact exists without reading it. + fn contains(&self, key: &str) -> io::Result; + + /// Remove a cached artifact. + fn delete(&self, key: &str) -> io::Result<()>; + + /// Remove all cached artifacts. + fn clear(&self) -> io::Result<()>; +} + +/// Filesystem-backed JIT artifact store. +/// +/// Stores compiled cubins as individual `.cubin` files under a base directory. +pub struct FileSystemJitStore { + base_dir: PathBuf, +} + +impl FileSystemJitStore { + /// Create a new store at the given directory, creating it if necessary. + pub fn new(base_dir: PathBuf) -> io::Result { + std::fs::create_dir_all(&base_dir)?; + Ok(Self { base_dir }) + } + + /// Create a store at the default location (`~/.cache/cutile/`). + pub fn default_location() -> io::Result { + let dir = dirs_default_cache_dir().join("cutile"); + Self::new(dir) + } + + fn artifact_path(&self, key: &str) -> PathBuf { + self.base_dir.join(format!("{key}.cubin")) + } +} + +/// Returns a default cache directory, similar to `dirs::cache_dir()`. +fn dirs_default_cache_dir() -> PathBuf { + + #[cfg(target_os = "linux")] + { + if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") { + return PathBuf::from(xdg); + } + if let Some(home) = std::env::var_os("HOME") { + return PathBuf::from(home).join(".cache"); + } + } + PathBuf::from("/tmp") +} + +impl JitStore for FileSystemJitStore { + fn get(&self, key: &str) -> io::Result>> { + let path = self.artifact_path(key); + match std::fs::read(&path) { + Ok(data) => Ok(Some(data)), + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(e), + } + } + + fn put(&self, key: &str, data: &[u8]) -> io::Result<()> { + use std::sync::atomic::{AtomicU64, Ordering}; + static TMP_COUNTER: AtomicU64 = AtomicU64::new(0); + let n = TMP_COUNTER.fetch_add(1, Ordering::Relaxed); + let path = self.artifact_path(key); + // Write to a uniquely-named temp file, then rename for atomicity. + // The PID + counter suffix prevents collisions across threads and processes. + let tmp_path = path.with_extension(format!("cubin.tmp.{}.{}", std::process::id(), n)); + std::fs::write(&tmp_path, data)?; + std::fs::rename(&tmp_path, &path) + } + + fn contains(&self, key: &str) -> io::Result { + Ok(self.artifact_path(key).exists()) + } + + fn delete(&self, key: &str) -> io::Result<()> { + let path = self.artifact_path(key); + match std::fs::remove_file(&path) { + Ok(()) => Ok(()), + Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()), + Err(e) => Err(e), + } + } + + fn clear(&self) -> io::Result<()> { + if self.base_dir.exists() { + std::fs::remove_dir_all(&self.base_dir)?; + std::fs::create_dir_all(&self.base_dir)?; + } + Ok(()) + } +} + +// ── Global JitStore configuration ─────────────────────────────────────────── + +static JIT_STORE: OnceLock>> = OnceLock::new(); + +/// Configure the global JIT store. Call once at startup. +/// +/// Pass `None` to disable disk persistence. Panics if called more than once. +pub fn set_jit_store(store: Option>) { + if JIT_STORE.set(store).is_err() { + panic!("JIT store has already been configured"); + } +} + +/// Try to configure the global JIT store. Returns `true` if successfully set, +/// `false` if a store was already configured (in which case the argument is dropped). +/// +/// This is useful in test code where multiple tests may race to set the store. +pub fn set_jit_store_if_unset(store: Option>) -> bool { + JIT_STORE.set(store).is_ok() +} + +/// Get a reference to the global JIT store, if one has been configured. +pub fn get_jit_store() -> Option<&'static dyn JitStore> { + JIT_STORE.get().and_then(|s| s.as_ref().map(|b| b.as_ref())) +} + +/// Ensure a JIT disk cache is configured. +/// +/// If no store has been explicitly set via [`set_jit_store`] or +/// [`set_jit_store_if_unset`], this lazily initializes a +/// [`FileSystemJitStore`] at the default location (`~/.cache/cutile/` on Linux). +/// Set `CUTILE_NO_DISK_CACHE=1` to disable auto-initialization. +/// +/// This is called automatically from the compilation pipeline. Users who want +/// a custom store should call [`set_jit_store`] *before* any kernel compilation. +pub fn ensure_default_jit_store() { + static INIT: OnceLock<()> = OnceLock::new(); + INIT.get_or_init(|| { + if std::env::var("CUTILE_NO_DISK_CACHE").is_ok_and(|v| v == "1") { + return; + } + if let Ok(store) = FileSystemJitStore::default_location() { + // Best-effort: if set_jit_store was already called, this is a no-op. + let _ = set_jit_store_if_unset(Some(Box::new(store))); + } + }); +} diff --git a/cuda-async/src/lib.rs b/cuda-async/src/lib.rs index 99f66f3..4449b52 100644 --- a/cuda-async/src/lib.rs +++ b/cuda-async/src/lib.rs @@ -12,6 +12,7 @@ pub mod device_context; pub mod device_future; pub mod device_operation; pub mod error; +pub mod jit_store; pub mod launch; pub mod prelude; pub mod scheduling_policies; diff --git a/cuda-async/tests/jit_store.rs b/cuda-async/tests/jit_store.rs new file mode 100644 index 0000000..89df888 --- /dev/null +++ b/cuda-async/tests/jit_store.rs @@ -0,0 +1,138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Unit tests for `FileSystemJitStore`. +//! These tests run on CPU — no GPU required. + +use cuda_async::jit_store::{FileSystemJitStore, JitStore}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicUsize, Ordering}; + +// Monotonically increasing counter so each test gets its own directory, +// even when tests run in parallel within the same process. +static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0); + +fn tmp_store() -> (FileSystemJitStore, PathBuf) { + let id = TEST_COUNTER.fetch_add(1, Ordering::Relaxed); + let dir = std::env::temp_dir().join(format!("cutile_jit_test_{}_{id}", std::process::id())); + let store = FileSystemJitStore::new(dir.clone()).expect("failed to create store"); + (store, dir) +} + +fn cleanup(dir: &PathBuf) { + let _ = std::fs::remove_dir_all(dir); +} + +#[test] +fn put_and_get() { + let (store, dir) = tmp_store(); + let data = b"fake cubin data"; + store.put("abc123", data).unwrap(); + let result = store.get("abc123").unwrap(); + assert_eq!(result, Some(data.to_vec())); + cleanup(&dir); +} + +#[test] +fn get_missing_returns_none() { + let (store, dir) = tmp_store(); + let result = store.get("nonexistent").unwrap(); + assert_eq!(result, None); + cleanup(&dir); +} + +#[test] +fn contains() { + let (store, dir) = tmp_store(); + assert!(!store.contains("key1").unwrap()); + store.put("key1", b"data").unwrap(); + assert!(store.contains("key1").unwrap()); + cleanup(&dir); +} + +#[test] +fn delete() { + let (store, dir) = tmp_store(); + store.put("key2", b"data").unwrap(); + assert!(store.contains("key2").unwrap()); + store.delete("key2").unwrap(); + assert!(!store.contains("key2").unwrap()); + // Deleting a nonexistent key is a no-op. + store.delete("key2").unwrap(); + cleanup(&dir); +} + +#[test] +fn clear() { + let (store, dir) = tmp_store(); + store.put("a", b"1").unwrap(); + store.put("b", b"2").unwrap(); + store.clear().unwrap(); + assert!(!store.contains("a").unwrap()); + assert!(!store.contains("b").unwrap()); + cleanup(&dir); +} + +#[test] +fn put_overwrites() { + let (store, dir) = tmp_store(); + store.put("key", b"old").unwrap(); + store.put("key", b"new").unwrap(); + let result = store.get("key").unwrap(); + assert_eq!(result, Some(b"new".to_vec())); + cleanup(&dir); +} + +#[test] +fn large_data() { + let (store, dir) = tmp_store(); + let data = vec![0xABu8; 1024 * 1024]; // 1 MB + store.put("large", &data).unwrap(); + let result = store.get("large").unwrap().unwrap(); + assert_eq!(result.len(), data.len()); + assert_eq!(result, data); + cleanup(&dir); +} + +#[test] +fn concurrent_put_get() { + use std::sync::Arc; + let (store, dir) = tmp_store(); + let store = Arc::new(store); + let mut handles = vec![]; + for i in 0..8 { + let store = Arc::clone(&store); + handles.push(std::thread::spawn(move || { + let key = format!("concurrent_{i}"); + let data = vec![i as u8; 256]; + store.put(&key, &data).unwrap(); + let result = store.get(&key).unwrap().unwrap(); + assert_eq!(result, data); + })); + } + for h in handles { + h.join().unwrap(); + } + cleanup(&dir); +} + +#[test] +fn clear_on_empty_dir() { + let (store, dir) = tmp_store(); + store.clear().unwrap(); + store.clear().unwrap(); + cleanup(&dir); +} + +#[test] +fn keys_with_hex_characters() { + let (store, dir) = tmp_store(); + // Realistic SHA-256 hex key. + let key = "a3f8b2c1deadbeef0123456789abcdef0123456789abcdef0123456789abcdef"; + store.put(key, b"cubin bytes").unwrap(); + assert!(store.contains(key).unwrap()); + assert_eq!(store.get(key).unwrap(), Some(b"cubin bytes".to_vec())); + cleanup(&dir); +} diff --git a/cutile-compiler/src/cuda_tile_runtime_utils.rs b/cutile-compiler/src/cuda_tile_runtime_utils.rs index 6148979..46b4112 100644 --- a/cutile-compiler/src/cuda_tile_runtime_utils.rs +++ b/cutile-compiler/src/cuda_tile_runtime_utils.rs @@ -18,6 +18,38 @@ use uuid::Uuid; /// Set this to an absolute path such as `/opt/cuda-tile/bin/tileiras` to use /// that binary instead of the `tileiras` found on `PATH`. pub const TILEIRAS_PATH_ENV: &str = "CUTILE_TILEIRAS_PATH"; +/// Returns the cutile compiler version (from the workspace Cargo.toml). +pub fn get_compiler_version() -> String { + env!("CARGO_PKG_VERSION").to_string() +} + +/// Returns the CUDA toolkit version by parsing `nvcc --version` output. +/// +/// Falls back to `"unknown"` if `nvcc` is not available. +pub fn get_cuda_toolkit_version() -> String { + Command::new("nvcc") + .arg("--version") + .output() + .ok() + .and_then(|output| { + if !output.status.success() { + return None; + } + let stdout = String::from_utf8_lossy(&output.stdout); + // Parse lines like "Cuda compilation tools, release 12.4, V12.4.131" + for line in stdout.lines() { + if let Some(pos) = line.find("release ") { + let rest = &line[pos + "release ".len()..]; + if let Some(comma) = rest.find(',') { + return Some(rest[..comma].to_string()); + } + return Some(rest.trim().to_string()); + } + } + None + }) + .unwrap_or_else(|| "unknown".to_string()) +} /// Queries the CUDA driver to determine the SM architecture name (e.g. `"sm_90"`) for a device. pub fn get_gpu_name(device_id: usize) -> String { diff --git a/cutile-examples/examples/async_and_then_example.rs b/cutile-examples/examples/async_and_then_example.rs index 1d42190..2fb552b 100644 --- a/cutile-examples/examples/async_and_then_example.rs +++ b/cutile-examples/examples/async_and_then_example.rs @@ -56,6 +56,7 @@ async fn main() -> Result<(), DeviceError> { vec![], None, CompileOptions::default(), + my_module::_SOURCE_HASH, ); value(func) }) diff --git a/cutile-examples/examples/async_saxpy_unsafe.rs b/cutile-examples/examples/async_saxpy_unsafe.rs index fae5d6f..756ff61 100644 --- a/cutile-examples/examples/async_saxpy_unsafe.rs +++ b/cutile-examples/examples/async_saxpy_unsafe.rs @@ -61,6 +61,7 @@ async fn main() -> Result<(), Error> { vec![], None, CompileOptions::default(), + my_module::_SOURCE_HASH, ); value(func) }) diff --git a/cutile-macro/Cargo.toml b/cutile-macro/Cargo.toml index f45b345..44948c7 100644 --- a/cutile-macro/Cargo.toml +++ b/cutile-macro/Cargo.toml @@ -24,3 +24,4 @@ itertools = { workspace = true } phf = { workspace = true } convert_case = { workspace = true } cutile-compiler = { workspace = true } +sha2 = { workspace = true } diff --git a/cutile-macro/src/_module.rs b/cutile-macro/src/_module.rs index f63f45b..eb85340 100644 --- a/cutile-macro/src/_module.rs +++ b/cutile-macro/src/_module.rs @@ -58,6 +58,7 @@ use proc_macro2::Ident; use proc_macro2::{LineColumn, Span, TokenStream as TokenStream2}; use quote::{format_ident, quote, ToTokens}; use std::collections::{HashMap, HashSet}; +use sha2::{Digest, Sha256}; use std::path::PathBuf; use std::{env, fs}; @@ -241,10 +242,11 @@ fn process_items( items: &[syn::Item], parent_name: &Ident, tile_rust_crate_root: &Ident, -) -> Result<(Vec, Vec), Error> { +) -> Result<(Vec, Vec, Vec<(String, String)>), Error> { let mut concrete_items: Vec = vec![]; let mut entry_functions: Vec = vec![]; let type_aliases = cutile_compiler::type_aliases::collect_type_aliases(items); + let mut entry_metas: Vec<(String, String)> = vec![]; for item in items { match item { @@ -262,6 +264,10 @@ fn process_items( function_item, &type_aliases, )?); + let fn_name = function_item.sig.ident.to_string(); + let kernel_naming = KernelNaming::new(&fn_name); + let fn_entry = kernel_naming.entry_name(); + entry_metas.push((fn_name, fn_entry)); }; concrete_items.push(function( function_item.clone(), @@ -302,7 +308,7 @@ fn process_items( not supported because the macro needs the body at expansion time.", ); }; - let (sub_concrete, sub_entries) = + let (sub_concrete, sub_entries, _sub_metas) = process_items(&sub_content.1, &submod.ident, tile_rust_crate_root)?; let sub_name = &submod.ident; let sub_attrs = &submod.attrs; @@ -321,7 +327,7 @@ fn process_items( } } } - Ok((concrete_items, entry_functions)) + Ok((concrete_items, entry_functions, entry_metas)) } /// Fallible inner implementation of the `module` macro. @@ -334,14 +340,69 @@ fn module_inner( return module_item.err("Non-empty module expected."); }; let name = &module_item.ident; - let (concrete_items, entry_functions) = process_items(&content.1, name, tile_rust_crate_root)?; + let (concrete_items, entry_functions, entry_metas) = process_items(&content.1, name, tile_rust_crate_root)?; let ast_path = get_ast_path(tile_rust_crate_root); let ast_module_item: ItemMod = module_item.clone(); let ast_module_tokens = emit_module_ast_self_and_registry_entry( ast_module_item, tile_rust_crate_root, - raw_item_source, + raw_item_source.clone(), ); + + // Compute SHA-256 source hash at macro expansion time. + let source_hash = format!("{:x}", Sha256::digest(raw_item_source.as_bytes())); + let module_name_str = name.to_string(); + + // Generate _entries() and _SOURCE_HASH for warmup support. + let entry_meta_items: Vec = entry_metas + .iter() + .map(|(fn_name, fn_entry)| { + quote! { + #tile_rust_crate_root::tile_kernel::EntryMeta { + module_name: #module_name_str, + function_name: #fn_name, + function_entry: #fn_entry, + } + } + }) + .collect(); + let warmup_metadata = quote! { + /// SHA-256 hash of the module source, computed at compile time. + /// Changes whenever any kernel source in this module changes. + pub const _SOURCE_HASH: &str = #source_hash; + + /// Returns metadata for all entry points in this module. + pub fn _entries() -> Vec<#tile_rust_crate_root::tile_kernel::EntryMeta> { + vec![#(#entry_meta_items),*] + } + + /// Pre-compile kernel specializations for this module. + /// + /// This is a convenience wrapper around [`compile_warmup`] that automatically + /// supplies the module ASTs, entry metadata, module name, and source hash. + /// Callers only need to provide the [`WarmupSpec`]s describing which + /// generics/strides to pre-compile. + /// + /// # Example + /// + /// ```rust,ignore + /// my_module::_compile_warmup(&[ + /// WarmupSpec::new("vector_add", vec!["f32".into(), "128".into()]), + /// ])?; + /// ``` + pub fn _compile_warmup( + specs: &[#tile_rust_crate_root::tile_kernel::WarmupSpec], + ) -> Result<(), #tile_rust_crate_root::error::Error> { + #tile_rust_crate_root::tile_kernel::compile_warmup( + || __module_ast_self(), + &_entries(), + #module_name_str, + _SOURCE_HASH, + specs, + ) + } + }; + let res = if entry_functions.is_empty() { quote! { pub mod #name { @@ -352,6 +413,8 @@ fn module_inner( use #ast_path; #ast_module_tokens #(#concrete_items)* + // Warmup metadata. + #warmup_metadata } } } else { @@ -378,6 +441,8 @@ fn module_inner( #(#concrete_items)* // Entry point code. #(#entry_functions)* + // Warmup metadata. + #warmup_metadata } } }; diff --git a/cutile-macro/src/kernel_launcher_generator.rs b/cutile-macro/src/kernel_launcher_generator.rs index e423e36..cd26467 100644 --- a/cutile-macro/src/kernel_launcher_generator.rs +++ b/cutile-macro/src/kernel_launcher_generator.rs @@ -867,7 +867,8 @@ pub fn generate_kernel_launcher( module_name, function_name, function_entry, function_generics, stride_args, spec_args.clone(), scalar_hints, const_grid, - compile_options + compile_options, + _SOURCE_HASH, )?; }}) .unwrap() diff --git a/cutile/Cargo.toml b/cutile/Cargo.toml index ca5518c..b9be5bf 100644 --- a/cutile/Cargo.toml +++ b/cutile/Cargo.toml @@ -26,3 +26,5 @@ cutile-compiler = { workspace = true } cutile-macro = { workspace = true } cutile-ir = { workspace = true } linkme = { workspace = true } +sha2 = { workspace = true } +once_cell = { workspace = true } diff --git a/cutile/src/lib.rs b/cutile/src/lib.rs index 8308049..5851c05 100644 --- a/cutile/src/lib.rs +++ b/cutile/src/lib.rs @@ -186,6 +186,7 @@ pub mod tile_kernel; pub mod utils; pub use cuda_async; +pub use cuda_async::jit_store::{self, FileSystemJitStore, JitStore}; pub use cuda_core; pub use cuda_core::{DType, DTypeId}; pub use cutile_compiler; diff --git a/cutile/src/tile_kernel.rs b/cutile/src/tile_kernel.rs index 51f72d4..0b5c9b7 100644 --- a/cutile/src/tile_kernel.rs +++ b/cutile/src/tile_kernel.rs @@ -11,13 +11,30 @@ use cuda_core::DType; use cuda_core::{memcpy_dtoh_async, Function}; use cutile_compiler::ast::Module; use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules}; -use cutile_compiler::cuda_tile_runtime_utils::{compile_tile_ir_module, get_gpu_name}; +use cutile_compiler::cuda_tile_runtime_utils::{compile_tile_ir_module, get_gpu_name, get_cuda_toolkit_version, get_compiler_version}; use cutile_compiler::specialization::{DivHint, SpecializationBits}; +use once_cell::sync::OnceCell; +use sha2::{Digest, Sha256}; use std::alloc::{alloc, Layout}; use std::fs; use std::future::IntoFuture; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; + +// JIT diagnostic logging (set CUTILE_JIT_LOG=1 to enable) + +fn jit_log_enabled() -> bool { + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var("CUTILE_JIT_LOG").is_ok_and(|v| v == "1")) +} + +macro_rules! jit_log { + ($($arg:tt)*) => { + if jit_log_enabled() { + eprintln!("[cutile::jit] {}", format!($($arg)*)); + } + }; +} use crate::error::*; use crate::tensor::{IntoPartition, IntoPartitionArc, Partition, Tensor}; @@ -34,7 +51,8 @@ pub use cutile_compiler::compiler::utils::CompileOptions; /// Two kernel invocations that share the same `TileFunctionKey` can reuse the same compiled /// CUDA module and function, avoiding recompilation. The key captures everything that can /// change the generated GPU code: module name, function name, generic type/const parameters, -/// tensor stride layouts, (optionally) the launch grid, and compile options. +/// tensor stride layouts, (optionally) the launch grid, compile options, source hash, +/// GPU architecture, compiler version, and CUDA toolkit version. #[derive(Debug, Eq, PartialEq, Hash, Clone)] pub struct TileFunctionKey { module_name: String, @@ -45,6 +63,10 @@ pub struct TileFunctionKey { pub scalar_hints: Vec<(String, DivHint)>, pub grid: Option<(u32, u32, u32)>, pub compile_options: CompileOptions, + source_hash: String, + gpu_name: String, + compiler_version: String, + cuda_toolkit_version: String, } impl TileFunctionKey { @@ -57,6 +79,10 @@ impl TileFunctionKey { scalar_hints: Vec<(String, DivHint)>, grid: Option<(u32, u32, u32)>, compile_options: CompileOptions, + source_hash: String, + gpu_name: String, + compiler_version: String, + cuda_toolkit_version: String, ) -> Self { Self { module_name, @@ -67,11 +93,153 @@ impl TileFunctionKey { scalar_hints, grid, compile_options, + source_hash, + gpu_name, + compiler_version, + cuda_toolkit_version, } } } -impl FunctionKey for TileFunctionKey {} +/// Builder for [`TileFunctionKey`]. +/// +/// With 11 positional arguments it is easy to silently transpose two `String` +/// fields and produce a wrong-but-valid key. The builder makes each field +/// self-documenting and keeps future additions backward-compatible. +/// +/// # Example +/// +/// ```rust,ignore +/// let key = TileFunctionKey::builder("linalg", "matmul") +/// .generics(vec!["f32".into(), "128".into()]) +/// .source_hash(linalg::_SOURCE_HASH) +/// .gpu_name(get_gpu_name(device_id)) +/// .compiler_version(get_compiler_version()) +/// .cuda_toolkit_version(get_cuda_toolkit_version()) +/// .build(); +/// ``` +pub struct TileFunctionKeyBuilder { + module_name: String, + function_name: String, + function_generics: Vec, + stride_args: Vec<(String, Vec)>, + spec_args: Vec<(String, SpecializationBits)>, + scalar_hints: Vec<(String, DivHint)>, + grid: Option<(u32, u32, u32)>, + compile_options: CompileOptions, + source_hash: String, + gpu_name: String, + compiler_version: String, + cuda_toolkit_version: String, +} + +impl TileFunctionKeyBuilder { + pub fn generics(mut self, generics: Vec) -> Self { + self.function_generics = generics; + self + } + pub fn stride_args(mut self, stride_args: Vec<(String, Vec)>) -> Self { + self.stride_args = stride_args; + self + } + pub fn spec_args(mut self, spec_args: Vec<(String, SpecializationBits)>) -> Self { + self.spec_args = spec_args; + self + } + pub fn grid(mut self, grid: (u32, u32, u32)) -> Self { + self.grid = Some(grid); + self + } + pub fn compile_options(mut self, options: CompileOptions) -> Self { + self.compile_options = options; + self + } + pub fn source_hash(mut self, hash: impl Into) -> Self { + self.source_hash = hash.into(); + self + } + pub fn gpu_name(mut self, name: impl Into) -> Self { + self.gpu_name = name.into(); + self + } + pub fn compiler_version(mut self, version: impl Into) -> Self { + self.compiler_version = version.into(); + self + } + pub fn cuda_toolkit_version(mut self, version: impl Into) -> Self { + self.cuda_toolkit_version = version.into(); + self + } + pub fn build(self) -> TileFunctionKey { + TileFunctionKey { + module_name: self.module_name, + function_name: self.function_name, + function_generics: self.function_generics, + stride_args: self.stride_args, + spec_args: self.spec_args, + scalar_hints: self.scalar_hints, + grid: self.grid, + compile_options: self.compile_options, + source_hash: self.source_hash, + gpu_name: self.gpu_name, + compiler_version: self.compiler_version, + cuda_toolkit_version: self.cuda_toolkit_version, + } + } +} + +impl TileFunctionKey { + /// Start building a key with required `module_name` and `function_name`. + /// All other fields default to empty / `None` / `default()`. + pub fn builder( + module_name: impl Into, + function_name: impl Into, + ) -> TileFunctionKeyBuilder { + TileFunctionKeyBuilder { + module_name: module_name.into(), + function_name: function_name.into(), + function_generics: vec![], + stride_args: vec![], + spec_args: vec![], + scalar_hints: vec![], + grid: None, + compile_options: CompileOptions::default(), + source_hash: String::new(), + gpu_name: String::new(), + compiler_version: String::new(), + cuda_toolkit_version: String::new(), + } + } +} + +impl FunctionKey for TileFunctionKey { + fn get_disk_hash_string(&self) -> String { + let canonical = format!( + "{}:{}:{}:{}:{}:{:?}:{:?}:{}:{}:{}:{}", + self.module_name, + self.function_name, + self.function_generics.join(","), + self.stride_args + .iter() + .map(|(k, v)| format!("{}={:?}", k, v)) + .collect::>() + .join(";"), + self.spec_args + .iter() + .map(|(k, v)| format!("{}={:?}", k, v)) + .collect::>() + .join(";"), + self.grid, + self.compile_options, + self.source_hash, + self.gpu_name, + self.compiler_version, + self.cuda_toolkit_version, + ); + let hash = Sha256::digest(canonical.as_bytes()); + format!("{:x}", hash) + } +} /// Reads IR (MLIR or PTX) from a file. /// @@ -123,12 +291,25 @@ fn write_ir( println!("IR written to {path:?}"); } +/// Attempt to load a cubin from the global JitStore. +fn try_load_from_jit_store(disk_key: &str) -> Option> { + let store = cuda_async::jit_store::get_jit_store()?; + store.get(disk_key).ok().flatten() +} + +// ── Single-flight compilation dedup is handled by once_cell::sync::OnceLock ── + /// Compiles a tile function to CUDA and caches it for reuse. /// /// Handles the complete compilation pipeline from Rust/MLIR to CUDA: -/// 1. Checks the thread-local cache for a previously compiled function -/// 2. If not cached, compiles the module AST to MLIR, then to PTX/CUBIN -/// 3. Loads the compiled function and caches it for future use +/// 1. Checks the global kernel cache (process-wide, cross-thread) +/// 2. Checks the disk cache (JitStore) for a previously persisted cubin +/// 3. If not cached, compiles the module AST to MLIR, then to PTX/CUBIN +/// 4. Stores the result in the global cache and optionally persists to disk +/// +/// **Compilation dedup**: When multiple threads need the same kernel, `OnceLock::get_or_try_init` +/// ensures only one thread performs compilation while others block. Once initialization completes, +/// all threads see the same cached result. /// /// The caching key is based on the module name, function name, type generics, stride arguments, /// and compile-time grid dimensions, ensuring correct reuse across different specializations. @@ -174,9 +355,14 @@ pub fn compile_from_context Module>( scalar_hints: Vec<(String, DivHint)>, const_grid: Option<(u32, u32, u32)>, compile_options: CompileOptions, + source_hash: &str, ) -> Result<(Arc, Arc), Error> { + cuda_async::jit_store::ensure_default_jit_store(); + let device_id: usize = ctx.get_device_id(); - // Compilation constructs a lookup key. + let gpu_name = get_gpu_name(device_id); + let compiler_version = get_compiler_version(); + let cuda_toolkit_version = get_cuda_toolkit_version(); let key = TileFunctionKey::new( module_name.to_string(), function_name.to_string(), @@ -186,18 +372,86 @@ pub fn compile_from_context Module>( scalar_hints, const_grid, compile_options, + source_hash.to_string(), + gpu_name.clone(), + compiler_version, + cuda_toolkit_version, ); - let cache_hash_str = key.get_hash_string(); - if contains_cuda_function(device_id, &key) { - // A hit to the thread local kernel cache returns the compiled function. - let func = get_cuda_function(device_id, &key)?; - let validator = get_function_validator(device_id, &key)?; - Ok((func, validator)) - } else { - let gpu_name = get_gpu_name(device_id); - // LINKING Phase B: build the module set by walking the kernel's - // `use` graph against the linker registry. The legacy chained - // `_module_asts()` is no longer consulted here. + let key_str = key.get_hash_string(); + let disk_key = key.get_disk_hash_string(); + + let cache = get_kernel_cache(); + let slot = { + let entry = cache + .entry(key_str.clone()) + .or_insert_with(|| Arc::new(OnceCell::new())); + Arc::clone(entry.value()) + }; // entry dropped here — releases DashMap shard lock. + + // Use OnceCell::get_or_try_init for single-flight compilation dedup. + // Only one thread executes the closure; others block and see the result. + let compiled = slot.get_or_try_init(|| -> Result { + // Try disk cache first. + if let Some(cubin_bytes) = try_load_from_jit_store(&disk_key) { + jit_log!( + "{module_name}::{function_name} → disk cache hit ({} bytes)", + cubin_bytes.len() + ); + let modules = CUDATileModules::from_kernel(kernel_ast())?; + let compiler = CUDATileFunctionCompiler::new( + &modules, + module_name, + function_name, + &key.function_generics, + &key.stride_args + .iter() + .map(|x| (x.0.as_str(), x.1.as_slice())) + .collect::>(), + &key.spec_args + .iter() + .map(|x| (x.0.as_str(), &x.1)) + .collect::>(), + &key.scalar_hints + .iter() + .map(|x| (x.0.as_str(), &x.1)) + .collect::>(), + const_grid, + gpu_name.clone(), + &key.compile_options, + )?; + let validator = Arc::new(compiler.get_validator()); + match load_module_from_bytes(&cubin_bytes, device_id) { + Ok(module) => { + let function = + Arc::new(module.load_function(function_entry).map_err(|e| { + Error::KernelLaunch(KernelLaunchError(format!( + "failed to load '{function_entry}' from cached cubin: {e}" + ))) + })?); + return Ok(CompiledKernel { + module, + function, + validator, + }); + } + Err(e) => { + // Corrupted or incompatible cubin (e.g. disk error, partial write, + // architecture mismatch). Delete the bad entry and fall through to + // full JIT recompilation so the caller is not permanently blocked. + jit_log!( + "{module_name}::{function_name} → corrupted disk cubin, \ + deleting and recompiling (error: {e})" + ); + if let Some(store) = cuda_async::jit_store::get_jit_store() { + let _ = store.delete(&disk_key); + } + } + } + } + + // Full JIT compilation. + jit_log!("{module_name}::{function_name} → JIT compiling..."); + let t0 = std::time::Instant::now(); let modules = CUDATileModules::from_kernel(kernel_ast())?; let _debug_mlir_path = modules.get_entry_arg_string_by_function_name( module_name, @@ -276,7 +530,7 @@ pub fn compile_from_context Module>( write_ir( module_name, function_name, - cache_hash_str.as_str(), + key_str.as_str(), "mlir", path.as_str(), mlir.as_str(), @@ -305,11 +559,27 @@ pub fn compile_from_context Module>( // println!(); // } let stage3_start = std::time::Instant::now(); + let jit_elapsed = t0.elapsed(); + // Persist to disk cache if a JitStore is configured. + if let Some(store) = cuda_async::jit_store::get_jit_store() { + if let Ok(cubin_bytes) = std::fs::read(&cubin_filename) { + if store.put(&disk_key, &cubin_bytes).is_ok() { + jit_log!( + "{module_name}::{function_name} → saved to disk cache ({} bytes)", + cubin_bytes.len() + ); + } + } + } let module = load_module_from_file(&cubin_filename, device_id)?; - let function = Arc::new( - module - .load_function(function_entry) - .expect("Failed to compile function."), + let function = Arc::new(module.load_function(function_entry).map_err(|e| { + Error::KernelLaunch(KernelLaunchError(format!( + "failed to load '{function_entry}' from compiled cubin: {e}" + ))) + })?); + jit_log!( + "{module_name}::{function_name} → JIT compiled in {:.1?}", + jit_elapsed ); let stage3_ms = stage3_start.elapsed().as_secs_f64() * 1000.0; if std::env::var_os("CUTILE_JIT_TIMING").is_some() { @@ -317,17 +587,335 @@ pub fn compile_from_context Module>( "CUTILE_JIT_TIMING module={} function={} key={} stage1_ms={:.3} stage2_ms={:.3} stage3_ms={:.3} generics={}", module_name, function_name, - cache_hash_str, + key_str, stage1_ms, stage2_ms, stage3_ms, key.function_generics.join(","), ); } - insert_cuda_function(device_id, &key, (module, function.clone()))?; + insert_cuda_function(device_id, &key, (module.clone(), function.clone()))?; insert_function_validator(device_id, &key, validator.clone())?; - Ok((function, validator)) + Ok(CompiledKernel { + module, + function, + validator, + }) + })?; + + Ok(( + Arc::clone(&compiled.function), + Arc::clone(&compiled.validator), + )) +} + +// ── Warmup types and functions ─────────────────────────────────────────────── + +/// Metadata for a single kernel entry point, generated by the `#[cutile::module]` macro. +#[derive(Debug, Clone)] +pub struct EntryMeta { + pub module_name: &'static str, + pub function_name: &'static str, + pub function_entry: &'static str, +} + +/// User-provided specialization for warmup compilation. +/// +/// Each `WarmupSpec` describes one kernel specialization to pre-compile. +#[derive(Debug, Clone)] +pub struct WarmupSpec { + pub function_name: String, + pub function_generics: Vec, + pub stride_args: Vec<(String, Vec)>, + pub spec_args: Vec<(String, SpecializationBits)>, + pub scalar_hints: Vec<(String, DivHint)>, + pub const_grid: Option<(u32, u32, u32)>, +} + +impl WarmupSpec { + /// Create a warmup spec with just generics (no strides, no const grid). + pub fn new(function_name: &str, generics: Vec) -> Self { + Self { + function_name: function_name.to_string(), + function_generics: generics, + stride_args: vec![], + spec_args: vec![], + scalar_hints: vec![], + const_grid: None, + } + } + + /// Set stride arguments for this spec. + pub fn with_strides(mut self, stride_args: Vec<(String, Vec)>) -> Self { + self.stride_args = stride_args; + self + } + + /// Set specialization arguments for this spec. + pub fn with_spec_args(mut self, spec_args: Vec<(String, SpecializationBits)>) -> Self { + self.spec_args = spec_args; + self + } + + /// Set scalar divisibility hints for this spec. + pub fn with_scalar_hints(mut self, scalar_hints: Vec<(String, DivHint)>) -> Self { + self.scalar_hints = scalar_hints; + self + } + + /// Set a const grid for this spec. + pub fn with_const_grid(mut self, grid: (u32, u32, u32)) -> Self { + self.const_grid = Some(grid); + self + } +} + +/// Pre-compile a set of kernel specializations without launching. +/// +/// Builds the module ASTs once, then compiles each requested specialization. +/// Results are placed in the global kernel cache and optionally persisted +/// to the JitStore. +/// +/// # Example +/// +/// ```rust,ignore +/// compile_warmup( +/// || linalg::__module_ast_self(), +/// &linalg::_entries(), +/// "linalg", +/// linalg::_SOURCE_HASH, +/// &[ +/// WarmupSpec::new("vector_add", vec!["f32".into(), "128".into()]), +/// WarmupSpec::new("vector_add", vec!["f16".into(), "256".into()]), +/// WarmupSpec::new("relu", vec!["f32".into(), "128".into()]), +/// ], +/// )?; +/// ``` +pub fn compile_warmup Module>( + module_asts: F, + entries: &[EntryMeta], + module_name: &str, + source_hash: &str, + specs: &[WarmupSpec], +) -> Result<(), Error> { + cuda_async::jit_store::ensure_default_jit_store(); + + let device_id = get_default_device(); + let gpu_name = get_gpu_name(device_id); + let compiler_version = get_compiler_version(); + let cuda_toolkit_version = get_cuda_toolkit_version(); + + // Build module ASTs once, shared across all specs in this warmup call. + let modules = CUDATileModules::from_kernel(module_asts())?; + + for spec in specs { + // Find matching entry metadata. + let entry = entries + .iter() + .find(|e| e.function_name == spec.function_name) + .ok_or_else(|| { + Error::KernelLaunch(KernelLaunchError(format!( + "compile_warmup: unknown function '{}' in module '{}'", + spec.function_name, module_name + ))) + })?; + + let key = TileFunctionKey::new( + module_name.to_string(), + spec.function_name.clone(), + spec.function_generics.clone(), + spec.stride_args.clone(), + spec.spec_args.clone(), + spec.scalar_hints.clone(), + spec.const_grid, + CompileOptions::default(), + source_hash.to_string(), + gpu_name.clone(), + compiler_version.clone(), + cuda_toolkit_version.clone(), + ); + + let key_str = key.get_hash_string(); + let disk_key = key.get_disk_hash_string(); + let cache = get_kernel_cache(); + let slot = { + let entry = cache + .entry(key_str.clone()) + .or_insert_with(|| Arc::new(OnceCell::new())); + Arc::clone(entry.value()) + }; // entry dropped — releases DashMap shard lock. + + // Use OnceCell::get_or_try_init for single-flight compilation dedup. + // Only one thread executes the closure; others block and see the result. + let _ = slot.get_or_try_init(|| -> Result { + jit_log!( + "warmup: {module_name}::{} <{}> → compiling...", + spec.function_name, + spec.function_generics.join(", ") + ); + + // Try disk cache first. + if let Some(cubin_bytes) = try_load_from_jit_store(&disk_key) { + jit_log!( + "warmup: {module_name}::{} → disk cache hit ({} bytes)", + spec.function_name, + cubin_bytes.len() + ); + let compiler = CUDATileFunctionCompiler::new( + &modules, + module_name, + &spec.function_name, + &spec.function_generics, + &spec + .stride_args + .iter() + .map(|x| (x.0.as_str(), x.1.as_slice())) + .collect::>(), + &spec + .spec_args + .iter() + .map(|x| (x.0.as_str(), &x.1)) + .collect::>(), + &spec + .scalar_hints + .iter() + .map(|x| (x.0.as_str(), &x.1)) + .collect::>(), + spec.const_grid, + gpu_name.clone(), + &key.compile_options, + )?; + let validator = Arc::new(compiler.get_validator()); + match load_module_from_bytes(&cubin_bytes, device_id) { + Ok(module) => { + let function = Arc::new( + module.load_function(entry.function_entry).map_err(|e| { + Error::KernelLaunch(KernelLaunchError(format!( + "failed to load '{}' from cached cubin: {e}", + entry.function_entry + ))) + })?, + ); + return Ok(CompiledKernel { + module, + function, + validator, + }); + } + Err(e) => { + // Corrupted or incompatible cubin. Delete and fall through to JIT. + jit_log!( + "warmup: {module_name}::{} → corrupted disk cubin, \ + deleting and recompiling (error: {e})", + spec.function_name + ); + if let Some(store) = cuda_async::jit_store::get_jit_store() { + let _ = store.delete(&disk_key); + } + } + } + } + + // Full JIT compilation. + let t0 = std::time::Instant::now(); + let compiler = CUDATileFunctionCompiler::new( + &modules, + module_name, + &spec.function_name, + &spec.function_generics, + &spec + .stride_args + .iter() + .map(|x| (x.0.as_str(), x.1.as_slice())) + .collect::>(), + &spec + .spec_args + .iter() + .map(|x| (x.0.as_str(), &x.1)) + .collect::>(), + &spec + .scalar_hints + .iter() + .map(|x| (x.0.as_str(), &x.1)) + .collect::>(), + spec.const_grid, + gpu_name.clone(), + &key.compile_options, + )?; + let validator = Arc::new(compiler.get_validator()); + let module_op = compiler.compile()?; + let cubin_filename = compile_tile_ir_module(&module_op, &gpu_name); + let jit_elapsed = t0.elapsed(); + + // Persist to disk cache. + if let Some(store) = cuda_async::jit_store::get_jit_store() { + if let Ok(cubin_bytes) = std::fs::read(&cubin_filename) { + if store.put(&disk_key, &cubin_bytes).is_ok() { + jit_log!( + "warmup: {module_name}::{} → saved to disk cache ({} bytes)", + spec.function_name, + cubin_bytes.len() + ); + } + } + } + + let module = load_module_from_file(&cubin_filename, device_id)?; + let function = + Arc::new(module.load_function(entry.function_entry).map_err(|e| { + Error::KernelLaunch(KernelLaunchError(format!( + "failed to load '{}' from compiled cubin: {e}", + entry.function_entry + ))) + })?); + jit_log!( + "warmup: {module_name}::{} → JIT compiled in {:.1?}", + spec.function_name, + jit_elapsed + ); + Ok(CompiledKernel { + module, + function, + validator, + }) + })?; } + + Ok(()) +} + +/// Execute a warmup routine with realistic kernel launches. +/// +/// The provided closure should launch kernels with production-representative +/// shapes and data. This warms up both the JIT compilation cache and the +/// CUDA runtime (driver initialization, shared memory allocation, occupancy +/// calculation, etc.). +/// +/// # Example +/// +/// ```rust,ignore +/// execute_warmup(|| { +/// let x = api::zeros::([4096, 4096]).sync()?; +/// let y = api::zeros::([4096, 4096]).sync()?; +/// let z = api::zeros::([4096, 4096]).sync()?; +/// linalg::matmul(z, x, y) +/// .generics(vec!["f32".into(), "128".into()]) +/// .grid((32, 32, 1)) +/// .sync()?; +/// Ok(()) +/// })?; +/// ``` +pub fn execute_warmup(f: F) -> Result<(), Error> +where + F: FnOnce() -> Result<(), Error>, +{ + // Ensure device context is initialized. + let device_id = get_default_device(); + let _ = with_global_device_context(device_id, |_| {})?; + + // Run user-provided warmup routine. + // Kernels inside will auto-JIT via existing compile_from_context path. + f() } /// Validates that all partition grids match the expected launch grid. @@ -471,6 +1059,7 @@ where scalar_hints: Vec<(String, DivHint)>, grid: Option<(u32, u32, u32)>, compile_options: CompileOptions, + source_hash: &str, ) -> Result<(Arc, Arc), Error> { compile_from_context( ctx, @@ -484,6 +1073,7 @@ where scalar_hints, grid, compile_options, + source_hash, ) } /// Sets the type and const generic arguments for this kernel. diff --git a/cutile/tests/gpu.rs b/cutile/tests/gpu.rs index 7f81799..49a92fd 100644 --- a/cutile/tests/gpu.rs +++ b/cutile/tests/gpu.rs @@ -16,3 +16,9 @@ mod tensor; #[path = "gpu/num_tiles.rs"] mod num_tiles; + +#[path = "gpu/warmup.rs"] +mod warmup; + +#[path = "gpu/warmup_bench.rs"] +mod warmup_bench; diff --git a/cutile/tests/gpu/warmup.rs b/cutile/tests/gpu/warmup.rs new file mode 100644 index 0000000..cef4ac3 --- /dev/null +++ b/cutile/tests/gpu/warmup.rs @@ -0,0 +1,962 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! GPU integration tests for compile_warmup and execute_warmup. + +use crate::common; +use cutile::api; +use cutile::jit_store::FileSystemJitStore; +use cutile::prelude::{DeviceOp, PartitionOp}; +use cutile::tile_kernel::{ + contains_cuda_function, evict_kernel, execute_warmup, get_default_device, + load_module_from_bytes, CompileOptions, FunctionKey, + TileFunctionKey, TileKernel, WarmupSpec, +}; +use cutile_compiler::cuda_tile_runtime_utils::{ + get_compiler_version, get_cuda_toolkit_version, get_gpu_name, +}; +use cutile_compiler::specialization::SpecializationBits; +use once_cell::sync::Lazy; +use std::process::Command; +use std::sync::Arc; +use std::sync::Mutex; + +static WARMUP_CACHE_TEST_LOCK: Lazy> = Lazy::new(|| Mutex::new(())); + +#[cutile::module] +mod warmup_test_module { + use cutile::core::*; + + #[cutile::entry()] + fn vector_add( + z: &mut Tensor, + x: &Tensor, + y: &Tensor, + ) { + let tile_x = load_tile_like(x, z); + let tile_y = load_tile_like(y, z); + z.store(tile_x + tile_y); + } +} + +fn vector_add_stride_args() -> Vec<(String, Vec)> { + vec![ + ("z".to_string(), vec![1]), + ("x".to_string(), vec![1]), + ("y".to_string(), vec![1]), + ] +} + +fn vector_add_spec_args(len: usize, tile: usize) -> Vec<(String, SpecializationBits)> { + let x = api::ones::(&[len]).sync().unwrap(); + let y = api::ones::(&[len]).sync().unwrap(); + let z = api::zeros::(&[len]).partition([tile]).sync().unwrap(); + let z_spec = z.unpartition().spec().clone(); + vec![ + ("z".to_string(), z_spec), + ("x".to_string(), x.spec().clone()), + ("y".to_string(), y.spec().clone()), + ] +} + +fn unique_temp_cache_dir(tag: &str) -> std::path::PathBuf { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time went backwards") + .as_nanos(); + std::env::temp_dir().join(format!("cutile_warmup_{tag}_{}_{}", std::process::id(), nanos)) +} + +fn run_warmup_worker(role: &str, cache_root: &std::path::Path) -> std::process::Output { + let exe = std::env::current_exe().expect("failed to resolve current test binary path"); + let no_disk = if role == "no-disk-cache" { "1" } else { "0" }; + let mut cmd = Command::new(exe); + cmd.arg("--exact") + .arg("warmup::cross_process_warmup_worker") + .arg("--nocapture") + .env("CUTILE_WARMUP_WORKER_ROLE", role) + .env("XDG_CACHE_HOME", cache_root) + .env("CUTILE_JIT_LOG", "1") + .env("CUTILE_NO_DISK_CACHE", no_disk); + cmd.output().expect("failed to run warmup worker") +} + +// Compile_warmup +#[test] +fn compile_warmup_populates_cache() { + common::with_test_stack(move || { + let _guard = WARMUP_CACHE_TEST_LOCK + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + + // Uses the macro-generated _compile_warmup helper — callers only pass specs. + warmup_test_module::_compile_warmup(&[WarmupSpec::new( + "vector_add", + vec!["f32".into(), "64".into()], + ) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 64))]) + .expect("_compile_warmup failed"); + + // Verify the kernel is in cache. + let device_id = get_default_device(); + let key = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "64".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 64)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(get_gpu_name(device_id)) + .compiler_version(get_compiler_version()) + .cuda_toolkit_version(get_cuda_toolkit_version()) + .build(); + assert!( + contains_cuda_function(device_id, &key), + "kernel should be in cache after compile_warmup" + ); + }); +} + +#[test] +fn compile_warmup_skips_duplicate() { + common::with_test_stack(move || { + let specs = &[WarmupSpec::new("vector_add", vec!["f32".into(), "128".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 128))]; + // First call compiles. + warmup_test_module::_compile_warmup(specs) + .expect("first compile_warmup failed"); + + // Second call should be a no-op (hits cache). + warmup_test_module::_compile_warmup(specs) + .expect("second compile_warmup failed"); + }); +} + +#[test] +fn compile_warmup_unknown_function_errors() { + common::with_test_stack(|| { + let result = warmup_test_module::_compile_warmup( + &[WarmupSpec::new("nonexistent_fn", vec!["f32".into()])], + ); + assert!(result.is_err(), "should error for unknown function"); + }); +} + +// Compile_warmup with JitStore disk persistence +#[test] +fn compile_warmup_persists_to_disk() { + common::with_test_stack(|| { + let dir = + std::env::temp_dir().join(format!("cutile_warmup_disk_test_{}", std::process::id())); + let _ = std::fs::remove_dir_all(&dir); + let store = FileSystemJitStore::new(dir.clone()).expect("failed to create store"); + + // Configure JitStore (note: can only be set once per process). + // If this fails because it's already set from another test, that's OK — + // just skip the disk assertions. + let store_was_set = cuda_async::jit_store::set_jit_store_if_unset(Some(Box::new(store))); + + warmup_test_module::_compile_warmup(&[WarmupSpec::new("vector_add", vec!["f32".into(), "256".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 256))]) + .expect("compile_warmup failed"); + + if store_was_set { + // Verify a cubin file was written to disk. + let cubin_count = std::fs::read_dir(&dir) + .unwrap() + .filter(|e| { + e.as_ref() + .unwrap() + .path() + .extension() + .is_some_and(|ext| ext == "cubin") + }) + .count(); + assert!( + cubin_count > 0, + "at least one .cubin should be persisted to disk" + ); + } + + let _ = std::fs::remove_dir_all(&dir); + }); +} + +// Execute_warmup +#[test] +fn execute_warmup_runs_kernel() { + common::with_test_stack(|| { + execute_warmup(|| { + let x = api::ones::(&[256]).sync()?; + let y = api::ones::(&[256]).sync()?; + let z = api::zeros::(&[256]).partition([64]).sync()?; + let _result = warmup_test_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), "64".into()]) + .sync()?; + Ok(()) + }) + .expect("execute_warmup failed"); + }); +} + +// Multi-thread compilation dedup +// Spawns multiple threads that all compile the same kernel specialization +// concurrently. Verifies that all threads succeed and the kernel ends up +// in cache. With single-flight dedup, only one thread performs the actual +// JIT; the rest wait and get the cached result. +#[test] +fn multi_thread_compile_dedup() { + common::with_test_stack(|| { + let _guard = WARMUP_CACHE_TEST_LOCK + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let n_threads = 4; + let barrier = Arc::new(std::sync::Barrier::new(n_threads)); + let handles: Vec<_> = (0..n_threads) + .map(|_| { + let barrier = Arc::clone(&barrier); + std::thread::Builder::new() + .stack_size(common::TEST_STACK_SIZE) + .spawn(move || { + barrier.wait(); + let x = api::ones::(&[256]).sync().unwrap(); + let y = api::ones::(&[256]).sync().unwrap(); + let z = api::zeros::(&[256]).partition([128]).sync().unwrap(); + warmup_test_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), "128".into()]) + .sync() + .unwrap(); + }) + .unwrap() + }) + .collect(); + + for h in handles { + h.join().expect("thread panicked during concurrent compile"); + } + + // Build a key with runtime specialization bits to match launcher behavior. + let x_probe = api::ones::(&[256]).sync().unwrap(); + let y_probe = api::ones::(&[256]).sync().unwrap(); + let z_probe = api::zeros::(&[256]).partition([128]).sync().unwrap(); + let z_spec = z_probe.unpartition().spec().clone(); + + // Verify the kernel is in cache after concurrent compilation. + let device_id = get_default_device(); + let key = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "128".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vec![ + ("z".to_string(), z_spec), + ("x".to_string(), x_probe.spec().clone()), + ("y".to_string(), y_probe.spec().clone()), + ]) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(get_gpu_name(device_id)) + .compiler_version(get_compiler_version()) + .cuda_toolkit_version(get_cuda_toolkit_version()) + .build(); + assert!( + contains_cuda_function(device_id, &key), + "kernel should be in cache after concurrent compilation" + ); + }); +} + +// Verifies the disk → memory cache path: compile a kernel (populating both +// caches), evict from memory, re-warmup → the second compilation should load +// from disk instead of re-JIT-compiling. +#[test] +fn disk_cache_hit_after_memory_eviction() { + common::with_test_stack(|| { + let dir = std::env::temp_dir().join(format!( + "cutile_disk_read_test_{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&dir); + let store = FileSystemJitStore::new(dir.clone()).expect("failed to create store"); + let store_was_set = + cuda_async::jit_store::set_jit_store_if_unset(Some(Box::new(store))); + + if !store_was_set { + // Another test already configured the JitStore — we can't control the + // disk directory, so skip. The JIT log output still demonstrates the + // path when run standalone. + println!("Skipping disk_cache_hit_after_memory_eviction: JitStore already set"); + return; + } + + let specs = &[WarmupSpec::new("vector_add", vec!["f32".into(), "2".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 2))]; + + // Step 1: compile → populates memory + disk. + warmup_test_module::_compile_warmup(specs) + .expect("first compile_warmup failed"); + + // Verify cubin was written to disk. + let cubin_count = std::fs::read_dir(&dir) + .unwrap() + .filter(|e| { + e.as_ref() + .unwrap() + .path() + .extension() + .is_some_and(|ext| ext == "cubin") + }) + .count(); + assert!(cubin_count > 0, "cubin should be on disk after compile"); + + // Step 2: evict from memory cache. + let device_id = get_default_device(); + let key = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "2".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 2)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(get_gpu_name(device_id)) + .compiler_version(get_compiler_version()) + .cuda_toolkit_version(get_cuda_toolkit_version()) + .build(); + evict_kernel(&key.get_hash_string()); + assert!( + !contains_cuda_function(device_id, &key), + "should be evicted from memory" + ); + + // Step 3: re-warmup → should hit disk cache (visible with CUTILE_JIT_LOG=1). + warmup_test_module::_compile_warmup(specs) + .expect("second compile_warmup (disk hit) failed"); + + // Step 4: verify back in memory cache. + assert!( + contains_cuda_function(device_id, &key), + "kernel should be in memory cache after disk hit" + ); + + let _ = std::fs::remove_dir_all(&dir); + }); +} + +// Compilation failure does not poison cache +// Verifies that a failed compile_warmup (unknown function) does not prevent +// a subsequent valid warmup from succeeding. +#[test] +fn failed_warmup_does_not_poison_cache() { + common::with_test_stack(|| { + // First call: invalid function name → should error. + let result = warmup_test_module::_compile_warmup( + &[WarmupSpec::new("nonexistent_fn", vec!["f32".into()])], + ); + assert!(result.is_err(), "should error for unknown function"); + + // Second call: valid params → should succeed despite prior failure. + warmup_test_module::_compile_warmup(&[WarmupSpec::new("vector_add", vec!["f32".into(), "4".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 4))]) + .expect("valid warmup should succeed after failed one"); + + let device_id = get_default_device(); + let key = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "4".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 4)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(get_gpu_name(device_id)) + .compiler_version(get_compiler_version()) + .cuda_toolkit_version(get_cuda_toolkit_version()) + .build(); + assert!( + contains_cuda_function(device_id, &key), + "kernel should be in cache after recovery from prior failure" + ); + }); +} + +// Corrupted disk cubin self-healing +// Verifies that if a .cubin on disk is corrupted (e.g. partial write), compile_warmup +// deletes the bad file and falls through to JIT recompilation instead of hard-erroring. +#[test] +fn disk_cache_corrupted_cubin_self_heals() { + common::with_test_stack(|| { + let dir = std::env::temp_dir().join(format!( + "cutile_corrupted_cubin_test_{}", + std::process::id() + )); + let _ = std::fs::remove_dir_all(&dir); + let store = FileSystemJitStore::new(dir.clone()).expect("failed to create store"); + let store_was_set = + cuda_async::jit_store::set_jit_store_if_unset(Some(Box::new(store))); + + if !store_was_set { + println!("Skipping disk_cache_corrupted_cubin_self_heals: JitStore already set"); + return; + } + + let device_id = get_default_device(); + let spec = WarmupSpec::new("vector_add", vec!["f32".into(), "8".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 8)); + + // Step 1: compile → write valid cubin to disk. + warmup_test_module::_compile_warmup(std::slice::from_ref(&spec)) + .expect("initial compile_warmup failed"); + + // Step 2: overwrite the on-disk cubin with garbage bytes. + let key = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "8".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 8)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(get_gpu_name(device_id)) + .compiler_version(get_compiler_version()) + .cuda_toolkit_version(get_cuda_toolkit_version()) + .build(); + let disk_key = key.get_disk_hash_string(); + let cubin_path = dir.join(format!("{disk_key}.cubin")); + std::fs::write(&cubin_path, b"this is not a valid cubin") + .expect("failed to corrupt cubin"); + assert!(cubin_path.exists(), "corrupted cubin should exist before test"); + + // Step 3: evict from memory so the next warmup must go to disk. + evict_kernel(&key.get_hash_string()); + assert!( + !contains_cuda_function(device_id, &key), + "kernel should be evicted from memory" + ); + + // Step 4: re-warmup — should detect corruption, delete bad file, JIT recompile. + warmup_test_module::_compile_warmup(std::slice::from_ref(&spec)) + .expect("compile_warmup should self-heal after corrupted cubin"); + + // Step 5: kernel must be back in memory cache (recompiled successfully). + assert!( + contains_cuda_function(device_id, &key), + "kernel should be in memory cache after self-healing recompile" + ); + + // Step 6: the corrupted file should be gone (deleted by recovery path). + assert!( + !std::fs::read(&cubin_path) + .ok() + .is_some_and(|b| b == b"this is not a valid cubin"), + "corrupted cubin bytes should have been replaced or deleted" + ); + + let _ = std::fs::remove_dir_all(&dir); + }); +} + +// Multi-spec warmup does not skip subsequent specs +// Tests that when warming up multiple specs [A, B], if A is cached (skipped), +// B is still compiled correctly. +#[test] +fn multi_spec_warmup_compiles_all() { + common::with_test_stack(|| { + // Pre-compile spec A so it's in cache. + warmup_test_module::_compile_warmup(&[WarmupSpec::new("vector_add", vec!["f32".into(), "32".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 32))]) + .expect("pre-compile spec A failed"); + + // Now warmup [A, B] — A should skip, B should compile. + warmup_test_module::_compile_warmup(&[ + WarmupSpec::new("vector_add", vec!["f32".into(), "32".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 32)), + WarmupSpec::new("vector_add", vec!["f32".into(), "16".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 16)), + ]) + .expect("multi-spec warmup failed"); + + // Verify BOTH are in cache — the critical check is that B was compiled. + let device_id = get_default_device(); + let gpu_name = get_gpu_name(device_id); + let compiler_version = get_compiler_version(); + let cuda_toolkit_version = get_cuda_toolkit_version(); + + let key_a = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "32".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 32)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(gpu_name.clone()) + .compiler_version(compiler_version.clone()) + .cuda_toolkit_version(cuda_toolkit_version.clone()) + .build(); + let key_b = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "16".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 16)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(gpu_name.clone()) + .compiler_version(compiler_version.clone()) + .cuda_toolkit_version(cuda_toolkit_version.clone()) + .build(); + + assert!( + contains_cuda_function(device_id, &key_a), + "spec A should be in cache" + ); + assert!( + contains_cuda_function(device_id, &key_b), + "spec B must be in cache — multi-spec warmup should not skip subsequent specs" + ); + }); +} + +// Load_module_from_bytes multi-thread safety +// Verifies that load_module_from_bytes can be called concurrently from +// multiple threads without tmp file collisions . +#[test] +fn load_module_from_bytes_concurrent() { + common::with_test_stack(|| { + // First, compile a kernel to get valid cubin bytes. + warmup_test_module::_compile_warmup(&[WarmupSpec::new("vector_add", vec!["f32".into(), "64".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 64))]) + .expect("warmup failed"); + + // Find the cubin file from the JitStore directory or compile output. + // Use the compile_module path to get a cubin: re-compile and read the file. + let device_id = get_default_device(); + let gpu_name = get_gpu_name(device_id); + + // Get cubin bytes by compiling the kernel and reading the output file. + let modules = + cutile_compiler::compiler::CUDATileModules::new(vec![warmup_test_module::__module_ast_self()]) + .unwrap(); + let compiler = cutile_compiler::compiler::CUDATileFunctionCompiler::new( + &modules, + "warmup_test_module", + "vector_add", + &["f32".to_string(), "64".to_string()], + &[ + ("z", &[1i32][..]), + ("x", &[1i32][..]), + ("y", &[1i32][..]), + ], + &[], + &[], + None, + gpu_name.clone(), + &CompileOptions::default(), + ) + .unwrap(); + let module_op = compiler.compile().unwrap(); + let cubin_filename = + cutile_compiler::cuda_tile_runtime_utils::compile_tile_ir_module(&module_op, &gpu_name); + let cubin_bytes = std::fs::read(&cubin_filename).expect("failed to read cubin"); + + // Spawn multiple threads, each loading the same cubin bytes. + let n_threads = 4; + let barrier = Arc::new(std::sync::Barrier::new(n_threads)); + let cubin = Arc::new(cubin_bytes); + + let handles: Vec<_> = (0..n_threads) + .map(|_| { + let barrier = Arc::clone(&barrier); + let cubin = Arc::clone(&cubin); + std::thread::Builder::new() + .stack_size(common::TEST_STACK_SIZE) + .spawn(move || { + barrier.wait(); + let module = load_module_from_bytes(&cubin, device_id) + .expect("load_module_from_bytes failed"); + // Verify the module can load the function. + let _func = module + .load_function("vector_add_entry") + .expect("failed to load function from module"); + }) + .unwrap() + }) + .collect(); + + for h in handles { + h.join().expect("thread panicked in load_module_from_bytes"); + } + }); +} + +// Empty specs returns Ok(()) +#[test] +fn compile_warmup_empty_specs() { + common::with_test_stack(|| { + let result = warmup_test_module::_compile_warmup(&[]); + assert!( + result.is_ok(), + "compile_warmup with empty specs should return Ok(())" + ); + }); +} + +// Corrupted cubin bytes must return Err, not panic +#[test] +fn corrupted_cubin_returns_error_not_panic() { + common::with_test_stack(|| { + let device_id = get_default_device(); + let result = load_module_from_bytes(b"this is not a valid cubin", device_id); + assert!( + result.is_err(), + "corrupted cubin should return Err, not panic" + ); + }); +} + +// Different kernel specializations can compile concurrently without interference +#[test] +fn different_keys_parallel_compile() { + common::with_test_stack(|| { + let _guard = WARMUP_CACHE_TEST_LOCK + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let barrier = Arc::new(std::sync::Barrier::new(2)); + + let b1 = Arc::clone(&barrier); + let h1 = std::thread::Builder::new() + .stack_size(common::TEST_STACK_SIZE) + .spawn(move || { + b1.wait(); + let x = api::ones::(&[256]).sync().unwrap(); + let y = api::ones::(&[256]).sync().unwrap(); + let z = api::zeros::(&[256]).partition([128]).sync().unwrap(); + warmup_test_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), "128".into()]) + .sync() + .unwrap(); + }) + .unwrap(); + + let b2 = Arc::clone(&barrier); + let h2 = std::thread::Builder::new() + .stack_size(common::TEST_STACK_SIZE) + .spawn(move || { + b2.wait(); + let x = api::ones::(&[512]).sync().unwrap(); + let y = api::ones::(&[512]).sync().unwrap(); + let z = api::zeros::(&[512]).partition([256]).sync().unwrap(); + warmup_test_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), "256".into()]) + .sync() + .unwrap(); + }) + .unwrap(); + + h1.join().expect("thread 1 panicked"); + h2.join().expect("thread 2 panicked"); + + // Build keys with runtime specialization bits to match launcher behavior. + let x_probe_8 = api::ones::(&[256]).sync().unwrap(); + let y_probe_8 = api::ones::(&[256]).sync().unwrap(); + let z_probe_8 = api::zeros::(&[256]).partition([128]).sync().unwrap(); + let z_spec_8 = z_probe_8.unpartition().spec().clone(); + + let x_probe_32 = api::ones::(&[512]).sync().unwrap(); + let y_probe_32 = api::ones::(&[512]).sync().unwrap(); + let z_probe_32 = api::zeros::(&[512]).partition([256]).sync().unwrap(); + let z_spec_32 = z_probe_32.unpartition().spec().clone(); + + // Verify both distinct keys are in cache. + let device_id = get_default_device(); + let gpu_name = get_gpu_name(device_id); + let cv = get_compiler_version(); + let tv = get_cuda_toolkit_version(); + let key_8 = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "128".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vec![ + ("z".to_string(), z_spec_8), + ("x".to_string(), x_probe_8.spec().clone()), + ("y".to_string(), y_probe_8.spec().clone()), + ]) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(gpu_name.clone()) + .compiler_version(cv.clone()) + .cuda_toolkit_version(tv.clone()) + .build(); + let key_32 = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "256".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vec![ + ("z".to_string(), z_spec_32), + ("x".to_string(), x_probe_32.spec().clone()), + ("y".to_string(), y_probe_32.spec().clone()), + ]) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(gpu_name) + .compiler_version(cv) + .cuda_toolkit_version(tv) + .build(); + assert!(contains_cuda_function(device_id, &key_8)); + assert!(contains_cuda_function(device_id, &key_32)); + }); +} + +// Multi-thread dedup: verify OnceCell single-initialization via timing. +// If dedup works, wall time ≈ 1x single compile. +// If broken (each thread compiles independently), wall time ≈ N x single compile. +#[test] +fn multi_thread_dedup_timing_evidence() { + common::with_test_stack(|| { + // First, compile a different spec to estimate single-compile time. + let t_single = std::time::Instant::now(); + warmup_test_module::_compile_warmup(&[WarmupSpec::new( + "vector_add", + vec!["f32".into(), "128".into()], + ) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 128))]) + .unwrap(); + let single_duration = t_single.elapsed(); + + // Now race 4 threads on a FRESH spec (tile_size=4, not previously compiled). + let n_threads = 4; + let barrier = Arc::new(std::sync::Barrier::new(n_threads)); + let t_parallel = std::time::Instant::now(); + let handles: Vec<_> = (0..n_threads) + .map(|_| { + let barrier = Arc::clone(&barrier); + std::thread::Builder::new() + .stack_size(common::TEST_STACK_SIZE) + .spawn(move || { + barrier.wait(); + let x = api::ones::(&[256]).sync().unwrap(); + let y = api::ones::(&[256]).sync().unwrap(); + let z = api::zeros::(&[256]).partition([4]).sync().unwrap(); + warmup_test_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), "4".into()]) + .sync() + .unwrap(); + }) + .unwrap() + }) + .collect(); + for h in handles { + h.join().unwrap(); + } + let parallel_duration = t_parallel.elapsed(); + + // If dedup works: parallel ≈ single. If broken: parallel ≈ 4 * single. + // Use 3.5x as threshold to reduce timing flakiness in shared CI/GPU environments. + let ratio = parallel_duration.as_secs_f64() / single_duration.as_secs_f64(); + eprintln!( + "[dedup timing] single={:.1?} parallel(4)={:.1?} ratio={:.2}", + single_duration, parallel_duration, ratio + ); + assert!( + ratio < 3.5, + "parallel compile of 4 threads took {ratio:.2}x single — dedup may be broken \ + (single={single_duration:.1?}, parallel={parallel_duration:.1?})" + ); + }); +} + +// CUTILE_NO_DISK_CACHE=1 disables disk persistence +#[test] +fn no_disk_cache_env_disables_persistence() { + common::with_test_stack(|| { + let cache_root = unique_temp_cache_dir("no_disk_cache"); + let _ = std::fs::remove_dir_all(&cache_root); + + let output = run_warmup_worker("no-disk-cache", &cache_root); + let combined = format!( + "{}\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + assert!( + output.status.success(), + "no-disk-cache worker failed\n{}", + combined + ); + + // The cache directory should either not exist or contain no .cubin files. + let cache_dir = cache_root.join("cutile"); + let cubin_count = if cache_dir.exists() { + std::fs::read_dir(&cache_dir) + .unwrap() + .filter(|e| { + e.as_ref() + .ok() + .and_then(|x| x.path().extension().map(|ext| ext == "cubin")) + .unwrap_or(false) + }) + .count() + } else { + 0 + }; + assert_eq!( + cubin_count, 0, + "no cubin should be on disk when CUTILE_NO_DISK_CACHE=1\n{}", + combined + ); + + let _ = std::fs::remove_dir_all(&cache_root); + }); +} + +// Cross-process disk-hit integration +#[test] +fn cross_process_disk_hit_integration() { + common::with_test_stack(|| { + let cache_root = unique_temp_cache_dir("cross_process"); + let _ = std::fs::remove_dir_all(&cache_root); + + // Process A: produce persisted cubin(s). + let producer = run_warmup_worker("producer", &cache_root); + assert!( + producer.status.success(), + "producer failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&producer.stdout), + String::from_utf8_lossy(&producer.stderr) + ); + + let cache_dir = cache_root.join("cutile"); + let cubin_count = std::fs::read_dir(&cache_dir) + .expect("cache dir should exist after producer") + .filter(|e| { + e.as_ref() + .ok() + .and_then(|x| x.path().extension().map(|ext| ext == "cubin")) + .unwrap_or(false) + }) + .count(); + assert!(cubin_count > 0, "producer should persist at least one cubin"); + + // Process B: first request should be disk-hit (new process memory cache is empty). + let consumer = run_warmup_worker("consumer", &cache_root); + let combined = format!( + "{}\n{}", + String::from_utf8_lossy(&consumer.stdout), + String::from_utf8_lossy(&consumer.stderr) + ); + assert!( + consumer.status.success(), + "consumer failed\n{}", + combined + ); + assert!( + combined.contains("disk cache hit"), + "consumer should report disk cache hit\n{}", + combined + ); + + let _ = std::fs::remove_dir_all(&cache_root); + }); +} + +// Multi-spec A=disk-hit, B=cold-compile +#[test] +fn multi_spec_disk_hit_then_cold_compile_integration() { + common::with_test_stack(|| { + let cache_root = unique_temp_cache_dir("multi_spec"); + let _ = std::fs::remove_dir_all(&cache_root); + + let output = run_warmup_worker("multi-spec", &cache_root); + let combined = format!( + "{}\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + assert!(output.status.success(), "worker failed\n{}", combined); + assert!( + combined.contains("disk cache hit"), + "expected spec A disk-hit\n{}", + combined + ); + assert!( + combined.contains("JIT compiled"), + "expected spec B cold-compile\n{}", + combined + ); + + let _ = std::fs::remove_dir_all(&cache_root); + }); +} + +// Worker test used by cross-process integration tests above. +// Runs as no-op unless CUTILE_WARMUP_WORKER_ROLE is set. +#[test] +fn cross_process_warmup_worker() { + let Ok(role) = std::env::var("CUTILE_WARMUP_WORKER_ROLE") else { + return; + }; + + common::with_test_stack(move || { + let spec_a = WarmupSpec::new("vector_add", vec!["f32".into(), "64".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 64)); + let spec_b = WarmupSpec::new("vector_add", vec!["f32".into(), "32".into()]) + .with_strides(vector_add_stride_args()) + .with_spec_args(vector_add_spec_args(256, 32)); + + match role.as_str() { + "producer" => { + warmup_test_module::_compile_warmup(std::slice::from_ref(&spec_a)) + .expect("producer warmup failed"); + } + "consumer" => { + warmup_test_module::_compile_warmup(std::slice::from_ref(&spec_a)) + .expect("consumer warmup failed"); + } + "multi-spec" => { + // Step 1: compile A and ensure it's persisted. + warmup_test_module::_compile_warmup(std::slice::from_ref(&spec_a)) + .expect("initial warmup for spec A failed"); + + // Step 2: evict A from memory so A path is forced to disk-hit. + let device_id = get_default_device(); + let key_a = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "64".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 64)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(get_gpu_name(device_id)) + .compiler_version(get_compiler_version()) + .cuda_toolkit_version(get_cuda_toolkit_version()) + .build(); + evict_kernel(&key_a.get_hash_string()); + + // Step 3: A should disk-hit, B should cold-compile. + warmup_test_module::_compile_warmup(&[spec_a.clone(), spec_b.clone()]) + .expect("multi-spec warmup failed"); + + // Verify both A and B are present in memory cache. + let gpu_name = get_gpu_name(device_id); + let compiler_version = get_compiler_version(); + let cuda_toolkit_version = get_cuda_toolkit_version(); + let key_b = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "32".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 32)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(gpu_name.clone()) + .compiler_version(compiler_version.clone()) + .cuda_toolkit_version(cuda_toolkit_version.clone()) + .build(); + let key_a_after = TileFunctionKey::builder("warmup_test_module", "vector_add") + .generics(vec!["f32".into(), "64".into()]) + .stride_args(vector_add_stride_args()) + .spec_args(vector_add_spec_args(256, 64)) + .source_hash(warmup_test_module::_SOURCE_HASH) + .gpu_name(gpu_name) + .compiler_version(compiler_version) + .cuda_toolkit_version(cuda_toolkit_version) + .build(); + assert!(contains_cuda_function(device_id, &key_a_after)); + assert!(contains_cuda_function(device_id, &key_b)); + } + "no-disk-cache" => { + // Compile a kernel with CUTILE_NO_DISK_CACHE=1 (set by the parent process). + // The parent asserts that no .cubin files appear on disk afterward. + warmup_test_module::_compile_warmup(std::slice::from_ref(&spec_a)) + .expect("no-disk-cache warmup failed"); + } + other => panic!("unknown worker role: {other}"), + } + }); +} diff --git a/cutile/tests/gpu/warmup_bench.rs b/cutile/tests/gpu/warmup_bench.rs new file mode 100644 index 0000000..bc0b01e --- /dev/null +++ b/cutile/tests/gpu/warmup_bench.rs @@ -0,0 +1,258 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Verification benchmarks for warmup, disk cache, and compilation dedup. +//! +//! These tests are designed to **demonstrate observable effects**, not just +//! "does it not crash". Run with: +//! +//! ```bash +//! CUTILE_JIT_LOG=1 cargo test --test gpu warmup_bench -- --nocapture 2>&1 +//! ``` +//! +//! You should see `[cutile::jit]` logs showing cache hits vs JIT compilations, +//! and timing comparisons printed to stdout. + +use crate::common; +use cutile::api; +use cutile::prelude::{DeviceOp, PartitionOp}; +use cutile::tile_kernel::{ + contains_cuda_function, execute_warmup, get_default_device, CompileOptions, + TileFunctionKey, TileKernel, WarmupSpec, +}; +use cutile_compiler::cuda_tile_runtime_utils::{ + get_compiler_version, get_cuda_toolkit_version, get_gpu_name, +}; +use cutile_compiler::specialization::SpecializationBits; +use std::time::Instant; + +// Use a separate module with a distinct name to avoid cache collisions with warmup.rs tests. +#[cutile::module] +mod bench_module { + use cutile::core::*; + + #[cutile::entry()] + fn vector_add( + z: &mut Tensor, + x: &Tensor, + y: &Tensor, + ) { + let tile_x = load_tile_like(x, z); + let tile_y = load_tile_like(y, z); + z.store(tile_x + tile_y); + } +} + +fn stride_args() -> Vec<(String, Vec)> { + vec![ + ("z".to_string(), vec![1]), + ("x".to_string(), vec![1]), + ("y".to_string(), vec![1]), + ] +} + +fn vector_add_spec_args(len: usize, tile: usize) -> Vec<(String, SpecializationBits)> { + let x = api::ones::(&[len]).sync().unwrap(); + let y = api::ones::(&[len]).sync().unwrap(); + let z = api::zeros::(&[len]).partition([tile]).sync().unwrap(); + let z_spec = z.unpartition().spec().clone(); + vec![ + ("z".to_string(), z_spec), + ("x".to_string(), x.spec().clone()), + ("y".to_string(), y.spec().clone()), + ] +} + +// Helper: run a vector_add kernel with specific generics, return wall-clock time. +fn timed_kernel_call(tile_size: &str) -> std::time::Duration { + let n: usize = tile_size.parse().unwrap(); + let t0 = Instant::now(); + let x = api::ones::(&[256]).sync().unwrap(); + let y = api::ones::(&[256]).sync().unwrap(); + let z = api::zeros::(&[256]).partition([n]).sync().unwrap(); + let _result = bench_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), tile_size.into()]) + .sync() + .unwrap(); + t0.elapsed() +} + +// Warmup eliminates first-call JIT latency +// +// Demonstrates that compile_warmup pre-compiles kernels so the first real +// call hits the memory cache instead of triggering JIT compilation. +// +// Uses different tile sizes (32 vs 64) to ensure fresh cache entries: +// - tile_size=32: called WITHOUT warmup → first call includes JIT +// - tile_size=64: called WITH warmup → first call is a cache hit +#[test] +fn warmup_eliminates_first_call_jit() { + common::with_test_stack(|| { + // ── Without warmup: first call includes JIT compilation ── + let cold_duration = timed_kernel_call("32"); + + // ── With warmup: pre-compile tile_size=64, then call it ── + let spec_args_64 = vector_add_spec_args(256, 64); + let warmup_t0 = Instant::now(); + bench_module::_compile_warmup(&[WarmupSpec::new("vector_add", vec!["f32".into(), "64".into()]) + .with_strides(stride_args()) + .with_spec_args(spec_args_64.clone())]) + .expect("compile_warmup failed"); + let warmup_duration = warmup_t0.elapsed(); + + // Now call the warmed-up kernel — should be near-instant (cache hit). + let warm_duration = timed_kernel_call("64"); + + println!("\n╔══════════════════════════════════════════════════════════╗"); + println!("║ Warmup Verification: First-Call Latency ║"); + println!("╠══════════════════════════════════════════════════════════╣"); + println!( + "║ Without warmup (tile=32): {:>10.1?} (includes JIT) ║", + cold_duration + ); + println!( + "║ Warmup step (tile=64): {:>10.1?} (pre-compile) ║", + warmup_duration + ); + println!( + "║ With warmup (tile=64): {:>10.1?} (cache hit) ║", + warm_duration + ); + println!("╠══════════════════════════════════════════════════════════╣"); + if warm_duration < cold_duration { + let speedup = cold_duration.as_secs_f64() / warm_duration.as_secs_f64().max(0.001); + println!( + "║ ✓ Warmed-up call is {:.1}x faster ║", + speedup + ); + } else { + println!("║ (both calls similar — kernel may already be cached) ║"); + } + println!("╚══════════════════════════════════════════════════════════╝\n"); + + // The warmed-up call should be significantly faster than the cold call. + // We don't assert a specific ratio because CI timing varies, but the + // JIT log output (CUTILE_JIT_LOG=1) provides definitive evidence. + // At minimum, verify the kernel IS in cache after warmup. + let device_id = get_default_device(); + let key = TileFunctionKey::new( + "bench_module".into(), + "vector_add".into(), + vec!["f32".into(), "64".into()], + stride_args(), + spec_args_64, + vec![], + None, + CompileOptions::default(), + bench_module::_SOURCE_HASH.into(), + get_gpu_name(device_id), + get_compiler_version(), + get_cuda_toolkit_version(), + ); + assert!( + contains_cuda_function(device_id, &key), + "kernel should be in memory cache after warmup" + ); + }); +} + +/// Second call always hits memory cache +// +/// Demonstrates that the second call to the same kernel is a memory cache hit, +/// regardless of whether warmup was used. +#[test] +fn second_call_hits_memory_cache() { + common::with_test_stack(|| { + // First call: JIT compiles (tile=16, unique to this test). + let first = timed_kernel_call("16"); + + // Second call: memory cache hit. + let second = timed_kernel_call("16"); + + println!("\n╔══════════════════════════════════════════════════════════╗"); + println!("║ Memory Cache Verification: 1st vs 2nd Call ║"); + println!("╠══════════════════════════════════════════════════════════╣"); + println!( + "║ First call (tile=16): {:>10.1?} (JIT compile) ║", + first + ); + println!( + "║ Second call (tile=16): {:>10.1?} (memory cache) ║", + second + ); + println!("╠══════════════════════════════════════════════════════════╣"); + if second < first { + let speedup = first.as_secs_f64() / second.as_secs_f64().max(0.001); + println!( + "║ ✓ Cache hit is {:.0}x faster ║", + speedup + ); + } + println!("╚══════════════════════════════════════════════════════════╝\n"); + + // Memory cache hit should be dramatically faster (100x+ typical). + assert!( + second < first, + "second call ({second:?}) should be faster than first ({first:?})" + ); + }); +} + +// Compile_warmup + execute_warmup combined flow +// +// Demonstrates a realistic warmup workflow: +// 1. compile_warmup pre-compiles the kernel +// 2. execute_warmup runs it with real data (also warms CUDA runtime) +// 3. Subsequent calls are fast +#[test] +fn full_warmup_workflow() { + common::with_test_stack(|| { + // Step 1: Pre-compile via compile_warmup. + let spec_args_128 = vector_add_spec_args(256, 128); + let t0 = Instant::now(); + bench_module::_compile_warmup(&[WarmupSpec::new("vector_add", vec!["f32".into(), "128".into()]) + .with_strides(stride_args()) + .with_spec_args(spec_args_128)]) + .expect("compile_warmup failed"); + let compile_time = t0.elapsed(); + + // Step 2: Execute warmup with real data (warms CUDA runtime). + let t1 = Instant::now(); + execute_warmup(|| { + let x = api::ones::(&[256]).sync()?; + let y = api::ones::(&[256]).sync()?; + let z = api::zeros::(&[256]).partition([128]).sync()?; + let _result = bench_module::vector_add(z, &x, &y) + .generics(vec!["f32".into(), "128".into()]) + .sync()?; + Ok(()) + }) + .expect("execute_warmup failed"); + let execute_time = t1.elapsed(); + + // Step 3: Production call — should be fast. + let production_time = timed_kernel_call("128"); + + println!("\n╔══════════════════════════════════════════════════════════╗"); + println!("║ Full Warmup Workflow Verification ║"); + println!("╠══════════════════════════════════════════════════════════╣"); + println!( + "║ 1. compile_warmup: {:>10.1?} (JIT to cache) ║", + compile_time + ); + println!( + "║ 2. execute_warmup: {:>10.1?} (cache + CUDA init) ║", + execute_time + ); + println!( + "║ 3. production call: {:>10.1?} (fully warm) ║", + production_time + ); + println!("╚══════════════════════════════════════════════════════════╝\n"); + + // Timing can vary significantly on shared/noisy GPUs, so keep this benchmark informative + // without enforcing a strict ordering between compile and execute phases. + }); +} diff --git a/cutile/tests/warmup.rs b/cutile/tests/warmup.rs new file mode 100644 index 0000000..e443d50 --- /dev/null +++ b/cutile/tests/warmup.rs @@ -0,0 +1,210 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Tests for cache key correctness, JitStore integration, and warmup APIs. +//! +//! - `cache_key_*` tests are CPU-only (no GPU required). +//! - `warmup_*` tests require GPU (compile + launch). + +use cutile::tile_kernel::{CompileOptions, EntryMeta, FunctionKey, TileFunctionKey, WarmupSpec}; +use cutile_compiler::specialization::{DivHint, SpecializationBits}; + +fn default_key() -> cutile::tile_kernel::TileFunctionKeyBuilder { + TileFunctionKey::builder("m", "f") + .source_hash("hash") + .gpu_name("sm_90") + .compiler_version("0.0.1") + .cuda_toolkit_version("12.4") +} + +// TileFunctionKey hash properties +#[test] +fn cache_key_hash_deterministic() { + let key1 = TileFunctionKey::builder("mod", "fn") + .generics(vec!["f32".into()]) + .source_hash("abc123") + .gpu_name("sm_90") + .compiler_version("0.0.1-alpha") + .cuda_toolkit_version("12.4") + .build(); + let key2 = key1.clone(); + assert_eq!(key1.get_hash_string(), key2.get_hash_string()); + assert_eq!(key1.get_disk_hash_string(), key2.get_disk_hash_string()); +} + +#[test] +fn cache_key_different_source_hash() { + let key_a = default_key().source_hash("hash_v1").build(); + let key_b = default_key().source_hash("hash_v2").build(); + assert_ne!(key_a.get_hash_string(), key_b.get_hash_string()); + assert_ne!(key_a.get_disk_hash_string(), key_b.get_disk_hash_string()); +} + +#[test] +fn cache_key_different_gpu_name() { + let key_a = default_key().gpu_name("sm_80").build(); + let key_b = default_key().gpu_name("sm_90").build(); + assert_ne!(key_a.get_hash_string(), key_b.get_hash_string()); + assert_ne!(key_a.get_disk_hash_string(), key_b.get_disk_hash_string()); +} + +#[test] +fn cache_key_different_compiler_version() { + let key_a = default_key().compiler_version("0.0.1").build(); + let key_b = default_key().compiler_version("0.0.2").build(); + assert_ne!(key_a.get_hash_string(), key_b.get_hash_string()); +} + +#[test] +fn cache_key_different_cuda_toolkit_version() { + let key_a = default_key().cuda_toolkit_version("12.4").build(); + let key_b = default_key().cuda_toolkit_version("12.6").build(); + assert_ne!(key_a.get_hash_string(), key_b.get_hash_string()); +} + +#[test] +fn cache_key_different_generics() { + let key_a = default_key().generics(vec!["f32".into()]).build(); + let key_b = default_key().generics(vec!["f16".into()]).build(); + assert_ne!(key_a.get_hash_string(), key_b.get_hash_string()); +} + +#[test] +fn cache_key_disk_hash_is_sha256_length() { + let key = default_key().build(); + let disk_hash = key.get_disk_hash_string(); + // SHA-256 hex output = 64 characters. + assert_eq!(disk_hash.len(), 64, "disk hash should be 64 hex chars"); + assert!( + disk_hash.chars().all(|c| c.is_ascii_hexdigit()), + "disk hash should be lowercase hex" + ); +} + +// When `nvcc` is unavailable, `get_cuda_toolkit_version()` returns `"unknown"`. +// Verify that `"unknown"` still produces a distinct key from any real version, +// so kernels compiled without a known toolkit version are never falsely reused +// when a real version becomes available (or vice versa). +#[test] +fn cache_key_toolkit_unknown_is_distinct() { + let key_unknown = default_key().cuda_toolkit_version("unknown").build(); + let key_real = default_key().cuda_toolkit_version("12.4").build(); + assert_ne!( + key_unknown.get_hash_string(), + key_real.get_hash_string(), + "unknown toolkit must produce distinct memory key" + ); + assert_ne!( + key_unknown.get_disk_hash_string(), + key_real.get_disk_hash_string(), + "unknown toolkit must produce distinct disk key" + ); +} + + +#[test] +fn cache_key_source_hash_change_invalidates() { + let make_key = |source_hash: &str| -> TileFunctionKey { + TileFunctionKey::builder("linalg", "matmul") + .generics(vec!["f32".into(), "128".into()]) + .stride_args(vec![("a".into(), vec![1, 128])]) + .grid((4, 4, 1)) + .source_hash(source_hash) + .gpu_name("sm_90") + .compiler_version("0.1.0") + .cuda_toolkit_version("12.4") + .build() + }; + let key_v1 = make_key("aabbccdd11223344"); + let key_v2 = make_key("eeff0011deadbeef"); + assert_ne!(key_v1.get_hash_string(), key_v2.get_hash_string()); + assert_ne!(key_v1.get_disk_hash_string(), key_v2.get_disk_hash_string()); +} + +// Cache keys must distinguish data alignments to prevent incorrect kernel reuse. +#[test] +fn cache_key_different_spec_args() { + let spec_aligned = SpecializationBits { + shape_div: vec![DivHint::from_value(16), DivHint::from_value(16)], + stride_div: vec![DivHint::from_value(16), DivHint::from_value(16)], + stride_one: vec![false, true], + base_ptr_div: DivHint::from_ptr(16), + elements_disjoint: true, + }; + let spec_misaligned = SpecializationBits { + shape_div: vec![DivHint::from_value(4), DivHint::from_value(4)], + stride_div: vec![DivHint::from_value(4), DivHint::from_value(4)], + stride_one: vec![false, true], + base_ptr_div: DivHint::from_ptr(4), + elements_disjoint: true, + }; + let key_a = default_key() + .spec_args(vec![("x".into(), spec_aligned)]) + .build(); + let key_b = default_key() + .spec_args(vec![("x".into(), spec_misaligned)]) + .build(); + assert_ne!( + key_a.get_hash_string(), + key_b.get_hash_string(), + "different SpecializationBits must produce distinct memory keys" + ); + assert_ne!( + key_a.get_disk_hash_string(), + key_b.get_disk_hash_string(), + "different SpecializationBits must produce distinct disk keys" + ); +} + +#[test] +fn cache_key_different_compile_options() { + let key_a = default_key() + .compile_options(CompileOptions::default().max_divisibility(8)) + .build(); + let key_b = default_key() + .compile_options(CompileOptions::default().max_divisibility(16)) + .build(); + assert_ne!( + key_a.get_hash_string(), + key_b.get_hash_string(), + "different CompileOptions must produce distinct memory keys" + ); + assert_ne!( + key_a.get_disk_hash_string(), + key_b.get_disk_hash_string(), + "different CompileOptions must produce distinct disk keys" + ); + + let key_c = default_key() + .compile_options(CompileOptions::default().occupancy(2)) + .build(); + let key_d = default_key() + .compile_options(CompileOptions::default().occupancy(4)) + .build(); + assert_ne!(key_c.get_hash_string(), key_d.get_hash_string()); +} + +#[test] +fn warmup_spec_builder() { + let spec = WarmupSpec::new("my_kernel", vec!["f32".into(), "128".into()]) + .with_strides(vec![("x".into(), vec![1, 128])]) + .with_const_grid((4, 1, 1)); + assert_eq!(spec.function_name, "my_kernel"); + assert_eq!(spec.function_generics, vec!["f32", "128"]); + assert_eq!(spec.stride_args.len(), 1); + assert_eq!(spec.const_grid, Some((4, 1, 1))); +} + +#[test] +fn entry_meta_fields() { + let meta = EntryMeta { + module_name: "linalg", + function_name: "vector_add", + function_entry: "vector_add_entry", + }; + assert_eq!(meta.module_name, "linalg"); + assert_eq!(meta.function_name, "vector_add"); + assert_eq!(meta.function_entry, "vector_add_entry"); +} diff --git a/scripts/run_cpu_tests.sh b/scripts/run_cpu_tests.sh index 673143c..2a23d2d 100755 --- a/scripts/run_cpu_tests.sh +++ b/scripts/run_cpu_tests.sh @@ -37,6 +37,13 @@ run_step \ "cutile type inference sanity regressions" \ cargo test -p cutile --test type_inference_sanity +run_step \ + "cutile warmup/cache-key CPU tests" \ + cargo test -p cutile --test warmup + +run_step \ + cargo test -p cuda-async --test jit_store + print_summary_and_exit \ "All CPU tests passed!" \ "Some CPU checks failed. See output above for details."