Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,27 @@ jobs:
run: |
source "${HOME}/.cargo/env"
cargo build

- name: Format check
env:
CUDA_TOOLKIT_PATH: /usr/local/cuda-13
CUDA_TILE_USE_LLVM_INSTALL_DIR: /usr/lib/llvm-21
run: |
source "${HOME}/.cargo/env"
cargo fmt -- --check

- name: Clippy
env:
CUDA_TOOLKIT_PATH: /usr/local/cuda-13
CUDA_TILE_USE_LLVM_INSTALL_DIR: /usr/lib/llvm-21
run: |
source "${HOME}/.cargo/env"
cargo clippy

- name: Test (compile only)
env:
CUDA_TOOLKIT_PATH: /usr/local/cuda-13
CUDA_TILE_USE_LLVM_INSTALL_DIR: /usr/lib/llvm-21
run: |
source "${HOME}/.cargo/env"
cargo test --no-run
7 changes: 3 additions & 4 deletions cuda-async/src/device_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ impl<T: Send + ?Sized> Drop for DeviceBox<T> {
with_deallocator_stream(self._device_id, |stream| {
free_async(self.cudptr, stream);
})
.expect(
format!(
.unwrap_or_else(|_| {
panic!(
"Failed to free device pointer on device_id={}",
self._device_id
)
.as_str(),
)
})
}
}
}
Expand Down
7 changes: 2 additions & 5 deletions cuda-async/src/device_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,7 @@ where
/// Useful when you need to schedule operations on a specific device outside the
/// default `.await` / `.sync()` path.
pub fn global_policy(device_id: usize) -> Result<Arc<GlobalSchedulingPolicy>, DeviceError> {
with_global_device_context(device_id, |device_context| {
let policy = device_context.policy.clone();
policy
})
with_global_device_context(device_id, |device_context| device_context.policy.clone())
}

