From b0979baf79826a9942dc9fdf54863a69253f5501 Mon Sep 17 00:00:00 2001 From: sunteng Date: Mon, 13 Apr 2026 08:49:55 +0800 Subject: [PATCH] feat: custom MemPool support for async tensor allocation - Add MemPool wrapping CUmemoryPool with device ownership and RAII drop - set_device_pool() registers pool per-device; rejects cross-device pools - ExecutionContext::new() auto-resolves pool from stream's device via pool_for_stream() - All scheduling paths (.sync, .await, .schedule, .sync_on, .async_on) carry pool - Add cuda-async integration tests covering lifecycle, scheduling freeze, and cross-device rejection --- cuda-async/src/device_context.rs | 74 ++++++- cuda-async/src/device_operation.rs | 22 +- cuda-async/tests/pool_allocation.rs | 328 ++++++++++++++++++++++++++++ cuda-core/src/api.rs | 13 ++ cuda-core/src/cudarc_shim.rs | 70 ++++++ cuda-core/src/runtime.rs | 86 +++++++- cutile/src/api.rs | 6 +- cutile/src/tensor.rs | 3 +- scripts/run_gpu_tests.sh | 8 + 9 files changed, 601 insertions(+), 9 deletions(-) create mode 100644 cuda-async/tests/pool_allocation.rs diff --git a/cuda-async/src/device_context.rs b/cuda-async/src/device_context.rs index 11732d33..d158a43c 100644 --- a/cuda-async/src/device_context.rs +++ b/cuda-async/src/device_context.rs @@ -7,7 +7,7 @@ use crate::error::{device_assert, device_error, DeviceError}; use crate::scheduling_policies::{SchedulingPolicy, StreamPoolRoundRobin}; -use cuda_core::{Device, Function, Module, Stream}; +use cuda_core::{Device, Function, Module, Stream, MemPool}; use std::cell::Cell; use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; @@ -88,6 +88,7 @@ pub struct AsyncDeviceContext { device: Arc, deallocator_stream: Arc, policy: Arc, + pool: Option>, functions: DeviceFunctions, validators: DeviceFunctionValidators, } @@ -163,6 +164,7 @@ pub fn new_device_context( device, deallocator_stream, policy, + pool: None, functions: HashMap::new(), validators: HashMap::new(), }) @@ -204,6 +206,7 @@ pub fn 
init_with_default_policy( device, deallocator_stream, policy: Arc::new(policy), + pool: None, functions: HashMap::new(), validators: HashMap::new(), }; @@ -317,6 +320,75 @@ pub fn set_default_device(default_device_id: usize) { }) } +/// Set a custom memory pool for the given device **on the current thread**. +/// +/// Subsequent allocations on this device will use the given pool instead of the +/// default pool. The pool is resolved at scheduling time (`.sync()`, `.await`, +/// `.schedule()`, `.sync_on()`, `.async_on()`) and carried on +/// [`ExecutionContext`](crate::device_operation::ExecutionContext), so it also +/// applies to futures that are later polled on other threads. +/// +/// # Thread-locality +/// +/// `AsyncDeviceContext` — and therefore the pool registration — lives in a +/// `thread_local!`. Calling `set_device_pool(0, pool)` on thread A does **not** +/// affect allocations scheduled by thread B on device 0. +/// +/// If you build a `DeviceFuture` on one thread and move it to another, the pool +/// travels with the future via its `ExecutionContext` snapshot — the destination +/// thread does not need its own `set_device_pool` call. But if thread B +/// independently creates ops via `.sync()`/`.await`, those ops see thread B's +/// pool (typically `None` unless B also called `set_device_pool`). +/// +/// Multi-threaded workers that want a shared pool should each call +/// `set_device_pool` during their initialization. +/// +/// # Errors +/// +/// Returns [`DeviceError::Context`](crate::error::DeviceError::Context) if +/// `pool` was created on a different device than `device_id`. 
+pub fn set_device_pool(device_id: usize, pool: Arc<MemPool>) -> Result<(), DeviceError> {
+    let pool_device = pool.device().ordinal();
+    device_assert(
+        device_id,
+        pool_device == device_id,
+        &format!("pool belongs to device {pool_device}, expected device {device_id}"),
+    )?;
+    with_global_device_context_mut(device_id, |device_context| {
+        device_context.pool = Some(pool);
+    })
+}
+
+/// Clear the custom memory pool for the given device **on the current thread**,
+/// reverting to the default pool.
+///
+/// Only affects the calling thread's pool registration; see
+/// [`set_device_pool`] for the full thread-locality contract. In-flight
+/// `DeviceFuture`s that already captured the pool are unaffected (the pool is
+/// kept alive via `Arc` until those futures complete).
+pub fn clear_device_pool(device_id: usize) -> Result<(), DeviceError> {
+    with_global_device_context_mut(device_id, |device_context| {
+        device_context.pool = None;
+    })
+}
+
+/// Returns the custom memory pool registered for the given device **on the
+/// current thread**, if any.
+///
+/// Returns `Ok(None)` when the calling thread has not registered a pool, even
+/// if another thread has done so. See [`set_device_pool`] for thread-locality.
+pub fn get_device_pool(device_id: usize) -> Result<Option<Arc<MemPool>>, DeviceError> {
+    with_global_device_context(device_id, |device_context| device_context.pool.clone())
+}
+
+/// Resolve the custom memory pool associated with the device that owns `stream`.
+///
+/// Errors from the device-context lookup are downgraded to `None`; this is the
+/// single choke-point for that decision so callers don't each re-derive it.
+pub fn pool_for_stream(stream: &Arc<Stream>) -> Option<Arc<MemPool>> {
+    get_device_pool(stream.device().ordinal()).ok().flatten()
+}
+
 /// Run a closure with the scheduling policy of the current thread's default device.
 ///
 /// This is the function called internally by [`DeviceOp::sync()`] and by the
