Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ jobs:
# Install rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "${HOME}/.cargo/env"
rustup default nightly
wget https://apt.llvm.org/llvm.sh
bash ./llvm.sh 21
apt install -y --no-install-recommends libmlir-21-dev mlir-21-tools libpolly-21-dev
Expand All @@ -66,7 +65,7 @@ jobs:
- name: Clippy
run: |
source "${HOME}/.cargo/env"
cargo clippy
cargo clippy -- --deny clippy::all --allow clippy::missing-safety-doc --allow clippy::type_complexity

- name: Test (compile only)
run: |
Expand Down
7 changes: 2 additions & 5 deletions cuda-async/src/device_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub enum DeviceFutureState {
}

/// Shared state between a CUDA stream callback and the async waker.
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct StreamCallbackState {
pub(crate) waker: AtomicWaker,
pub(crate) complete: AtomicBool,
Expand All @@ -42,10 +42,7 @@ pub struct StreamCallbackState {
impl StreamCallbackState {
/// Creates a new callback state with the completion flag unset.
pub fn new() -> Self {
Self {
waker: AtomicWaker::new(),
complete: AtomicBool::new(false),
}
Self::default()
}
/// Marks the operation as complete and wakes the associated task.
pub fn signal(&self) {
Expand Down
22 changes: 13 additions & 9 deletions cuda-async/src/device_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ pub trait DeviceOp:
/// - `dup`, `copy_host_vec_to_device`
///
/// See [`Scope`](crate::cuda_graph::Scope) for the full safety proof.
pub trait GraphNode: DeviceOp {}
pub trait GraphNode {}

// Arc

Expand Down Expand Up @@ -987,10 +987,12 @@ where
if !self.computed.load(Ordering::Acquire) {
// Safety: This block is guaranteed to execute at most once.
// Put the input in a box so the pointer is dropped when this block exits.
let input = unsafe { (&mut *self.input.get()).take() }.ok_or(device_error(
context.get_device_id(),
"Select operation failed.",
))?;
let input = self.input.get();
let input = unsafe { input.as_mut() };
let input = input
.unwrap()
.take()
.ok_or_else(|| device_error(context.get_device_id(), "Select operation failed."))?;
let (left, right) = input.execute(context)?;
// Update internal state.
unsafe {
Expand All @@ -1002,12 +1004,14 @@ where
Ok(())
}
unsafe fn left(&self) -> T1 {
let left = unsafe { (&mut *self.left.get()).take() }.unwrap();
left
let cell = self.left.get();
let cell = unsafe { cell.as_mut() };
cell.unwrap().take().unwrap()
}
unsafe fn right(&self) -> T2 {
let right = unsafe { (&mut *self.right.get()).take() }.unwrap();
right
let cell = self.right.get();
let cell = unsafe { cell.as_mut() };
cell.unwrap().take().unwrap()
}
}

Expand Down
4 changes: 2 additions & 2 deletions cuda-async/src/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ impl Drop for AsyncKernelLaunch {
let _ = self
.args
.iter()
.map(|arg| {
.map(|&arg| {
// Reconstruct the boxes. Pointers will be dropped when they go out of scope.
unsafe { Box::from_raw(*arg) }
unsafe { Box::from_raw(arg as *mut usize) }
})
.collect::<Vec<_>>();
}
Expand Down
2 changes: 0 additions & 2 deletions cuda-core/src/cudarc_shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ pub(crate) mod primary_ctx {
}

/// Low-level device query operations.

#[allow(dead_code)]
pub(crate) mod device {

Expand Down Expand Up @@ -191,7 +190,6 @@ pub(crate) mod ctx {
}

/// Low-level CUDA stream operations.

#[allow(dead_code)]
pub(crate) mod stream {
use super::{DriverError, IntoResult};
Expand Down
18 changes: 12 additions & 6 deletions cutile-ir/src/bytecode/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@ pub struct EncodingWriter {

impl EncodingWriter {
pub fn new() -> Self {
Self {
buf: Vec::new(),
required_alignment: 1,
}
Self::default()
}

pub fn with_capacity(cap: usize) -> Self {
Self {
buf: Vec::with_capacity(cap),
required_alignment: 1,
..Self::default()
}
}

Expand Down Expand Up @@ -223,6 +220,15 @@ impl EncodingWriter {
}
}

impl Default for EncodingWriter {
fn default() -> Self {
Self {
buf: Default::default(),
required_alignment: 1,
}
}
}

/// Patch a `u32` value at `offset` in the buffer (little-endian).
pub fn patch_u32(buf: &mut [u8], offset: usize, value: u32) {
buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
Expand Down Expand Up @@ -270,7 +276,7 @@ fn convert_to_f8(
// Handle special values.
if f64_exp == 0x7FF {
// Inf or NaN
if f64_man != 0 || (nan_only_all_ones && f64_man == 0) {
if f64_man != 0 || nan_only_all_ones {
// NaN (or Inf mapped to NaN for formats without infinities)
if nan_only_all_ones {
return (sign << 7) | ((max_exp as u8) << man_bits) | man_mask;
Expand Down
4 changes: 2 additions & 2 deletions cutile-ir/src/ir/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ impl<'a> ModulePrinter<'a> {
let pad = " ".repeat(self.indent);

// Operands: [lb, ub, step, init_values...]
let lb = op.operands.get(0).map(|v| v.index());
let lb = op.operands.first().map(|v| v.index());
let ub = op.operands.get(1).map(|v| v.index());
let step = op.operands.get(2).map(|v| v.index());
let init_values = &op.operands[3.min(op.operands.len())..];
Expand Down Expand Up @@ -2601,7 +2601,7 @@ fn format_dense_i32_array(attr: &Attribute) -> String {
.collect();
format!("[{}]", elems.join(", "))
}
_ => format!("{}", format_attr(attr)),
_ => format_attr(attr).to_string(),
}
}

Expand Down
10 changes: 3 additions & 7 deletions cutile-ir/src/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,7 @@ fn strip_prefix_suffix<'a>(s: &'a str, prefix: &str, _suffix: &str) -> Option<&'
}
}
// No matching close — try without nesting (just strip last char if it's '>').
if after_prefix.ends_with('>') {
Some(&after_prefix[..after_prefix.len() - 1])
} else {
None
}
after_prefix.strip_suffix('>')
}

fn parse_scalar(s: &str) -> Option<ScalarType> {
Expand Down Expand Up @@ -314,7 +310,7 @@ fn parse_tile(inner: &str) -> Option<Type> {
before
.trim_end_matches('x')
.split('x')
.map(|d| parse_dim(d))
.map(parse_dim)
.collect()
};
let ptr_inner_start = ptr_start + "ptr<".len();
Expand Down Expand Up @@ -368,7 +364,7 @@ fn parse_tensor_view(inner: &str) -> Option<Type> {

let strides = if let Some(sp) = strides_part {
let sp = sp.trim_start_matches('[').trim_end_matches(']');
sp.split(',').map(|s| parse_dim(s)).collect()
sp.split(',').map(parse_dim).collect()
} else {
vec![DYNAMIC; shape.len()]
};
Expand Down
2 changes: 1 addition & 1 deletion cutile-macro/src/_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ pub fn trait_(mut item: ItemTrait) -> Result<TokenStream, Error> {
);
let res = match attributes {
Some(attributes)
if attributes.name_as_str().as_deref() == Some("cuda_tile :: variadic_trait") =>
if attributes.name_as_str() == Some("cuda_tile :: variadic_trait".into()) =>
{
desugar_variadic_trait_decl(&item)?
}
Expand Down
16 changes: 5 additions & 11 deletions cutile-macro/src/rank_instantiation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,9 +1109,7 @@ impl RankInstantiator {
/// Rewrite a free-fn signature (generics, args, return) and its body.
pub fn rewrite_function(mut self, item: &ItemFn) -> Result<ItemFn, Error> {
let mut item = item.clone();
if let Err(e) = rewrite_fn_sig(&mut item.sig, &self.bindings) {
return Err(e);
}
rewrite_fn_sig(&mut item.sig, &self.bindings)?;
self.visit_block_mut(&mut item.block);
self.into_result(item)
}
Expand All @@ -1125,9 +1123,7 @@ impl RankInstantiator {
Ok(t) => *item.self_ty = t,
Err(e) => return Err(e),
}
if let Err(e) = rewrite_generics_for_rank(&mut item.generics, &self.bindings) {
return Err(e);
}
rewrite_generics_for_rank(&mut item.generics, &self.bindings)?;
if let Some(trait_) = &mut item.trait_ {
let path = &mut trait_.1;
if path.segments.is_empty() {
Expand All @@ -1138,9 +1134,7 @@ impl RankInstantiator {
}
let last_seg = path.segments.last_mut().unwrap();
if let PathArguments::AngleBracketed(path_args) = &mut last_seg.arguments {
if let Err(e) = rewrite_generic_args_for_rank(path_args, &self.bindings) {
return Err(e);
}
rewrite_generic_args_for_rank(path_args, &self.bindings)?
}
}

Expand All @@ -1164,8 +1158,8 @@ impl RankInstantiator {
}
let mut result = fn_impl.clone();
self.rewrite_impl_method(&original_self_ty, &mut result);
if self.error.is_some() {
return Err(self.error.unwrap());
if let Some(error) = self.error {
return Err(error);
}
impl_items.push(ImplItem::Fn(result));
}
Expand Down
43 changes: 17 additions & 26 deletions cutile-macro/src/shadow_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,8 @@ impl RankPolyOpSpec {
}
// Rank-dependent non-shape arg types as trait generics (e.g. `Idx0`
// for `idx: [i32; N]`). Caller's array literal pins them.
for slot in &self.rank_dep_arg_idents {
if let Some(id) = slot {
all_trait_params.push(quote! { #id });
}
for id in self.rank_dep_arg_idents.iter().flatten() {
all_trait_params.push(quote! { #id });
}
if let Some(ref out) = extra_out_trait_param {
all_trait_params.push(out.clone());
Expand Down Expand Up @@ -565,7 +563,7 @@ impl RankPolyOpSpec {
let mut return_concrete = rewrite_ty_for_rank(&self.return_type, combo, &self.cgas);
for (orig, replacement) in self.dead_lifetimes.iter().zip(self.dead_lt_idents.iter()) {
return_concrete =
replace_lifetimes_with(&return_concrete, &[orig.clone()], replacement);
replace_lifetimes_with(&return_concrete, std::slice::from_ref(orig), replacement);
}

let mut trait_instantiation_args: Vec<TokenStream2> = Vec::new();
Expand Down Expand Up @@ -773,10 +771,8 @@ impl RankPolyOpSpec {
trait_args.push(quote! { #i });
}
// Rank-dep arg generics, matching trait declaration ordering.
for slot in &self.rank_dep_arg_idents {
if let Some(id) = slot {
trait_args.push(quote! { #id });
}
for id in self.rank_dep_arg_idents.iter().flatten() {
trait_args.push(quote! { #id });
}
if use_free_out {
trait_args.push(quote! { #out_ident });
Expand Down Expand Up @@ -827,10 +823,8 @@ impl RankPolyOpSpec {
for i in &extra_shape_generic_idents {
all_wrapper_generics.push(quote! { #i });
}
for slot in &self.rank_dep_arg_idents {
if let Some(id) = slot {
all_wrapper_generics.push(quote! { #id });
}
for id in self.rank_dep_arg_idents.iter().flatten() {
all_wrapper_generics.push(quote! { #id });
}
if use_free_out {
all_wrapper_generics.push(quote! { #out_ident });
Expand Down Expand Up @@ -1744,7 +1738,7 @@ pub fn desugar_variadic_trait_decl(item: &ItemTrait) -> Result<TokenStream2, Err
for param in &item.generics.params {
let drop_it = matches!(
param,
GenericParam::Const(c) if cga_idents.iter().any(|i| *i == c.ident)
GenericParam::Const(c) if cga_idents.contains(&c.ident)
);
if !drop_it {
new_params.push(param.clone());
Expand Down Expand Up @@ -1942,7 +1936,7 @@ fn emit_variadic_trait_impl_for_rank(
for param in &item.generics.params {
let skip = matches!(
param,
GenericParam::Const(c) if cga_idents.iter().any(|i| *i == c.ident)
GenericParam::Const(c) if cga_idents.contains(&c.ident)
);
if !skip {
all_impl_params.push(quote! { #param });
Expand Down Expand Up @@ -2012,7 +2006,7 @@ fn rewrite_trait_method_for_rank_poly(
.map(|(i, _)| i);
if let Some(i) = cga_idx {
if let CgaRole::ShapeBound { sh_ident } = &shape.roles[i] {
pt.ty = Box::new(syn::parse_quote! { #sh_ident });
*pt.ty = syn::parse_quote! { #sh_ident };
}
// Free CGAs aren't in args by definition (classify_cgas's
// post-condition), so reaching this branch with a Free role
Expand All @@ -2031,7 +2025,7 @@ fn rewrite_trait_method_for_rank_poly(
} else {
syn::parse_quote! { Self::Out }
};
*ret = Box::new(new_ret);
**ret = new_ret;
}
}
}
Expand Down Expand Up @@ -2110,13 +2104,13 @@ fn rewrite_impl_method_body_for_rank(
if let FnArg::Typed(pt) = arg {
let new_ty = rewrite_ty_for_rank(&pt.ty, combo, cgas);
let new_ty = bind_anon_lifetimes_to(&new_ty, recv_lt);
pt.ty = Box::new(new_ty);
*pt.ty = new_ty;
}
}
if let ReturnType::Type(_, ret) = &mut new_sig.output {
let new_ret = rewrite_ty_for_rank(ret, combo, cgas);
let new_ret = bind_anon_lifetimes_to(&new_ret, recv_lt);
*ret = Box::new(new_ret);
**ret = new_ret;
}
let muted_args: Vec<TokenStream2> = new_sig
.inputs
Expand Down Expand Up @@ -2212,10 +2206,8 @@ fn type_uses_lifetime(ty: &Type) -> bool {
for arg in ab.args.iter() {
match arg {
GenericArgument::Lifetime(_) => return true,
GenericArgument::Type(t) => {
if type_uses_lifetime(t) {
return true;
}
GenericArgument::Type(t) if type_uses_lifetime(t) => {
return true;
}
_ => {}
}
Expand All @@ -2236,11 +2228,10 @@ fn filter_cuda_tile_attrs(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
attrs
.iter()
.filter(|a| {
let path = a.path();
!path
a.path()
.segments
.first()
.is_some_and(|s| s.ident == "cuda_tile")
.is_none_or(|s| s.ident != "cuda_tile")
})
.cloned()
.collect()
Expand Down
Loading