diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 30bc8884..0e0b96a4 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -54,3 +54,27 @@ jobs: run: | source "${HOME}/.cargo/env" cargo build + + - name: Format check + env: + CUDA_TOOLKIT_PATH: /usr/local/cuda-13 + CUDA_TILE_USE_LLVM_INSTALL_DIR: /usr/lib/llvm-21 + run: | + source "${HOME}/.cargo/env" + cargo fmt -- --check + + - name: Clippy + env: + CUDA_TOOLKIT_PATH: /usr/local/cuda-13 + CUDA_TILE_USE_LLVM_INSTALL_DIR: /usr/lib/llvm-21 + run: | + source "${HOME}/.cargo/env" + cargo clippy + + - name: Test (compile only) + env: + CUDA_TOOLKIT_PATH: /usr/local/cuda-13 + CUDA_TILE_USE_LLVM_INSTALL_DIR: /usr/lib/llvm-21 + run: | + source "${HOME}/.cargo/env" + cargo test --no-run diff --git a/cuda-async/src/device_box.rs b/cuda-async/src/device_box.rs index 90c5fc37..c896df79 100644 --- a/cuda-async/src/device_box.rs +++ b/cuda-async/src/device_box.rs @@ -53,13 +53,12 @@ impl Drop for DeviceBox { with_deallocator_stream(self._device_id, |stream| { free_async(self.cudptr, stream); }) - .expect( - format!( + .unwrap_or_else(|_| { + panic!( "Failed to free device pointer on device_id={}", self._device_id ) - .as_str(), - ) + }) } } } diff --git a/cuda-async/src/device_context.rs b/cuda-async/src/device_context.rs index f29a727c..916704d5 100644 --- a/cuda-async/src/device_context.rs +++ b/cuda-async/src/device_context.rs @@ -272,10 +272,7 @@ where /// Useful when you need to schedule operations on a specific device outside the /// default `.await` / `.sync()` path. pub fn global_policy(device_id: usize) -> Result, DeviceError> { - with_global_device_context(device_id, |device_context| { - let policy = device_context.policy.clone(); - policy - }) + with_global_device_context(device_id, |device_context| device_context.policy.clone()) } pub unsafe fn with_deallocator_stream(device_id: usize, f: F) -> Result @@ -309,7 +306,7 @@ where /// set_default_device(1); /// let tensor = api::zeros([1024, 1024]).await; // runs on GPU 1 /// ``` -pub fn set_default_device(default_device_id: usize) -> () { +pub fn set_default_device(default_device_id: usize) { DEVICE_CONTEXTS.with(|ctx| { ctx.default_device.set(default_device_id); }) diff --git a/cuda-async/src/device_future.rs b/cuda-async/src/device_future.rs index 17ddd379..8084656f 100644 --- a/cuda-async/src/device_future.rs +++ b/cuda-async/src/device_future.rs @@ -138,16 +138,13 @@ impl> Unpin for DeviceFuture {} impl> Future for DeviceFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.state { - DeviceFutureState::Failed => { - self.state = DeviceFutureState::Complete; - let error = self - .error - .take() - .expect("Failed state must carry an error."); - return Poll::Ready(Err(error)); - } - _ => {} + if self.state == DeviceFutureState::Failed { + self.state = DeviceFutureState::Complete; + let error = self + .error + .take() + .expect("Failed state must carry an error."); + return Poll::Ready(Err(error)); } // If this is being polled, it needs a waker. diff --git a/cuda-async/src/device_operation.rs b/cuda-async/src/device_operation.rs index 68e9bbd4..33af8dc3 100644 --- a/cuda-async/src/device_operation.rs +++ b/cuda-async/src/device_operation.rs @@ -203,8 +203,8 @@ pub trait DeviceOperation: ) -> Result<::Output, DeviceError> { let ctx = ExecutionContext::new(stream.clone()); // This is okay since we synchronize immediately. - let res = unsafe { self.execute(&ctx) }; - res + + unsafe { self.execute(&ctx) } } /// Execute on an **explicit stream** and block until the GPU finishes. /// @@ -413,9 +413,9 @@ impl IntoDeviceOperation for T { value(self) } } -impl Into> for f32 { - fn into(self) -> Value { - Value::new(self) +impl From for Value { + fn from(val: f32) -> Self { + Value::new(val) } } diff --git a/cuda-core/src/api.rs b/cuda-core/src/api.rs index 590d8289..e3d1014c 100644 --- a/cuda-core/src/api.rs +++ b/cuda-core/src/api.rs @@ -80,7 +80,7 @@ pub unsafe fn malloc_async(num_bytes: usize, stream: &Arc) -> sys::C /// /// # Safety /// `dptr` must have been allocated with `malloc_async` and must not be used after this call. -pub unsafe fn free_async(dptr: sys::CUdeviceptr, stream: &Arc) -> () { +pub unsafe fn free_async(dptr: sys::CUdeviceptr, stream: &Arc) { crate::memory::free_async(dptr, stream.cu_stream()).expect("Free async failed.") } @@ -93,7 +93,7 @@ pub unsafe fn memcpy_htod_async( src: *const T, num_elements: usize, stream: &Arc, -) -> () { +) { let num_bytes = num_elements * mem::size_of::(); unsafe { crate::memory::memcpy_htod_async(dst, src, num_bytes, stream.cu_stream()) } .expect("memcpy_htod_async failed.") @@ -108,7 +108,7 @@ pub unsafe fn memcpy_dtoh_async( src: sys::CUdeviceptr, num_elements: usize, stream: &Arc, -) -> () { +) { let num_bytes = num_elements * mem::size_of::(); unsafe { crate::memory::memcpy_dtoh_async(dst, src, num_bytes, stream.cu_stream()) } .expect("memcpy_dtoh_async failed.") @@ -123,7 +123,7 @@ pub unsafe fn memcpy_dtod_async( src: sys::CUdeviceptr, num_elements: usize, stream: &Arc, -) -> () { +) { let num_bytes = num_elements * mem::size_of::(); unsafe { crate::memory::memcpy_dtod_async(dst, src, num_bytes, stream.cu_stream()) } .expect("memcpy_dtod_async failed.") diff --git a/cuda-core/src/cudarc_shim.rs b/cuda-core/src/cudarc_shim.rs index a51a237f..4a2a01de 100644 --- a/cuda-core/src/cudarc_shim.rs +++ b/cuda-core/src/cudarc_shim.rs @@ -202,7 +202,7 @@ impl CudaContext { /// Records an error into the context's error state if the result is `Err`. pub fn record_err(&self, result: Result) { if let Err(err) = result { - self.error_state.store(err.0 as u32, Ordering::Relaxed) + self.error_state.store(err.0, Ordering::Relaxed) } } } @@ -814,7 +814,7 @@ pub mod ctx { /// Sets flags on the current context. pub fn set_flags(flags: cuda_bindings::CUctx_flags) -> Result<(), DriverError> { - unsafe { cuda_bindings::cuCtxSetFlags(flags as u32).result() } + unsafe { cuda_bindings::cuCtxSetFlags(flags).result() } } /// Blocks until all work in the current context is complete. @@ -859,7 +859,7 @@ pub mod stream { pub fn create(kind: StreamKind) -> Result { let mut stream = MaybeUninit::uninit(); unsafe { - cuda_bindings::cuStreamCreate(stream.as_mut_ptr(), kind.flags() as u32).result()?; + cuda_bindings::cuStreamCreate(stream.as_mut_ptr(), kind.flags()).result()?; Ok(stream.assume_init()) } } @@ -889,7 +889,7 @@ pub mod stream { event: cuda_bindings::CUevent, flags: cuda_bindings::CUevent_wait_flags, ) -> Result<(), DriverError> { - cuda_bindings::cuStreamWaitEvent(stream, event, flags as u32).result() + cuda_bindings::cuStreamWaitEvent(stream, event, flags).result() } /// Attaches memory to a stream for managed memory visibility. @@ -902,7 +902,7 @@ pub mod stream { num_bytes: usize, flags: cuda_bindings::CUmemAttach_flags, ) -> Result<(), DriverError> { - cuda_bindings::cuStreamAttachMemAsync(stream, dptr, num_bytes, flags as u32).result() + cuda_bindings::cuStreamAttachMemAsync(stream, dptr, num_bytes, flags).result() } /// Enqueues a host function callback on the stream. @@ -1028,7 +1028,7 @@ pub mod event { ) -> Result { let mut event = MaybeUninit::uninit(); unsafe { - cuda_bindings::cuEventCreate(event.as_mut_ptr(), flags as u32).result()?; + cuda_bindings::cuEventCreate(event.as_mut_ptr(), flags).result()?; Ok(event.assume_init()) } } @@ -1125,7 +1125,7 @@ pub mod memory { flags: sys::CUmemAttach_flags, ) -> Result { let mut dev_ptr = MaybeUninit::uninit(); - sys::cuMemAllocManaged(dev_ptr.as_mut_ptr(), num_bytes, flags as u32).result()?; + sys::cuMemAllocManaged(dev_ptr.as_mut_ptr(), num_bytes, flags).result()?; Ok(dev_ptr.assume_init()) } diff --git a/cuda-tile-rs/build.rs b/cuda-tile-rs/build.rs index 65e08c9f..8fb452dc 100644 --- a/cuda-tile-rs/build.rs +++ b/cuda-tile-rs/build.rs @@ -120,8 +120,7 @@ fn main() { llvm_path.display() ); println!( - "cargo:rustc-env={}={}", - "TABLEGEN_210_PREFIX", + "cargo:rustc-env=TABLEGEN_210_PREFIX={}", llvm_path.display() ); @@ -138,7 +137,7 @@ fn main() { LLVM_LIB_PATH_VAR, llvm_lib_path.display() ); - println!("cargo:warning={}", "Defaultling to download mode."); + println!("cargo:warning=Defaultling to download mode."); } } diff --git a/cuda-tile-rs/examples/build_translate_basic.rs b/cuda-tile-rs/examples/build_translate_basic.rs index 949790be..1132e94d 100644 --- a/cuda-tile-rs/examples/build_translate_basic.rs +++ b/cuda-tile-rs/examples/build_translate_basic.rs @@ -4,7 +4,6 @@ */ use cuda_tile_rs::util::{operation_parse, parse_named_attr}; -use melior::Context; use melior::dialect::DialectRegistry; use melior::ir::attribute::{StringAttribute, TypeAttribute}; use melior::ir::r#type::FunctionType; @@ -12,6 +11,7 @@ use melior::ir::{ Attribute, Block, BlockLike, Identifier, Location, Region, RegionLike, Type, Value, ValueLike, }; use melior::utility::{register_all_dialects, register_all_llvm_translations}; +use melior::Context; use std::error::Error; use std::process::Command; diff --git a/cuda-tile-rs/src/cuda_tile_c_bindings.rs b/cuda-tile-rs/src/cuda_tile_c_bindings.rs index 07bc0628..00853ca1 100644 --- a/cuda-tile-rs/src/cuda_tile_c_bindings.rs +++ b/cuda-tile-rs/src/cuda_tile_c_bindings.rs @@ -201,7 +201,7 @@ unsafe extern "C" { unsafe extern "C" { #[doc = " Returns a cuda_tile RoundingModeAttr with the given rounding mode string."] pub fn mlirCudaTileRoundingModeAttrGet(ctx: MlirContext, value: MlirStringRef) - -> MlirAttribute; + -> MlirAttribute; } unsafe extern "C" { #[doc = " Returns the rounding mode string of the given cuda_tile RoundingModeAttr."] @@ -287,7 +287,7 @@ unsafe extern "C" { unsafe extern "C" { #[doc = " Returns a cuda_tile PaddingValueAttr with the given padding value string."] pub fn mlirCudaTilePaddingValueAttrGet(ctx: MlirContext, value: MlirStringRef) - -> MlirAttribute; + -> MlirAttribute; } unsafe extern "C" { #[doc = " Returns the padding value string of the given cuda_tile PaddingValueAttr."] diff --git a/cuda-tile-rs/src/lib.rs b/cuda-tile-rs/src/lib.rs index 40079714..db1d7404 100644 --- a/cuda-tile-rs/src/lib.rs +++ b/cuda-tile-rs/src/lib.rs @@ -80,13 +80,13 @@ mod tests { use crate::cuda_tile::{self}; use crate::util::{attribute_parse, operation_parse, type_parse}; - use melior::Context; use melior::dialect::DialectRegistry; - use melior::ir::RegionLike; use melior::ir::attribute::StringAttribute; use melior::ir::operation::{OperationBuilder, OperationLike}; + use melior::ir::RegionLike; use melior::ir::{Attribute, Block, Identifier, Location, Module, Region}; use melior::utility::{register_all_dialects, register_all_llvm_translations}; + use melior::Context; static TEST_MUTEX: Mutex<()> = Mutex::new(()); static REGISTER_GLOBALS: Once = Once::new(); diff --git a/cuda-tile-rs/src/util.rs b/cuda-tile-rs/src/util.rs index 165a7ced..7baec37c 100644 --- a/cuda-tile-rs/src/util.rs +++ b/cuda-tile-rs/src/util.rs @@ -4,9 +4,9 @@ */ use melior::{ - Context, StringRef, - ir::{Attribute, Identifier, Operation, Type, operation::OperationLike}, + ir::{operation::OperationLike, Attribute, Identifier, Operation, Type}, pass::PassManager, + Context, StringRef, }; use mlir_sys::mlirPassManagerRunOnOp; use mlir_sys::{mlirAttributeParseGet, mlirOperationCreateParse}; @@ -21,7 +21,7 @@ pub fn operation_parse<'c>( ) -> Option> { let source = CString::new(source).unwrap(); let source = StringRef::from_c_str(&source); - let source_name = CString::new(source_name.unwrap_or_else(|| "sourceName")).unwrap(); + let source_name = CString::new(source_name.unwrap_or("sourceName")).unwrap(); let source_name_ref = StringRef::from_c_str(&source_name); unsafe { Operation::from_option_raw(mlirOperationCreateParse( @@ -50,10 +50,10 @@ pub fn parse_named_attr<'c>( name: &str, attr_str: &str, ) -> (Identifier<'c>, Attribute<'c>) { - let Some(attr) = Attribute::parse(&context, attr_str) else { + let Some(attr) = Attribute::parse(context, attr_str) else { panic!("Failed to parse named attribute {name} = {attr_str}"); }; - (Identifier::new(&context, name), attr) + (Identifier::new(context, name), attr) } pub fn execute_pass_manager( diff --git a/cutile-compiler/src/compiler/compile_block.rs b/cutile-compiler/src/compiler/compile_block.rs index 8eae400d..8cce1909 100644 --- a/cutile-compiler/src/compiler/compile_block.rs +++ b/cutile-compiler/src/compiler/compile_block.rs @@ -279,13 +279,11 @@ impl<'m, 'c> CUDATileFunctionCompiler<'m> { }; } Stmt::Item(item) => { - let mut binding_name: Option = None; - let mut ct_ty: Option = None; match item { Item::Const(const_item) => { // This is like a let binding. - binding_name = Some(const_item.ident.to_string()); - ct_ty = self.compile_type( + let binding_name: Option = Some(const_item.ident.to_string()); + let ct_ty: Option = self.compile_type( &*const_item.ty, generic_args, &HashMap::new(), diff --git a/cutile-examples/src/lib.rs b/cutile-examples/src/lib.rs index 9229ac17..ad88d32d 100644 --- a/cutile-examples/src/lib.rs +++ b/cutile-examples/src/lib.rs @@ -51,10 +51,10 @@ pub fn fmha_ref_exec( let qk_scaled = qk.mul(&sm_scale_tensor).expect("Failed to scale qk."); let qk_softmax = softmax(&qk_scaled, 3).expect("Failed to softmax qk."); - let qkv = qk_softmax + // (m x m) @ (m x d) + qk_softmax .broadcast_matmul(&v_host) - .expect("Failed to execute qk @ v."); // (m x m) @ (m x d) - qkv // (b, h, m, d) + .expect("Failed to execute qk @ v.") // (b, h, m, d) } /// Computes the theoretical peak (speed-of-light) tensor core TFLOPS for a Blackwell GPU. diff --git a/cutile-macro/src/_module.rs b/cutile-macro/src/_module.rs index 3beb4e15..c0ef4fdf 100644 --- a/cutile-macro/src/_module.rs +++ b/cutile-macro/src/_module.rs @@ -227,7 +227,7 @@ fn module_inner( for item in &content.1 { match item { syn::Item::Use(use_item) => { - concrete_items.push(use_item.to_token_stream().into()); + concrete_items.push(use_item.to_token_stream()); // Include module_ast dependency as part of the export. if !is_core { // println!("{use_item:#?}"); @@ -246,26 +246,26 @@ fn module_inner( let module_ast_call_str = format!( "{}::{}()", module_ast_use_path.last().unwrap(), - get_asts_ident().to_string() + get_asts_ident() ); module_ast_calls.push(module_ast_call_str); let module_ast_use_path_str = format!("use {};", module_ast_use_path.join("::")); let module_ast_use_path_item = syn::parse::(module_ast_use_path_str.parse().unwrap()).unwrap(); - concrete_items.push(module_ast_use_path_item.to_token_stream().into()); + concrete_items.push(module_ast_use_path_item.to_token_stream()); } } syn::Item::Fn(function_item) => { let entry_attrs = get_meta_list( - format!("{} :: entry", tile_rust_crate_root.to_string()).as_str(), + format!("{} :: entry", tile_rust_crate_root).as_str(), &function_item.attrs, ); if entry_attrs.is_some() { - entry_functions.push(kernel_launcher(name, &function_item)?); + entry_functions.push(kernel_launcher(name, function_item)?); }; ast_content.push(Item::Fn(function_item.clone())); - concrete_items.push(function(function_item.clone(), &tile_rust_crate_root)?); + concrete_items.push(function(function_item.clone(), tile_rust_crate_root)?); } syn::Item::Struct(struct_item) => { ast_content.push(Item::Struct(struct_item.clone())); @@ -281,7 +281,7 @@ fn module_inner( concrete_items.push(trait_(item_clone)?.into()); } syn::Item::Type(type_item) => { - concrete_items.push(type_item.to_token_stream().into()); + concrete_items.push(type_item.to_token_stream()); } syn::Item::Impl(impl_item) => { if !is_core { @@ -297,22 +297,22 @@ fn module_inner( } ast_content.push(Item::Macro(macro_item.clone())); let item_clone = macro_item.clone(); - concrete_items.push(item_clone.to_token_stream().into()); + concrete_items.push(item_clone.to_token_stream()); } other => { return other.err("Unsupported item type in module."); } } } - let ast_path = get_ast_path(&tile_rust_crate_root); + let ast_path = get_ast_path(tile_rust_crate_root); let ast_module_item: ItemMod = module_item.clone(); let ast_module_tokens = module_asts( ast_module_item, module_ast_calls, - &tile_rust_crate_root, + tile_rust_crate_root, raw_item_source, ); - let res = if entry_functions.len() == 0 { + let res = if entry_functions.is_empty() { quote! { pub mod #name { #![allow(nonstandard_style)] @@ -585,7 +585,7 @@ pub fn function(mut item: ItemFn, tile_rust_crate_root: &Ident) -> Result Result Result Result for Error { impl Error { pub fn to_compile_error(&self) -> TokenStream2 { match self { - Self::Syn(err) => err.to_compile_error().into(), + Self::Syn(err) => err.to_compile_error(), } } } diff --git a/cutile-macro/src/kernel_launcher_generator.rs b/cutile-macro/src/kernel_launcher_generator.rs index 5202fd02..bb089e75 100644 --- a/cutile-macro/src/kernel_launcher_generator.rs +++ b/cutile-macro/src/kernel_launcher_generator.rs @@ -161,7 +161,7 @@ impl RequiredGenerics { let mut type_params = vec![]; for name in &self.names { let is_launcher_type_param = self.launcher_type_params.contains(name); - if is_launcher_type_param && self.get_ty(&name) == SupportedGenericType::TypeParam { + if is_launcher_type_param && self.get_ty(name) == SupportedGenericType::TypeParam { type_params.push(format!("{}: Send + WithDType", name.clone())); } } @@ -172,8 +172,8 @@ impl RequiredGenerics { let mut type_params = vec![]; for name in &self.names { let is_launcher_type_param = self.launcher_type_params.contains(name); - if is_launcher_type_param && self.get_ty(&name) == SupportedGenericType::TypeParam { - type_params.push(format!("{}", name.clone())); + if is_launcher_type_param && self.get_ty(name) == SupportedGenericType::TypeParam { + type_params.push(name.clone().to_string()); } } syn::parse2::( @@ -203,7 +203,7 @@ impl RequiredGenerics { /// - `["a", "b"]` → `"(a, b)"` /// - `["a", "b", "c"]` → `"(a, (b, c))"` pub fn join_as_cons_tuple(vals: &Vec) -> String { - if vals.len() == 0 { + if vals.is_empty() { return "()".to_string(); } if vals.len() == 1 { @@ -233,7 +233,7 @@ fn zippable(expr: &str, wrap_as_val: bool) -> String { if !wrap_as_val { return expr.to_string(); } - return format!("value({})", expr); + format!("value({})", expr) } /// Generates async code to zip inputs into a cons-cell tuple structure. @@ -261,7 +261,7 @@ pub fn zip_cons(inputs: &Vec, var_name: &str, wrap_as_val: bool) -> Expr let mut zip_block = syn::parse2::(quote! {{ }}) .unwrap(); - if inputs.len() == 0 { + if inputs.is_empty() { return zip_block; } let mut i = inputs.len() - 1; @@ -308,7 +308,7 @@ pub fn zip_and_then_flatten(inputs: &Vec, var_name: &str, wrap_as_val: b let mut zip_block = syn::parse2::(quote! {{ }}) .unwrap(); - if inputs.len() == 0 { + if inputs.is_empty() { zip_block .block .stmts @@ -385,7 +385,7 @@ pub fn generate_launcher_arg_types( launcher_args_name: &str, ) -> (Type, TokenStream2) { let launcher_args_ident = Ident::new(launcher_args_name, Span::call_site()); - let launcher_args_type: Type = if generic_args.args.len() > 0 { + let launcher_args_type: Type = if !generic_args.args.is_empty() { parse_quote! { #launcher_args_ident #generic_args } } else { parse_quote! { #launcher_args_ident } @@ -496,7 +496,7 @@ pub fn generate_kernel_launcher( // Added support for * mut T to allow for unsafe kernels. match ty { Type::Reference(ref_ty) => { - let res = get_tensor_code(i, var_name, &ref_ty, &mut required_generics)?; + let res = get_tensor_code(i, var_name, ref_ty, &mut required_generics)?; arg_types.push(res.fn_arg.ty.as_ref().clone()); stride_args.push(res.stride_expr_str); builder_statements.extend(res.builder_statements); @@ -589,7 +589,7 @@ pub fn generate_kernel_launcher( r#"let {param_names_tuple_str}: {launcher_args_type_str} = input.execute(ctx)?;"# ))); - if required_generics.names.len() > 0 { + if !required_generics.names.is_empty() { launcher_method.block.stmts.push(parse_stmt(format!( r#" let function_generics: Vec = if self.function_generics.is_some() {{ @@ -859,7 +859,7 @@ fn get_tensor_code( let Some(type_ident) = type_ident else { return ty.err("Expected a named type identifier for tensor parameter."); }; - if type_ident.to_string() != "Tensor" { + if type_ident != "Tensor" { return ty.err(&format!("Expected Tensor type, got {}.", type_ident)); } let Some(GenericArgument::Type(syn::Type::Path(element_type_path))) = @@ -1018,111 +1018,105 @@ pub fn infer_shape_params_from_tensor_type( match generic_arg { GenericArgument::Type(type_param) => { // Currently, this is either shape or element_type. - match type_param { - syn::Type::Path(type_path) => { - let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); - match required_generics.get_ty(&last_ident) { - SupportedGenericType::TypeParam => { - // This is an element type. - required_generics - .launcher_type_params - .push(last_ident.clone()); - required_generics.expressions.insert( - last_ident.clone(), - Some(format!("vec![{var_name}.dtype().as_str().to_string()]")), - ); - } - SupportedGenericType::ConstArray => { - // This is a CGA type. - if is_mutable { - required_generics.expressions.insert(last_ident.clone(), Some(format!("{var_name}.partition_shape.iter().map(|x| x.to_string()).collect::>()"))); - } else { - // This might make sense for a small tensor. - required_generics.expressions.insert(last_ident.clone(), Some(format!("{var_name}.shape.iter().map(|x| x.to_string()).collect::>()"))); - } - } - SupportedGenericType::ConstScalar => { - return type_path.err( - "Unexpected constant scalar type in tensor generic argument.", - ); + if let syn::Type::Path(type_path) = type_param { + let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); + match required_generics.get_ty(&last_ident) { + SupportedGenericType::TypeParam => { + // This is an element type. + required_generics + .launcher_type_params + .push(last_ident.clone()); + required_generics.expressions.insert( + last_ident.clone(), + Some(format!("vec![{var_name}.dtype().as_str().to_string()]")), + ); + } + SupportedGenericType::ConstArray => { + // This is a CGA type. + if is_mutable { + required_generics.expressions.insert(last_ident.clone(), Some(format!("{var_name}.partition_shape.iter().map(|x| x.to_string()).collect::>()"))); + } else { + // This might make sense for a small tensor. + required_generics.expressions.insert(last_ident.clone(), Some(format!("{var_name}.shape.iter().map(|x| x.to_string()).collect::>()"))); } - SupportedGenericType::Unknown => {} } + SupportedGenericType::ConstScalar => { + return type_path.err( + "Unexpected constant scalar type in tensor generic argument.", + ); + } + SupportedGenericType::Unknown => {} } - _ => {} } } GenericArgument::Const(const_param) => { // println!("expand GenericArgument::Const? {const_param:#?}"); - match const_param { - Expr::Block(block_expr) => { - // This is something like Tensor - if block_expr.block.stmts.len() != 1 { - return block_expr.err(&format!( - "Expected exactly 1 statement in block expression, got {}.", - block_expr.block.stmts.len() - )); - } - let statement = &block_expr.block.stmts[0]; - let Stmt::Expr(statement_expr, _) = statement else { - return block_expr.err( - "Unexpected block expression: expected an expression statement.", - ); - }; - match statement_expr { - Expr::Array(array_expr) => { - // This is something like Tensor - for (i, elem) in array_expr.elems.iter().enumerate() { - match elem { - Expr::Lit(_lit) => { - // Nothing to do to build generic arg expressions. - continue; - } - Expr::Unary(_unary_expr) => { - // Nothing to do to build generic arg expressions. - continue; - } - Expr::Path(path) => { - let ident = get_ident_from_path_expr(path).to_string(); - match required_generics.get_ty(&ident) { - SupportedGenericType::TypeParam => { - // This is an element type. - return path.err("Unexpected type param in array type expression."); - } - SupportedGenericType::ConstArray => { - // This is a CGA type. - return path.err("Unexpected const generic array param in array type expression."); - } - SupportedGenericType::ConstScalar => { - if is_mutable { - required_generics.expressions.insert(ident.clone(), Some(format!("vec![{var_name}.partition_shape[{i}].to_string()]"))); - } else { - required_generics.expressions.insert(ident.clone(), Some(format!("vec![{var_name}.shape[{i}].to_string()]"))); - } + if let Expr::Block(block_expr) = const_param { + // This is something like Tensor + if block_expr.block.stmts.len() != 1 { + return block_expr.err(&format!( + "Expected exactly 1 statement in block expression, got {}.", + block_expr.block.stmts.len() + )); + } + let statement = &block_expr.block.stmts[0]; + let Stmt::Expr(statement_expr, _) = statement else { + return block_expr + .err("Unexpected block expression: expected an expression statement."); + }; + match statement_expr { + Expr::Array(array_expr) => { + // This is something like Tensor + for (i, elem) in array_expr.elems.iter().enumerate() { + match elem { + Expr::Lit(_lit) => { + // Nothing to do to build generic arg expressions. + continue; + } + Expr::Unary(_unary_expr) => { + // Nothing to do to build generic arg expressions. + continue; + } + Expr::Path(path) => { + let ident = get_ident_from_path_expr(path).to_string(); + match required_generics.get_ty(&ident) { + SupportedGenericType::TypeParam => { + // This is an element type. + return path.err("Unexpected type param in array type expression."); + } + SupportedGenericType::ConstArray => { + // This is a CGA type. + return path.err("Unexpected const generic array param in array type expression."); + } + SupportedGenericType::ConstScalar => { + if is_mutable { + required_generics.expressions.insert(ident.clone(), Some(format!("vec![{var_name}.partition_shape[{i}].to_string()]"))); + } else { + required_generics.expressions.insert(ident.clone(), Some(format!("vec![{var_name}.shape[{i}].to_string()]"))); } - SupportedGenericType::Unknown => {} } + SupportedGenericType::Unknown => {} } - _ => { - return elem.err("Unsupported array element in tensor shape expression."); - } + } + _ => { + return elem.err( + "Unsupported array element in tensor shape expression.", + ); } } } - Expr::Repeat(repeat_expr) => { - // TODO (hme): Unclear under what circumstance it would be beneficial to support this. - return repeat_expr.err( - "Repeat expressions in tensor shape are not yet supported.", - ); - } - _ => { - return block_expr.err( - "Unexpected block expression in tensor const generic argument.", - ); - } + } + Expr::Repeat(repeat_expr) => { + // TODO (hme): Unclear under what circumstance it would be beneficial to support this. + return repeat_expr + .err("Repeat expressions in tensor shape are not yet supported."); + } + _ => { + return block_expr.err( + "Unexpected block expression in tensor const generic argument.", + ); } } - _ => {} } } _ => {} diff --git a/cutile-macro/src/rewrite_variadics.rs b/cutile-macro/src/rewrite_variadics.rs index 01809045..98cac519 100644 --- a/cutile-macro/src/rewrite_variadics.rs +++ b/cutile-macro/src/rewrite_variadics.rs @@ -100,7 +100,6 @@ use cutile_compiler::types::parse_signed_literal_as_i32; use proc_macro2::{Ident, Span, TokenTree}; use quote::ToTokens; use std::collections::BTreeMap; -#[allow(unused_assignments)] use std::collections::{HashMap, HashSet}; use syn::{ parse_quote, spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprCall, ExprMethodCall, @@ -201,7 +200,7 @@ fn try_get_path_expr_ident_str(maybe_path_expr: &Expr) -> Result, fn get_vod_from_call(expr: &mut ExprCall) -> Result, Error> { let name = match &*expr.func { Expr::Path(path_expr) => { - if path_expr.path.segments.len() == 0 { + if path_expr.path.segments.is_empty() { return Ok(None); } else { let fn_name = path_expr @@ -284,7 +283,7 @@ fn get_ident_generic_args( syn_err(type_path.span(), "Expected at least one path segment") })?; let last_seg = maybe_last_seg.clone(); - if last_seg.ident.to_string() != vtd.name { + if last_seg.ident != vtd.name { return Err(syn_err( last_seg.ident.span(), &format!( @@ -300,14 +299,11 @@ fn get_ident_generic_args( // This is a type of the form T<...> Ok((last_seg.ident.clone(), type_params.clone())) } - _ => Err(syn_err( - type_path.span(), - &format!("Unexpected generic arguments"), - )), + _ => Err(syn_err(type_path.span(), "Unexpected generic arguments")), } } Type::Reference(ref_type) => get_ident_generic_args(&ref_type.elem, vtd), - _ => Err(syn_err(ty.span(), &format!("Unexpected type"))), + _ => Err(syn_err(ty.span(), "Unexpected type")), } } @@ -427,7 +423,7 @@ fn get_concrete_op_or_method_ident_from_types( let Some(cga_instances) = get_cga_type(ty, const_instances)? else { return Err(syn_err( op_or_method_ident.span(), - &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unable to get cga instances for type: {}", ty.to_token_stream().to_string()), + &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unable to get cga instances for type: {}", ty.to_token_stream()), )); }; // This is a variadic type with cga instances. @@ -484,7 +480,7 @@ fn get_concrete_op_or_method_ident_from_types( return Err(syn_err( op_or_method_ident.span(), &format!("Unable to infer call to {}. Try binding it to a statically typed variable. \nDebug info:\n const_length_values={:#?}, vod.const_length_vars={:#?}", - op_or_method_ident.to_string(), + op_or_method_ident, const_length_values, vod.const_length_vars), )); @@ -497,7 +493,7 @@ fn get_concrete_op_or_method_ident_from_types( op_or_method_ident.span(), &format!( "Unable to infer call to {}. Try binding it to a statically typed variable.", - op_or_method_ident.to_string() + op_or_method_ident ), )); } @@ -506,16 +502,16 @@ fn get_concrete_op_or_method_ident_from_types( if cga_instances.is_none() { return Err(syn_err( op_or_method_ident.span(), - &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unable to get cga instances for output type: {}", output_type.to_token_stream().to_string()), + &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unable to get cga instances for output type: {}", output_type.to_token_stream()), )); } let cga_instances = cga_instances.unwrap(); - let (expected_type_name, vod_cga_var_names) = vod.output_map.clone(); + let (expected_type_name, vod_cga_var_names) = vod.output_map; if expected_type_name != vtd.name { return Err(syn_err( op_or_method_ident.span(), - &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unexpected output type: {}", output_type.to_token_stream().to_string()), + &format!("get_concrete_op_ident_from_types({op_or_method_ident}, ...): Unexpected output type: {}", output_type.to_token_stream()), )); } if vod_cga_var_names.len() != cga_instances.n.len() { @@ -559,7 +555,7 @@ fn get_concrete_op_or_method_ident_from_types( output_type } else { let (return_type_name, return_type_generic_args) = vod.return_type; - if return_type_generic_args.len() == 0 { + if return_type_generic_args.is_empty() { let ty = syn::parse::(return_type_name.parse().map_err(|_| { syn_err( op_or_method_ident.span(), @@ -592,7 +588,7 @@ fn get_concrete_op_or_method_ident_from_types( if output_type.is_none() { return Err(syn_err( op_or_method_ident.span(), - &format!("Failed to infer return type generic args {:?} \nop={} \nvod_cga_name_to_context_cga_name={vod_cga_name_to_context_cga_name:#?}", missing_cgas, op_or_method_ident.to_string()), + &format!("Failed to infer return type generic args {:?} \nop={} \nvod_cga_name_to_context_cga_name={vod_cga_name_to_context_cga_name:#?}", missing_cgas, op_or_method_ident), )); } output_type @@ -743,7 +739,7 @@ impl ConstInstances { }) } fn from_generics(generics: &Generics) -> Result { - let (cga_param, _u32_param) = parse_cgas(&generics); + let (cga_param, _u32_param) = parse_cgas(generics); let inst_u32: HashMap = HashMap::new(); let mut inst_array: HashMap = HashMap::new(); let var_arrays: HashMap = HashMap::new(); @@ -839,7 +835,7 @@ impl VariadicLengthIterator { ) })? + 1) as usize; i_max *= len; - if variadic_lengths.insert(var.clone(), len.clone()).is_some() { + if variadic_lengths.insert(var.clone(), len).is_some() { return Err(syn_err( Span::call_site(), &format!("Duplicate variadic_length_var '{var}'"), @@ -867,7 +863,7 @@ impl VariadicLengthIterator { } } else { i_max *= len; - variadic_lengths.insert(var.clone(), len.clone()); + variadic_lengths.insert(var.clone(), len); } } Ok(VariadicLengthIterator { @@ -899,7 +895,6 @@ impl VariadicLengthItem { // Ordered by key. self.variadic_length_instance .values() - .into_iter() .map(|x| *x as u32) .collect::>() } @@ -922,7 +917,7 @@ impl Iterator for VariadicLengthIterator { for len_var in &self.cga_length_vars { let len = *variadic_length_instance .get(len_var) - .expect(&format!("Unexpected length var {len_var}")); + .unwrap_or_else(|| panic!("Unexpected length var {len_var}")); cga_length_instance.push((len_var.clone(), len)); } Some(VariadicLengthItem { @@ -1013,7 +1008,7 @@ pub fn variadic_struct( )); } let mut result: Vec<(ItemStruct, Option)> = vec![]; - for (_, var_cga_iter_item) in cga_iter.enumerate() { + for var_cga_iter_item in cga_iter { let mut concrete = item.clone(); // This just constructs the current instantiation of the const generic arrays for this struct. // There is usually only one CGA for structs. @@ -1067,7 +1062,7 @@ pub fn variadic_struct( constructors.push(dyn_constructor); if num_dynamic == 0 { let constructor_name = - format!("{}", maybe_constructor_name.clone().unwrap()); + maybe_constructor_name.clone().unwrap().to_string(); let const_constructor = format!( r#" pub fn const_{constructor_name}() -> Self {{ @@ -1146,7 +1141,7 @@ pub fn variadic_trait( let cga_iter = VariadicLengthIterator::new(attributes, &cgas)?; let rewrite_variadics = RewriteVariadicsPass {}; let mut result: Vec = vec![]; - for (_, n_list) in cga_iter.enumerate() { + for n_list in cga_iter { let const_instances = ConstInstances::from_variadic(&n_list, &cgas)?; result.push(rewrite_variadics.rewrite_trait(&item, &const_instances)?); } @@ -1179,7 +1174,7 @@ pub fn variadic_impl(attributes: &SingleMetaList, item: ItemImpl) -> Result = vec![]; - for (_, n_list) in cga_iter.enumerate() { + for n_list in cga_iter { let const_instances = ConstInstances::from_variadic(&n_list, &cgas)?; result.push(rewrite_variadics.rewrite_impl(&item, &const_instances)?); } @@ -1187,7 +1182,7 @@ pub fn variadic_impl(attributes: &SingleMetaList, item: ItemImpl) -> Result = vec![]; // Iterate over the set of const generic arrays. - for (_, cga_iter_item) in cga_iter.enumerate() { + for cga_iter_item in cga_iter { // Generate as many items as the product of const generic array instances. let const_instances = const_instances_impl .instantiate_new_var_cgas(&cga_iter_item.vec_of_cga_lengths(), &cgas)?; @@ -1213,7 +1208,7 @@ pub(self) fn variadic_impl_fn_gen( } /// Rewrites a single impl method using the given const instantiations. -pub(self) fn rewrite_impl_fn( +fn rewrite_impl_fn( self_ty: &Type, item: &ImplItemFn, const_instances: &ConstInstances, @@ -1226,11 +1221,8 @@ pub(self) fn rewrite_impl_fn( } /// Desugars const generic arrays in a function signature's generics, inputs, and output. -pub(self) fn rewrite_fn_sig( - sig: &mut Signature, - const_instances: &ConstInstances, -) -> Result<(), Error> { - desugar_generics(&mut sig.generics, &const_instances)?; +fn rewrite_fn_sig(sig: &mut Signature, const_instances: &ConstInstances) -> Result<(), Error> { + desugar_generics(&mut sig.generics, const_instances)?; let mut desugared_inputs = sig.inputs.clone(); for input in desugared_inputs.iter_mut() { match input { @@ -1238,18 +1230,15 @@ pub(self) fn rewrite_fn_sig( // Leave this. } FnArg::Typed(fn_param) => { - let fn_param_type = desugar_ty(&*fn_param.ty, &const_instances)?; + let fn_param_type = desugar_ty(&fn_param.ty, const_instances)?; *fn_param.ty = fn_param_type; } } } sig.inputs = desugared_inputs; let mut desugared_outputs = sig.output.clone(); - match &mut desugared_outputs { - ReturnType::Type(_, return_type) => { - *return_type = Box::new(desugar_ty(&return_type.clone(), &const_instances)?); - } - _ => {} + if let ReturnType::Type(_, return_type) = &mut desugared_outputs { + **return_type = desugar_ty(&return_type.clone(), const_instances)?; } sig.output = desugared_outputs; Ok(()) @@ -1305,7 +1294,7 @@ pub fn variadic_op(attributes: &SingleMetaList, item: ItemFn) -> Result = vec![]; // Iterate over the set of const generic arrays. let rewrite_variadics = RewriteVariadicsPass {}; - for (_, n_list) in cga_iter.enumerate() { + for n_list in cga_iter { // Generate as many items as the product of const generic array instances. let const_instances = ConstInstances::from_variadic(&n_list, &cgas)?; result.push(rewrite_variadics.rewrite_function(&item, &const_instances)?); @@ -1314,7 +1303,7 @@ pub fn variadic_op(attributes: &SingleMetaList, item: ItemFn) -> Result Result<(), Error> { @@ -1368,7 +1357,7 @@ pub(self) fn desugar_generics( } /// Expands a CGA path into angle-bracketed individual const generic arguments. -pub(self) fn expand_cga( +fn expand_cga( path: &Path, instances: &ConstInstances, ) -> Result { @@ -1406,16 +1395,13 @@ pub(self) fn expand_cga( } else { Err(syn_err( path.span(), - &format!( - "{} is not a const generic array.", - path.to_token_stream().to_string() - ), + &format!("{} is not a const generic array.", path.to_token_stream()), )) } } /// Desugars variadic types in a path, replacing CGA syntax with concrete type names and args. -pub(self) fn desugar_path(path: &Path, instances: &ConstInstances) -> Result { +fn desugar_path(path: &Path, instances: &ConstInstances) -> Result { let mut result_path = path.clone(); for (i, seg) in path.segments.iter().enumerate() { let param_name = seg.ident.to_string(); @@ -1427,7 +1413,7 @@ pub(self) fn desugar_path(path: &Path, instances: &ConstInstances) -> Result Result { // This is a type of the form T<...> let (type_ident, last_seg_args) = - desugar_cga(&instances, &seg.ident, &type_params)?; + desugar_cga(instances, &seg.ident, type_params)?; ( type_ident.clone(), PathArguments::AngleBracketed(last_seg_args), @@ -1479,7 +1465,7 @@ pub(self) fn desugar_path(path: &Path, instances: &ConstInstances) -> Result Result<(), Error> { @@ -1487,15 +1473,12 @@ pub(self) fn desugar_generic_arguments( for arg in &mut generic_args.args { match arg { GenericArgument::Type(ty) => { - *arg = GenericArgument::Type(desugar_ty(&ty, &const_instances)?); + *arg = GenericArgument::Type(desugar_ty(ty, const_instances)?); } _ => { return Err(syn_err( span, - &format!( - "Unsupported generic argument {}", - arg.to_token_stream().to_string() - ), + &format!("Unsupported generic argument {}", arg.to_token_stream()), )) } } @@ -1504,7 +1487,7 @@ pub(self) fn desugar_generic_arguments( } /// Recursively desugars const generic array syntax within a type. -pub(self) fn desugar_ty(ty: &Type, instances: &ConstInstances) -> Result { +fn desugar_ty(ty: &Type, instances: &ConstInstances) -> Result { // Desugar const generic arrays as they appear as const generic arguments. Ok(match ty { Type::Path(type_path) => { @@ -1569,7 +1552,7 @@ pub(self) fn desugar_ty(ty: &Type, instances: &ConstInstances) -> Result { let mut result = tuple_type.clone(); for elem in &mut result.elems { - *elem = desugar_ty(&elem, instances)?; + *elem = desugar_ty(elem, instances)?; } Type::Tuple(result) } @@ -1579,7 +1562,7 @@ pub(self) fn desugar_ty(ty: &Type, instances: &ConstInstances) -> Result Result, Error> { @@ -1799,7 +1782,7 @@ pub(self) fn get_cga_type( type_ref.span(), &format!( "get_cga_type: Type::Reference not supported: {}", - type_ref.to_token_stream().to_string() + type_ref.to_token_stream() ), )); } @@ -1808,85 +1791,75 @@ pub(self) fn get_cga_type( } GenericArgument::Const(const_param) => { // println!("expand GenericArgument::Const? {const_param:#?}"); - match const_param { - Expr::Block(block_expr) => { - // TODO (hme): Would be great to get rid of this syntax. - // This is something like Tensor - if block_expr.block.stmts.len() != 1 { - return Err(syn_err( - block_expr.span(), - &format!( - "Expected exactly 1 statement in block expression, got {}", - block_expr.block.stmts.len() - ), - )); + if let Expr::Block(block_expr) = const_param { + // TODO (hme): Would be great to get rid of this syntax. + // This is something like Tensor + if block_expr.block.stmts.len() != 1 { + return Err(syn_err( + block_expr.span(), + &format!( + "Expected exactly 1 statement in block expression, got {}", + block_expr.block.stmts.len() + ), + )); + } + let statement = &block_expr.block.stmts[0]; + let Stmt::Expr(statement_expr, _) = statement else { + return Err(syn_err(block_expr.span(), "Unexpected block expression.")); + }; + match statement_expr { + Expr::Array(array_expr) => { + // This is something like Tensor + n.push(array_expr.elems.len() as u32); + cgas.push(Some(generic_arg.to_token_stream().to_string())); } - let statement = &block_expr.block.stmts[0]; - let Stmt::Expr(statement_expr, _) = statement else { - return Err(syn_err(block_expr.span(), "Unexpected block expression.")); - }; - match statement_expr { - Expr::Array(array_expr) => { - // This is something like Tensor - n.push(array_expr.elems.len() as u32); - cgas.push(Some(generic_arg.to_token_stream().to_string())); - } - Expr::Repeat(repeat_expr) => { - // println!("Expr::Repeat: {:?}", repeat_expr.expr); - let _thing_to_repeat = - repeat_expr.expr.to_token_stream().to_string(); - match &*repeat_expr.len { - Expr::Path(len_path) => { - // This is something like Tensor - let num_rep_var = len_path.to_token_stream().to_string(); - if !const_instances.inst_u32.contains_key(&num_rep_var) { - return Err(syn_err( - len_path.span(), - &format!( - "Expected instance for generic argument {}", - num_rep_var - ), - )); - } - let num_rep = - const_instances.inst_u32.get(&num_rep_var).unwrap(); - n.push(*num_rep); - cgas.push(Some(generic_arg.to_token_stream().to_string())); - } - Expr::Lit(len_lit) => { - // This is something like Tensor - let num_repetitions: u32 = len_lit - .to_token_stream() - .to_string() - .parse::() - .map_err(|e| { - syn_err( - len_lit.span(), - &format!( - "Failed to parse repeat length as u32: {e}" - ), - ) - })?; - n.push(num_repetitions); - cgas.push(Some(generic_arg.to_token_stream().to_string())); - } - _ => { + Expr::Repeat(repeat_expr) => { + // println!("Expr::Repeat: {:?}", repeat_expr.expr); + let _thing_to_repeat = repeat_expr.expr.to_token_stream().to_string(); + match &*repeat_expr.len { + Expr::Path(len_path) => { + // This is something like Tensor + let num_rep_var = len_path.to_token_stream().to_string(); + if !const_instances.inst_u32.contains_key(&num_rep_var) { return Err(syn_err( - ty.span(), - "Unexpected repeat expression.", - )) + len_path.span(), + &format!( + "Expected instance for generic argument {}", + num_rep_var + ), + )); } + let num_rep = + const_instances.inst_u32.get(&num_rep_var).unwrap(); + n.push(*num_rep); + cgas.push(Some(generic_arg.to_token_stream().to_string())); + } + Expr::Lit(len_lit) => { + // This is something like Tensor + let num_repetitions: u32 = len_lit + .to_token_stream() + .to_string() + .parse::() + .map_err(|e| { + syn_err( + len_lit.span(), + &format!( + "Failed to parse repeat length as u32: {e}" + ), + ) + })?; + n.push(num_repetitions); + cgas.push(Some(generic_arg.to_token_stream().to_string())); + } + _ => { + return Err(syn_err(ty.span(), "Unexpected repeat expression.")) } } - _ => { - return Err(syn_err( - block_expr.span(), - "Unexpected block expression.", - )) - } + } + _ => { + return Err(syn_err(block_expr.span(), "Unexpected block expression.")) } } - _ => {} } } _ => {} @@ -1975,7 +1948,7 @@ impl RewriteVariadicsPass { // This is not a variadic struct, so we don't attempt to rewrite its name. let mut item = item.clone(); for field in &mut item.fields { - field.ty = desugar_ty(&field.ty, &const_instances)?; + field.ty = desugar_ty(&field.ty, const_instances)?; } Ok(item) } @@ -1988,19 +1961,19 @@ impl RewriteVariadicsPass { let mut item = item.clone(); let mut variables: TrainMap = self.bind_parameters(None, &item.sig)?; let (inputs, output) = get_sig_types(&item.sig, None); - let inputs = inputs.into_iter().map(|x| Some(x)).collect::>(); + let inputs = inputs.into_iter().map(Some).collect::>(); item.sig.ident = get_concrete_op_ident_from_types( &item.sig.ident, &inputs, Some(output.clone()), - &const_instances, + const_instances, true, )? .0; - self.rewrite_sig(&mut item.sig, &const_instances)?; + self.rewrite_sig(&mut item.sig, const_instances)?; self.rewrite_statements( &mut item.block.stmts, - &const_instances, + const_instances, &mut variables, Some(output), )?; @@ -2013,7 +1986,7 @@ impl RewriteVariadicsPass { const_instances: &ConstInstances, ) -> Result { let mut item = item.clone(); - if const_instances.inst_u32.len() == 0 { + if const_instances.inst_u32.is_empty() { return Ok(item); } if const_instances.inst_u32.len() != 1 { @@ -2023,11 +1996,11 @@ impl RewriteVariadicsPass { )); } let key = const_instances.inst_u32.keys().next().unwrap().clone(); - let n = const_instances.inst_u32.get(&key).unwrap().clone(); + let n = *const_instances.inst_u32.get(&key).unwrap(); let trait_name = item.ident.to_string(); - let concrete_name = concrete_name(&trait_name, &vec![n]); + let concrete_name = concrete_name(&trait_name, &[n]); item.ident = Ident::new(&concrete_name, item.ident.span()); - desugar_generics(&mut item.generics, &const_instances)?; + desugar_generics(&mut item.generics, const_instances)?; // Update items. let mut impl_items: Vec = vec![]; for concrete_item in &mut item.items { @@ -2052,7 +2025,7 @@ impl RewriteVariadicsPass { ) })?; let (inputs, output) = get_sig_types(&result.sig, Some(&self_type)); - let inputs = inputs.into_iter().map(|x| Some(x)).collect::>(); + let inputs = inputs.into_iter().map(Some).collect::>(); result.sig.ident = get_concrete_op_or_method_ident_from_types( vod, &result.sig.ident, @@ -2067,12 +2040,7 @@ impl RewriteVariadicsPass { self.rewrite_sig(&mut result.sig, &const_instances)?; impl_items.push(TraitItem::Fn(result)); } - _ => { - return Err(syn_err( - concrete_item.span(), - &format!("Unsupported impl item"), - )) - } + _ => return Err(syn_err(concrete_item.span(), "Unsupported impl item")), } } item.items = impl_items; @@ -2086,14 +2054,14 @@ impl RewriteVariadicsPass { ) -> Result { let mut item = item.clone(); let self_ty = *item.self_ty.clone(); - *item.self_ty = desugar_ty(&*item.self_ty, &const_instances)?; - desugar_generics(&mut item.generics, &const_instances)?; + *item.self_ty = desugar_ty(&item.self_ty, const_instances)?; + desugar_generics(&mut item.generics, const_instances)?; let mut variadic_trait_vtd = None; // Update generics in trait definition. if let Some(trait_) = &mut item.trait_ { let path_copy = trait_.1.clone(); let path = &mut trait_.1; - if path.segments.len() == 0 { + if path.segments.is_empty() { return Err(syn_err( path.span(), "Expected at least one path segment in trait path", @@ -2102,26 +2070,18 @@ impl RewriteVariadicsPass { let last_seg = path.segments.last_mut().unwrap(); let ident_vtd = get_variadic_type_data(last_seg.ident.to_string().as_str()); if ident_vtd.is_some() { - match ident_vtd { - Some(vtd) => { - if const_instances.inst_u32.len() != 1 { - return Err(syn_err( - path.span(), - "Only one CGA is permitted for variadic traits.", - )); - } - *path = desugar_path(&path_copy, const_instances)?; - variadic_trait_vtd = Some(vtd); - } - None => {} - } - } else { - match &mut last_seg.arguments { - PathArguments::AngleBracketed(path_args) => { - desugar_generic_arguments(path_args, &const_instances)? + if let Some(vtd) = ident_vtd { + if const_instances.inst_u32.len() != 1 { + return Err(syn_err( + path.span(), + "Only one CGA is permitted for variadic traits.", + )); } - _ => {} + *path = desugar_path(&path_copy, const_instances)?; + variadic_trait_vtd = Some(vtd); } + } else if let PathArguments::AngleBracketed(path_args) = &mut last_seg.arguments { + desugar_generic_arguments(path_args, const_instances)? } } @@ -2131,7 +2091,7 @@ impl RewriteVariadicsPass { match concrete_item { ImplItem::Type(type_impl) => { let mut result = type_impl.clone(); - result.ty = desugar_ty(&type_impl.ty, &const_instances)?; + result.ty = desugar_ty(&type_impl.ty, const_instances)?; impl_items.push(ImplItem::Type(result)); } ImplItem::Fn(fn_impl) => { @@ -2149,8 +2109,8 @@ impl RewriteVariadicsPass { let results: Vec = variadic_impl_fn_gen( &attributes, &self_ty, - &fn_impl, - &const_instances, + fn_impl, + const_instances, )?; for result in results { impl_items.push(ImplItem::Fn(result)); @@ -2162,7 +2122,7 @@ impl RewriteVariadicsPass { self.rewrite_impl_fn( &self_ty, &mut result, - &const_instances, + const_instances, variadic_trait_vtd.clone(), )?; // println!("{:#?}", &result); @@ -2170,12 +2130,7 @@ impl RewriteVariadicsPass { } } } - _ => { - return Err(syn_err( - concrete_item.span(), - &format!("Unsupported impl item."), - )) - } + _ => return Err(syn_err(concrete_item.span(), "Unsupported impl item.")), } } item.items = impl_items; @@ -2201,18 +2156,15 @@ impl RewriteVariadicsPass { let method_name = item.sig.ident.to_string(); let vmmd = if variadic_trait_vtd.is_some() { let vtd = variadic_trait_vtd.unwrap(); - match get_variadic_method_data(&vtd, &method_name)? { - Some((op_name, vod)) => Some((op_name, vtd, vod)), - None => None, - } + get_variadic_method_data(&vtd, &method_name)?.map(|(op_name, vod)| (op_name, vtd, vod)) } else { - get_variadic_method_meta_data(&self_ty, &method_name)? + get_variadic_method_meta_data(self_ty, &method_name)? }; if let Some((_op_name, _vtd, vod)) = vmmd { // If it is, then rewrite the sig ident. // We do the same thing on the method call side. let (inputs, output) = get_sig_types(&item.sig, Some(self_ty)); - let inputs = inputs.into_iter().map(|x| Some(x)).collect::>(); + let inputs = inputs.into_iter().map(Some).collect::>(); // TODO (hme): This may result in redundant suffixes, but that should be okay. item.sig.ident = get_concrete_op_or_method_ident_from_types( vod, @@ -2249,7 +2201,7 @@ impl RewriteVariadicsPass { _ => { return Err(syn_err( fn_param.span(), - &format!("Unexpected function param pattern."), + "Unexpected function param pattern.", )) } } @@ -2282,7 +2234,7 @@ impl RewriteVariadicsPass { sig: &mut Signature, const_instances: &ConstInstances, ) -> Result<(), Error> { - rewrite_fn_sig(sig, &const_instances) + rewrite_fn_sig(sig, const_instances) } fn rewrite_statements( @@ -2318,7 +2270,7 @@ impl RewriteVariadicsPass { )?; } binding_ty = Some(*pat_type.ty.clone()); - let new_ty = desugar_ty(&*pat_type.ty, &const_instances)?; + let new_ty = desugar_ty(&pat_type.ty, const_instances)?; *pat_type.ty = new_ty; // Skip normal single-variable logic - compiler will handle tuple binding continue; @@ -2326,12 +2278,12 @@ impl RewriteVariadicsPass { _ => { return Err(syn_err( pat_type.span(), - &format!("let binding LHS not implemented."), + "let binding LHS not implemented.", )) } } binding_ty = Some(*pat_type.ty.clone()); - let new_ty = desugar_ty(&*pat_type.ty, &const_instances)?; + let new_ty = desugar_ty(&pat_type.ty, const_instances)?; // println!("rewrite_statements Stmt::Local Pat::Type {:#?}", new_ty); *pat_type.ty = new_ty; } @@ -2353,32 +2305,24 @@ impl RewriteVariadicsPass { } continue; // Skip normal single-variable logic } - _ => { - return Err(syn_err( - local.span(), - &format!("Local pattern type not supported"), - )) - } + _ => return Err(syn_err(local.span(), "Local pattern type not supported")), } if binding_name.is_none() { - return Err(syn_err(local.span(), &format!("Unable to rewrite expr."))); + return Err(syn_err(local.span(), "Unable to rewrite expr.")); } let binding_name = binding_name.unwrap(); - match &mut local.init { - Some(init) => { - // Rewrite the expression but preserve explicit type annotations - let inferred_ty = self.rewrite_expr( - &mut *init.expr, - const_instances, - variables, - binding_ty.clone(), - )?; - // Only use inferred type if we don't have an explicit type annotation - if binding_ty.is_none() { - binding_ty = inferred_ty; - } + if let Some(init) = &mut local.init { + // Rewrite the expression but preserve explicit type annotations + let inferred_ty = self.rewrite_expr( + &mut init.expr, + const_instances, + variables, + binding_ty.clone(), + )?; + // Only use inferred type if we don't have an explicit type annotation + if binding_ty.is_none() { + binding_ty = inferred_ty; } - None => {} } variables.insert( binding_name.clone(), @@ -2395,7 +2339,7 @@ impl RewriteVariadicsPass { let return_type = Some(*const_item.ty.clone()); // This is like a let binding with limitations. self.rewrite_expr( - &mut *const_item.expr, + &mut const_item.expr, const_instances, variables, return_type, @@ -2406,13 +2350,13 @@ impl RewriteVariadicsPass { item.span(), &format!( "{}\nOnly const local item definitions are supported.", - item.to_token_stream().to_string() + item.to_token_stream() ), )) } }; let Some(binding_name) = binding_name else { - return Err(syn_err(item.span(), &format!("Unable to rewrite expr."))); + return Err(syn_err(item.span(), "Unable to rewrite expr.")); }; variables.insert( binding_name.clone(), @@ -2441,7 +2385,7 @@ impl RewriteVariadicsPass { _ => { return Err(syn_err( assign_expr.span(), - &format!("Expr::Assign not supported"), + "Expr::Assign not supported", )) } } @@ -2504,7 +2448,7 @@ impl RewriteVariadicsPass { inner_expr_span, &format!( "Index expression not supported: {}", - index_expr.expr.to_token_stream().to_string() + index_expr.expr.to_token_stream() ), )); } @@ -2553,19 +2497,16 @@ impl RewriteVariadicsPass { let Type::Slice(slice_ty) = *ty.elem.clone() else { return Err(syn_err( expr_span, - &format!("Index expression not supported (reference)"), + "Index expression not supported (reference)", )); }; Ok(Some(*slice_ty.elem.clone())) } None => Err(syn_err( expr_span, - &format!("Failed to compute type for index expression"), - )), - Some(_other) => Err(syn_err( - expr_span, - &format!("Index expression not supported"), + "Failed to compute type for index expression", )), + Some(_other) => Err(syn_err(expr_span, "Index expression not supported")), } } } @@ -2611,7 +2552,7 @@ impl RewriteVariadicsPass { Expr::While(while_expr) => { // While loop: while condition { body } // Rewrite condition and body - self.rewrite_expr(&mut *while_expr.cond, const_instances, variables, None)?; + self.rewrite_expr(&mut while_expr.cond, const_instances, variables, None)?; let mut block_vars = variables.fork(); self.rewrite_statements( &mut while_expr.body.stmts, @@ -2641,7 +2582,7 @@ impl RewriteVariadicsPass { if let Some((_Else, else_expr)) = &mut if_expr.else_branch { let mut block_vars = variables.fork(); self.rewrite_expr( - &mut **else_expr, + else_expr, const_instances, &mut block_vars, return_type.clone(), @@ -2668,12 +2609,12 @@ impl RewriteVariadicsPass { ), Expr::Cast(cast_expr) => { self.rewrite_expr( - &mut *cast_expr.expr, + &mut cast_expr.expr, const_instances, variables, return_type.clone(), )?; - *cast_expr.ty = desugar_ty(&*cast_expr.ty, const_instances)?; + *cast_expr.ty = desugar_ty(&cast_expr.ty, const_instances)?; Ok(return_type) } Expr::Path(path_expr) => { @@ -2711,7 +2652,7 @@ impl RewriteVariadicsPass { ) } Expr::Reference(ref_expr) => self.rewrite_expr( - &mut *ref_expr.expr, + &mut ref_expr.expr, const_instances, variables, return_type.clone(), @@ -2726,26 +2667,26 @@ impl RewriteVariadicsPass { None => Ok(return_type), }, Expr::Assign(assign_expr) => self.rewrite_expr( - &mut *assign_expr.right, + &mut assign_expr.right, const_instances, variables, return_type.clone(), ), Expr::Unary(unary_expr) => self.rewrite_expr( - &mut *unary_expr.expr, + &mut unary_expr.expr, const_instances, variables, return_type.clone(), ), Expr::Binary(bin_expr) => { self.rewrite_expr( - &mut *bin_expr.left, + &mut bin_expr.left, const_instances, variables, return_type.clone(), )?; self.rewrite_expr( - &mut *bin_expr.right, + &mut bin_expr.right, const_instances, variables, return_type.clone(), @@ -2765,12 +2706,12 @@ impl RewriteVariadicsPass { Ok(return_type) } Expr::Repeat(repeat_expr) => { - self.rewrite_expr(&mut *repeat_expr.len, const_instances, variables, None)?; + self.rewrite_expr(&mut repeat_expr.len, const_instances, variables, None)?; Ok(return_type) } Expr::Field(field_expr) => { return_type = self.rewrite_expr( - &mut *field_expr.base, + &mut field_expr.base, const_instances, variables, return_type.clone(), @@ -2780,7 +2721,7 @@ impl RewriteVariadicsPass { Expr::Struct(struct_expr) => { // TODO (hme): Similar code fragment in desugar_ty. // Can this be refactored into a rewrite for any PathSegment? - if struct_expr.path.segments.len() == 0 { + if struct_expr.path.segments.is_empty() { return Err(syn_err( struct_expr.span(), "Expected at least one path segment in struct expression", @@ -2789,37 +2730,29 @@ impl RewriteVariadicsPass { let last_seg = struct_expr.path.segments.last_mut().unwrap(); let name = last_seg.ident.to_string(); let vtd = get_variadic_type_data(name.as_str()); - match vtd { - Some(_vtd) => { - if return_type.is_none() { - return Err(syn_err( - struct_expr.span(), - "Variadic structs require a static type annotation. Try assigning to a statically typed let binding.", - )); - } - let (last_type_ident, last_seg_args) = match &last_seg.arguments { - PathArguments::AngleBracketed(type_params) => { - let (type_ident, last_seg_args) = - desugar_cga(&const_instances, &last_seg.ident, &type_params)?; - ( - type_ident.clone(), - PathArguments::AngleBracketed(last_seg_args), - ) - } - PathArguments::None => (last_seg.ident.clone(), PathArguments::None), - _ => { - return Err(syn_err( - struct_expr.span(), - "Unexpected Path arguments.", - )) - } - }; - *last_seg = PathSegment { - ident: last_type_ident, - arguments: last_seg_args, - }; + if let Some(_vtd) = vtd { + if return_type.is_none() { + return Err(syn_err( + struct_expr.span(), + "Variadic structs require a static type annotation. Try assigning to a statically typed let binding.", + )); } - None => {} + let (last_type_ident, last_seg_args) = match &last_seg.arguments { + PathArguments::AngleBracketed(type_params) => { + let (type_ident, last_seg_args) = + desugar_cga(const_instances, &last_seg.ident, type_params)?; + ( + type_ident.clone(), + PathArguments::AngleBracketed(last_seg_args), + ) + } + PathArguments::None => (last_seg.ident.clone(), PathArguments::None), + _ => return Err(syn_err(struct_expr.span(), "Unexpected Path arguments.")), + }; + *last_seg = PathSegment { + ident: last_type_ident, + arguments: last_seg_args, + }; } for field in &mut struct_expr.fields { self.rewrite_expr(&mut field.expr, const_instances, variables, None)?; @@ -2927,7 +2860,7 @@ impl RewriteVariadicsPass { Expr::Lit(_lit_expr) => Ok(return_type), Expr::Paren(paren_expr) => { return_type = self.rewrite_expr( - &mut *paren_expr.expr, + &mut paren_expr.expr, const_instances, variables, return_type.clone(), @@ -2939,10 +2872,7 @@ impl RewriteVariadicsPass { // The compiler will handle parsing and compilation of closure bodies Ok(return_type) } - _ => Err(syn_err( - expr.span(), - &format!("Expression type not supported"), - )), + _ => Err(syn_err(expr.span(), "Expression type not supported")), } } @@ -2955,7 +2885,7 @@ impl RewriteVariadicsPass { ) -> Result, Error> { let result_path = desugar_path(&expr.path, const_instances)?; expr.path = result_path; - if expr.path.segments.len() == 0 { + if expr.path.segments.is_empty() { // TODO (hme): What would this be? return Ok(None); } @@ -2990,12 +2920,12 @@ impl RewriteVariadicsPass { let method_ident = &expr.method; let method_name = method_ident.to_string(); let self_ty = - match self.rewrite_expr(&mut *expr.receiver, const_instances, variables, None)? { + match self.rewrite_expr(&mut expr.receiver, const_instances, variables, None)? { Some(ty) => ty, None => { return Err(syn_err( expr.receiver.span(), - &format!("Unable to infer receiver type"), + "Unable to infer receiver type", )) } }; @@ -3073,7 +3003,7 @@ impl RewriteVariadicsPass { _ => { return Err(syn_err( expr.func.span(), - &format!("Unexpected function call expression."), + "Unexpected function call expression.", )) } }; @@ -3113,7 +3043,7 @@ impl RewriteVariadicsPass { return_type } }; - self.rewrite_expr(&mut *expr.func, const_instances, variables, None)?; + self.rewrite_expr(&mut expr.func, const_instances, variables, None)?; // println!("rewrite_call {}: maybe_inferred_rtype = {maybe_inferred_rtype:#?}", expr.to_token_stream().to_string()); Ok(maybe_inferred_rtype) } @@ -3137,7 +3067,7 @@ impl RewriteVariadicsPass { pub fn desugar_structure_cgas(item: &ItemStruct) -> Result { let const_instances = ConstInstances::from_generics(&item.generics)?; let rewrite_pass = RewriteVariadicsPass {}; - rewrite_pass.rewrite_struct(&item, &const_instances) + rewrite_pass.rewrite_struct(item, &const_instances) } /// Desugars const generic array syntax in a function definition. @@ -3157,7 +3087,7 @@ pub fn desugar_structure_cgas(item: &ItemStruct) -> Result { pub fn desugar_function_cgas(item: &ItemFn) -> Result { let rewrite_pass = RewriteVariadicsPass {}; let const_instances = ConstInstances::from_generics(&item.sig.generics)?; - rewrite_pass.rewrite_function(&item, &const_instances) + rewrite_pass.rewrite_function(item, &const_instances) } /// Desugars const generic array syntax in an impl block. diff --git a/cutile-macro/src/types.rs b/cutile-macro/src/types.rs index 1cf4d31c..5b617040 100644 --- a/cutile-macro/src/types.rs +++ b/cutile-macro/src/types.rs @@ -472,7 +472,7 @@ pub fn get_variadic_method_data( Some(op_name) => Ok(Some(( op_name, get_variadic_op_data(op_name) - .expect(format!("{op_name} is not a variadic op.").as_str()), + .unwrap_or_else(|| panic!("{op_name} is not a variadic op.")), ))), None => Ok(None), } @@ -1130,7 +1130,7 @@ impl ConstGenericArrayTypeListIterator { impl Iterator for ConstGenericArrayTypeListIterator { type Item = Result, Error>; fn next(&mut self) -> Option { - if self.state.len() == 0 { + if self.state.is_empty() { // First pass should always contain something. for item in &mut self.iterators { match item.next() { @@ -1145,34 +1145,32 @@ impl Iterator for ConstGenericArrayTypeListIterator { } } Some(Ok(self.state.clone())) + } else if self.done { + None } else { - if self.done { - None - } else { - for _i in 0..self.iterators.len() { - // Traverse in reverse to remain consistent with traversal order of individual ConstGenericArrayIterator. - // The traversal is a mixed-radix counter. - // We're done when the most significant position is None. - let i = (self.iterators.len() - 1) - _i; - let iter = &mut self.iterators[i]; - let item: Option = iter.next(); - match item { - Some(item) => { - self.state[i] = item; - break; - } - None => { - if i == 0 { - self.done = true; - return None; - } - self.iterators[i] = iter.renew(); - self.state[i] = self.iterators[i].next().unwrap(); + for _i in 0..self.iterators.len() { + // Traverse in reverse to remain consistent with traversal order of individual ConstGenericArrayIterator. + // The traversal is a mixed-radix counter. + // We're done when the most significant position is None. + let i = (self.iterators.len() - 1) - _i; + let iter = &mut self.iterators[i]; + let item: Option = iter.next(); + match item { + Some(item) => { + self.state[i] = item; + break; + } + None => { + if i == 0 { + self.done = true; + return None; } + self.iterators[i] = iter.renew(); + self.state[i] = self.iterators[i].next().unwrap(); } } - Some(Ok(self.state.clone())) } + Some(Ok(self.state.clone())) } } } diff --git a/cutile-macro/src/validate_dsl_syntax.rs b/cutile-macro/src/validate_dsl_syntax.rs index e429ae0e..e1b2548e 100644 --- a/cutile-macro/src/validate_dsl_syntax.rs +++ b/cutile-macro/src/validate_dsl_syntax.rs @@ -104,10 +104,10 @@ use crate::error::{Error, SpannedError}; // * mut T for unsafe kernels. pub fn validate_entry_point_parameters(item: &ItemFn) -> Result<(), Error> { let (input_types, _output_type) = get_sig_types(&item.sig, None); - for (_i, ty) in input_types.iter().enumerate() { + for ty in input_types.iter() { match ty { Type::Reference(_) => { - let Some(ident) = get_type_ident(&ty) else { + let Some(ident) = get_type_ident(ty) else { return ty.err("Not a supported parameter type."); }; let type_name = ident.to_string(); @@ -136,7 +136,7 @@ pub fn validate_entry_point_parameters(item: &ItemFn) -> Result<(), Error> { _ => { ty.err(&format!( "{} is not a supported parameter type.", - ty.to_token_stream().to_string() + ty.to_token_stream() ))?; } } diff --git a/cutile/src/_core.rs b/cutile/src/_core.rs index e594f304..309a106e 100644 --- a/cutile/src/_core.rs +++ b/cutile/src/_core.rs @@ -311,7 +311,7 @@ pub mod core { pub fn check_partition_access( part: &Partition, index: [i32; N], - ) -> () { + ) { // This is either instantiated, in which case an actual bounds check takes place, // or the check is performed statically and nothing is emitted. // The bounds check is implemented as an assertion. @@ -1340,7 +1340,7 @@ pub mod core { // TODO (hme): Bounds checks. let tensor_token: Token = get_tensor_token(self); let p: Partition = make_partition_view(self, tile, tensor_token); - return p; + p } pub fn partition_permuted<'a, const R: [i32; N], const I: [i32; N]>( &'a self, @@ -1351,7 +1351,7 @@ pub mod core { let tensor_token: Token = get_tensor_token(self); let p: Partition = make_partition_view_permuted(self, tile, dim_map, tensor_token); - return p; + p } pub unsafe fn partition_mut<'a, const R: [i32; N]>( &'a mut self, @@ -1400,7 +1400,7 @@ pub mod core { /// output.store(tile); /// } /// ``` - pub fn store(&mut self, result: Tile) -> () { + pub fn store(&mut self, result: Tile) { store_tile(self, result); } } @@ -1463,7 +1463,7 @@ pub mod core { pub fn set_tensor_token( tensor: &Tensor, token: Token, - ) -> () { + ) { unreachable!() } @@ -1564,9 +1564,9 @@ pub mod core { /// let tile = partition.load([5]); // Load 6th tile (64 elements starting at 5*64) /// ``` pub fn load(&self, index: [i32; N]) -> Tile { - check_partition_access(&self, index); + check_partition_access(self, index); let result: Tile = load_from_view(self, index); - return result; + result } } @@ -3628,10 +3628,7 @@ pub mod core { /// store_tile(tensor, tile); /// ``` #[cuda_tile::variadic_op(N = 6)] - pub fn store_tile( - y: &mut Tensor, - result: Tile, - ) -> () { + pub fn store_tile(y: &mut Tensor, result: Tile) { let tile_shape: Shape = y.shape(); let tensor_token: Token = get_tensor_token(y); let mut y_partition: PartitionMut = @@ -3667,7 +3664,7 @@ pub mod core { let pid: (i32, i32, i32) = get_tile_block_id(); let tile_shape: Shape = y.shape(); let tensor_token: Token = get_tensor_token(x); - let x_partition: Partition = make_partition_view(&x, tile_shape, tensor_token); + let x_partition: Partition = make_partition_view(x, tile_shape, tensor_token); let tile_x: Tile = load_from_view(&x_partition, [pid.0, pid.1]); tile_x } @@ -3710,7 +3707,7 @@ pub mod core { let pid: (i32, i32, i32) = get_tile_block_id(); let tile_shape: Shape = y.shape(); let tensor_token: Token = get_tensor_token(x); - let x_partition: Partition = make_partition_view(&x, tile_shape, tensor_token); + let x_partition: Partition = make_partition_view(x, tile_shape, tensor_token); let tile_x: Tile = load_from_view(&x_partition, [pid.0]); tile_x } diff --git a/cutile/src/api.rs b/cutile/src/api.rs index f04df4c1..40c52d69 100644 --- a/cutile/src/api.rs +++ b/cutile/src/api.rs @@ -489,7 +489,7 @@ pub(crate) fn candle_tensor_to_vec( ) -> (Vec, Vec, Vec) { let shape: Vec = tensor.shape().dims().iter().map(|x| *x as i32).collect(); let strides: Vec = tensor.stride().iter().map(|x| *x as i32).collect(); - let size: usize = tensor.shape().dims().iter().fold(1, |acc, x| acc * x); + let size: usize = tensor.shape().dims().iter().product(); let vec = tensor.reshape((size,)).unwrap().to_vec1().unwrap(); (vec, shape, strides) } diff --git a/cutile/src/error.rs b/cutile/src/error.rs index e9bfea8c..e7d4ce58 100644 --- a/cutile/src/error.rs +++ b/cutile/src/error.rs @@ -82,24 +82,24 @@ impl error::Error for Error {} /// Creates a `Error::Tensor` from the given message string. pub fn tensor_error(err_str: &str) -> Error { - return Error::Tensor(TensorError(err_str.to_string())); + Error::Tensor(TensorError(err_str.to_string())) } /// Returns `Err(Error::Tensor(...))` with the given message string. pub fn tensor_error_result(err_str: &str) -> Result { - return Err(tensor_error(err_str)); + Err(tensor_error(err_str)) } // Kernel Launch /// Creates a `Error::KernelLaunch` from the given message string. pub fn kernel_launch_error(err_str: &str) -> Error { - return Error::KernelLaunch(KernelLaunchError(err_str.to_string())); + Error::KernelLaunch(KernelLaunchError(err_str.to_string())) } /// Returns `Err(Error::KernelLaunch(...))` with the given message string. pub fn kernel_launch_error_result(err_str: &str) -> Result { - return Err(kernel_launch_error(err_str)); + Err(kernel_launch_error(err_str)) } // anyhow diff --git a/cutile/src/tensor.rs b/cutile/src/tensor.rs index ba3cf6ed..34a24ba6 100644 --- a/cutile/src/tensor.rs +++ b/cutile/src/tensor.rs @@ -339,9 +339,9 @@ impl Partition> { } } -impl Into> for Partition { - fn into(self) -> Arc { - Arc::new(self.unpartition()) +impl From> for Arc { + fn from(val: Partition) -> Self { + Arc::new(val.unpartition()) } } diff --git a/cutile/src/tile_kernel.rs b/cutile/src/tile_kernel.rs index 659bd56d..19a5bc0d 100644 --- a/cutile/src/tile_kernel.rs +++ b/cutile/src/tile_kernel.rs @@ -108,7 +108,7 @@ fn write_ir( ) { let filename = format!("{module_name}_{function_name}_{cache_hash_str}.{extension}"); let path = PathBuf::from(dir).join(filename); - fs::write(path.clone(), contents).expect(format!("Failed to write {path:?}").as_str()); // Writes the string as bytes + fs::write(path.clone(), contents).unwrap_or_else(|_| panic!("Failed to write {path:?}")); // Writes the string as bytes println!("IR written to {path:?}"); } @@ -174,7 +174,7 @@ pub fn compile_from_context Vec>( // A hit to the thread local kernel cache returns the compiled function. let func = get_cuda_function(device_id, &key)?; let validator = get_function_validator(device_id, &key)?; - return Ok((func, validator)); + Ok((func, validator)) } else { let gpu_name = get_gpu_name(device_id); // A miss compiles, caches, and returns the compiled function. @@ -277,7 +277,7 @@ pub fn compile_from_context Vec>( ); insert_cuda_function(device_id, &key, (module, function.clone()))?; insert_function_validator(device_id, &key, validator.clone())?; - return Ok((function, validator)); + Ok((function, validator)) } } @@ -314,7 +314,7 @@ pub fn infer_launch_grid( ) -> Result<(u32, u32, u32), Error> { if grid != (0, 0, 0) { // A launch grid was specified. - if inferred_grids.len() > 0 { + if !inferred_grids.is_empty() { validate_grids(grid, inferred_grids).with_context(|| { "Specified launch grid does not match inferred tensor partition grid" })?; @@ -322,7 +322,7 @@ pub fn infer_launch_grid( return Ok(grid); } // Try to infer launch grid. - if inferred_grids.len() == 0 { + if inferred_grids.is_empty() { return kernel_launch_error_result("Launch grid required."); } let grid = inferred_grids[0]; @@ -455,7 +455,7 @@ where inferred_grids: &[(u32, u32, u32)], ) -> Result<(u32, u32, u32), Error> { let grid = self.get_launch_grid(); - infer_launch_grid(grid, &inferred_grids) + infer_launch_grid(grid, inferred_grids) } /// Returns the currently configured launch grid dimensions. fn get_launch_grid(&self) -> (u32, u32, u32);