pub unsafe fn with_deallocator_stream<F, R>(device_id: usize, f: F) -> Result<R, DeviceError>
Expand Down Expand Up @@ -309,7 +306,7 @@ where
/// set_default_device(1);
/// let tensor = api::zeros([1024, 1024]).await; // runs on GPU 1
/// ```
pub fn set_default_device(default_device_id: usize) -> () {
pub fn set_default_device(default_device_id: usize) {
DEVICE_CONTEXTS.with(|ctx| {
ctx.default_device.set(default_device_id);
})
Expand Down
17 changes: 7 additions & 10 deletions cuda-async/src/device_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,13 @@ impl<T: Send, DO: DeviceOperation<Output = T>> Unpin for DeviceFuture<T, DO> {}
impl<T: Send, DO: DeviceOperation<Output = T>> Future for DeviceFuture<T, DO> {
type Output = Result<T, DeviceError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state {
DeviceFutureState::Failed => {
self.state = DeviceFutureState::Complete;
let error = self
.error
.take()
.expect("Failed state must carry an error.");
return Poll::Ready(Err(error));
}
_ => {}
if self.state == DeviceFutureState::Failed {
self.state = DeviceFutureState::Complete;
let error = self
.error
.take()
.expect("Failed state must carry an error.");
return Poll::Ready(Err(error));
}

// If this is being polled, it needs a waker.
Expand Down
10 changes: 5 additions & 5 deletions cuda-async/src/device_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ pub trait DeviceOperation:
) -> Result<<Self as DeviceOperation>::Output, DeviceError> {
let ctx = ExecutionContext::new(stream.clone());
// This is okay since we synchronize immediately.
let res = unsafe { self.execute(&ctx) };
res

unsafe { self.execute(&ctx) }
}
/// Execute on an **explicit stream** and block until the GPU finishes.
///
Expand Down Expand Up @@ -413,9 +413,9 @@ impl<T: Send> IntoDeviceOperation<T> for T {
value(self)
}
}
impl Into<Value<f32>> for f32 {
fn into(self) -> Value<f32> {
Value::new(self)
impl From<f32> for Value<f32> {
fn from(val: f32) -> Self {
Value::new(val)
}
}

Expand Down
8 changes: 4 additions & 4 deletions cuda-core/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub unsafe fn malloc_async(num_bytes: usize, stream: &Arc<CudaStream>) -> sys::C
///
/// # Safety
/// `dptr` must have been allocated with `malloc_async` and must not be used after this call.
pub unsafe fn free_async(dptr: sys::CUdeviceptr, stream: &Arc<CudaStream>) -> () {
pub unsafe fn free_async(dptr: sys::CUdeviceptr, stream: &Arc<CudaStream>) {
crate::memory::free_async(dptr, stream.cu_stream()).expect("Free async failed.")
}

Expand All @@ -93,7 +93,7 @@ pub unsafe fn memcpy_htod_async<T>(
src: *const T,
num_elements: usize,
stream: &Arc<CudaStream>,
) -> () {
) {
let num_bytes = num_elements * mem::size_of::<T>();
unsafe { crate::memory::memcpy_htod_async(dst, src, num_bytes, stream.cu_stream()) }
.expect("memcpy_htod_async failed.")
Expand All @@ -108,7 +108,7 @@ pub unsafe fn memcpy_dtoh_async<T>(
src: sys::CUdeviceptr,
num_elements: usize,
stream: &Arc<CudaStream>,
) -> () {
) {
let num_bytes = num_elements * mem::size_of::<T>();
unsafe { crate::memory::memcpy_dtoh_async(dst, src, num_bytes, stream.cu_stream()) }
.expect("memcpy_dtoh_async failed.")
Expand All @@ -123,7 +123,7 @@ pub unsafe fn memcpy_dtod_async<T>(
src: sys::CUdeviceptr,
num_elements: usize,
stream: &Arc<CudaStream>,
) -> () {
) {
let num_bytes = num_elements * mem::size_of::<T>();
unsafe { crate::memory::memcpy_dtod_async(dst, src, num_bytes, stream.cu_stream()) }
.expect("memcpy_dtod_async failed.")
Expand Down
14 changes: 7 additions & 7 deletions cuda-core/src/cudarc_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl CudaContext {
/// Records an error into the context's error state if the result is `Err`.
pub fn record_err<T>(&self, result: Result<T, DriverError>) {
if let Err(err) = result {
self.error_state.store(err.0 as u32, Ordering::Relaxed)
self.error_state.store(err.0, Ordering::Relaxed)
}
}
}
Expand Down Expand Up @@ -814,7 +814,7 @@ pub mod ctx {

/// Sets flags on the current context.
pub fn set_flags(flags: cuda_bindings::CUctx_flags) -> Result<(), DriverError> {
unsafe { cuda_bindings::cuCtxSetFlags(flags as u32).result() }
unsafe { cuda_bindings::cuCtxSetFlags(flags).result() }
}

/// Blocks until all work in the current context is complete.
Expand Down Expand Up @@ -859,7 +859,7 @@ pub mod stream {
pub fn create(kind: StreamKind) -> Result<cuda_bindings::CUstream, DriverError> {
let mut stream = MaybeUninit::uninit();
unsafe {
cuda_bindings::cuStreamCreate(stream.as_mut_ptr(), kind.flags() as u32).result()?;
cuda_bindings::cuStreamCreate(stream.as_mut_ptr(), kind.flags()).result()?;
Ok(stream.assume_init())
}
}
Expand Down Expand Up @@ -889,7 +889,7 @@ pub mod stream {
event: cuda_bindings::CUevent,
flags: cuda_bindings::CUevent_wait_flags,
) -> Result<(), DriverError> {
cuda_bindings::cuStreamWaitEvent(stream, event, flags as u32).result()
cuda_bindings::cuStreamWaitEvent(stream, event, flags).result()
}

/// Attaches memory to a stream for managed memory visibility.
Expand All @@ -902,7 +902,7 @@ pub mod stream {
num_bytes: usize,
flags: cuda_bindings::CUmemAttach_flags,
) -> Result<(), DriverError> {
cuda_bindings::cuStreamAttachMemAsync(stream, dptr, num_bytes, flags as u32).result()
cuda_bindings::cuStreamAttachMemAsync(stream, dptr, num_bytes, flags).result()
}

/// Enqueues a host function callback on the stream.
Expand Down Expand Up @@ -1028,7 +1028,7 @@ pub mod event {
) -> Result<cuda_bindings::CUevent, DriverError> {
let mut event = MaybeUninit::uninit();
unsafe {
cuda_bindings::cuEventCreate(event.as_mut_ptr(), flags as u32).result()?;
cuda_bindings::cuEventCreate(event.as_mut_ptr(), flags).result()?;
Ok(event.assume_init())
}
}
Expand Down Expand Up @@ -1125,7 +1125,7 @@ pub mod memory {
flags: sys::CUmemAttach_flags,
) -> Result<sys::CUdeviceptr, DriverError> {
let mut dev_ptr = MaybeUninit::uninit();
sys::cuMemAllocManaged(dev_ptr.as_mut_ptr(), num_bytes, flags as u32).result()?;
sys::cuMemAllocManaged(dev_ptr.as_mut_ptr(), num_bytes, flags).result()?;
Ok(dev_ptr.assume_init())
}

Expand Down
5 changes: 2 additions & 3 deletions cuda-tile-rs/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ fn main() {
llvm_path.display()
);
println!(
"cargo:rustc-env={}={}",
"TABLEGEN_210_PREFIX",
"cargo:rustc-env=TABLEGEN_210_PREFIX={}",
llvm_path.display()
);

Expand All @@ -138,7 +137,7 @@ fn main() {
LLVM_LIB_PATH_VAR,
llvm_lib_path.display()
);
println!("cargo:warning={}", "Defaultling to download mode.");
println!("cargo:warning=Defaultling to download mode.");
}
}

Expand Down
2 changes: 1 addition & 1 deletion cuda-tile-rs/examples/build_translate_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
*/

use cuda_tile_rs::util::{operation_parse, parse_named_attr};
use melior::Context;
use melior::dialect::DialectRegistry;
use melior::ir::attribute::{StringAttribute, TypeAttribute};
use melior::ir::r#type::FunctionType;
use melior::ir::{
Attribute, Block, BlockLike, Identifier, Location, Region, RegionLike, Type, Value, ValueLike,
};
use melior::utility::{register_all_dialects, register_all_llvm_translations};
use melior::Context;
use std::error::Error;
use std::process::Command;

Expand Down
4 changes: 2 additions & 2 deletions cuda-tile-rs/src/cuda_tile_c_bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ unsafe extern "C" {
unsafe extern "C" {
#[doc = " Returns a cuda_tile RoundingModeAttr with the given rounding mode string."]
pub fn mlirCudaTileRoundingModeAttrGet(ctx: MlirContext, value: MlirStringRef)
-> MlirAttribute;
-> MlirAttribute;
}
unsafe extern "C" {
#[doc = " Returns the rounding mode string of the given cuda_tile RoundingModeAttr."]
Expand Down Expand Up @@ -287,7 +287,7 @@ unsafe extern "C" {
unsafe extern "C" {
#[doc = " Returns a cuda_tile PaddingValueAttr with the given padding value string."]
pub fn mlirCudaTilePaddingValueAttrGet(ctx: MlirContext, value: MlirStringRef)
-> MlirAttribute;
-> MlirAttribute;
}
unsafe extern "C" {
#[doc = " Returns the padding value string of the given cuda_tile PaddingValueAttr."]
Expand Down
4 changes: 2 additions & 2 deletions cuda-tile-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ mod tests {

use crate::cuda_tile::{self};
use crate::util::{attribute_parse, operation_parse, type_parse};
use melior::Context;
use melior::dialect::DialectRegistry;
use melior::ir::RegionLike;
use melior::ir::attribute::StringAttribute;
use melior::ir::operation::{OperationBuilder, OperationLike};
use melior::ir::RegionLike;
use melior::ir::{Attribute, Block, Identifier, Location, Module, Region};
use melior::utility::{register_all_dialects, register_all_llvm_translations};
use melior::Context;

static TEST_MUTEX: Mutex<()> = Mutex::new(());
static REGISTER_GLOBALS: Once = Once::new();
Expand Down
10 changes: 5 additions & 5 deletions cuda-tile-rs/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
*/

use melior::{
Context, StringRef,
ir::{Attribute, Identifier, Operation, Type, operation::OperationLike},
ir::{operation::OperationLike, Attribute, Identifier, Operation, Type},
pass::PassManager,
Context, StringRef,
};
use mlir_sys::mlirPassManagerRunOnOp;
use mlir_sys::{mlirAttributeParseGet, mlirOperationCreateParse};
Expand All @@ -21,7 +21,7 @@ pub fn operation_parse<'c>(
) -> Option<Operation<'c>> {
let source = CString::new(source).unwrap();
let source = StringRef::from_c_str(&source);
let source_name = CString::new(source_name.unwrap_or_else(|| "sourceName")).unwrap();
let source_name = CString::new(source_name.unwrap_or("sourceName")).unwrap();
let source_name_ref = StringRef::from_c_str(&source_name);
unsafe {
Operation::from_option_raw(mlirOperationCreateParse(
Expand Down Expand Up @@ -50,10 +50,10 @@ pub fn parse_named_attr<'c>(
name: &str,
attr_str: &str,
) -> (Identifier<'c>, Attribute<'c>) {
let Some(attr) = Attribute::parse(&context, attr_str) else {
let Some(attr) = Attribute::parse(context, attr_str) else {
panic!("Failed to parse named attribute {name} = {attr_str}");
};
(Identifier::new(&context, name), attr)
(Identifier::new(context, name), attr)
}

pub fn execute_pass_manager(
Expand Down
6 changes: 2 additions & 4 deletions cutile-compiler/src/compiler/compile_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,11 @@ impl<'m, 'c> CUDATileFunctionCompiler<'m> {
};
}
Stmt::Item(item) => {
let mut binding_name: Option<String> = None;
let mut ct_ty: Option<TileRustType> = None;
match item {
Item::Const(const_item) => {
// This is like a let binding.
binding_name = Some(const_item.ident.to_string());
ct_ty = self.compile_type(
let binding_name: Option<String> = Some(const_item.ident.to_string());
let ct_ty: Option<TileRustType> = self.compile_type(
&*const_item.ty,
generic_args,
&HashMap::new(),
Expand Down
6 changes: 3 additions & 3 deletions cutile-examples/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ pub fn fmha_ref_exec<T: WithDType>(
let qk_scaled = qk.mul(&sm_scale_tensor).expect("Failed to scale qk.");
let qk_softmax = softmax(&qk_scaled, 3).expect("Failed to softmax qk.");

let qkv = qk_softmax
// (m x m) @ (m x d)
qk_softmax
.broadcast_matmul(&v_host)
.expect("Failed to execute qk @ v."); // (m x m) @ (m x d)
qkv // (b, h, m, d)
.expect("Failed to execute qk @ v.") // (b, h, m, d)
}

/// Computes the theoretical peak (speed-of-light) tensor core TFLOPS for a Blackwell GPU.
Expand Down
Loading