diff --git a/cuda-async/src/device_operation.rs b/cuda-async/src/device_operation.rs
index 0dfc2286..27776609 100644
--- a/cuda-async/src/device_operation.rs
+++ b/cuda-async/src/device_operation.rs
@@ -5,11 +5,11 @@
 
 //! Lazy, composable GPU operations and combinator types.
 
-use crate::device_context::with_default_device_policy;
+use crate::device_context::{pool_for_stream, with_default_device_policy};
 use crate::device_future::DeviceFuture;
 use crate::error::{device_error, DeviceError};
 use crate::scheduling_policies::SchedulingPolicy;
-use cuda_core::{Device, Stream};
+use cuda_core::{Device, Stream, MemPool};
 use std::cell::{Cell, UnsafeCell};
 use std::fmt::Debug;
 use std::future::IntoFuture;
@@ -61,16 +61,19 @@ pub struct ExecutionContext {
     ordinal: DeviceOrdinal,
     cuda_stream: Arc<Stream>,
     device: Arc<Device>,
+    pool: Option<Arc<MemPool>>,
 }
 
 impl ExecutionContext {
     pub fn new(cuda_stream: Arc<Stream>) -> Self {
         let device = cuda_stream.device().clone();
         let ordinal = device.ordinal();
+        let pool = pool_for_stream(&cuda_stream);
         Self {
             cuda_stream,
             device,
             ordinal,
+            pool,
         }
     }
     pub fn get_cuda_stream(&self) -> &Arc<Stream> {
@@ -82,6 +85,21 @@
     pub fn get_device_id(&self) -> DeviceOrdinal {
         self.ordinal
     }
+    pub fn get_pool(&self) -> Option<&Arc<MemPool>> {
+        self.pool.as_ref()
+    }
+    /// Allocates device memory on this context's stream, using the custom pool if set.
+    ///
+    /// # Safety
+    /// The stream must be valid and not destroyed.
+ pub unsafe fn alloc_async(&self, num_bytes: usize) -> cuda_core::sys::CUdeviceptr { + match &self.pool { + Some(pool) => { + cuda_core::malloc_from_pool_async(num_bytes, pool, &self.cuda_stream) + } + None => cuda_core::malloc_async(num_bytes, &self.cuda_stream), + } + } #[expect( dead_code, reason = "kept for direct synchronous execution in tests and future blocking APIs" diff --git a/cuda-async/tests/pool_allocation.rs b/cuda-async/tests/pool_allocation.rs new file mode 100644 index 00000000..be06cfd1 --- /dev/null +++ b/cuda-async/tests/pool_allocation.rs @@ -0,0 +1,328 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Integration tests for custom memory pool allocation. +//! +//! Tests that exercise the `MemPool` lifecycle, device-level pool +//! configuration, and pool-aware allocation through `ExecutionContext`. +//! +//! Each test runs on a fresh thread so that thread-local `DEVICE_CONTEXTS` +//! starts clean. 
+ +use cuda_async::device_context::{ + clear_device_pool, get_device_pool, global_policy, init_device_contexts, set_device_pool, + with_device, +}; +use cuda_async::device_operation::{value, DeviceOp}; +use cuda_async::prelude::*; + +fn on_fresh_thread(f: F) { + std::thread::spawn(f).join().expect("test thread panicked"); +} + +// --------------------------------------------------------------------------- +// MemPool RAII lifecycle +// --------------------------------------------------------------------------- + +#[test] +fn create_and_drop_mem_pool() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + + assert!(!pool.cu_pool().is_null()); + }); +} + +#[test] +fn default_mem_pool_is_not_owned() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.default_mem_pool()) + .expect("get context failed") + .expect("default pool failed"); + + assert!(!pool.cu_pool().is_null()); + }); +} + +#[test] +fn set_release_threshold() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + + pool.set_release_threshold(u64::MAX) + .expect("set threshold failed"); + + pool.set_release_threshold(1024 * 1024) + .expect("set finite threshold failed"); + }); +} + +// --------------------------------------------------------------------------- +// Device-level pool configuration +// --------------------------------------------------------------------------- + +#[test] +fn set_and_get_device_pool() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool_opt = get_device_pool(0).expect("get pool failed"); + 
assert!(pool_opt.is_none()); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + let pool_ptr = pool.cu_pool(); + set_device_pool(0, pool).expect("set pool failed"); + + let retrieved = get_device_pool(0) + .expect("get pool failed") + .expect("pool should be set"); + assert_eq!(retrieved.cu_pool(), pool_ptr); + }); +} + +#[test] +fn clear_device_pool_reverts_to_none() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + set_device_pool(0, pool).expect("set pool failed"); + + assert!(get_device_pool(0).expect("get failed").is_some()); + + clear_device_pool(0).expect("clear pool failed"); + assert!(get_device_pool(0).expect("get failed").is_none()); + }); +} + +#[test] +fn set_device_pool_rejects_cross_device_pool() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + + let err = set_device_pool(99, pool) + .expect_err("expected cross-device pool to be rejected"); + match err { + cuda_async::error::DeviceError::Context { device_id, message } => { + assert_eq!(device_id, 99, "error should point to target device"); + assert!( + message.contains("pool belongs to device 0") + && message.contains("expected device 99"), + "message should name both devices, got: {message}" + ); + } + other => panic!("expected DeviceError::Context, got {other:?}"), + } + + assert!(get_device_pool(0).expect("get pool failed").is_none()); + }); +} + +// --------------------------------------------------------------------------- +// Pool-aware allocation through DeviceOp +// --------------------------------------------------------------------------- + +#[test] +fn 
alloc_with_custom_pool_via_device_op() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + pool.set_release_threshold(u64::MAX) + .expect("set threshold failed"); + set_device_pool(0, pool).expect("set pool failed"); + + let op = with_context(|ctx| { + let num_bytes = 1024; + let dptr = unsafe { ctx.alloc_async(num_bytes) }; + assert!(dptr != 0, "allocation returned null pointer"); + value(dptr) + }); + let dptr = op.sync().expect("device op failed"); + assert!(dptr != 0); + }); +} + +#[test] +fn alloc_without_pool_uses_default() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let op = with_context(|ctx| { + assert!(ctx.get_pool().is_none()); + let num_bytes = 1024; + let dptr = unsafe { ctx.alloc_async(num_bytes) }; + assert!(dptr != 0, "allocation returned null pointer"); + value(dptr) + }); + let dptr = op.sync().expect("device op failed"); + assert!(dptr != 0); + }); +} + +#[test] +fn pool_is_frozen_at_scheduling_time() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool_a = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool A creation failed"); + let pool_b = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool B creation failed"); + + let pool_a_ptr = pool_a.cu_pool() as usize; + + set_device_pool(0, pool_a).expect("set pool_a failed"); + + // Schedule while pool_a is active — freezes pool_a into ExecutionContext. 
+ let policy = global_policy(0).expect("get policy failed"); + let future = with_context(move |ctx| { + let p = ctx.get_pool().expect("pool should be present"); + assert_eq!(p.cu_pool() as usize, pool_a_ptr, "should use frozen pool_a, not pool_b"); + value(()) + }) + .schedule(&policy) + .expect("schedule failed"); + + // Change global pool AFTER scheduling — must not affect the already-frozen ExecutionContext. + set_device_pool(0, pool_b).expect("set pool_b failed"); + + // Execute: DeviceFuture carries pool_a in its ExecutionContext, pool_b is ignored. + futures::executor::block_on(future).expect("future failed"); + }); +} + +// --------------------------------------------------------------------------- +// Explicit .schedule() path +// --------------------------------------------------------------------------- + +#[test] +fn schedule_applies_device_pool() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + pool.set_release_threshold(u64::MAX) + .expect("set threshold failed"); + let pool_ptr = pool.cu_pool() as usize; + set_device_pool(0, pool).expect("set pool failed"); + + let policy = global_policy(0).expect("get policy failed"); + let future = with_context(move |ctx| { + let p = ctx.get_pool().expect("pool should be present via schedule"); + assert_eq!(p.cu_pool() as usize, pool_ptr, "schedule must pick up device pool"); + let dptr = unsafe { ctx.alloc_async(512) }; + assert!(dptr != 0, "allocation returned null pointer"); + value(dptr) + }) + .schedule(&policy) + .expect("schedule failed"); + + let dptr = futures::executor::block_on(future).expect("future failed"); + assert!(dptr != 0); + }); +} + +#[test] +fn sync_on_applies_device_pool() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool = with_device(0, |device| 
device.new_mem_pool()) + .expect("get context failed") + .expect("pool creation failed"); + pool.set_release_threshold(u64::MAX) + .expect("set threshold failed"); + let pool_ptr = pool.cu_pool() as usize; + set_device_pool(0, pool).expect("set pool failed"); + + let stream = global_policy(0) + .expect("get policy failed") + .next_stream() + .expect("get stream failed"); + + let dptr = with_context(move |ctx| { + let p = ctx.get_pool().expect("pool should be present via sync_on"); + assert_eq!(p.cu_pool() as usize, pool_ptr, "sync_on must pick up device pool"); + let dptr = unsafe { ctx.alloc_async(512) }; + assert!(dptr != 0, "allocation returned null pointer"); + value(dptr) + }) + .sync_on(&stream) + .expect("sync_on failed"); + assert!(dptr != 0); + }); +} + +// --------------------------------------------------------------------------- +// Multiple pools +// --------------------------------------------------------------------------- + +#[test] +fn switch_between_pools() { + on_fresh_thread(|| { + init_device_contexts(0, 1).expect("init failed (requires GPU)"); + + let pool_a = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool A creation failed"); + let pool_b = with_device(0, |device| device.new_mem_pool()) + .expect("get context failed") + .expect("pool B creation failed"); + + let pool_a_ptr = pool_a.cu_pool(); + let pool_b_ptr = pool_b.cu_pool(); + assert_ne!(pool_a_ptr, pool_b_ptr, "pools should be distinct"); + + set_device_pool(0, pool_a).expect("set A failed"); + let op_a = with_context(|ctx| { + let dptr = unsafe { ctx.alloc_async(512) }; + assert!(dptr != 0); + value(()) + }); + op_a.sync().expect("op A failed"); + + set_device_pool(0, pool_b).expect("set B failed"); + let op_b = with_context(|ctx| { + let dptr = unsafe { ctx.alloc_async(512) }; + assert!(dptr != 0); + value(()) + }); + op_b.sync().expect("op B failed"); + + clear_device_pool(0).expect("clear failed"); + let op_default = with_context(|ctx| 
{ + assert!(ctx.get_pool().is_none()); + let dptr = unsafe { ctx.alloc_async(512) }; + assert!(dptr != 0); + value(()) + }); + op_default.sync().expect("default op failed"); + }); +} diff --git a/cuda-core/src/api.rs b/cuda-core/src/api.rs index 11fe50dc..ae493d6f 100644 --- a/cuda-core/src/api.rs +++ b/cuda-core/src/api.rs @@ -78,6 +78,19 @@ pub unsafe fn malloc_async(num_bytes: usize, stream: &Arc) -> sys::CUdev .expect("Malloc async failed.") } +/// Asynchronously allocates `num_bytes` of device memory from a specific pool on the given stream. +/// +/// # Safety +/// `stream` must be a valid, non-destroyed CUDA stream. `pool` must be a valid memory pool. +pub unsafe fn malloc_from_pool_async( + num_bytes: usize, + pool: &Arc, + stream: &Arc, +) -> sys::CUdeviceptr { + crate::cudarc_shim::pool::malloc_from_pool_async(pool.cu_pool(), stream.cu_stream(), num_bytes) + .expect("Malloc from pool async failed.") +} + /// Asynchronously frees device memory on the given stream. /// /// # Safety diff --git a/cuda-core/src/cudarc_shim.rs b/cuda-core/src/cudarc_shim.rs index 7afaac98..1413bb9c 100644 --- a/cuda-core/src/cudarc_shim.rs +++ b/cuda-core/src/cudarc_shim.rs @@ -455,6 +455,76 @@ pub(crate) mod event { } } +/// Low-level CUDA memory pool operations. +pub(crate) mod pool { + use super::{DriverError, IntoResult}; + use std::mem::MaybeUninit; + + /// Creates a new memory pool with the given properties. + /// + /// # Safety + /// `props` must describe a valid pool configuration for an available device. + pub unsafe fn create( + props: &cuda_bindings::CUmemPoolProps, + ) -> Result { + let mut pool = MaybeUninit::uninit(); + cuda_bindings::cuMemPoolCreate(pool.as_mut_ptr(), props as *const _ as *mut _) + .result()?; + Ok(pool.assume_init()) + } + + /// Destroys a memory pool. + /// + /// # Safety + /// `pool` must be valid and all allocations from it must have been freed. 
+    pub unsafe fn destroy(pool: cuda_bindings::CUmemoryPool) -> Result<(), DriverError> {
+        cuda_bindings::cuMemPoolDestroy(pool).result()
+    }
+
+    /// Returns the default memory pool for the given device.
+    ///
+    /// # Safety
+    /// `device` must be a valid device ordinal.
+    pub unsafe fn get_default(
+        device: cuda_bindings::CUdevice,
+    ) -> Result<cuda_bindings::CUmemoryPool, DriverError> {
+        let mut pool = MaybeUninit::uninit();
+        cuda_bindings::cuDeviceGetDefaultMemPool(pool.as_mut_ptr(), device).result()?;
+        Ok(pool.assume_init())
+    }
+
+    /// Sets the release threshold for a memory pool.
+    ///
+    /// # Safety
+    /// `pool` must be a valid pool handle.
+    pub unsafe fn set_release_threshold(
+        pool: cuda_bindings::CUmemoryPool,
+        threshold: u64,
+    ) -> Result<(), DriverError> {
+        cuda_bindings::cuMemPoolSetAttribute(
+            pool,
+            cuda_bindings::CUmemPool_attribute_enum_CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
+            &threshold as *const _ as *mut _,
+        )
+        .result()
+    }
+
+    /// Allocates device memory from a specific pool asynchronously.
+    ///
+    /// # Safety
+    /// `pool` and `stream` must be valid handles.
+    pub unsafe fn malloc_from_pool_async(
+        pool: cuda_bindings::CUmemoryPool,
+        stream: cuda_bindings::CUstream,
+        num_bytes: usize,
+    ) -> Result<cuda_bindings::CUdeviceptr, DriverError> {
+        let mut dev_ptr = MaybeUninit::uninit();
+        cuda_bindings::cuMemAllocFromPoolAsync(dev_ptr.as_mut_ptr(), num_bytes, pool, stream)
+            .result()?;
+        Ok(dev_ptr.assume_init())
+    }
+}
+
 /// Low-level CUDA memory allocation, transfer, and management operations.
 #[allow(dead_code)]
 pub(crate) mod memory {
diff --git a/cuda-core/src/runtime.rs b/cuda-core/src/runtime.rs
index 707fce97..8a92a4e1 100644
--- a/cuda-core/src/runtime.rs
+++ b/cuda-core/src/runtime.rs
@@ -12,7 +12,7 @@
 use std::ffi::{c_int, c_void, CString};
 use std::sync::Arc;
 
-use crate::cudarc_shim::{ctx, device, module, primary_ctx, stream};
+use crate::cudarc_shim::{ctx, device, module, pool, primary_ctx, stream};
 use crate::error::*;
 use crate::init;
 
@@ -198,6 +198,90 @@ impl Device {
             owned: true,
         }))
     }
