74 changes: 73 additions & 1 deletion cuda-async/src/device_context.rs
@@ -7,7 +7,7 @@
 
 use crate::error::{device_assert, device_error, DeviceError};
 use crate::scheduling_policies::{SchedulingPolicy, StreamPoolRoundRobin};
-use cuda_core::{Device, Function, Module, Stream};
+use cuda_core::{Device, Function, Module, Stream, MemPool};
 use std::cell::Cell;
 use std::collections::HashMap;
 use std::hash::{DefaultHasher, Hash, Hasher};
@@ -88,6 +88,7 @@ pub struct AsyncDeviceContext {
     device: Arc<Device>,
     deallocator_stream: Arc<Stream>,
     policy: Arc<dyn SchedulingPolicy>,
+    pool: Option<Arc<MemPool>>,
     functions: DeviceFunctions,
     validators: DeviceFunctionValidators,
 }
@@ -163,6 +164,7 @@ pub fn new_device_context(
         device,
         deallocator_stream,
         policy,
+        pool: None,
         functions: HashMap::new(),
         validators: HashMap::new(),
     })
@@ -204,6 +206,7 @@ pub fn init_with_default_policy(
         device,
         deallocator_stream,
         policy: Arc::new(policy),
+        pool: None,
         functions: HashMap::new(),
         validators: HashMap::new(),
     };
@@ -317,6 +320,75 @@ pub fn set_default_device(default_device_id: usize) {
     })
 }
 
+/// Set a custom memory pool for the given device **on the current thread**.
+///
+/// Subsequent allocations on this device will use the given pool instead of the
+/// default pool. The pool is resolved at scheduling time (`.sync()`, `.await`,
+/// `.schedule()`, `.sync_on()`, `.async_on()`) and carried on
+/// [`ExecutionContext`](crate::device_operation::ExecutionContext), so it also
+/// applies to futures that are later polled on other threads.
+///
+/// # Thread-locality
+///
+/// `AsyncDeviceContext` — and therefore the pool registration — lives in a
+/// `thread_local!`. Calling `set_device_pool(0, pool)` on thread A does **not**
+/// affect allocations scheduled by thread B on device 0.
+///
+/// If you build a `DeviceFuture` on one thread and move it to another, the pool
+/// travels with the future via its `ExecutionContext` snapshot — the destination
+/// thread does not need its own `set_device_pool` call. But if thread B
+/// independently creates ops via `.sync()`/`.await`, those ops see thread B's
+/// pool (typically `None` unless B also called `set_device_pool`).
+///
+/// Multi-threaded workers that want a shared pool should each call
+/// `set_device_pool` during their initialization.
+///
+/// # Errors
+///
+/// Returns [`DeviceError::Context`](crate::error::DeviceError::Context) if
+/// `pool` was created on a different device than `device_id`.
+pub fn set_device_pool(device_id: usize, pool: Arc<MemPool>) -> Result<(), DeviceError> {
+    let pool_device = pool.device().ordinal();
+    device_assert(
+        device_id,
+        pool_device == device_id,
+        &format!("pool belongs to device {pool_device}, expected device {device_id}"),
+    )?;
+    with_global_device_context_mut(device_id, |device_context| {
+        device_context.pool = Some(pool);
+    })
+}
+
+/// Clear the custom memory pool for the given device **on the current thread**,
+/// reverting to the default pool.
+///
+/// Only affects the calling thread's pool registration; see
+/// [`set_device_pool`] for the full thread-locality contract. In-flight
+/// `DeviceFuture`s that already captured the pool are unaffected (the pool is
+/// kept alive via `Arc` until those futures complete).
+pub fn clear_device_pool(device_id: usize) -> Result<(), DeviceError> {
+    with_global_device_context_mut(device_id, |device_context| {
+        device_context.pool = None;
+    })
+}
+
+/// Returns the custom memory pool registered for the given device **on the
+/// current thread**, if any.
+///
+/// Returns `Ok(None)` when the calling thread has not registered a pool, even
+/// if another thread has done so. See [`set_device_pool`] for thread-locality.
+pub fn get_device_pool(device_id: usize) -> Result<Option<Arc<MemPool>>, DeviceError> {
+    with_global_device_context(device_id, |device_context| device_context.pool.clone())
+}
+
+/// Resolve the custom memory pool associated with the device that owns `stream`.
+///
+/// Errors from the device-context lookup are downgraded to `None`; this is the
+/// single choke-point for that decision so callers don't each re-derive it.
+pub fn pool_for_stream(stream: &Arc<Stream>) -> Option<Arc<MemPool>> {
+    get_device_pool(stream.device().ordinal()).ok().flatten()
+}
+
 /// Run a closure with the scheduling policy of the current thread's default device.
 ///
 /// This is the function called internally by
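Taken together, these functions let each worker thread opt in to a shared pool. A minimal usage sketch follows; the crate-root re-exports and the `Device::new`/`MemPool::new` constructors are assumptions for illustration, and only `set_device_pool`, `get_device_pool`, and `clear_device_pool` themselves come from this diff:

    use std::sync::Arc;
    use std::thread;

    fn main() {
        // Hypothetical constructors: the real cuda_core device/pool API may differ.
        let device = cuda_core::Device::new(0).expect("no CUDA device 0");
        let pool = Arc::new(cuda_core::MemPool::new(&device).expect("pool creation failed"));

        // The registration is thread-local, so every worker that schedules work
        // on device 0 must register the shared pool during its own initialization.
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    cuda_async::set_device_pool(0, pool).expect("pool/device mismatch");
                    assert!(cuda_async::get_device_pool(0).unwrap().is_some());
                    // ... build and schedule ops here; allocations scheduled by
                    // this thread on device 0 now draw from the shared pool ...
                    cuda_async::clear_device_pool(0).expect("context lookup failed");
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
    }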
22 changes: 20 additions & 2 deletions cuda-async/src/device_operation.rs
@@ -5,11 +5,11 @@
 
 //! Lazy, composable GPU operations and combinator types.
 
-use crate::device_context::with_default_device_policy;
+use crate::device_context::{pool_for_stream, with_default_device_policy};
 use crate::device_future::DeviceFuture;
 use crate::error::{device_error, DeviceError};
 use crate::scheduling_policies::SchedulingPolicy;
-use cuda_core::{Device, Stream};
+use cuda_core::{Device, Stream, MemPool};
 use std::cell::{Cell, UnsafeCell};
 use std::fmt::Debug;
 use std::future::IntoFuture;
@@ -61,16 +61,19 @@ pub struct ExecutionContext {
     ordinal: DeviceOrdinal,
     cuda_stream: Arc<Stream>,
     device: Arc<Device>,
+    pool: Option<Arc<MemPool>>,
 }
 
 impl ExecutionContext {
     pub fn new(cuda_stream: Arc<Stream>) -> Self {
         let device = cuda_stream.device().clone();
         let ordinal = device.ordinal();
+        let pool = pool_for_stream(&cuda_stream);
         Self {
             cuda_stream,
             device,
             ordinal,
+            pool,
         }
     }
     pub fn get_cuda_stream(&self) -> &Arc<Stream> {
@@ -82,6 +85,21 @@ impl ExecutionContext {
     pub fn get_device_id(&self) -> DeviceOrdinal {
         self.ordinal
     }
+    pub fn get_pool(&self) -> Option<&Arc<MemPool>> {
+        self.pool.as_ref()
+    }
+    /// Allocates device memory on this context's stream, using the custom pool if set.
+    ///
+    /// # Safety
+    /// The stream must be valid and not destroyed.
+    pub unsafe fn alloc_async(&self, num_bytes: usize) -> cuda_core::sys::CUdeviceptr {
+        match &self.pool {
+            Some(pool) => {
+                cuda_core::malloc_from_pool_async(num_bytes, pool, &self.cuda_stream)
+            }
+            None => cuda_core::malloc_async(num_bytes, &self.cuda_stream),
+        }
+    }
     #[expect(
         dead_code,
         reason = "kept for direct synchronous execution in tests and future blocking APIs"
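On the consumer side, a sketch of how an op body might allocate through the snapshot; the import path for `ExecutionContext` is an assumption, while `new`, `get_pool`, and `alloc_async` come from this diff:

    use std::sync::Arc;
    use cuda_async::device_operation::ExecutionContext; // module path assumed

    fn alloc_scratch(stream: Arc<cuda_core::Stream>) -> cuda_core::sys::CUdeviceptr {
        // The pool is snapshotted at construction time from the constructing
        // thread's registration; a future polled on another thread still sees it.
        let ctx = ExecutionContext::new(stream);
        debug_assert!(ctx.get_pool().is_some(), "no custom pool registered on this thread");
        // SAFETY: the stream is live and not destroyed. The caller is responsible
        // for freeing the allocation; the matching free path is outside this diff.
        unsafe { ctx.alloc_async(1 << 20) } // 1 MiB, via the custom pool if one is set
    }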