+
+    /// Creates a new memory pool on this device.
+    ///
+    /// The returned pool is owned — it will be destroyed when the last `Arc`
+    /// is dropped. Pair with [`MemPool::set_release_threshold`] to control
+    /// when the pool returns memory to the OS.
+    pub fn new_mem_pool(self: &Arc<Device>) -> Result<Arc<MemPool>, DriverError> {
+        self.bind_to_thread()?;
+        let mut props: cuda_bindings::CUmemPoolProps = unsafe { std::mem::zeroed() };
+        props.allocType = cuda_bindings::CUmemAllocationType_enum_CU_MEM_ALLOCATION_TYPE_PINNED;
+        props.handleTypes = cuda_bindings::CUmemAllocationHandleType_enum_CU_MEM_HANDLE_TYPE_NONE;
+        props.location.type_ = cuda_bindings::CUmemLocationType_enum_CU_MEM_LOCATION_TYPE_DEVICE;
+        props.location.__bindgen_anon_1.id = self.ordinal as c_int;
+        let cu_pool = unsafe { pool::create(&props) }?;
+        Ok(Arc::new(MemPool {
+            cu_pool,
+            device: self.clone(),
+            owned: true,
+        }))
+    }
+
+    /// Returns the driver-owned default memory pool for this device.
+    ///
+    /// The returned wrapper is **not owned** — dropping it does not destroy the
+    /// default pool, which is shared across all users of the device.
+    pub fn default_mem_pool(self: &Arc<Device>) -> Result<Arc<MemPool>, DriverError> {
+        self.bind_to_thread()?;
+        let cu_pool = unsafe { pool::get_default(self.cu_device) }?;
+        Ok(Arc::new(MemPool {
+            cu_pool,
+            device: self.clone(),
+            owned: false,
+        }))
+    }
+}
+
+/// A CUDA memory pool handle.
+///
+/// Can be either **owned** (created via [`Device::new_mem_pool`], destroyed on
+/// drop) or **borrowed** (created via [`Device::default_mem_pool`], does NOT
+/// destroy on drop).
+///
+/// Used by async tensor allocation via `cuMemAllocFromPoolAsync` when a pool
+/// is registered via `cuda_async::device_context::set_device_pool`.
+#[derive(Debug)]
+pub struct MemPool {
+    pub(crate) cu_pool: cuda_bindings::CUmemoryPool,
+    pub(crate) device: Arc<Device>,
+    owned: bool,
+}
+
+unsafe impl Send for MemPool {}
+unsafe impl Sync for MemPool {}
+
+impl Drop for MemPool {
+    fn drop(&mut self) {
+        if !self.owned {
+            return;
+        }
+        let _ = self.device.bind_to_thread();
+        let _ = unsafe { pool::destroy(self.cu_pool) };
+    }
+}
+
+impl MemPool {
+    /// Returns the raw `CUmemoryPool` handle.
+    pub fn cu_pool(&self) -> cuda_bindings::CUmemoryPool {
+        self.cu_pool
+    }
+
+    /// Returns a reference to the parent device.
+    pub fn device(&self) -> &Arc<Device> {
+        &self.device
+    }
+
+    /// Sets the release threshold for this pool.
+    ///
+    /// Memory held by the pool is not returned to the OS until pool usage drops
+    /// below this threshold. Use `u64::MAX` to prevent the OS from reclaiming
+    /// pool memory (useful for inference workloads with stable memory footprints).
+    pub fn set_release_threshold(&self, threshold: u64) -> Result<(), DriverError> {
+        self.device.bind_to_thread()?;
+        unsafe { pool::set_release_threshold(self.cu_pool, threshold) }
+    }
 }
 
 /// A CUDA stream handle.
diff --git a/cutile/src/api.rs b/cutile/src/api.rs index 7e072290..7002bbf5 100644 --- a/cutile/src/api.rs +++ b/cutile/src/api.rs @@ -143,7 +143,7 @@ use cuda_async::error::DeviceError; use cuda_core::curand::{RandNormal, RandUniform, RNG}; use cuda_core::sys::CUdeviceptr; use cuda_core::DType; -use cuda_core::{malloc_async, memcpy_dtod_async, memcpy_dtoh_async, memcpy_htod_async}; +use cuda_core::{memcpy_dtod_async, memcpy_dtoh_async, memcpy_htod_async}; use half::f16; use std::alloc::{alloc, Layout}; use std::future::IntoFuture; @@ -170,7 +170,7 @@ impl DeviceOp for CopyDeviceToDevice { ctx: &ExecutionContext, ) -> Result<::Output, DeviceError> { let num_bytes = self.num_elements * std::mem::size_of::(); - let dst = malloc_async(num_bytes, ctx.get_cuda_stream()); + let dst = ctx.alloc_async(num_bytes); memcpy_dtod_async::(dst, self.src_ptr, self.num_elements, ctx.get_cuda_stream()); Ok(Tensor::from_raw_parts( dst, @@ -392,7 +392,7 @@ impl DeviceOp for CopyHostVecToDevice { let num_elements = vec.len(); let shape = vec![num_elements as i32]; let strides = vec![1]; - let dptr = malloc_async(element_size * num_elements, ctx.get_cuda_stream()); + let dptr = ctx.alloc_async(element_size * num_elements); memcpy_htod_async(dptr, vec.as_ptr(), num_elements, ctx.get_cuda_stream()); Ok(Tensor::from_raw_parts( dptr, diff --git a/cutile/src/tensor.rs b/cutile/src/tensor.rs index 9457f004..70baa8f8 100644 --- a/cutile/src/tensor.rs +++ b/cutile/src/tensor.rs @@ -207,7 +207,6 @@ use anyhow::Result; use cuda_async::device_buffer::{DeviceBuffer, DevicePointer}; use cuda_async::device_operation; use cuda_async::device_operation::{value, DeviceOp, IntoDeviceOp, Value}; -use cuda_core::malloc_async; use cuda_core::sys::CUdeviceptr; use cuda_core::{DType, DTypeId}; use std::fmt::Debug; @@ -619,7 +618,7 @@ impl Tensor { let num_bytes = len * size_of::(); value(MaybeUninit::new(unsafe { Self::from_raw_parts( - malloc_async(num_bytes, ctx.get_cuda_stream()), + 
ctx.alloc_async(num_bytes), num_bytes, ctx.get_device_id(), vec![len as i32], diff --git a/scripts/run_gpu_tests.sh b/scripts/run_gpu_tests.sh index 3c0a091c..14212fbf 100755 --- a/scripts/run_gpu_tests.sh +++ b/scripts/run_gpu_tests.sh @@ -45,6 +45,14 @@ run_step \ "cutile GPU error-quality tests" \ cargo test -p cutile --test gpu +for test_target in \ + pool_allocation +do + run_step \ + "cuda-async GPU integration test ${test_target}" \ + cargo test -p cuda-async --test "$test_target" +done + print_summary_and_exit \ "All GPU tests passed!" \ "Some GPU tests failed. See output above for details."