diff --git a/cutile-compiler/src/bounds.rs b/cutile-compiler/src/bounds.rs index 9b5fa47..8d8ba97 100644 --- a/cutile-compiler/src/bounds.rs +++ b/cutile-compiler/src/bounds.rs @@ -132,7 +132,7 @@ impl Add for Bounds { fn add(self, rhs: Bounds) -> Bounds { let a = self; let b = rhs; - let possible_bounds = vec![ + let possible_bounds = [ a.start + b.start, a.start + b.end, a.end + b.start, @@ -155,7 +155,7 @@ impl Sub for Bounds { fn sub(self, rhs: Bounds) -> Bounds { let a = self; let b = rhs; - let possible_bounds = vec![ + let possible_bounds = [ a.start - b.start, a.start - b.end, a.end - b.start, @@ -178,7 +178,7 @@ impl Mul for Bounds { fn mul(self, rhs: Bounds) -> Bounds { let a = self; let b = rhs; - let possible_bounds = vec![ + let possible_bounds = [ a.start * b.start, a.start * b.end, a.end * b.start, @@ -213,7 +213,7 @@ impl Div for Bounds { (_, 0) => panic!("Division by zero"), (0, _) => panic!("Division by zero"), _ => { - let possible_bounds = vec![ + let possible_bounds = [ a.start / b.start, a.start / b.end, a.end / b.start, @@ -239,7 +239,7 @@ impl Rem for Bounds { let a = self; let b = rhs; // TODO (hme): Verify this one. - let possible_bounds = vec![ + let possible_bounds = [ a.start % b.start, a.start % b.end, a.end % b.start, @@ -266,7 +266,7 @@ pub fn bop_bounds i64>(a: &Bounds, b: &Bounds, f: F if a.is_exact() && b.is_exact() { return Bounds::exact(f(a.start, b.start)); } - let possible_bounds = vec![ + let possible_bounds = [ f(a.start, b.start), f(a.start, b.end), f(a.end, b.start), diff --git a/cutile-compiler/src/compiler/_function.rs b/cutile-compiler/src/compiler/_function.rs index fd8983e..32f911c 100644 --- a/cutile-compiler/src/compiler/_function.rs +++ b/cutile-compiler/src/compiler/_function.rs @@ -64,6 +64,7 @@ struct FunctionParamTypes { } impl<'m> CUDATileFunctionCompiler<'m> { + #[allow(clippy::too_many_arguments)] pub fn new( modules: &'m CUDATileModules, module_name: &str, @@ -127,7 +128,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { // 7. Build stride_args HashMap. let stride_args: HashMap> = stride_args - .into_iter() + .iter() .map(|(k, v)| (k.to_string(), v.to_vec())) .collect::>(); @@ -155,20 +156,19 @@ impl<'m> CUDATileFunctionCompiler<'m> { .collect(); let (entry, validator) = generate_entry_point( modules, - &function, + function, &generic_vars, &stride_args, &spec_args_map, &scalar_hints_map, - &modules.primitives(), + modules.primitives(), &optimization_hints, )?; // 10. Check namespace collision. if modules .functions() - .get(kernel_naming.entry_name().as_str()) - .is_some() + .contains_key(kernel_naming.entry_name().as_str()) { return modules .resolve_span(module_name, &function.span()) @@ -517,14 +517,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { if std::env::var("CUTILE_DEBUG_COMPILER2").is_ok() { eprintln!( "compiler2: lowered entry function body:\n{}", - quote::quote!(#lowered_fn_item).to_string() + quote::quote!(#lowered_fn_item) ); } let return_value = self.compile_block( module, block_id, - &*lowered_fn_item.block, + &lowered_fn_item.block, generic_vars, &mut ctx, None, @@ -582,7 +582,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { None }; let value = self - .compile_expression(module, block_id, &arg, generic_args, ctx, expected)? + .compile_expression(module, block_id, arg, generic_args, ctx, expected)? .ok_or(self.jit_error( &arg.span(), &format!( @@ -606,7 +606,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let rust_ty_str = type_name::(); let rust_ty = syn::parse2::(rust_ty_str.parse()?).unwrap(); let tr_ty = self - .compile_type(&rust_ty, &generic_vars, &HashMap::new())? + .compile_type(&rust_ty, generic_vars, &HashMap::new())? .ok_or(self.jit_error(&rust_ty.span(), "failed to compile constant"))?; self.compile_constant_from_exact_bounds(module, block_id, bounds, tr_ty) } @@ -652,7 +652,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { }; let Some(const_ty_str) = get_cuda_tile_element_type_from_rust_primitive_str( &type_inst.rust_element_instance_ty, - &self.modules.primitives(), + self.modules.primitives(), ) else { return self .jit_error_result(&tr_ty.rust_ty.span(), "failed to compile constant value"); diff --git a/cutile-compiler/src/compiler/_module.rs b/cutile-compiler/src/compiler/_module.rs index bab05f6..5f03a70 100644 --- a/cutile-compiler/src/compiler/_module.rs +++ b/cutile-compiler/src/compiler/_module.rs @@ -447,8 +447,10 @@ fn trait_impl_matches_call( return false; } - let mut ctx = TraitMatchCtx::default(); - ctx.caller_array_params = generic_vars.inst_array.clone(); + let mut ctx = TraitMatchCtx{ + caller_array_params: generic_vars.inst_array.clone(), + ..Default::default() + }; collect_generics_for_trait_match(&item_impl.generics, &mut ctx); collect_generics_for_trait_match(&impl_method.sig.generics, &mut ctx); @@ -830,7 +832,7 @@ impl CUDATileModules { if matches!( generic_vars.instantiate_type(receiver_rust_ty, self.primitives())?, TypeInstance::ElementType(_) - ) { + ) { if let Some(impls) = self .name_resolver .trait_impls() diff --git a/cutile-compiler/src/compiler/_type.rs b/cutile-compiler/src/compiler/_type.rs index 17c209e..45371c5 100644 --- a/cutile-compiler/src/compiler/_type.rs +++ b/cutile-compiler/src/compiler/_type.rs @@ -131,6 +131,6 @@ fn rust_scalar_type(name: &str) -> Option { fn extract_pointer_element_type(ty_str: &str) -> Option { let after_mut = ty_str.split("mut").nth(1)?; let trimmed = after_mut.trim(); - let end = trimmed.find(|c: char| c == ',' || c == '>' || c == ' ')?; + let end = trimmed.find([',', '>', ' '])?; Some(trimmed[..end].to_string()) } diff --git a/cutile-compiler/src/compiler/_value.rs b/cutile-compiler/src/compiler/_value.rs index f6e7cc7..2cc1bad 100644 --- a/cutile-compiler/src/compiler/_value.rs +++ b/cutile-compiler/src/compiler/_value.rs @@ -364,9 +364,7 @@ impl TileRustValue { } pub fn take_type_meta_field(self, name: &str) -> Option { - let Some(mut type_meta) = self.type_meta else { - return None; - }; + let mut type_meta = self.type_meta?; type_meta.fields.remove(name) } diff --git a/cutile-compiler/src/compiler/compile_assume.rs b/cutile-compiler/src/compiler/compile_assume.rs index dbd86e8..2cfdb64 100644 --- a/cutile-compiler/src/compiler/compile_assume.rs +++ b/cutile-compiler/src/compiler/compile_assume.rs @@ -56,7 +56,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { "expected a simple function path for assume invocation", ); }; - let ident = get_ident_from_path_expr(&path_expr); + let ident = get_ident_from_path_expr(path_expr); let compiler_op_function = ident.to_string(); let mut args = self.compile_call_args(module, block_id, &call_expr.args, generic_vars, ctx)?; @@ -75,7 +75,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { "the first argument to `assume` must produce a value", ); }; - Ok(self.compile_value_assumption( + self.compile_value_assumption( module, block_id, val_value, @@ -83,10 +83,11 @@ impl<'m> CUDATileFunctionCompiler<'m> { &predicate_args, return_type, &call_expr.span(), - )?) + ) } /// Generates tile-ir assume operation with appropriate predicate attribute. + #[allow(clippy::too_many_arguments)] pub(crate) fn compile_value_assumption( &self, module: &mut Module, diff --git a/cutile-compiler/src/compiler/compile_binary_op.rs b/cutile-compiler/src/compiler/compile_binary_op.rs index 1852b3d..ea68353 100644 --- a/cutile-compiler/src/compiler/compile_binary_op.rs +++ b/cutile-compiler/src/compiler/compile_binary_op.rs @@ -138,6 +138,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { )?)) } + #[allow(clippy::too_many_arguments)] pub fn compile_binary_op_from_values( &self, module: &mut Module, @@ -156,8 +157,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { &format!( "binary `{:?}` requires operands of the same type, but got `{}` and `{}`", tile_rust_arithmetic_op, - lhs.ty.rust_ty.to_token_stream().to_string(), - rhs.ty.rust_ty.to_token_stream().to_string() + lhs.ty.rust_ty.to_token_stream(), + rhs.ty.rust_ty.to_token_stream() ), ); } @@ -180,14 +181,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { let operand_type = lhs.ty.clone(); let operand_rust_ty = &operand_type.rust_ty; let Some(operand_rust_element_type) = - operand_type.get_instantiated_rust_element_type(&self.modules.primitives()) + operand_type.get_instantiated_rust_element_type(self.modules.primitives()) else { return self.jit_error_result( span, &format!( "unable to determine element type for `{:?}` on `{}`", tile_rust_arithmetic_op, - operand_type.rust_ty.to_token_stream().to_string() + operand_type.rust_ty.to_token_stream() ), ); }; @@ -196,7 +197,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { span, &format!( "type `{}` cannot be used with binary `{:?}`", - operand_type.rust_ty.to_token_stream().to_string(), + operand_type.rust_ty.to_token_stream(), tile_rust_arithmetic_op ), ); @@ -206,7 +207,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let operand_result_ty = module.value_type(lhs_value).clone(); let Some(operand_cuda_tile_element_type) = - operand_type.get_cuda_tile_element_type(&self.modules.primitives())? + operand_type.get_cuda_tile_element_type(self.modules.primitives())? else { return self.jit_error_result( span, @@ -414,7 +415,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { span, &format!( "Binary operation is not implemented for {}", - operand_rust_ty.to_token_stream().to_string() + operand_rust_ty.to_token_stream() ), ); } @@ -426,7 +427,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { // Try to infer from lhs/rhs. if is_cmp { let bool_ty = syn::parse2::("bool".parse().unwrap()).unwrap(); - self.compile_type(&bool_ty, &generic_vars, &HashMap::new())? + self.compile_type(&bool_ty, generic_vars, &HashMap::new())? .unwrap() } else { operand_type @@ -448,17 +449,17 @@ impl<'m> CUDATileFunctionCompiler<'m> { } else { None }; - if let Some(bounds) = &op_bounds { + if let Some(bounds) = op_bounds { if bounds.is_exact() { // The lower/upper bounds are equivalent — emit a constant // instead. The op allocated above becomes dead (not appended // to any block). - return Ok(self.compile_constant_from_exact_bounds( + return self.compile_constant_from_exact_bounds( module, block_id, - bounds.clone(), + bounds, return_type, - )?); + ); } } diff --git a/cutile-compiler/src/compiler/compile_block.rs b/cutile-compiler/src/compiler/compile_block.rs index 0c00b6d..eb2c749 100644 --- a/cutile-compiler/src/compiler/compile_block.rs +++ b/cutile-compiler/src/compiler/compile_block.rs @@ -97,7 +97,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ), ); } - for (pat, value) in tuple.elems.iter().zip(elements.into_iter()) { + for (pat, value) in tuple.elems.iter().zip(elements) { self.bind_pattern_value(pat, value, inherited_mutability, ctx)?; } Ok(()) @@ -152,7 +152,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ), ); } - for (pat, value) in pats.into_iter().zip(elements.into_iter()) { + for (pat, value) in pats.into_iter().zip(elements) { self.bind_pattern_value(pat, value, inherited_mutability, ctx)?; } } @@ -206,7 +206,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ), ); } - for (pat, value) in tuple_struct.elems.iter().zip(elements.into_iter()) { + for (pat, value) in tuple_struct.elems.iter().zip(elements) { self.bind_pattern_value(pat, value, inherited_mutability, ctx)?; } Ok(()) @@ -260,7 +260,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(value) = self.compile_expression( module, block_id, - &*init.expr, + &init.expr, generic_args, ctx, init_ty, @@ -279,10 +279,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { Stmt::Item(item) => { match item { Item::Const(const_item) => { - let binding_name: Option = + let binding_name = Some(const_item.ident.to_string()); - let ct_ty: Option = self.compile_type( - &*const_item.ty, + let ct_ty = self.compile_type( + &const_item.ty, generic_args, &HashMap::new(), )?; @@ -295,7 +295,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { match self.compile_expression( module, block_id, - &*const_item.expr, + &const_item.expr, generic_args, ctx, ct_ty, @@ -309,7 +309,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &const_item.expr.span(), &format!( "failed to compile const initializer: `{}`", - const_item.expr.to_token_stream().to_string() + const_item.expr.to_token_stream() ), ) } @@ -375,7 +375,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { match self.compile_expression( module, block_id, - &*assign_expr.right, + &assign_expr.right, generic_args, ctx, rhs_ty, @@ -395,7 +395,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { return_value = self.compile_expression( module, block_id, - &*expr, + expr, generic_args, ctx, return_type.clone(), @@ -410,7 +410,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { return_value = self.compile_expression( module, block_id, - &*expr, + expr, generic_args, ctx, return_type.clone(), @@ -419,7 +419,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { self.compile_expression( module, block_id, - &*expr, + expr, generic_args, ctx, None, @@ -474,7 +474,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } Some(BlockTerminator::Return) => { self.resolve_span(&block_expr.span()) - .jit_assert(loop_carry_var_names.len() == 0, "unexpected state")?; + .jit_assert(loop_carry_var_names.is_empty(), "unexpected state")?; if return_value.is_some() { return self.jit_error_result( &block_expr.span(), diff --git a/cutile-compiler/src/compiler/compile_cuda_tile_op.rs b/cutile-compiler/src/compiler/compile_cuda_tile_op.rs index 399f6eb..7e00762 100644 --- a/cutile-compiler/src/compiler/compile_cuda_tile_op.rs +++ b/cutile-compiler/src/compiler/compile_cuda_tile_op.rs @@ -286,6 +286,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } } + #[allow(clippy::too_many_arguments)] fn offset_nested_mutable_indices( &self, module: &mut Module, @@ -393,18 +394,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { } }; - let cuda_tile_op_params = op_attrs - .parse_string_arr("params") - .unwrap_or_else(|| vec![]); + let cuda_tile_op_params = op_attrs.parse_string_arr("params").unwrap_or_default(); let cuda_tile_op_attribute_params = op_attrs .parse_string_arr("attribute_params") - .unwrap_or_else(|| vec![]); - let cuda_tile_op_hint_params = op_attrs - .parse_string_arr("hint_params") - .unwrap_or_else(|| vec![]); + .unwrap_or_default(); + let cuda_tile_op_hint_params = op_attrs.parse_string_arr("hint_params").unwrap_or_default(); let cuda_tile_op_named_attributes = op_attrs .parse_string_arr("named_attributes") - .unwrap_or_else(|| vec![]); + .unwrap_or_default(); let cuda_tile_op_static_params = op_attrs .parse_string_arr("static_params") .unwrap_or_default(); @@ -453,6 +450,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ) } + #[allow(clippy::too_many_arguments)] fn try_compile_cuda_tile_special_op( &self, module: &mut Module, @@ -657,7 +655,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { }; let element_type = tensor_value .ty - .get_instantiated_rust_element_type(&self.modules.primitives()) + .get_instantiated_rust_element_type(self.modules.primitives()) .ok_or_else(|| { self.jit_error( &call_expr.args[0].span(), @@ -911,17 +909,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { let (op_id, results) = op_builder.build(module); append_op(module, block_id, op_id); - let mut values = vec![]; - values.push(TileRustValue::new_structured_type( - results[0], - tile_elem_ty, - None, - )); - values.push(TileRustValue::new_primitive( - results[1], - token_elem_ty, - None, - )); + let values = vec![ + TileRustValue::new_structured_type(results[0], tile_elem_ty, None), + TileRustValue::new_primitive(results[1], token_elem_ty, None), + ]; Ok(Some(TileRustValue::new_compound(values, return_type_outer))) } @@ -1151,7 +1142,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let elem_ty_prefix = ptr_value .ty - .get_cuda_tile_element_type_prefix(&self.modules.primitives())?; + .get_cuda_tile_element_type_prefix(self.modules.primitives())?; let atomic_mode = AtomicMode::new(mode.as_str(), elem_ty_prefix)? as i64; let mut operands = vec![ptrs, arg]; @@ -1201,17 +1192,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { ) .build(module); append_op(module, block_id, op_id); - let mut values = vec![]; - values.push(TileRustValue::new_structured_type( - results[0], - tile_elem_ty, - None, - )); - values.push(TileRustValue::new_primitive( - results[1], - token_elem_ty, - None, - )); + let values = vec![ + TileRustValue::new_structured_type(results[0], tile_elem_ty, None), + TileRustValue::new_primitive(results[1], token_elem_ty, None), + ]; Ok(Some(TileRustValue::new_compound(values, return_type_outer))) } @@ -1369,17 +1353,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { ) .build(module); append_op(module, block_id, op_id); - let mut values = vec![]; - values.push(TileRustValue::new_structured_type( - results[0], - tile_elem_ty, - None, - )); - values.push(TileRustValue::new_primitive( - results[1], - token_elem_ty, - None, - )); + let values = vec![ + TileRustValue::new_structured_type(results[0], tile_elem_ty, None), + TileRustValue::new_primitive(results[1], token_elem_ty, None), + ]; Ok(Some(TileRustValue::new_compound(values, return_type_outer))) } @@ -1519,6 +1496,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { Ok(hint_params) } + #[allow(clippy::too_many_arguments)] fn compile_load_view_tko( &self, module: &mut Module, @@ -1637,7 +1615,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { .result(tile_result_ir_ty) .result(token_result_ir_ty) .operands(all_operands.iter().copied()) - .attrs(opt_hint_attrs.into_iter()) + .attrs(opt_hint_attrs) .attr( "memory_ordering_semantics", Attribute::i32(memory_ordering_value), @@ -1664,6 +1642,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ))) } + #[allow(clippy::too_many_arguments)] fn compile_store_view_tko( &self, module: &mut Module, @@ -1777,7 +1756,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { OpBuilder::new(Opcode::StoreViewTko, self.ir_location(&call_expr.span())) .result(token_result_ir_ty) .operands(all_operands.iter().copied()) - .attrs(opt_hint_attrs.into_iter()) + .attrs(opt_hint_attrs) .attr( "memory_ordering_semantics", Attribute::i32(memory_ordering_value), @@ -1877,6 +1856,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ))) } + #[allow(clippy::too_many_arguments)] fn compile_shape_query_op( &self, module: &mut Module, @@ -1982,7 +1962,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { .unwrap(); let elem_ty_str = operand_value .ty - .get_cuda_tile_element_type(&self.modules.primitives())? + .get_cuda_tile_element_type(self.modules.primitives())? .unwrap(); let elem_ir_ty = super::_type::make_scalar_tile_type(&elem_ty_str) .expect("failed to build scalar tile type for reduce element"); @@ -2064,7 +2044,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { result_value.value.unwrap() } else { let is_float = - super::_type::scalar_from_name(&elem_ty_str).map_or(false, |s| s.is_float()); + super::_type::scalar_from_name(&elem_ty_str).is_some_and(|s| s.is_float()); let add_opcode = if is_float { Opcode::AddF } else { Opcode::AddI }; let mut add_op_builder = OpBuilder::new(add_opcode, self.ir_location(&call_expr.span())) @@ -2142,7 +2122,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { .expect("failed to convert scan result type"); let elem_ty_str = operand_value .ty - .get_cuda_tile_element_type(&self.modules.primitives())? + .get_cuda_tile_element_type(self.modules.primitives())? .unwrap(); let elem_ir_ty = super::_type::make_scalar_tile_type(&elem_ty_str) .expect("failed to build scalar tile type for scan element"); @@ -2211,7 +2191,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { result_value.value.unwrap() } else { let is_float = - super::_type::scalar_from_name(&elem_ty_str).map_or(false, |s| s.is_float()); + super::_type::scalar_from_name(&elem_ty_str).is_some_and(|s| s.is_float()); let add_opcode = if is_float { Opcode::AddF } else { Opcode::AddI }; let mut add_op_builder = OpBuilder::new(add_opcode, self.ir_location(&call_expr.span())) @@ -2278,6 +2258,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } /// General-purpose op compilation for CUDA Tile dialect operations. + #[allow(clippy::too_many_arguments)] fn compile_general_op( &self, module: &mut Module, @@ -2306,17 +2287,11 @@ impl<'m> CUDATileFunctionCompiler<'m> { }; let return_type = if return_type.is_none() { - match rust_function_name.as_str() { - "constant" => { - return self.jit_error_result( - &call_expr.span(), - &format!( - "Return type required for {}", - call_expr.to_token_stream().to_string() - ), - ) - } - _ => {} + if rust_function_name == "constant" { + return self.jit_error_result( + &call_expr.span(), + &format!("Return type required for {}", call_expr.to_token_stream()), + ); } self.derive_type( module, @@ -2352,9 +2327,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { let field_meta_expr_parts = field_meta_expr_str.split(".").collect::>(); let field_meta_expr_param = field_meta_expr_parts[0]; let mut succeeded = false; - for i in 0..param_names.len() { - if param_names[i] == field_meta_expr_param { - let call_expr_arg = &call_expr.args[i]; + for (param_name, call_expr_arg) in param_names.iter().zip(call_expr.args.iter()) { + if param_name == field_meta_expr_param { let call_expr_arg_str = call_expr_arg.to_token_stream().to_string(); let final_expr_str = field_meta_expr_str.replace(field_meta_expr_param, &call_expr_arg_str); @@ -2387,65 +2361,61 @@ impl<'m> CUDATileFunctionCompiler<'m> { let mut operand_lengths: Vec = vec![]; let mut op_operands: Vec = Vec::new(); let mut compiled_args: Vec = Vec::new(); - for i in 0..cuda_tile_op_params.len() { - let call_expr_arg = &call_expr.args[i]; + for (op_param, call_expr_arg) in cuda_tile_op_params.iter().zip(call_expr.args.iter()) { let call_expr_arg_str = call_expr_arg.to_token_stream().to_string(); let op_arg = self.compile_expression(module, block_id, call_expr_arg, generic_args, ctx, None)?; if op_arg.is_none() { - return self - .jit_error_result(&call_expr.args[i].span(), "Failed to compile op arg"); + return self.jit_error_result(&call_expr_arg.span(), "Failed to compile op arg"); } let op_arg = op_arg.unwrap(); compiled_args.push(op_arg.clone()); - let op_param = &cuda_tile_op_params[i]; let mut arg_values: Vec = vec![]; - if op_arg.value.is_some() { - arg_values.push(op_arg.value.clone().unwrap()); - } else if op_arg.fields.is_some() { - let fields = op_arg.fields.as_ref().unwrap(); + if let Some(value) = op_arg.value { + arg_values.push(value); + } else if let Some(fields) = op_arg.fields { let op_path = op_param.split(".").collect::>(); if op_path.len() <= 1 { - return self.jit_error_result(&call_expr.args[i].span(), &format!("Field expression required for struct param {call_expr_arg_str}, got {op_param}")); + return self.jit_error_result(&call_expr_arg.span(), &format!("Field expression required for struct param {call_expr_arg_str}, got {op_param}")); } - let field = *op_path.last().clone().unwrap(); + let field = *op_path.last().unwrap(); match fields.get(field) { Some(field_value) => { - if field_value.value.is_some() { - arg_values.push(field_value.value.clone().unwrap()); - } else if field_value.values.is_some() { - for value in field_value.values.as_ref().unwrap().iter() { - let Some(v) = value.value.clone() else { - return self.jit_error_result(&call_expr.args[i].span(), &format!("Unexpected nested array {op_param} for {call_expr_arg_str}")); + if let Some(value) = field_value.value { + arg_values.push(value); + } else if let Some(values) = &field_value.values { + for value in values { + let Some(v) = value.value else { + return self.jit_error_result(&call_expr_arg.span(), &format!("Unexpected nested array {op_param} for {call_expr_arg_str}")); }; arg_values.push(v); } } else if field_value.fields.is_some() { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!( "Unexpected nested struct {op_param} for {call_expr_arg_str}" ), ); } else { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Unexpected op param {op_param} for {call_expr_arg_str}"), ); } } None => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Failed to access field {op_param} for {call_expr_arg_str}"), ) } } - } else if op_arg.values.is_some() { - for value in op_arg.values.as_ref().unwrap().iter() { - let Some(v) = value.value.clone() else { + } else if let Some(values) = op_arg.values.as_ref() { + for value in values { + let Some(v) = value.value else { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Unexpected nested array {op_param} for {call_expr_arg_str}"), ); }; @@ -2453,7 +2423,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } } else { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Unexpected op param {op_param} for {call_expr_arg_str}"), ); } @@ -2467,16 +2437,16 @@ impl<'m> CUDATileFunctionCompiler<'m> { let (attr_name, attr_value) = (name_attr_split[0], name_attr_split[1]); if attr_name.starts_with("signedness") && attr_value == "inferred_signedness" { let elem_ty = compiled_args - .get(0) + .first() .and_then(|arg| { arg.ty - .get_instantiated_rust_element_type(&self.modules.primitives()) + .get_instantiated_rust_element_type(self.modules.primitives()) }) .expect("Failed to get element type for signedness inference."); for arg in &compiled_args { let arg_elem_ty = arg .ty - .get_instantiated_rust_element_type(&self.modules.primitives()) + .get_instantiated_rust_element_type(self.modules.primitives()) .expect("Operand types are not all equivalent."); if arg_elem_ty != elem_ty { return self.jit_error_result(&call_expr.span(), &format!("Element type mismatch for signedness inference: expected {elem_ty}, got {arg_elem_ty}")); @@ -2491,7 +2461,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { // Resolve static_params: ZST marker types -> tile-ir attributes. let resolved_static_attrs = resolve_static_params(cuda_tile_op_static_params, call_expr, fn_item) - .map_err(|e| JITError::Generic(e))?; + .map_err(JITError::Generic)?; for attr_str in &resolved_static_attrs { if let Some((name, val_str)) = attr_str.split_once('=') { let name = name.trim(); @@ -2558,7 +2528,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let mut cuda_tile_op_attr_params_iter = cuda_tile_op_attribute_params.iter(); let mut maybe_next_attr_param = cuda_tile_op_attr_params_iter.next(); let fn_params = get_sig_param_names(&fn_item.sig); - for i in 0..fn_params.len() { + for (fn_param, call_expr_arg) in fn_params.iter().zip(call_expr.args.iter()) { if maybe_next_attr_param.is_none() { break; } @@ -2573,11 +2543,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { let (attr_id, attr_ty): (&str, &str) = (op_attr[0], op_attr[1]); match attr_ty { "array" => { - if attr_id != fn_params[i] { + if attr_id != fn_param { continue; } maybe_next_attr_param = cuda_tile_op_attr_params_iter.next(); - let call_expr_arg = &call_expr.args[i]; let call_expr_arg_str = call_expr_arg.to_token_stream().to_string(); let op_arg = self.compile_expression( module, @@ -2589,7 +2558,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { )?; if op_arg.is_none() { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Failed to compile attribute arg for {call_expr_arg_str}"), ); } @@ -2600,44 +2569,39 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(cga) = get_cga_from_type(&op_arg.ty.rust_ty, generic_args) else { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Failed to build attribute", ); }; - attrs.push(( - attr_id.to_string(), - Attribute::DenseI32Array( - cga.iter().map(|&x| x as i32).collect(), - ), - )); + attrs.push((attr_id.to_string(), Attribute::DenseI32Array(cga))); } _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Attribute type not implemented.", ) } }, _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Unexpected call arg {call_expr_arg_str} for {next_attr}"), ) } } } "dense" => { - if attr_id != fn_params[i] { + if attr_id != fn_param { continue; } - let (lit_value, _lit_ty_name) = match &call_expr.args[i] { + let (lit_value, _lit_ty_name) = match &call_expr_arg { Expr::Lit(lit_expr) => match &lit_expr.lit { Lit::Bool(b) => (b.value.to_string(), "i1".to_string()), Lit::Int(i) => (i.base10_digits().to_string(), "i32".to_string()), Lit::Float(f) => (f.base10_digits().to_string(), "f32".to_string()), _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Constant not supported", ) } @@ -2645,7 +2609,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { Expr::Unary(unary_expr) => { let UnOp::Neg(_) = unary_expr.op else { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Only unary negation is supported for constant values", ); }; @@ -2659,14 +2623,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { } _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Unsupported literal type for negation", ) } }, _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Only literal negation is supported for constant values", ) } @@ -2702,7 +2666,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { "E" => (get_const_hex(ty.as_str(), "e")?, ty.clone()), _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Constant not supported", ) } @@ -2711,14 +2675,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { } _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Unsupported expression for named attribute.", ) } }; // Build a DenseElements attribute from the literal value. let elem_ty_str = return_type - .get_cuda_tile_element_type(&self.modules.primitives())? + .get_cuda_tile_element_type(self.modules.primitives())? .unwrap_or("i32".to_string()); let result_ir_ty = super::_type::scalar_from_name(&elem_ty_str) .map(|sc| { @@ -2749,18 +2713,18 @@ impl<'m> CUDATileFunctionCompiler<'m> { )); } "rounding" => { - if attr_id != fn_params[i] { + if attr_id != fn_param { continue; } maybe_next_attr_param = cuda_tile_op_attr_params_iter.next(); - let rounding_mode_str = match &call_expr.args[i] { + let rounding_mode_str = match &call_expr_arg { Expr::Lit(ExprLit { lit: Lit::Str(lit_str), .. }) => lit_str.value(), _ => { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), "Rounding mode must be a string literal.", ) } @@ -2774,7 +2738,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ]; if !VALID_MODES.contains(&rounding_mode_str.as_str()) { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!( "Invalid rounding mode: \"{}\". Valid values are: {}", rounding_mode_str, @@ -2793,40 +2757,40 @@ impl<'m> CUDATileFunctionCompiler<'m> { attrs.push(int_attr(attr_id, 1)); } "integer" => { - if attr_id != fn_params[i] { + if attr_id != fn_param { continue; } maybe_next_attr_param = cuda_tile_op_attr_params_iter.next(); let op_arg = self.compile_expression( module, block_id, - &call_expr.args[i], + call_expr_arg, generic_args, ctx, None, )?; if op_arg.is_none() { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Failed to compile integer attribute {attr_id}"), ); } let op_arg = op_arg.unwrap(); if op_arg.value.is_none() { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Integer attribute {attr_id} must be a value"), ); } if let Some(bounds) = op_arg.bounds { if bounds.is_exact() { - attrs.push(int_attr(attr_id, bounds.start as i64)); + attrs.push(int_attr(attr_id, bounds.start)); } else { - return self.jit_error_result(&call_expr.args[i].span(), &format!("Integer attribute {attr_id} must be a constant value, got bounds: {bounds:?}")); + return self.jit_error_result(&call_expr_arg.span(), &format!("Integer attribute {attr_id} must be a constant value, got bounds: {bounds:?}")); } } else { return self.jit_error_result( - &call_expr.args[i].span(), + &call_expr_arg.span(), &format!("Integer attribute {attr_id} must be a constant value"), ); } @@ -2854,7 +2818,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let mut op_builder = OpBuilder::new(opcode, self.ir_location(&call_expr.span())) .operands(op_operands.iter().copied()) - .attrs(attrs.into_iter()); + .attrs(attrs); if function_returns(fn_item) { match return_type.kind { @@ -2884,7 +2848,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { if let Type::Tuple(tuple_type) = &return_type.rust_ty { let mut elem_types = vec![]; for elem in &tuple_type.elems { - let elem_ty = self.compile_type(&elem, generic_args, &HashMap::new())?; + let elem_ty = self.compile_type(elem, generic_args, &HashMap::new())?; if elem_ty.is_none() { return self.jit_error_result(&call_expr.span(), "failed to compile type"); } let elem_ty = elem_ty.unwrap(); if elem_ty.tile_ir_ty.is_none() { return self.jit_error_result(&call_expr.span(), "failed to compile tile type"); } @@ -2920,7 +2884,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } } Ok(Some(TileRustValue::new_compound(values, return_type))) - } else { self.jit_error_result(&call_expr.span(), &format!("operations that return multiple values must use a tuple return type, got `{}`", return_type.rust_ty.to_token_stream().to_string())) } + } else { self.jit_error_result(&call_expr.span(), &format!("operations that return multiple values must use a tuple return type, got `{}`", return_type.rust_ty.to_token_stream())) } } Kind::Struct => self.jit_error_result(&call_expr.span(), "this operation cannot return a struct; only scalar and structured (tile) types are supported as return types"), Kind::String => self.jit_error_result(&call_expr.span(), "this operation cannot return a string; only scalar and structured (tile) types are supported as return types"), @@ -2962,7 +2926,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(val) = self.compile_expression( module, block_id, - &expr, + expr, generic_vars, ctx, None, @@ -2987,10 +2951,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { } let re_repl = Regex::new(r"\{\}").unwrap(); for (i, element_ty) in element_type_instance.into_iter().enumerate() { - let rust_element_type_instance = element_ty.expect( - format!("failed to determine element type for print argument {}", i) - .as_str(), - ); + let rust_element_type_instance = element_ty.unwrap_or_else(|| { + panic!("failed to determine element type for print argument {}", i) + }); if !re_repl.is_match(&str_literal) { return self.jit_error_result( &mac.span(), @@ -3000,7 +2963,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(tile_element_type_instance) = get_cuda_tile_element_type_from_rust_primitive_str( &rust_element_type_instance, - &self.modules.primitives(), + self.modules.primitives(), ) else { return self.jit_error_result(&mac.span(), &format!("unable to determine tile element type for `{rust_element_type_instance}`")); diff --git a/cutile-compiler/src/compiler/compile_expression.rs b/cutile-compiler/src/compiler/compile_expression.rs index 5f0c1f7..aa26f94 100644 --- a/cutile-compiler/src/compiler/compile_expression.rs +++ b/cutile-compiler/src/compiler/compile_expression.rs @@ -417,8 +417,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { append_op(module, block_id, op_id); let mut values = Vec::with_capacity(rank); - for axis in 0..rank { - let mut value = TileRustValue::new_primitive(results[axis], i32_ty.clone(), None); + for (axis, result) in results.into_iter().enumerate().take(rank) { + let mut value = TileRustValue::new_primitive(result, i32_ty.clone(), None); let parent_axis = pv.dim_map.get(axis).copied().ok_or_else(|| { self.jit_error( span, @@ -562,7 +562,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { .ok_or_else(|| self.jit_error(span, "failed to compile Dim type"))?, }; let dim_origin = value.dim_origin.clone(); - let bounds = value.bounds.clone(); + let bounds = value.bounds; let mut fields = BTreeMap::new(); fields.insert("size".to_string(), value); let mut dim = TileRustValue::new_struct(fields, dim_type); @@ -1005,7 +1005,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(dim_value) = ctx.vars.get(&dim_name).cloned() else { return Ok(false); }; - if !get_type_ident(&dim_value.ty.rust_ty).is_some_and(|ident| ident == "Dim") { + if get_type_ident(&dim_value.ty.rust_ty).is_none_or(|ident| ident != "Dim") { return Ok(false); } let Some(dim_origin) = Self::value_dim_origin(&dim_value) else { @@ -1056,12 +1056,12 @@ impl<'m> CUDATileFunctionCompiler<'m> { let i32_type = self .compile_type(&parse_quote!(i32), generic_vars, &HashMap::new())? .ok_or_else(|| self.jit_error(&for_expr.span(), "failed to compile i32 type"))?; - let upper_bounds = dim_value.bounds.clone().or_else(|| { + let upper_bounds = dim_value.bounds.or_else(|| { dim_value .fields .as_ref() .and_then(|fields| fields.get("size")) - .and_then(|size| size.bounds.clone()) + .and_then(|size| size.bounds) }); let mut iterand_val = if let Some(bounds) = upper_bounds { let upper = bounds.end - 1; @@ -1249,8 +1249,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { "failed to compile range end expression", ); }; - let iterand_lower_const = start_val.bounds.clone(); - let iterand_upper_const = end_val.bounds.clone(); + let iterand_lower_const = start_val.bounds; + let iterand_upper_const = end_val.bounds; let lower_bound = start_val.value.unwrap(); let upper_bound = end_val.value.unwrap(); let step_value = if let Some(step_expr) = maybe_step_expr { @@ -1371,7 +1371,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { module, loop_block_id, &for_expr.body, - &generic_vars, + generic_vars, &mut for_variables, return_type, )?; @@ -1434,7 +1434,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { }) = self.compile_expression( module, loop_block_id, - &*while_expr.cond, + &while_expr.cond, generic_vars, &mut loop_variables, return_type.clone(), @@ -1577,7 +1577,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(conditional_val) = self.compile_expression( module, block_id, - &*if_expr.cond, + &if_expr.cond, generic_vars, ctx, None, @@ -1797,7 +1797,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { module, block_id, &block_expr.block, - &generic_vars, + generic_vars, &mut inner_block_vars, return_type, )?; @@ -1822,7 +1822,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { module, block_id, &block_expr.block, - &generic_vars, + generic_vars, &mut inner_block_vars, return_type, )?; @@ -1894,14 +1894,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { }; fields.insert(field_name, field_value); } - return Ok(Some(TileRustValue::new_struct(fields, return_type))); + Ok(Some(TileRustValue::new_struct(fields, return_type))) } Expr::Reference(ref_expr) => { // TODO (hme): Check whether all expr types can be supported. let return_type = match return_type { Some(ty) => { if let syn::Type::Reference(ref_type) = ty.rust_ty { - self.compile_type(&*ref_type.elem, generic_vars, &HashMap::new())? + self.compile_type(&ref_type.elem, generic_vars, &HashMap::new())? } else { None } @@ -1941,12 +1941,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { ctx, return_type, )?), - _ => { - return self.jit_error_result( - &ref_expr.span(), - "this reference expression form is not supported", - ) - } + _ => self.jit_error_result( + &ref_expr.span(), + "this reference expression form is not supported", + ), } } Expr::Tuple(tuple_expr) => { @@ -1982,7 +1980,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { match self.compile_expression( module, block_id, - &elem, + elem, generic_vars, ctx, elem_return_type, @@ -2034,14 +2032,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { Some(return_type) => { match &return_type.rust_ty { syn::Type::Array(array_type) => self.compile_type( - &*array_type.elem, + &array_type.elem, generic_vars, &HashMap::new(), )?, syn::Type::Slice(slice) => { // TODO (hme): Confirm this is right. self.compile_type( - &*slice.elem, + &slice.elem, generic_vars, &HashMap::new(), )? @@ -2051,7 +2049,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &elem.span(), &format!( "unexpected element type `{}`", - return_type.rust_ty.to_token_stream().to_string() + return_type.rust_ty.to_token_stream() ), ) } @@ -2062,7 +2060,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { match self.compile_expression( module, block_id, - &elem, + elem, generic_vars, ctx, elem_ty, @@ -2076,8 +2074,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { } }; } - let return_type = if return_type.is_none() { - if values.len() == 0 { + let return_type = if let Some(return_type) = return_type { + return_type + } else { + if values.is_empty() { return self.jit_error_result( &array_expr.span(), "unable to infer type for empty array; add a type annotation", @@ -2107,8 +2107,6 @@ impl<'m> CUDATileFunctionCompiler<'m> { ) } } - } else { - return_type.unwrap() }; Ok(Some(TileRustValue::new_compound(values, return_type))) } @@ -2172,8 +2170,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { ); }; let values: Vec = vec![value; len]; - let return_type = if return_type.is_none() { - if values.len() == 0 { + let return_type = if let Some(return_type) = return_type { + return_type + } else { + if values.is_empty() { return self.jit_error_result( &repeat_expr.span(), "unable to infer type for zero-length repeat expression; add a type annotation", @@ -2203,8 +2203,6 @@ impl<'m> CUDATileFunctionCompiler<'m> { ) } } - } else { - return_type.unwrap() }; Ok(Some(TileRustValue::new_compound(values, return_type))) } @@ -2301,18 +2299,18 @@ impl<'m> CUDATileFunctionCompiler<'m> { // 4. Single-segment, not a local, not in resolver — error. let suggestion = self.modules.name_resolver.find_all_definitions(&var_name); if suggestion.is_empty() { - return self.jit_error_result( + self.jit_error_result( &path_expr.span(), &format!("undefined variable `{var_name}`"), - ); + ) } else { - return self.jit_error_result( + self.jit_error_result( &path_expr.span(), &format!( "undefined variable `{var_name}` (did you mean the function defined in {}?)", suggestion.join(", ") ), - ); + ) } } Expr::Call(call_expr) => { @@ -2330,9 +2328,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { return_type, ); } - let ident = get_ident_from_path_expr(&path_expr); + let ident = get_ident_from_path_expr(path_expr); // Handle Some(...) specially - it's a Rust Option constructor, not a function call - if ident.to_string() == "Some" { + if ident == "Some" { if call_expr.args.len() != 1 { return self.jit_error_result( &call_expr.span(), @@ -2375,9 +2373,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { option_type, ))); } - if let Some(_) = self + if self .modules .get_cuda_tile_op_attrs(ident.to_string().as_str()) + .is_some() { Ok(self.compile_cuda_tile_op_call( module, @@ -2412,24 +2411,22 @@ impl<'m> CUDATileFunctionCompiler<'m> { module_name, fn_item, call_expr, - &generic_vars, + generic_vars, ctx, return_type, )?) } } else { - return self.jit_error_result( + self.jit_error_result( &call_expr.func.span(), &format!("call to `{}` is not supported", &call_expr_func_str), - ); + ) } } - _ => { - return self.jit_error_result( - &call_expr.func.span(), - &format!("Call to {} not supported.", &call_expr_func_str), - ) - } + _ => self.jit_error_result( + &call_expr.func.span(), + &format!("Call to {} not supported.", &call_expr_func_str), + ), } } Expr::MethodCall(method_call_expr) => { @@ -2456,8 +2453,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { if let Some(value) = self.compile_global_method_call( module, block_id, - &method_call_expr, - &generic_vars, + method_call_expr, + generic_vars, ctx, return_type.clone(), )? { @@ -2466,8 +2463,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { Ok(self.inline_method_call( module, block_id, - &method_call_expr, - &generic_vars, + method_call_expr, + generic_vars, ctx, return_type, )?) @@ -2505,7 +2502,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(field_value) = fields.get(&field_name.to_string()) else { return self.jit_error_result( &field_name.span(), - &format!("{} is not a field.", field_name.to_string()), + &format!("{} is not a field.", field_name), ); }; Ok(Some(field_value.clone())) @@ -2572,7 +2569,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let str = format!("-{}", int_lit.base10_digits()); let val = -int_lit .base10_parse::() - .expect(format!("Failed to parse literal {str}").as_str()) + .unwrap_or_else(|_| panic!("Failed to parse literal {str}")) as i64; (str, Some(Bounds::exact(val))) } @@ -2584,7 +2581,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } }; let Some(cuda_tile_ty) = return_type - .get_cuda_tile_element_type(&self.modules.primitives())? + .get_cuda_tile_element_type(self.modules.primitives())? else { return self.jit_error_result( &lit_expr.span(), @@ -2624,12 +2621,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { op_result, ct_type, bounds, ))) } - _ => { - return self.jit_error_result( - &unary_expr.span(), - "Non-const unary expressions not supported.", - ) - } + _ => self.jit_error_result( + &unary_expr.span(), + "Non-const unary expressions not supported.", + ), } } Expr::Cast(cast_expr) => { @@ -2637,7 +2632,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { .compile_expression( module, block_id, - &*cast_expr.expr, + &cast_expr.expr, generic_vars, ctx, None, @@ -2645,7 +2640,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { .unwrap(); let src_elem_ty: String = src_expr .ty - .get_instantiated_rust_element_type(&self.modules.primitives()) + .get_instantiated_rust_element_type(self.modules.primitives()) .unwrap(); let dst_elem_ty: String = get_rust_element_type_primitive(&cast_expr.ty); match (src_elem_ty.as_str(), dst_elem_ty.as_str()) { @@ -2686,7 +2681,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &lit_expr.span(), &format!( "Failed to infer type for lit expr {}.", - lit_expr.to_token_stream().to_string() + lit_expr.to_token_stream() ), ); }; @@ -2702,7 +2697,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let str = int_lit.base10_digits().to_string(); let val = int_lit .base10_parse::() - .expect(format!("Failed to parse literal {str}").as_str()) + .unwrap_or_else(|_| panic!("Failed to parse literal {str}")) as i64; (str, Some(Bounds::exact(val))) } @@ -2715,7 +2710,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { } }; let Some(cuda_tile_ty) = - return_type.get_cuda_tile_element_type(&self.modules.primitives())? + return_type.get_cuda_tile_element_type(self.modules.primitives())? else { return self.jit_error_result( &lit_expr.span(), @@ -2759,7 +2754,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { Ok(self.compile_binary_op( module, block_id, - &bin_expr, + bin_expr, generic_vars, ctx, return_type.clone(), @@ -2831,17 +2826,17 @@ impl<'m> CUDATileFunctionCompiler<'m> { // Closures cannot be used as standalone expressions in CUDA Tile. // They are only supported as arguments to specific operations (e.g., reduce, scan) // that compile them into tile-ir regions. - return self.jit_error_result( + self.jit_error_result( &closure_expr.span(), "closures are not supported as standalone values; \ they can only be used as arguments to operations like `reduce()` or `scan()`", - ); + ) } Expr::Index(index_expr) => { let Some(expr_val) = self.compile_expression( module, block_id, - &*index_expr.expr, + &index_expr.expr, generic_vars, ctx, return_type.clone(), @@ -2858,7 +2853,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let Some(index_val) = self.compile_expression( module, block_id, - &*index_expr.index, + &index_expr.index, generic_vars, ctx, i32_type, @@ -2930,15 +2925,12 @@ impl<'m> CUDATileFunctionCompiler<'m> { return Ok(Some(values.remove(index))); } } - return self.jit_error_result( + self.jit_error_result( &index_expr.expr.span(), "indexing is only supported on tuple/compound values and shape-like descriptors", - ); - } - _ => { - return self - .jit_error_result(&expr.span(), "this expression form is not supported") + ) } + _ => self.jit_error_result(&expr.span(), "this expression form is not supported"), } }) // stacker::maybe_grow } diff --git a/cutile-compiler/src/compiler/compile_global.rs b/cutile-compiler/src/compiler/compile_global.rs index 7dab06f..bdf7dc1 100644 --- a/cutile-compiler/src/compiler/compile_global.rs +++ b/cutile-compiler/src/compiler/compile_global.rs @@ -67,7 +67,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { })?; let value_ty = TileIrType::Tile(TileType { shape: vec![1], - element_type: TileElementType::Scalar(scalar.clone()), + element_type: TileElementType::Scalar(scalar), }); let init_expr = self.global_static_initializer(item)?; let init_value = self.global_scalar_initializer_value(&init_expr, module_name)?; @@ -203,6 +203,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ))) } + #[allow(clippy::too_many_arguments)] fn compile_global_store( &self, module: &mut Module, @@ -288,6 +289,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ))) } + #[allow(clippy::too_many_arguments)] fn compile_global_atomic_add( &self, module: &mut Module, diff --git a/cutile-compiler/src/compiler/compile_inline.rs b/cutile-compiler/src/compiler/compile_inline.rs index 5d77236..b63d09d 100644 --- a/cutile-compiler/src/compiler/compile_inline.rs +++ b/cutile-compiler/src/compiler/compile_inline.rs @@ -91,11 +91,12 @@ fn set_lit_span(lit: &mut syn::Lit, span: Span) { } impl<'m> CUDATileFunctionCompiler<'m> { + #[allow(clippy::too_many_arguments)] pub fn inline_function_call( &self, module: &mut Module, block_id: BlockId, - module_name: &String, + module_name: &str, fn_item: &ItemFn, call_expr: &ExprCall, generic_vars: &GenericVars, @@ -130,7 +131,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let param_names = get_sig_param_names(&fn_item.sig); let (input_params, _output_param) = get_sig_types(&fn_item.sig, None); let mut call_variables = CompilerContext::empty(); - call_variables.module_scope.push(module_name.clone()); + call_variables.module_scope.push(module_name.into()); let mut outer2inner_map = HashMap::new(); let sig_param_mutability = get_sig_param_mutability(&fn_item.sig); @@ -169,7 +170,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { generic_arg_inference.map_args_to_params(&call_arg_rust_tys, None); // println!("inline_function_call {:#?}: generic_vars={generic_vars:#?} \nexpr_generic_args={expr_generic_args:#?} \ngeneric_arg_inference={generic_arg_inference:#?}", fn_item.sig.ident.to_string()); generic_arg_inference - .get_generic_vars_instance(&generic_vars, &self.modules.primitives()) + .get_generic_vars_instance(generic_vars, self.modules.primitives()) }; self.add_module_const_vars(&mut call_generic_vars); // Add function call const generics as variables. diff --git a/cutile-compiler/src/compiler/compile_intrinsic.rs b/cutile-compiler/src/compiler/compile_intrinsic.rs index 7b40bbe..f45a450 100644 --- a/cutile-compiler/src/compiler/compile_intrinsic.rs +++ b/cutile-compiler/src/compiler/compile_intrinsic.rs @@ -120,6 +120,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { }) } + #[allow(clippy::too_many_arguments)] fn compile_nested_mutable_access_offset_metadata( &self, module: &mut Module, @@ -185,7 +186,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { pid } else { let ratio_value = - self.compile_constant(module, block_id, generic_vars, ratio as i32)?; + self.compile_constant(module, block_id, generic_vars, ratio)?; self.compile_binary_op_from_values( module, block_id, @@ -213,6 +214,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { /// with `#[cuda_tile::compiler_op(...)]`. These are internal operations /// like mma, tile ops, shape ops, reduce, arithmetic, cast, convert, /// return_type_meta_field, set_type_meta_field, check, and assume. + #[allow(clippy::too_many_arguments)] pub fn compile_compiler_op_call( &self, module: &mut Module, @@ -226,7 +228,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { return_type: Option, ) -> Result, JITError> { let call_expr_func_str = call_expr.func.to_token_stream().to_string(); - let ident = get_ident_from_path_expr(&path_expr); + let ident = get_ident_from_path_expr(path_expr); let Some(compiler_op_name) = compiler_op_attrs.parse_string("name") else { return self.jit_error_result( &call_expr.span(), @@ -242,7 +244,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { let out = operands.remove(0); let out_type = out.ty.clone(); let Some(out_rust_element_type) = - out_type.get_instantiated_rust_element_type(&self.modules.primitives()) + out_type.get_instantiated_rust_element_type(self.modules.primitives()) else { return self.jit_error_result( &call_expr.span(), @@ -262,7 +264,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ); }; let Some(out_cuda_tile_element_type) = - out_type.get_cuda_tile_element_type(&self.modules.primitives())? + out_type.get_cuda_tile_element_type(self.modules.primitives())? else { return self.jit_error_result( &call_expr.span(), @@ -273,13 +275,13 @@ impl<'m> CUDATileFunctionCompiler<'m> { ); }; let out_is_float = super::_type::scalar_from_name(&out_cuda_tile_element_type) - .map_or(false, |s| s.is_float()); + .is_some_and(|s| s.is_float()); let (opcode, attrs) = if out_is_float { (Opcode::MmaF, vec![]) } else if !out_is_float { let Some(lhs_elem_ty) = lhs .ty - .get_instantiated_rust_element_type(&self.modules.primitives()) + .get_instantiated_rust_element_type(self.modules.primitives()) else { return self.jit_error_result( &call_expr.span(), @@ -288,7 +290,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { }; let Some(rhs_elem_ty) = lhs .ty - .get_instantiated_rust_element_type(&self.modules.primitives()) + .get_instantiated_rust_element_type(self.modules.primitives()) else { return self.jit_error_result( &call_expr.span(), @@ -472,7 +474,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { "expected a structured type for the dimension map argument", ); }; - let Some(dim_map) = type_inst.try_extract_cga(&generic_vars) else { + let Some(dim_map) = type_inst.try_extract_cga(generic_vars) else { return self.jit_error_result( &call_expr.args[1].span(), "dimension map must be a const generic array type", @@ -497,12 +499,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { }; Ok(Some(dst_slice)) } - _ => { - return self.jit_error_result( - &call_expr.span(), - &format!("unrecognized shape operation `{}`", compiler_op_function), - ); - } + _ => self.jit_error_result( + &call_expr.span(), + &format!("unrecognized shape operation `{}`", compiler_op_function), + ), } } "dim_new" => { @@ -611,7 +611,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { })?, }; let dim_origin = value.dim_origin.clone(); - let bounds = value.bounds.clone(); + let bounds = value.bounds; let mut fields = BTreeMap::new(); fields.insert("size".to_string(), value); let mut dim = TileRustValue::new_struct(fields, return_type); @@ -1667,7 +1667,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ir_operand_type.clone(), // operand_i_prev_iter ]; let (local_block_id, local_block_args) = build_block(module, local_var_types); - let local_var_names = vec!["curr", "prev"]; + let local_var_names = ["curr", "prev"]; let mut local_vars = CompilerContext::empty(); for i in 0..local_block_args.len() { let value: Value = local_block_args[i]; @@ -1776,12 +1776,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { &call_expr.span(), )?)) } - _ => { - return self.jit_error_result( - &call_expr.span(), - &format!("arithmetic ops with {num_operands} operands not supported"), - ); - } + _ => self.jit_error_result( + &call_expr.span(), + &format!("arithmetic ops with {num_operands} operands not supported"), + ), } } "cast" => { @@ -1812,22 +1810,19 @@ impl<'m> CUDATileFunctionCompiler<'m> { } "tile_to_scalar" => { let Some(element_type) = - get_element_type_structured(&old_type, &self.modules.primitives()) + get_element_type_structured(&old_type, self.modules.primitives()) else { return self.jit_error_result( &call_expr.span(), &format!( "Failed to cast from {} to {}", - old_type.to_token_stream().to_string(), - get_sig_output_type(&fn_item.sig) - .to_token_stream() - .to_string() + old_type.to_token_stream(), + get_sig_output_type(&fn_item.sig).to_token_stream() ), ); }; new_value.ty.rust_ty = - syn::parse2::(format!("{element_type}").parse().unwrap()) - .unwrap(); + syn::parse2::(element_type.parse().unwrap()).unwrap(); } "pointer_to_tile" => { let element_type = get_rust_element_type_primitive(&old_type); @@ -1842,16 +1837,14 @@ impl<'m> CUDATileFunctionCompiler<'m> { } "tile_to_pointer" => { let Some(element_type) = - get_element_type_structured(&old_type, &self.modules.primitives()) + get_element_type_structured(&old_type, self.modules.primitives()) else { return self.jit_error_result( &call_expr.span(), &format!( "Failed to cast from {} to {}", - old_type.to_token_stream().to_string(), - get_sig_output_type(&fn_item.sig) - .to_token_stream() - .to_string() + old_type.to_token_stream(), + get_sig_output_type(&fn_item.sig).to_token_stream() ), ); }; @@ -1887,8 +1880,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { ); } let mut arg = args.pop().unwrap(); - let new_type_compiled = if return_type.is_some() { - return_type.unwrap() + let new_type_compiled = if let Some(return_type) = return_type { + return_type } else { let PathArguments::AngleBracketed(generic_args) = &path_expr.path.segments.last().unwrap().arguments @@ -1897,7 +1890,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &path_expr.span(), &format!( "Failed to get type parameters for {}", - path_expr.to_token_stream().to_string() + path_expr.to_token_stream() ), ); }; @@ -1915,18 +1908,18 @@ impl<'m> CUDATileFunctionCompiler<'m> { &path_expr.span(), &format!( "Failed to get type parameters for {}", - path_expr.to_token_stream().to_string() + path_expr.to_token_stream() ), ); }; let Some(new_type_compiled) = - self.compile_type(&new_type, &generic_vars, &HashMap::new())? + self.compile_type(new_type, generic_vars, &HashMap::new())? else { return self.jit_error_result( &call_expr.span(), &format!( "{compiler_op_function} failed to compile new type: {}", - new_type.to_token_stream().to_string() + new_type.to_token_stream() ), ); }; @@ -1957,16 +1950,16 @@ impl<'m> CUDATileFunctionCompiler<'m> { return Ok(Some(arg)); } let output_type = - tile_ir_type_from_trt(&new_type_compiled, &self.modules.primitives()) + tile_ir_type_from_trt(&new_type_compiled, self.modules.primitives()) .ok_or_else(|| { - self.jit_error( - &call_expr.span(), - &format!( - "Failed to obtain tile-ir type for convert {}", - call_expr.to_token_stream().to_string() - ), - ) - })?; + self.jit_error( + &call_expr.span(), + &format!( + "Failed to obtain tile-ir type for convert {}", + call_expr.to_token_stream() + ), + ) + })?; // These aren't required for all ops. let (op_id, results) = match ( old_element_type_str.as_str(), @@ -1994,9 +1987,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { // Integer → float: IToF with signedness from source type. (from, to) if super::_type::scalar_from_name(from) - .map_or(false, |s| s.is_integer()) + .is_some_and(|s| s.is_integer()) && super::_type::scalar_from_name(to) - .map_or(false, |s| s.is_float()) => + .is_some_and(|s| s.is_float()) => { let signedness = signedness_attr( "signedness", @@ -2011,7 +2004,6 @@ impl<'m> CUDATileFunctionCompiler<'m> { call_expr .args .to_token_stream() - .to_string() ), ); }; @@ -2024,9 +2016,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { // Float → integer: FToI with signedness from target type. (from, to) if super::_type::scalar_from_name(from) - .map_or(false, |s| s.is_float()) + .is_some_and(|s| s.is_float()) && super::_type::scalar_from_name(to) - .map_or(false, |s| s.is_integer()) => + .is_some_and(|s| s.is_integer()) => { let signedness = signedness_attr( "signedness", @@ -2040,7 +2032,6 @@ impl<'m> CUDATileFunctionCompiler<'m> { call_expr .args .to_token_stream() - .to_string() ), ); }; @@ -2058,9 +2049,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { // cutile-python's _get_type_conversion_encoder. (from, to) if super::_type::scalar_from_name(from) - .map_or(false, |s| s.is_float()) + .is_some_and(|s| s.is_float()) && super::_type::scalar_from_name(to) - .map_or(false, |s| s.is_float()) => + .is_some_and(|s| s.is_float()) => { let rounding = rounding_mode_attr("nearest_even"); let Some(input_value) = arg.value else { @@ -2071,7 +2062,6 @@ impl<'m> CUDATileFunctionCompiler<'m> { call_expr .args .to_token_stream() - .to_string() ), ); }; @@ -2097,12 +2087,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { new_type_compiled, ))) } - _ => { - return self.jit_error_result( - &call_expr.span(), - &format!("Unsupported convert compiler_op: {}", compiler_op_function), - ); - } + _ => self.jit_error_result( + &call_expr.span(), + &format!("Unsupported convert compiler_op: {}", compiler_op_function), + ), } } "return_type_meta_field" => { @@ -2253,19 +2241,19 @@ impl<'m> CUDATileFunctionCompiler<'m> { &call_expr.args[0].span(), &format!( "first argument to `set_type_meta_field` must be a simple variable path, got `{}`", - call_expr.to_token_stream().to_string() + call_expr.to_token_stream() ), ); }; let var_name = get_ident_from_path_expr(var_arg) .to_token_stream() .to_string(); - if ctx.vars.get(var_name.as_str()).is_none() { + if !ctx.vars.contains_key(var_name.as_str()) { return self.jit_error_result( &call_expr.args[0].span(), &format!( "first argument to `set_type_meta_field` must be a known variable, got `{}`", - call_expr.to_token_stream().to_string() + call_expr.to_token_stream() ), ); } @@ -2303,7 +2291,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { ); } ctx.vars.insert(var_name.clone(), result_value); - return Ok(None); + Ok(None) } "check" => { if self.entry_attrs.get_entry_arg_bool("unchecked_accesses") { @@ -2320,12 +2308,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { generic_vars, ctx, )?), - _ => { - return self.jit_error_result( - &call_expr.span(), - &format!("Unexpected compiler_op call {}", &call_expr_func_str), - ); - } + _ => self.jit_error_result( + &call_expr.span(), + &format!("Unexpected compiler_op call {}", &call_expr_func_str), + ), } } "assume" => { @@ -2333,12 +2319,10 @@ impl<'m> CUDATileFunctionCompiler<'m> { self.compile_assumption_call(call_expr, module, block_id, generic_vars, ctx)?; Ok(Some(tr_value)) } - _ => { - return self.jit_error_result( - &call_expr.span(), - &format!("Unexpected compiler_op {compiler_op_attrs:#?}"), - ); - } + _ => self.jit_error_result( + &call_expr.span(), + &format!("Unexpected compiler_op {compiler_op_attrs:#?}"), + ), } } @@ -2449,7 +2433,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &call_expr.span(), &format!( "expected a structured type instance for dimension map, got `{}`", - dim_map.rust_ty.to_token_stream().to_string() + dim_map.rust_ty.to_token_stream() ), ); }; @@ -2541,9 +2525,8 @@ impl<'m> CUDATileFunctionCompiler<'m> { continue; } } - if index_value.bounds.is_some() && is_static_shape_dim { + if let (Some(bounds), true) = (index_value.bounds, is_static_shape_dim) { // We can do a static bounds check. - let bounds = index_value.bounds.unwrap(); let num_partitions = (static_shape_dim as i64 + static_tile_dim as i64 - 1) / static_tile_dim as i64; if !(0 <= bounds.start && bounds.end < num_partitions) { @@ -2649,6 +2632,6 @@ impl<'m> CUDATileFunctionCompiler<'m> { .build(module); append_op(module, block_id, assert_op_id); } - return Ok(None); + Ok(None) } } diff --git a/cutile-compiler/src/compiler/compile_type.rs b/cutile-compiler/src/compiler/compile_type.rs index 2f3d4ef..ed0d3b7 100644 --- a/cutile-compiler/src/compiler/compile_type.rs +++ b/cutile-compiler/src/compiler/compile_type.rs @@ -37,13 +37,13 @@ impl<'m> CUDATileFunctionCompiler<'m> { match ty { // Array, Slice, and Tuple compile to the same compiler representation (compound values). syn::Type::Tuple(tuple) => { - if tuple.elems.len() == 0 { + if tuple.elems.is_empty() { return Ok(None); } else { let unknown_type_instance = TypeInstanceUserType::instantiate( - &ty, + ty, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), ) .unwrap(); let type_instance = TypeInstance::UserType(unknown_type_instance); @@ -52,9 +52,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { } syn::Type::Array(_) => { let unknown_type_instance = TypeInstanceUserType::instantiate( - &ty, + ty, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), ) .unwrap(); let type_instance = TypeInstance::UserType(unknown_type_instance); @@ -62,16 +62,16 @@ impl<'m> CUDATileFunctionCompiler<'m> { } syn::Type::Slice(_) => { let unknown_type_instance = TypeInstanceUserType::instantiate( - &ty, + ty, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), ) .unwrap(); let type_instance = TypeInstance::UserType(unknown_type_instance); return Ok(Some(TileRustType::new_compound(type_instance))); } syn::Type::Reference(ref_ty) => { - let mut res = self.compile_type(&*ref_ty.elem, generic_vars, type_params)?; + let mut res = self.compile_type(&ref_ty.elem, generic_vars, type_params)?; match &mut res { Some(cuda_tile_ty) => { cuda_tile_ty.rust_ty = ty.clone(); @@ -85,9 +85,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { let type_name = ident.to_string(); if type_name == "Option" { let option_type_instance = TypeInstanceUserType::instantiate( - &ty, + ty, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), ) .unwrap(); return Ok(Some(TileRustType::new_enum(TypeInstance::UserType( @@ -98,9 +98,9 @@ impl<'m> CUDATileFunctionCompiler<'m> { ty_attrs = self.modules.get_cuda_tile_type_attrs(type_name.as_str()); if ty_attrs.is_none() { let unknown_type_instance = TypeInstanceUserType::instantiate( - &ty, + ty, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), ) .unwrap(); let type_instance = TypeInstance::UserType(unknown_type_instance); @@ -109,12 +109,12 @@ impl<'m> CUDATileFunctionCompiler<'m> { type_instance, ))); } - structure = Some((type_name.clone(), &item_struct)); + structure = Some((type_name.clone(), item_struct)); type_instance = - Some(generic_vars.instantiate_type(ty, &self.modules.primitives())?); + Some(generic_vars.instantiate_type(ty, self.modules.primitives())?); } else { let local_type_instance = - generic_vars.instantiate_type(ty, &self.modules.primitives())?; + generic_vars.instantiate_type(ty, self.modules.primitives())?; if let TypeInstance::StringType(_string_inst) = local_type_instance { return Ok(Some(TileRustType::new_string(TypeInstance::StringType( _string_inst, @@ -144,13 +144,13 @@ impl<'m> CUDATileFunctionCompiler<'m> { None => return self.jit_error_result(&ty.span(), "Failed to compile type"), }, syn::Type::Ptr(_) => { - let type_name = get_type_ident(&ty); + let type_name = get_type_ident(ty); if type_name.is_none() { return self.jit_error_result(&ty.span(), "Failed to compile type"); } let type_name = type_name.unwrap().to_string(); let local_type_instance = - generic_vars.instantiate_type(ty, &self.modules.primitives())?; + generic_vars.instantiate_type(ty, self.modules.primitives())?; let Some(element_type_instance_str) = local_type_instance.get_rust_element_instance_ty() else { @@ -192,7 +192,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &ty.span(), &format!( "Unable to compile compiling type {} using attrs {ty_attrs:#?}", - ty.to_token_stream().to_string() + ty.to_token_stream() ), ); } @@ -255,7 +255,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { return Ok(Some(TileRustType::new_structured_type( type_name, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), args, type_instance, )?)); @@ -272,7 +272,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &ty.span(), &format!( "Unable to compile type {} using attrs {scalar_attrs:#?}", - element_instance.generic_ty.to_token_stream().to_string() + element_instance.generic_ty.to_token_stream() ), ); } @@ -283,13 +283,13 @@ impl<'m> CUDATileFunctionCompiler<'m> { None, Some(TypeInstance::ElementType(element_instance.clone())), )]; - return Ok(Some(TileRustType::new_primitive_type( + Ok(Some(TileRustType::new_primitive_type( type_name, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), args, TypeInstance::ElementType(element_instance), - )?)); + )?)) } else if let Some(TypeInstance::PtrType(ptr_instance)) = type_instance { let Some(pointer_attrs) = self.modules.get_primitives_attrs("Pointer", "* mut E") else { @@ -302,7 +302,7 @@ impl<'m> CUDATileFunctionCompiler<'m> { &ty.span(), &format!( "Unable to compile compiling type {} using attrs {pointer_attrs:#?}", - ty.to_token_stream().to_string() + ty.to_token_stream() ), ); } @@ -313,18 +313,18 @@ impl<'m> CUDATileFunctionCompiler<'m> { None, Some(TypeInstance::PtrType(ptr_instance.clone())), )]; - return Ok(Some(TileRustType::new_primitive_type( + Ok(Some(TileRustType::new_primitive_type( type_name, generic_vars, - &self.modules.primitives(), + self.modules.primitives(), args, TypeInstance::PtrType(ptr_instance), - )?)); + )?)) } else { - return self.jit_error_result( + self.jit_error_result( &ty.span(), &format!("Unable to instantiate Scalar or Pointer impls: type_instance={type_instance:#?}"), - ); + ) } } } diff --git a/cutile-compiler/src/compiler/shared_utils.rs b/cutile-compiler/src/compiler/shared_utils.rs index 45c0c9e..c771ec6 100644 --- a/cutile-compiler/src/compiler/shared_utils.rs +++ b/cutile-compiler/src/compiler/shared_utils.rs @@ -92,13 +92,13 @@ impl AtomicMode { &format!("invalid atomic mode `{mode}`; valid modes are: And, Or, Xor, Add, AddF, Max, Min, Umax, Umin, Xchg"), ), }; - if elem_ty_prefix == ElementTypePrefix::Float { - if ![AtomicMode::XChg, AtomicMode::AddF].contains(&result) { - return SourceLocation::unknown().jit_error_result(&format!( - "float types only support `Xchg` and `AddF` atomic modes, got `{:?}`", - result - )); - } + if elem_ty_prefix == ElementTypePrefix::Float + && ![AtomicMode::XChg, AtomicMode::AddF].contains(&result) + { + return SourceLocation::unknown().jit_error_result(&format!( + "float types only support `Xchg` and `AddF` atomic modes, got `{:?}`", + result + )); } Ok(result) } @@ -334,7 +334,7 @@ pub fn extract_zst_type_name(expr: &syn::Expr, param_name: &str) -> Result Ok(path.path.segments.last().unwrap().ident.to_string()), _ => SourceLocation::unknown().jit_error_result(&format!( "`{param_name}` must be a unit-struct type-as-value path, got `{}`", - expr.to_token_stream().to_string() + expr.to_token_stream() )), } } @@ -391,7 +391,7 @@ pub fn extract_string_literal( } _ => SourceLocation::unknown().jit_error_result(&format!( "`{param_name}` must be a string literal, got `{}`", - expr.to_token_stream().to_string() + expr.to_token_stream() )), } } @@ -497,7 +497,7 @@ pub fn collect_mutated_variables_from_block( ) -> Result, JITError> { let mut local_vars: HashSet = HashSet::new(); let mut result: BTreeSet = BTreeSet::new(); - for (_i, statement) in block.stmts.iter().enumerate() { + for statement in block.stmts.iter() { match statement { Stmt::Local(local) => { let mut var_names: Vec = vec![]; @@ -628,7 +628,7 @@ pub fn dedup(v: &mut Vec) { pub fn parse_list_of_expr(tokens: TokenStream) -> Result, JITError> { let mut args: Vec = vec![]; let mut arg_expr: Vec = vec![]; - for (_i, token) in tokens.clone().into_iter().enumerate() { + for token in tokens.clone().into_iter() { match &token { TokenTree::Literal(_lit) => { arg_expr.push(token.clone()); @@ -638,7 +638,7 @@ pub fn parse_list_of_expr(tokens: TokenStream) -> Result, JITError> { } TokenTree::Punct(punct) => { if punct.as_char() == ',' { - if arg_expr.len() > 0 { + if !arg_expr.is_empty() { let expr = syn::parse2::(arg_expr.into_iter().collect()).unwrap(); args.push(expr); @@ -649,14 +649,12 @@ pub fn parse_list_of_expr(tokens: TokenStream) -> Result, JITError> { } } _ => { - return SourceLocation::unknown().jit_error_result(&format!( - "unexpected token `{}` in expression list", - token.to_string() - )); + return SourceLocation::unknown() + .jit_error_result(&format!("unexpected token `{}` in expression list", token)); } } } - if arg_expr.len() > 0 { + if !arg_expr.is_empty() { let expr = syn::parse2::(arg_expr.into_iter().collect()).unwrap(); args.push(expr); } @@ -677,7 +675,7 @@ pub fn update_token( let Some(var_arg_ident) = get_ident_from_expr(var_arg) else { return SourceLocation::unknown().jit_error_result(&format!( "expected a variable name, got `{}`", - var_arg.to_token_stream().to_string() + var_arg.to_token_stream() )); }; let var_name = var_arg_ident.to_string(); @@ -710,7 +708,7 @@ pub fn get_token_from_expr( let Some(var_arg_ident) = get_ident_from_expr(var_arg) else { return SourceLocation::unknown().jit_error_result(&format!( "expected a variable name, got `{}`", - var_arg.to_token_stream().to_string() + var_arg.to_token_stream() )); }; let var_name = var_arg_ident.to_string(); @@ -742,7 +740,7 @@ pub fn update_outer_block_type_meta( inner_block_vars: &mut CompilerContext, outer_block_vars: &mut CompilerContext, field_name: String, -) -> () { +) { let mut var_map = std::collections::HashMap::new(); for var_name in outer_block_vars.var_keys() { var_map.insert(var_name.clone(), var_name.clone()); @@ -756,7 +754,7 @@ pub fn update_type_meta( outer_block_vars: &mut CompilerContext, outer2inner_vars: &std::collections::HashMap, _field_name: String, -) -> () { +) { use super::shared_types::Mutability; let outer_keys_ = outer_block_vars.var_keys(); let outer_keys = outer_keys_ diff --git a/cutile-compiler/src/compiler/tile_rust_type.rs b/cutile-compiler/src/compiler/tile_rust_type.rs index cd5faaa..fc6ea19 100644 --- a/cutile-compiler/src/compiler/tile_rust_type.rs +++ b/cutile-compiler/src/compiler/tile_rust_type.rs @@ -21,7 +21,7 @@ use syn::ItemImpl; /// Build a `TypeInstance::StructuredType` for a synthetic tile type with element info. fn synthetic_tile_instance(rust_ty: syn::Type, element_name: &str, shape: &[i32]) -> TypeInstance { let elem_ty = syn::parse_str::(element_name).unwrap_or(rust_ty.clone()); - TypeInstance::StructuredType(crate::generics::TypeInstanceStructuredType { + TypeInstance::StructuredType(Box::new(crate::generics::TypeInstanceStructuredType { generic_ty: rust_ty.clone(), instance_ty: rust_ty, primitive_type: Some(crate::generics::TypInstancePrimitiveType::ElementType( @@ -32,12 +32,12 @@ fn synthetic_tile_instance(rust_ty: syn::Type, element_name: &str, shape: &[i32] }, )), shape: shape.to_vec(), - }) + })) } /// Build a `TypeInstance::StructuredType` for a synthetic pointer tile type. fn synthetic_ptr_instance(rust_ty: syn::Type, element_name: &str) -> TypeInstance { - TypeInstance::StructuredType(crate::generics::TypeInstanceStructuredType { + TypeInstance::StructuredType(Box::new(crate::generics::TypeInstanceStructuredType { generic_ty: rust_ty.clone(), instance_ty: rust_ty.clone(), primitive_type: Some(crate::generics::TypInstancePrimitiveType::PtrType( @@ -49,7 +49,7 @@ fn synthetic_ptr_instance(rust_ty: syn::Type, element_name: &str) -> TypeInstanc }, )), shape: vec![], - }) + })) } /// A compiled type binding: maps a Rust `syn::Type` to its CUDA Tile type metadata. @@ -107,12 +107,10 @@ impl TileRustType { .ok_or_else(|| { JITError::generic_err(&format!( "unable to determine element type for `{}`", - self.rust_ty.to_token_stream().to_string() + self.rust_ty.to_token_stream() )) })?; - Ok(super::shared_utils::ElementTypePrefix::new( - &cuda_elem_ty_str, - )?) + super::shared_utils::ElementTypePrefix::new(&cuda_elem_ty_str) } pub(crate) fn get_cuda_tile_type_str(&self) -> Option { self.cuda_tile_ty_str.clone() @@ -132,7 +130,7 @@ impl TileRustType { let rust_ty = type_instance.get_source_type().clone(); let type_param_str = params .iter_mut() - .map(|tp| tp.instantiate(generic_vars, &primitives)) + .map(|tp| tp.instantiate(generic_vars, primitives)) .collect::, _>>()? .join(","); let type_str = format!("{}<{}>", cuda_tile_name, type_param_str); @@ -156,12 +154,12 @@ impl TileRustType { type_instance: TypeInstance, ) -> Result { let rust_ty = type_instance.get_source_type().clone(); - let type_str = if params.len() == 0 { - format!("{}", cuda_tile_name) + let type_str = if params.is_empty() { + cuda_tile_name.to_string() } else { let type_param_str = params .iter_mut() - .map(|tp| tp.instantiate(generic_args, &primitives)) + .map(|tp| tp.instantiate(generic_args, primitives)) .collect::, _>>()? .join(","); format!("{}<{}>", cuda_tile_name, type_param_str) diff --git a/cutile-compiler/src/error.rs b/cutile-compiler/src/error.rs index cacc795..03f14df 100644 --- a/cutile-compiler/src/error.rs +++ b/cutile-compiler/src/error.rs @@ -60,11 +60,11 @@ impl error::Error for JITError {} impl JITError { /// Create a `Generic` error value (not wrapped in `Result`). pub fn generic_err(err_str: &str) -> JITError { - return JITError::Generic(err_str.to_string()); + JITError::Generic(err_str.to_string()) } /// Create a `Generic` error wrapped in `Err`. pub fn generic(err_str: &str) -> Result { - return Err(JITError::generic_err(err_str)); + Err(JITError::generic_err(err_str)) } /// Create a `Located` error that carries a real source location captured /// at proc macro expansion time. diff --git a/cutile-compiler/src/generics.rs b/cutile-compiler/src/generics.rs index c073577..79b6691 100644 --- a/cutile-compiler/src/generics.rs +++ b/cutile-compiler/src/generics.rs @@ -138,7 +138,7 @@ impl GenericVars { { Some(GenericVarType::ConstVariable) } else { - return None; + None } } @@ -390,7 +390,7 @@ impl GenericVars { let name = const_param.ident.to_string(); let from_generic_args = self; if let Some(res) = try_get_const_generic_from_generic_argument( - &generic_arg, + generic_arg, from_generic_args, ) { // This is something like @@ -398,24 +398,20 @@ impl GenericVars { inst_i32.insert(name.clone(), res); } else { let Some(res) = - get_cga_from_generic_argument(&generic_arg, from_generic_args) + get_cga_from_generic_argument(generic_arg, from_generic_args) else { return SourceLocation::unknown().jit_error_result(&format!( "unable to resolve generic argument `{}` for parameter `{}`", - generic_arg.to_token_stream().to_string(), - generic_param.to_token_stream().to_string() + generic_arg.to_token_stream(), + generic_param.to_token_stream() )); }; // This is something like // {[...]} -> CONST_ARRAY_PARAM inst_array.insert(name.clone(), res); if let Expr::Path(length_expr) = &ty_arr.len { - let length_var = length_expr - .path - .get_ident() - .unwrap() - .to_string() - .to_string(); + let length_var = + length_expr.path.get_ident().unwrap().to_string(); len2array.insert(length_var, name.clone()); } } @@ -423,8 +419,8 @@ impl GenericVars { _ => { return SourceLocation::unknown().jit_error_result(&format!( "unable to resolve generic argument `{}` for parameter `{}`", - generic_arg.to_token_stream().to_string(), - generic_param.to_token_stream().to_string() + generic_arg.to_token_stream(), + generic_param.to_token_stream() )); } } @@ -504,8 +500,8 @@ impl GenericVars { _ => { return SourceLocation::unknown().jit_error_result(&format!( "unable to resolve generic argument `{}` for parameter `{}`", - generic_arg.to_token_stream().to_string(), - generic_param.to_token_stream().to_string() + generic_arg.to_token_stream(), + generic_param.to_token_stream() )); } } @@ -622,12 +618,12 @@ impl GenericVars { if let Some(instance) = TypeInstanceStructuredType::instantiate(&maybe_generic_ty, self, primitives) { - return Ok(TypeInstance::StructuredType(instance)); + Ok(TypeInstance::StructuredType(Box::new(instance))) } else { - return SourceLocation::unknown().jit_error_result(&format!( + SourceLocation::unknown().jit_error_result(&format!( "unable to resolve generic type `{}`", - maybe_generic_ty.to_token_stream().to_string() - )); + maybe_generic_ty.to_token_stream() + )) } } } @@ -657,7 +653,7 @@ pub enum TypeInstance { /// A pointer type (`*mut E` / `*const E`). PtrType(TypeInstancePtrType), /// A shaped type with element type and dimensions (e.g. `Tile`). - StructuredType(TypeInstanceStructuredType), + StructuredType(Box), } impl TypeInstance { @@ -879,7 +875,7 @@ impl Instantiable for TypeInstancePtrType { .unwrap(); Some(Self { generic_ty: maybe_generic_ty.clone(), - instance_ty: instance_ty, + instance_ty, is_mutable, rust_element_instance_ty: concrete_ptr_ty.to_string(), }) @@ -931,7 +927,7 @@ impl Instantiable for TypeInstanceStructuredType { .path .segments .last_mut() - .expect(format!("Unexpected structured type {maybe_generic_ty:#?}.").as_str()); + .unwrap_or_else(|| panic!("Unexpected structured type {maybe_generic_ty:#?}.")); let PathArguments::AngleBracketed(type_params) = &mut last_seg.arguments else { panic!( "Unexpected structured type generic arguments {:#?} for {maybe_generic_ty:#?}", @@ -954,10 +950,8 @@ impl Instantiable for TypeInstanceStructuredType { let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); // println!("get_variadic_type_args: Type::Path: {}", last_ident); - if generic_vars.inst_array.contains_key(&last_ident) { + if let Some(array_instance) = generic_vars.inst_array.get(&last_ident) { // This is something like Shape for const generic array D: [i32; N]. - let array_instance = - generic_vars.inst_array.get(&last_ident).unwrap(); if shape.is_some() && !allows_extra_cga { panic!("Unexpected array arg: {last_ident:#?}") } @@ -1018,13 +1012,13 @@ impl Instantiable for TypeInstanceStructuredType { // Map-shape metadata beyond the tile shape does not affect // TileRustType's element or shape instantiation. } else { - panic!("Failed to get cuda tile type for ty={} \n generic_arg={generic_arg:#?} \n generic_args={generic_vars:#?}", maybe_generic_ty.to_token_stream().to_string()); + panic!("Failed to get cuda tile type for ty={} \n generic_arg={generic_arg:#?} \n generic_args={generic_vars:#?}", maybe_generic_ty.to_token_stream()); } } syn::Type::Ptr(_) => { let Some(ptr_inst) = TypeInstancePtrType::instantiate( type_param, - &generic_vars, + generic_vars, primitives, ) else { panic!("Unexpected primitives {primitives:#?}.") @@ -1149,13 +1143,14 @@ impl Instantiable for TypeInstanceStructuredType { // This is something like Tensor let num_rep_var = len_path.to_token_stream().to_string(); - if !generic_vars.get_i32(&num_rep_var).is_some() { - panic!( - "Expected instance for generic argument {}", - num_rep_var - ); - } - generic_vars.get_i32(&num_rep_var).unwrap() + generic_vars.get_i32(&num_rep_var).unwrap_or_else( + || { + panic!( + "Expected instance for generic argument {}", + num_rep_var + ) + }, + ) } Expr::Lit(len_lit) => { // This is something like Tensor @@ -1316,11 +1311,7 @@ impl GenericArgInference { } /// Maps positional call arguments to their corresponding parameter names. - pub fn map_args_to_params( - &mut self, - call_arg_rust_tys: &Vec, - self_ty: Option<&Type>, - ) -> () { + pub fn map_args_to_params(&mut self, call_arg_rust_tys: &[syn::Type], self_ty: Option<&Type>) { let (fn_arg_types, _return_type) = get_sig_types(&self.sig, self_ty); // Get the generic parameters in this function signature. for i in 0..call_arg_rust_tys.len() { @@ -1446,7 +1437,7 @@ impl GenericArgInference { let Some(method_params) = &self.method_params else { panic!( "Method params undefined for {}", - method_call_expr.to_token_stream().to_string() + method_call_expr.to_token_stream() ) }; assert_eq!(expr_generic_args.args.len(), method_params.len()); @@ -1547,7 +1538,7 @@ impl GenericArgInference { /// Returns `true` if all generic parameters have been resolved. pub fn verify(&self) -> bool { // Check if computed and succeeded. - for (_key, val) in &self.param2arg { + for val in self.param2arg.values() { if val.is_none() { return false; } @@ -1701,7 +1692,7 @@ impl GenericArgInference { self.add_generic_args(type_param, type_arg); } - fn add_generic_args(&mut self, type_param: &syn::Type, type_arg: &syn::Type) -> () { + fn add_generic_args(&mut self, type_param: &syn::Type, type_arg: &syn::Type) { // Adds generic arguments to arg_map. // arg_map maps generic parameters (present in arg_map upon initialization) to various GenericArgument patterns (see below). // Each key in arg_map specifies the set of generic parameters in a function / method signature. @@ -1794,7 +1785,7 @@ impl GenericArgInference { } (syn::Type::Ptr(arg_type_ptr), syn::Type::Ptr(param_type_ptr)) => { // Something like (PointerTile<*mut f32, ...>, PointerTile<*mut E, ...>) - let param_elem_ty = match get_type_ident(&*param_type_ptr.elem) { + let param_elem_ty = match get_type_ident(¶m_type_ptr.elem) { Some(ident) => ident.to_string(), None => panic!( "Unable to extract ident from pointer {param_type_ptr:#?}" @@ -1966,7 +1957,7 @@ impl GenericArgInference { } (syn::Type::Ptr(arg_type_ptr), syn::Type::Ptr(param_type_ptr)) => { // Something like (PointerTile<*mut f32, ...>, PointerTile<*mut E, ...>) - let param_elem_ty = match get_type_ident(&*param_type_ptr.elem) { + let param_elem_ty = match get_type_ident(¶m_type_ptr.elem) { Some(ident) => ident.to_string(), None => panic!("Unable to extract ident from pointer {param_type_ptr:#?}"), }; @@ -1997,7 +1988,7 @@ impl GenericArgInference { pub fn infer_type(&self, ty: &syn::Type, _generic_vars: &GenericVars) -> syn::Type { let arg_map = &self.param2arg; // println!("Infer generic args for {} using \n {arg_map:#?}", ty.to_token_stream().to_string()); - let Some(mut result_args) = maybe_generic_args(&ty) else { + let Some(mut result_args) = maybe_generic_args(ty) else { // Is it a generic arg itself? // TODO (hme): *Really* need to make this recursive and just call with the following types. let mut result = ty.clone(); @@ -2180,36 +2171,29 @@ impl GenericArgInference { else { panic!("Unexpected block expression.") }; - match param_stmt_expr { - Expr::Array(param_array_expr) => { - for i in 0..param_array_expr.elems.iter().len() { - let param_elem = &mut param_array_expr.elems[i]; - let param_var = param_elem.to_token_stream().to_string(); - match arg_map.get(param_var.as_str()) { - None => { - // This is not a generic parameter. - } - Some(None) => { - panic!( - "Failed to infer generic parameter {param_var}" - ) - } - Some(Some(( - GenericArgType::GenericConstExpr, - target_expr, - ))) => { - *param_elem = syn::parse2::( - target_expr.parse().unwrap(), - ) - .unwrap(); - } - Some(Some((arg_type, _arg))) => { - panic!("Unexpected arg type {arg_type:#?}") - } + if let Expr::Array(param_array_expr) = param_stmt_expr { + for param_elem in &mut param_array_expr.elems { + let param_var = param_elem.to_token_stream().to_string(); + match arg_map.get(param_var.as_str()) { + None => { + // This is not a generic parameter. + } + Some(None) => { + panic!("Failed to infer generic parameter {param_var}") + } + Some(Some(( + GenericArgType::GenericConstExpr, + target_expr, + ))) => { + *param_elem = + syn::parse2::(target_expr.parse().unwrap()) + .unwrap(); + } + Some(Some((arg_type, _arg))) => { + panic!("Unexpected arg type {arg_type:#?}") } } } - _ => {} } } Expr::Path(param_path) => { @@ -2247,7 +2231,7 @@ pub fn get_cga_from_type(ty: &syn::Type, generic_args: &GenericVars) -> Option Option { let mut result: Option = None; match generic_arg { - GenericArgument::Type(type_param) => { - match type_param { - syn::Type::Path(type_path) => { - let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); - // println!("get_variadic_type_args: Type::Path: {}", last_ident); - if generic_args.inst_i32.contains_key(&last_ident) { - // This is something like N for const generic N: i32. - result = Some(generic_args.inst_i32.get(&last_ident).unwrap().clone()); - } - // If it's anything else, then return None. - } - _ => {} + GenericArgument::Type(syn::Type::Path(type_path)) => { + let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); + // println!("get_variadic_type_args: Type::Path: {}", last_ident); + if let Some(&value) = generic_args.inst_i32.get(&last_ident) { + // This is something like N for const generic N: i32. + result = Some(value); } } - GenericArgument::Const(const_param) => { - // println!("expand GenericArgument::Const? {const_param:#?}"); - match const_param { - Expr::Lit(lit) => { - let Lit::Int(int_lit) = &lit.lit else { - panic!("Expected int literal, got {:#?}", lit) - }; - // This is something like 32 in Tile - // TODO (hme): Add a test for this. - result = Some(int_lit.base10_parse().unwrap()); - } - _ => {} - } + GenericArgument::Const(Expr::Lit(lit)) => { + let Lit::Int(int_lit) = &lit.lit else { + panic!("Expected int literal, got {:#?}", lit) + }; + // This is something like 32 in Tile + // TODO (hme): Add a test for this. + result = Some(int_lit.base10_parse().unwrap()); } _ => {} } @@ -2322,87 +2294,69 @@ pub fn get_cga_from_generic_argument( ) -> Option> { let mut shape: Option> = None; match generic_arg { - GenericArgument::Type(type_param) => { - match type_param { - syn::Type::Path(type_path) => { - // This must be a CGA, or it will fail. - let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); - // println!("get_variadic_type_args: Type::Path: {}", last_ident); - if generic_args.inst_array.contains_key(&last_ident) { - // This is something like Shape for const generic array D: [i32; N]. - let array_instance = generic_args.inst_array.get(&last_ident).unwrap(); - if shape.is_some() { - panic!("Unexpected array arg: {last_ident:#?}") - } - shape = Some(array_instance.clone()); - } else if generic_args.inst_i32.contains_key(&last_ident) { - // This is something like N for const generic N: i32. - // This should have been handled by - // try_get_const_generic_from_generic_argument. - unimplemented!( - "Unexpected const arg {last_ident} for type {type_param:#?}" - ); - } else { - unimplemented!("Failed to get cga for {type_param:#?}"); - } + GenericArgument::Type(syn::Type::Path(type_path)) => { + // This must be a CGA, or it will fail. + let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); + // println!("get_variadic_type_args: Type::Path: {}", last_ident); + if let Some(array_instance) = generic_args.inst_array.get(&last_ident) { + // This is something like Shape for const generic array D: [i32; N]. + if shape.is_some() { + panic!("Unexpected array arg: {last_ident:#?}") } - _ => {} + shape = Some(array_instance.clone()); + } else if generic_args.inst_i32.contains_key(&last_ident) { + // This is something like N for const generic N: i32. + // This should have been handled by + // try_get_const_generic_from_generic_argument. + unimplemented!("Unexpected const arg {last_ident} for type {type_path:#?}"); + } else { + unimplemented!("Failed to get cga for {type_path:#?}"); } } - GenericArgument::Const(const_param) => { - // println!("expand GenericArgument::Const? {const_param:#?}"); - match const_param { - Expr::Block(block_expr) => { - // This is something like Tensor - assert_eq!(block_expr.block.stmts.len(), 1); - let statement = &block_expr.block.stmts[0]; - let Stmt::Expr(statement_expr, _) = statement else { - panic!("Unexpected block expression.") - }; - match statement_expr { - Expr::Array(array_expr) => { - // This is something like Tensor - let mut _shape: Vec = vec![]; - for elem in &array_expr.elems { - _shape.push(parse_expr_as_i32(elem, generic_args)); + GenericArgument::Const(Expr::Block(block_expr)) => { + // This is something like Tensor + assert_eq!(block_expr.block.stmts.len(), 1); + let statement = &block_expr.block.stmts[0]; + let Stmt::Expr(statement_expr, _) = statement else { + panic!("Unexpected block expression.") + }; + match statement_expr { + Expr::Array(array_expr) => { + // This is something like Tensor + let mut _shape: Vec = vec![]; + for elem in &array_expr.elems { + _shape.push(parse_expr_as_i32(elem, generic_args)); + } + shape = Some(_shape); + } + Expr::Repeat(repeat_expr) => { + // println!("Expr::Repeat: {:?}", repeat_expr.expr); + let thing_to_repeat = parse_expr_as_i32(&repeat_expr.expr, generic_args); + match &*repeat_expr.len { + Expr::Path(len_path) => { + // This is something like Tensor + let num_rep_var = len_path.to_token_stream().to_string(); + if generic_args.get_i32(&num_rep_var).is_none() { + panic!("Expected instance for generic argument {}", num_rep_var); } - shape = Some(_shape); + let num_rep = generic_args.get_i32(&num_rep_var).unwrap(); + shape = Some(vec![thing_to_repeat; num_rep as usize]); } - Expr::Repeat(repeat_expr) => { - // println!("Expr::Repeat: {:?}", repeat_expr.expr); - let thing_to_repeat = - parse_expr_as_i32(&repeat_expr.expr, generic_args); - match &*repeat_expr.len { - Expr::Path(len_path) => { - // This is something like Tensor - let num_rep_var = len_path.to_token_stream().to_string(); - if !generic_args.get_i32(&num_rep_var).is_some() { - panic!( - "Expected instance for generic argument {}", - num_rep_var - ); - } - let num_rep = generic_args.get_i32(&num_rep_var).unwrap(); - shape = Some(vec![thing_to_repeat; num_rep as usize]); - } - Expr::Lit(len_lit) => { - // This is something like Tensor - let num_rep: u32 = len_lit - .to_token_stream() - .to_string() - .parse::() - .unwrap(); - shape = Some(vec![thing_to_repeat; num_rep as usize]); - } - _ => { - unimplemented!("Unexpected repeat expression: {repeat_expr:#?}") - } - } + Expr::Lit(len_lit) => { + // This is something like Tensor + let num_rep: u32 = len_lit + .to_token_stream() + .to_string() + .parse::() + .unwrap(); + shape = Some(vec![thing_to_repeat; num_rep as usize]); + } + _ => { + unimplemented!("Unexpected repeat expression: {repeat_expr:#?}") } - _ => panic!("Unexpected block expression."), } } - _ => {} + _ => panic!("Unexpected block expression."), } } _ => {} @@ -2417,7 +2371,7 @@ pub fn parse_expr_as_i32(expr: &Expr, generic_args: &GenericVars) -> i32 { Expr::Path(path) => { let ident = get_ident_from_path_expr(path); match generic_args.inst_i32.get(ident.to_string().as_str()) { - Some(val) => return *val, + Some(&val) => val, None => panic!("Undefined generic parameter {ident}"), } } diff --git a/cutile-compiler/src/hints.rs b/cutile-compiler/src/hints.rs index 1729c01..7f9a25c 100644 --- a/cutile-compiler/src/hints.rs +++ b/cutile-compiler/src/hints.rs @@ -153,47 +153,46 @@ impl OptimizationHints { result.target_gpu_name = Some(target_gpu_name); for sm_key_val in &opt_hints.elems { let (opt_key, opt_value) = Self::parse_key_value(sm_key_val)?; - match opt_key.as_str() { - _ => { - if !opt_key.starts_with("sm_") { - return SourceLocation::unknown().jit_error_result(&format!( - "Unexpected optimization hint {}.", - sm_key_val.to_token_stream().to_string() - )); - } - let Expr::Tuple(hints_tuple) = opt_value else { - return SourceLocation::unknown() - .jit_error_result("expected a tuple expression for architecture-specific optimization hints"); - }; - let mut sm_hints_result = SMHints::new(opt_key.clone()); - for hint_key_val in hints_tuple.elems.iter() { - let (key, hints) = Self::parse_key_value(hint_key_val)?; - match key.as_str() { - "num_cta_in_cga" => sm_hints_result.set_num_cta_in_cga(&hints)?, - "occupancy" => sm_hints_result.set_occupancy(&hints)?, - "max_divisibility" => sm_hints_result.set_max_divisibility(&hints)?, - "allow_tma" | "latency" => { - return SourceLocation::unknown().jit_error_result(&format!( - "'{key}' is a per-op hint and cannot be set at the entry level. \ - Use it as a parameter on individual load/store operations instead." - )); - } - _ => { - return SourceLocation::unknown().jit_error_result(&format!( - "Unexpected optimization hint key '{key}'." - )); - } + { + if !opt_key.starts_with("sm_") { + return SourceLocation::unknown().jit_error_result(&format!( + "Unexpected optimization hint {}.", + sm_key_val.to_token_stream() + )); + } + let Expr::Tuple(hints_tuple) = opt_value else { + return SourceLocation::unknown().jit_error_result( + "expected a tuple expression for architecture-specific optimization hints", + ); + }; + let mut sm_hints_result = SMHints::new(opt_key.clone()); + for hint_key_val in hints_tuple.elems.iter() { + let (key, hints) = Self::parse_key_value(hint_key_val)?; + match key.as_str() { + "num_cta_in_cga" => sm_hints_result.set_num_cta_in_cga(&hints)?, + "occupancy" => sm_hints_result.set_occupancy(&hints)?, + "max_divisibility" => sm_hints_result.set_max_divisibility(&hints)?, + "allow_tma" | "latency" => { + return SourceLocation::unknown().jit_error_result(&format!( + "'{key}' is a per-op hint and cannot be set at the entry level. \ + Use it as a parameter on individual load/store operations instead." + )); + } + _ => { + return SourceLocation::unknown().jit_error_result(&format!( + "Unexpected optimization hint key '{key}'." + )); } } - if result - .tile_as_hints - .insert(opt_key.clone(), sm_hints_result) - .is_some() - { - return SourceLocation::unknown().jit_error_result(&format!( - "Duplicate optimization hint key '{opt_key}'." - )); - } + } + if result + .tile_as_hints + .insert(opt_key.clone(), sm_hints_result) + .is_some() + { + return SourceLocation::unknown().jit_error_result(&format!( + "Duplicate optimization hint key '{opt_key}'." + )); } } } diff --git a/cutile-compiler/src/kernel_entry_generator.rs b/cutile-compiler/src/kernel_entry_generator.rs index d3bcfbb..bf87796 100644 --- a/cutile-compiler/src/kernel_entry_generator.rs +++ b/cutile-compiler/src/kernel_entry_generator.rs @@ -197,11 +197,7 @@ impl TensorInput { fn_args } - fn get_dynamic_elements( - &self, - static_elements: &Vec, - i_arg_name: String, - ) -> Vec { + fn get_dynamic_elements(&self, static_elements: &[String], i_arg_name: String) -> Vec { let var_name = self.var_name.clone(); let mut dynamic_elements = vec![]; for (i, dim) in static_elements.iter().enumerate() { @@ -452,6 +448,7 @@ fn generic_arg_to_const_array_string(arg: &GenericArgument) -> Result().expect(format!("{s}").as_str()) + s.parse::().unwrap_or_else(|_| panic!("{s}")) } }) .collect::>(), @@ -545,7 +542,7 @@ pub fn generate_entry_point( if s == "- 1" { -1 } else { - s.parse::().expect(format!("{s}").as_str()) + s.parse::().unwrap_or_else(|_| panic!("{}", s)) } }) .collect::>(), @@ -687,183 +684,154 @@ pub fn get_tensor_shape( let mut shape: Option = None; for generic_arg in &type_generic_args.args { match generic_arg { - GenericArgument::Type(type_param) => { + GenericArgument::Type(syn::Type::Path(type_path)) => { // Currently, this is either shape or element_type - match type_param { - syn::Type::Path(type_path) => { - let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); - // println!("get_variadic_type_args: Type::Path: {}", last_ident); - if shape.is_none() && generic_vars.inst_array.contains_key(&last_ident) { - // This is something like Shape for const generic array D: [i32; N]. - let array_instance = generic_vars.inst_array.get(&last_ident).unwrap(); - shape = Some(InputTensorShape { - generic_cga_var: Some(last_ident.clone()), - shape_param: last_ident, - shape: array_instance.iter().map(|elem| elem.to_string()).collect(), - }); - } - } - _ => {} + let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); + // println!("get_variadic_type_args: Type::Path: {}", last_ident); + if shape.is_none() && generic_vars.inst_array.contains_key(&last_ident) { + // This is something like Shape for const generic array D: [i32; N]. + let array_instance = generic_vars.inst_array.get(&last_ident).unwrap(); + shape = Some(InputTensorShape { + generic_cga_var: Some(last_ident.clone()), + shape_param: last_ident, + shape: array_instance.iter().map(|elem| elem.to_string()).collect(), + }); } } - GenericArgument::Const(const_param) => { - // println!("expand GenericArgument::Const? {const_param:#?}"); - match const_param { - Expr::Block(block_expr) => { - // This is something like Tensor - if block_expr.block.stmts.len() != 1 { - return SourceLocation::unknown().jit_error_result(&format!( - "Expected exactly 1 statement in block expression, got {}", - block_expr.block.stmts.len() - )); - } - let statement = &block_expr.block.stmts[0]; - let Stmt::Expr(statement_expr, _) = statement else { - return SourceLocation::unknown() - .jit_error_result("Unexpected block expression."); - }; - match statement_expr { - Expr::Array(array_expr) => { - // This is something like Tensor - let mut _shape = vec![]; - for elem in &array_expr.elems { - match elem { - Expr::Lit(lit) => { - let val = match &lit.lit { - Lit::Int(int_lit) => int_lit.to_string(), - _ => return SourceLocation::unknown().jit_error_result( - &format!("Unexpected array element {elem:#?} in {array_expr:#?}"), - ), - }; - _shape.push(val); - } - Expr::Unary(unary_expr) => { - _shape.push(unary_expr.to_token_stream().to_string()); - } - Expr::Path(path) => { - let ident = get_ident_from_path_expr(path); - match generic_vars - .inst_i32 - .get(ident.to_string().as_str()) - { - Some(val) => _shape.push(val.to_string()), - None => { - return SourceLocation::unknown() - .jit_error_result(&format!( - "Undefined generic parameter {ident}" - )); - } - } - } - Expr::Index(index) => { - let Expr::Path(path) = index.expr.as_ref() else { - return SourceLocation::unknown().jit_error_result( - &format!( - "Unexpected const generic array base {elem:#?}" - ), - ); - }; - let ident = get_ident_from_path_expr(path); - let Some(shape) = generic_vars - .inst_array - .get(ident.to_string().as_str()) - else { - return SourceLocation::unknown() - .jit_error_result(&format!( - "Undefined const generic array parameter {ident}" - )); - }; - let i = crate::types::parse_signed_literal_as_i32( - &index.index, - ); - let Some(dim) = shape.get(i as usize) else { - return SourceLocation::unknown() - .jit_error_result(&format!( - "Index {i} out of bounds for const generic array `{ident}` of length {}", - shape.len() - )); - }; - _shape.push(dim.to_string()); - } - _ => { - return SourceLocation::unknown().jit_error_result( - &format!( + GenericArgument::Const(Expr::Block(block_expr)) => { + // This is something like Tensor + if block_expr.block.stmts.len() != 1 { + return SourceLocation::unknown().jit_error_result(&format!( + "Expected exactly 1 statement in block expression, got {}", + block_expr.block.stmts.len() + )); + } + let statement = &block_expr.block.stmts[0]; + let Stmt::Expr(statement_expr, _) = statement else { + return SourceLocation::unknown() + .jit_error_result("Unexpected block expression."); + }; + match statement_expr { + Expr::Array(array_expr) => { + // This is something like Tensor + let mut _shape = vec![]; + for elem in &array_expr.elems { + match elem { + Expr::Lit(lit) => { + let val = match &lit.lit { + Lit::Int(int_lit) => int_lit.to_string(), + _ => return SourceLocation::unknown() + .jit_error_result(&format!( "Unexpected array element {elem:#?} in {array_expr:#?}" - ), - ) + )), + }; + _shape.push(val); + } + Expr::Unary(unary_expr) => { + _shape.push(unary_expr.to_token_stream().to_string()); + } + Expr::Path(path) => { + let ident = get_ident_from_path_expr(path); + match generic_vars.inst_i32.get(ident.to_string().as_str()) { + Some(val) => _shape.push(val.to_string()), + None => { + return SourceLocation::unknown().jit_error_result( + &format!("Undefined generic parameter {ident}"), + ); } } } + Expr::Index(index) => { + let Expr::Path(path) = index.expr.as_ref() else { + return SourceLocation::unknown().jit_error_result( + &format!( + "Unexpected const generic array base {elem:#?}" + ), + ); + }; + let ident = get_ident_from_path_expr(path); + let Some(shape) = + generic_vars.inst_array.get(ident.to_string().as_str()) + else { + return SourceLocation::unknown().jit_error_result( + &format!( + "Undefined const generic array parameter {ident}" + ), + ); + }; + let i = crate::types::parse_signed_literal_as_i32(&index.index); + let Some(dim) = shape.get(i as usize) else { + return SourceLocation::unknown() + .jit_error_result(&format!( + "Index {i} out of bounds for const generic array `{ident}` of length {}", + shape.len() + )); + }; + _shape.push(dim.to_string()); + } + _ => { + return SourceLocation::unknown().jit_error_result(&format!( + "Unexpected array element {elem:#?} in {array_expr:#?}" + )) + } + } + } + if shape.is_none() { + shape = Some(InputTensorShape { + generic_cga_var: None, + shape_param: block_expr.block.to_token_stream().to_string(), + shape: _shape, + }); + } + } + Expr::Repeat(repeat_expr) => { + // println!("Expr::Repeat: {:?}", repeat_expr.expr); + let thing_to_repeat = repeat_expr.expr.to_token_stream().to_string(); + match &*repeat_expr.len { + Expr::Path(len_path) => { + // This is something like Tensor + let num_rep_var = len_path.to_token_stream().to_string(); + if generic_vars.get_i32(&num_rep_var).is_none() { + return SourceLocation::unknown().jit_error_result(&format!( + "Expected instance for generic argument {}", + num_rep_var + )); + } + let num_rep = generic_vars.get_i32(&num_rep_var).unwrap(); if shape.is_none() { shape = Some(InputTensorShape { generic_cga_var: None, shape_param: block_expr.block.to_token_stream().to_string(), - shape: _shape, + shape: vec![thing_to_repeat; num_rep as usize], }); } } - Expr::Repeat(repeat_expr) => { - // println!("Expr::Repeat: {:?}", repeat_expr.expr); - let thing_to_repeat = - repeat_expr.expr.to_token_stream().to_string(); - match &*repeat_expr.len { - Expr::Path(len_path) => { - // This is something like Tensor - let num_rep_var = len_path.to_token_stream().to_string(); - if !generic_vars.get_i32(&num_rep_var).is_some() { - return SourceLocation::unknown().jit_error_result( - &format!( - "Expected instance for generic argument {}", - num_rep_var - ), - ); - } - let num_rep = generic_vars.get_i32(&num_rep_var).unwrap(); - if shape.is_none() { - shape = Some(InputTensorShape { - generic_cga_var: None, - shape_param: block_expr - .block - .to_token_stream() - .to_string(), - shape: vec![thing_to_repeat; num_rep as usize], - }); - } - } - Expr::Lit(len_lit) => { - // This is something like Tensor - let num_rep: u32 = len_lit - .to_token_stream() - .to_string() - .parse::() - .unwrap(); - if shape.is_none() { - shape = Some(InputTensorShape { - generic_cga_var: None, - shape_param: block_expr - .block - .to_token_stream() - .to_string(), - shape: vec![thing_to_repeat; num_rep as usize], - }); - } - } - _ => { - return SourceLocation::unknown().jit_error_result( - &format!( - "Unexpected repeat expression: {repeat_expr:#?}" - ), - ) - } + Expr::Lit(len_lit) => { + // This is something like Tensor + let num_rep: u32 = len_lit + .to_token_stream() + .to_string() + .parse::() + .unwrap(); + if shape.is_none() { + shape = Some(InputTensorShape { + generic_cga_var: None, + shape_param: block_expr.block.to_token_stream().to_string(), + shape: vec![thing_to_repeat; num_rep as usize], + }); } } _ => { - return SourceLocation::unknown() - .jit_error_result("Unexpected block expression.") + return SourceLocation::unknown().jit_error_result(&format!( + "Unexpected repeat expression: {repeat_expr:#?}" + )) } } } - _ => {} + _ => { + return SourceLocation::unknown() + .jit_error_result("Unexpected block expression.") + } } } _ => {} diff --git a/cutile-compiler/src/passes/node_ids.rs b/cutile-compiler/src/passes/node_ids.rs index 28694a9..0db81a6 100644 --- a/cutile-compiler/src/passes/node_ids.rs +++ b/cutile-compiler/src/passes/node_ids.rs @@ -38,7 +38,7 @@ pub fn assign_block_expr_ids(block: &mut syn::Block) { pub fn expr_id(expr: &Expr) -> Option { expr_attrs(expr)? .iter() - .find_map(|attr| node_id_from_attr(attr)) + .find_map(node_id_from_attr) } pub fn set_expr_id(expr: &mut Expr, id: NodeId) { @@ -83,7 +83,7 @@ impl VisitMut for NodeIdAssigner { match expr { Expr::Assign(assign) => { // The destination is binding syntax, not a value expression. - self.visit_expr_mut(&mut *assign.right); + self.visit_expr_mut(&mut assign.right); } Expr::Call(call) => { // The callee is name-resolution syntax in this DSL. diff --git a/cutile-compiler/src/passes/type_inference.rs b/cutile-compiler/src/passes/type_inference.rs index e851f58..e666a23 100644 --- a/cutile-compiler/src/passes/type_inference.rs +++ b/cutile-compiler/src/passes/type_inference.rs @@ -363,7 +363,7 @@ enum InferVarKind { #[derive(Clone, Debug)] enum InferredTy { - Known(TileRustType), + Known(Box), Var(InferVarId), Unknown, } @@ -559,7 +559,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { self.infer_expr(&init.expr, annotated_type.clone())? } else { match self.infer_expr_term(&init.expr, None)? { - InferredTy::Known(ty) => Some(ty), + InferredTy::Known(ty) => Some(*ty), term @ InferredTy::Var(_) => { if let Some(name) = local_binding_name(&local.pat) { self.bind_inferred_var(name, term); @@ -797,7 +797,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { self.infer_bool_binary(binary)? } else if expected.is_none() && binary_result_matches_operands(&binary.op) { match self.infer_expr_term(expr, None)? { - InferredTy::Known(ty) => Some(ty), + InferredTy::Known(ty) => Some(*ty), term @ InferredTy::Var(_) => { self.record_expr_term(expr, term); None @@ -826,7 +826,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { Expr::Unary(unary) => { if expected.is_none() && matches!(unary.op, syn::UnOp::Neg(_)) { match self.infer_expr_term(expr, None)? { - InferredTy::Known(ty) => Some(ty), + InferredTy::Known(ty) => Some(*ty), term @ InferredTy::Var(_) => { self.record_expr_term(expr, term); None @@ -861,7 +861,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { then_ty.or(else_ty).or(Some(expected)) } else { match self.infer_if_expr_term(if_expr)? { - InferredTy::Known(ty) => Some(ty), + InferredTy::Known(ty) => Some(*ty), term @ InferredTy::Var(_) => { self.record_expr_term(expr, term); None @@ -913,7 +913,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { }; if let Some(ty) = inferred.clone() { - self.record_expr_term(expr, InferredTy::Known(ty.clone())); + self.record_expr_term(expr, InferredTy::Known(Box::new(ty.clone()))); self.results.insert_expr_type(expr, ty); } Ok(inferred) @@ -922,7 +922,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { fn bind_tile_var(&mut self, name: String, ty: TileRustType) { self.syn_vars.insert(name.clone(), ty.rust_ty.clone()); self.local_terms - .insert(name.clone(), InferredTy::Known(ty.clone())); + .insert(name.clone(), InferredTy::Known(Box::new(ty.clone()))); self.vars.insert(name, ty); } @@ -967,14 +967,15 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { let _ = self.infer_expr(source_expr, None)?; self.results .insert_resolved_expr_type(cast_expr, ResolvedType::from_syn_type(target_ty)); - match self - .compiler + self.compiler .compile_type(target_ty, self.generic_vars, &HashMap::new()) - { - Ok(ty) => Ok(ty), - Err(err) if is_surface_only_scalar_type(target_ty) => Ok(None), - Err(err) => Err(err), - } + .or_else(|e| { + if is_surface_only_scalar_type(target_ty) { + Ok(None) + } else { + Err(e) + } + }) } fn infer_bool_binary( @@ -1267,23 +1268,24 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { if pats.len().saturating_sub(1) > exprs.len() { return Ok(()); } - for idx in 0..rest_pos { - pairs.push((pats[idx], idx)); - } - let suffix_len = pats.len() - rest_pos - 1; - for suffix_idx in 0..suffix_len { - pairs.push(( - pats[rest_pos + 1 + suffix_idx], - exprs.len() - suffix_len + suffix_idx, - )); + for (idx, pat) in pats.iter().enumerate() { + if idx == rest_pos { + continue; + } + let idx = if idx < rest_pos { + idx + } else { + exprs.len() - pats.len() + idx + }; + pairs.push((pat, idx)); } } None => { if pats.len() != exprs.len() { return Ok(()); } - for idx in 0..pats.len() { - pairs.push((pats[idx], idx)); + for (idx, pat) in pats.iter().enumerate() { + pairs.push((pat, idx)); } } } @@ -1357,9 +1359,9 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { return Ok(term); } if let Some(ty) = self.infer_expr(expr, Some(expected.clone()))? { - return Ok(InferredTy::Known(ty)); + return Ok(InferredTy::Known(Box::new(ty))); } - return Ok(InferredTy::Known(expected)); + return Ok(InferredTy::Known(Box::new(expected))); } if let Some(existing) = self.expr_term(expr) { @@ -1373,16 +1375,16 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { self.record_expr_term(expr, term.clone()); Ok(term) } else if let Some(ty) = self.vars.get(&name).cloned() { - Ok(InferredTy::Known(ty)) + Ok(InferredTy::Known(Box::new(ty))) } else if let Some(ty) = self.infer_global_const_type(path)? { - Ok(InferredTy::Known(ty)) + Ok(InferredTy::Known(Box::new(ty))) } else { Ok(InferredTy::Unknown) } } Expr::Path(path) => { if let Some(ty) = self.infer_associated_const_type(path)? { - Ok(InferredTy::Known(ty)) + Ok(InferredTy::Known(Box::new(ty))) } else { Ok(InferredTy::Unknown) } @@ -1395,7 +1397,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { else { return Ok(InferredTy::Unknown); }; - Ok(InferredTy::Known(ty)) + Ok(InferredTy::Known(Box::new(ty))) } else { Ok(self.literal_infer_var(expr, lit)) } @@ -1414,7 +1416,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { let term = self.infer_block_term(&block.block)?; if matches!(term, InferredTy::Unknown) { if let Some(ty) = self.infer_expr(expr, None)? { - return Ok(InferredTy::Known(ty)); + return Ok(InferredTy::Known(Box::new(ty))); } } self.record_expr_term(expr, term.clone()); @@ -1424,7 +1426,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { let term = self.infer_block_term(&unsafe_expr.block)?; if matches!(term, InferredTy::Unknown) { if let Some(ty) = self.infer_expr(expr, None)? { - return Ok(InferredTy::Known(ty)); + return Ok(InferredTy::Known(Box::new(ty))); } } self.record_expr_term(expr, term.clone()); @@ -1442,7 +1444,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { return Ok(InferredTy::Unknown); }; let _ = self.infer_bool_binary(binary)?; - let term = InferredTy::Known(bool_ty); + let term = InferredTy::Known(Box::new(bool_ty)); self.record_expr_term(expr, term.clone()); Ok(term) } @@ -1450,7 +1452,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { let Some(ty) = self.infer_cast_target(expr, &cast.expr, &cast.ty)? else { return Ok(InferredTy::Unknown); }; - let term = InferredTy::Known(ty); + let term = InferredTy::Known(Box::new(ty)); self.record_expr_term(expr, term.clone()); Ok(term) } @@ -1461,7 +1463,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { Ok(term) } _ => match self.infer_expr(expr, None)? { - Some(ty) => Ok(InferredTy::Known(ty)), + Some(ty) => Ok(InferredTy::Known(Box::new(ty))), None => Ok(InferredTy::Unknown), }, } @@ -1621,7 +1623,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { fn term_known_type(&self, term: &InferredTy) -> Option { match self.normalize_term(term.clone()) { - InferredTy::Known(ty) => Some(ty.clone()), + InferredTy::Known(ty) => Some(*ty), InferredTy::Var(id) => { let root_id = self.find_var_id(id)?; self.inference @@ -1672,7 +1674,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { } (InferredTy::Known(known), term @ InferredTy::Var(_)) | (term @ InferredTy::Var(_), InferredTy::Known(known)) => { - self.unify_with_known(term, known) + self.unify_with_known(term, *known) } (InferredTy::Var(lhs), InferredTy::Var(rhs)) => self.unify_vars(lhs, rhs), } @@ -1750,7 +1752,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { ) -> Result { match self.normalize_term(term) { InferredTy::Known(known) => Ok(InferredTy::Known(known)), - InferredTy::Unknown => Ok(InferredTy::Known(expected)), + InferredTy::Unknown => Ok(InferredTy::Known(Box::new(expected))), InferredTy::Var(id) => { let Some(root_id) = self.find_var_id(id) else { return Ok(InferredTy::Unknown); @@ -1759,12 +1761,12 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { return Ok(InferredTy::Unknown); }; if let Some(known) = &var.value { - return Ok(InferredTy::Known(known.clone())); + return Ok(InferredTy::Known(Box::new(known.clone()))); } if literal_kind_accepts_type(var.kind, &expected.rust_ty) { var.value = Some(expected.clone()); var.origin_propagated = false; - Ok(InferredTy::Known(expected)) + Ok(InferredTy::Known(Box::new(expected))) } else { Ok(InferredTy::Var(id)) } @@ -2452,7 +2454,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { if let Expr::Closure(closure) = arg { if let Some(signature) = signature { if let Some((param_types, return_type)) = - self.instantiate_closure_signature(&signature, &generic_arg_inf) + self.instantiate_closure_signature(signature, &generic_arg_inf) { self.infer_closure(closure, ¶m_types, return_type)?; } else { @@ -2797,7 +2799,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { } } - for ((trait_name, self_name), _impls) in self.compiler.modules.trait_impls() { + for (trait_name, self_name) in self.compiler.modules.trait_impls().keys() { if self_name == &qualifier_name { if let Some(ty) = self.trait_associated_const_syn_type(trait_name, &const_name, &self_ty) @@ -2806,7 +2808,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { } } } - for ((trait_name, self_name), _impl_item) in self.compiler.modules.primitives() { + for (trait_name, self_name) in self.compiler.modules.primitives().keys() { if self_name == &qualifier_name { if let Some(ty) = self.trait_associated_const_syn_type(trait_name, &const_name, &self_ty) @@ -3015,7 +3017,7 @@ impl<'a, 'm> TypeInferenceCx<'a, 'm> { return Ok(None); } let mut generic_arg_inf = GenericArgInference::new_method(impl_item, impl_method); - generic_arg_inf.map_args_to_params(&call_arg_rust_tys.to_vec(), Some(self_ty)); + generic_arg_inf.map_args_to_params(call_arg_rust_tys, Some(self_ty)); generic_arg_inf.apply_provided_generics_method_call(method_call, self.generic_vars); if !generic_arg_inf.verify() { return Ok(None); @@ -4430,7 +4432,7 @@ pub fn infer_method_generics( } let mut generic_arg_inference = GenericArgInference::new_method(impl_item, impl_method); - generic_arg_inference.map_args_to_params(&call_arg_rust_tys.to_vec(), Some(self_ty)); + generic_arg_inference.map_args_to_params(call_arg_rust_tys, Some(self_ty)); let inferred = generic_arg_inference.get_generic_vars_instance(caller_generic_vars, primitives); if method_call.turbofish.is_some() { diff --git a/cutile-compiler/src/specialization.rs b/cutile-compiler/src/specialization.rs index 597310e..7e5f2ab 100644 --- a/cutile-compiler/src/specialization.rs +++ b/cutile-compiler/src/specialization.rs @@ -163,7 +163,7 @@ pub fn compute_spec( .map(|(&s, &d)| (s, d)) .collect(); sorted.sort(); - spec.elements_disjoint = sorted.first().map_or(true, |(s, _)| *s > 0); + spec.elements_disjoint = sorted.first().is_none_or(|(s, _)| *s > 0); for w in sorted.windows(2) { if w[1].0 <= 0 || w[1].0 < w[0].0 * w[0].1 { spec.elements_disjoint = false; diff --git a/cutile-compiler/src/syn_utils.rs b/cutile-compiler/src/syn_utils.rs index 27dc707..b517566 100644 --- a/cutile-compiler/src/syn_utils.rs +++ b/cutile-compiler/src/syn_utils.rs @@ -122,32 +122,28 @@ impl SingleMetaList { let Meta::List(meta_list) = attr.meta else { panic!("Unexpected attribute list {:#?}", attr.meta) }; - let tokens = proc_macro2::TokenStream::from(meta_list.tokens.clone()); + let tokens = meta_list.tokens.clone(); let mut result = syn::parse2::(tokens).unwrap(); result.name = Some(meta_list.path.to_token_stream().to_string()); result.meta_list = Some(meta_list); - return result; + result } /// Returns the attribute path as a single string. - pub fn name_as_str(&self) -> Option { - match &self.name { - Some(s) => Some(s.clone()), - None => None, - } + pub fn name_as_str(&self) -> Option<&str> { + self.name.as_deref() } /// Returns the attribute path split by `::` separators. pub fn name_as_vec(&self) -> Option> { - match &self.name { - Some(s) => Some(s.as_str().split(" :: ").collect()), - None => None, - } + self.name + .as_ref() + .map(|s| s.as_str().split(" :: ").collect()) } fn get_value(&self, name: &str) -> Option<&Expr> { for item in &self.variables { match item { Meta::NameValue(name_value) => { let meta_ident = name_value.path.get_ident(); - let meta_name = meta_ident.clone().unwrap().to_string(); + let meta_name = meta_ident.unwrap().to_string(); if name == meta_name { return Some(&name_value.value); } @@ -247,10 +243,10 @@ impl Parse for SingleMetaList { } } -impl Into> for SingleMetaList { - fn into(self) -> Vec { +impl From for Vec { + fn from(val: SingleMetaList) -> Self { let mut res = vec![]; - for meta in self.variables { + for meta in val.variables { let attr: Attribute = parse_quote! { #[noname(#meta)] }; @@ -261,7 +257,7 @@ impl Into> for SingleMetaList { } /// Removes all attributes whose path matches one of the given names. -pub fn clear_attributes(attr_names: HashSet<&str>, attrs: &mut Vec) -> () { +pub fn clear_attributes(attr_names: HashSet<&str>, attrs: &mut Vec) { // filter == keep *attrs = attrs .clone() @@ -310,10 +306,7 @@ pub fn get_attribute( /// Looks up an attribute by full path and parses it as a [`SingleMetaList`]. pub fn get_meta_list(attr_name: &str, outer_attrs: &Vec) -> Option { - match get_attribute(attr_name, outer_attrs, false) { - Some(attr) => Some(SingleMetaList::from_attribute(attr)), - None => None, - } + get_attribute(attr_name, outer_attrs, false).map(SingleMetaList::from_attribute) } /// Like [`get_meta_list`] but matches only the last path segment. @@ -321,10 +314,7 @@ pub fn get_meta_list_by_last_segment( last_seg: &str, outer_attrs: &Vec, ) -> Option { - match get_attribute(last_seg, outer_attrs, true) { - Some(attr) => Some(SingleMetaList::from_attribute(attr)), - None => None, - } + get_attribute(last_seg, outer_attrs, true).map(SingleMetaList::from_attribute) } /// Finds the first `cuda_tile::*` attribute and parses it as a [`SingleMetaList`]. @@ -621,7 +611,7 @@ pub fn get_sig_output_type(sig: &Signature) -> Type { pub fn function_returns(fn_item: &ItemFn) -> bool { match &fn_item.sig.output { ReturnType::Type(_, return_type) => match &**return_type { - Type::Tuple(type_tuple) => type_tuple.elems.len() > 0, + Type::Tuple(type_tuple) => !type_tuple.elems.is_empty(), _ => true, }, ReturnType::Default => false, @@ -642,7 +632,7 @@ pub fn get_ident_from_path_expr(path_expr: &ExprPath) -> Ident { pub fn get_ident_from_expr(expr: &Expr) -> Option { match expr { Expr::Path(path_expr) => Some(get_ident_from_path(&path_expr.path)), - Expr::Reference(ref_expr) => get_ident_from_expr(&*ref_expr.expr), + Expr::Reference(ref_expr) => get_ident_from_expr(&ref_expr.expr), _ => None, } } @@ -752,7 +742,7 @@ pub fn get_supported_generic_params(generics: &Generics) -> Vec<(String, Option< } /// Removes lifetime arguments from angle-bracketed generic arguments in place. -pub fn strip_generic_args_lifetimes(gen_args: &mut AngleBracketedGenericArguments) -> () { +pub fn strip_generic_args_lifetimes(gen_args: &mut AngleBracketedGenericArguments) { let mut res = gen_args.args.clone(); res.clear(); for gen_arg in gen_args.args.iter() { @@ -765,7 +755,7 @@ pub fn strip_generic_args_lifetimes(gen_args: &mut AngleBracketedGenericArgument } /// Removes lifetime parameters from generics in place. -pub fn strip_generics_lifetimes(generics: &mut Generics) -> () { +pub fn strip_generics_lifetimes(generics: &mut Generics) { let mut res = generics.params.clone(); res.clear(); for gen_param in generics.params.iter() { diff --git a/cutile-compiler/src/type_aliases.rs b/cutile-compiler/src/type_aliases.rs index 150a6e9..e55ce10 100644 --- a/cutile-compiler/src/type_aliases.rs +++ b/cutile-compiler/src/type_aliases.rs @@ -142,7 +142,7 @@ pub fn normalize_item_fn_param_type_aliases( let FnArg::Typed(PatType { ty, .. }) = arg else { continue; }; - *ty = Box::new(normalize_type_aliases(ty, aliases)?); + **ty = normalize_type_aliases(ty, aliases)?; } Ok(item) } @@ -292,7 +292,7 @@ fn build_alias_substitution( } let mut subst = AliasSubstitution::default(); - for (formal, actual) in formals.into_iter().zip(actual_args.into_iter()) { + for (formal, actual) in formals.into_iter().zip(actual_args) { match (formal, actual) { (GenericParam::Type(param), GenericArgument::Type(ty)) => { subst.types.insert(param.ident.to_string(), ty); diff --git a/cutile-compiler/src/types.rs b/cutile-compiler/src/types.rs index 9abc72d..71da227 100644 --- a/cutile-compiler/src/types.rs +++ b/cutile-compiler/src/types.rs @@ -204,7 +204,7 @@ impl TypeParamPrimitive { } _ => SourceLocation::unknown().jit_error_result(&format!( "unsupported primitive type `{}`", - self.rust_ty.to_token_stream().to_string() + self.rust_ty.to_token_stream() )), } } @@ -555,10 +555,11 @@ pub fn get_ptr_type_instance( }; if is_element_type(&maybe_element_type, primitives) { Some((prefix, maybe_element_type)) - } else if let Some(element_type) = generic_vars.inst_types.get(&maybe_element_type) { - Some((prefix, element_type.to_string())) } else { - None + generic_vars + .inst_types + .get(&maybe_element_type) + .map(|element_type| (prefix, element_type.to_string())) } } @@ -575,12 +576,12 @@ pub fn get_cuda_tile_element_type_from_rust_primitive_str( /// Returns the Rust identifier string for a primitive type. pub fn get_rust_element_type_primitive(ty: &syn::Type) -> String { - let type_ident = get_type_ident(&ty); + let type_ident = get_type_ident(ty); assert!( type_ident.is_some(), "get_element_type_primitive failed for {ty:#?}" ); - return type_ident.unwrap().to_string(); + type_ident.unwrap().to_string() } /// Returns the CUDA Tile element type string for a Rust primitive type. @@ -612,8 +613,8 @@ pub fn get_element_type_structured( let (_type_ident, type_generic_args) = get_ident_generic_args(ty); let mut element_type: Option = None; for generic_arg in &type_generic_args.args { - match generic_arg { - GenericArgument::Type(type_param) => match type_param { + if let GenericArgument::Type(type_param) = generic_arg { + match type_param { syn::Type::Path(type_path) => { let ident_str = type_path.path.segments.last().unwrap().ident.to_string(); if get_primitives_attrs("ElementType", &ident_str, primitives).is_some() { @@ -635,8 +636,7 @@ pub fn get_element_type_structured( element_type = get_element_type_structured(&type_ref.elem, primitives) } _ => {} - }, - _ => {} + } } } element_type @@ -647,9 +647,7 @@ pub fn get_cuda_tile_element_type_structured( ty: &syn::Type, primitives: &HashMap<(String, String), ItemImpl>, ) -> Option { - let Some(rust_element_type) = get_element_type_structured(ty, primitives) else { - return None; - }; + let rust_element_type = get_element_type_structured(ty, primitives)?; get_cuda_tile_element_type_from_rust_primitive_str(&rust_element_type, primitives) } @@ -684,7 +682,7 @@ pub fn parse_signed_literal_as_i32(expr: &Expr) -> i32 { Lit::Int(int_lit) => int_lit.base10_parse().unwrap(), _ => unimplemented!("Unexpected array element {expr:#?}"), }; - return val; + val } Expr::Unary(unary_expr) => match unary_expr.op { UnOp::Neg(_) => match &*unary_expr.expr { @@ -693,7 +691,7 @@ pub fn parse_signed_literal_as_i32(expr: &Expr) -> i32 { Lit::Int(int_lit) => int_lit.base10_parse().unwrap(), _ => unimplemented!("Unexpected array element {expr:#?}"), }; - return -val; + -val } _ => panic!("Unexpected unary expr {unary_expr:#?}"), }, @@ -755,28 +753,20 @@ pub fn get_type_mutability(ty: &Type) -> bool { /// Tries to extract a const generic array from a type's generic arguments. pub fn try_extract_cga(ty: &Type, generic_vars: &GenericVars) -> Option> { - let Some(mut type_generic_args) = maybe_generic_args(ty) else { - return None; - }; + let mut type_generic_args = maybe_generic_args(ty)?; strip_generic_args_lifetimes(&mut type_generic_args); let mut result = None; for generic_arg in type_generic_args.args.iter() { match generic_arg { GenericArgument::Lifetime(_) => continue, - GenericArgument::Type(type_param) => { + GenericArgument::Type(syn::Type::Path(type_path)) => { // Currently, this is either shape or element_type - match type_param { - syn::Type::Path(type_path) => { - let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); - // println!("get_variadic_type_args: Type::Path: {}", last_ident); - if generic_vars.inst_array.contains_key(&last_ident) { - // This is something like Shape for const generic array D: [i32; N]. - let array_instance = generic_vars.inst_array.get(&last_ident).unwrap(); - result = Some(array_instance.clone()); - } - } - _ => {} + let last_ident = type_path.path.segments.last().unwrap().ident.to_string(); + // println!("get_variadic_type_args: Type::Path: {}", last_ident); + if let Some(array_instance) = generic_vars.inst_array.get(&last_ident) { + // This is something like Shape for const generic array D: [i32; N]. + result = Some(array_instance.clone()); } } GenericArgument::Const(const_expr) => { @@ -902,13 +892,12 @@ pub fn try_extract_cga(ty: &Type, generic_vars: &GenericVars) -> Option Expr::Path(len_path) => { // This is something like Tensor let num_rep_var = len_path.to_token_stream().to_string(); - if !generic_vars.get_i32(&num_rep_var).is_some() { + generic_vars.get_i32(&num_rep_var).unwrap_or_else(|| { panic!( "Expected instance for generic argument {}", num_rep_var - ); - } - generic_vars.get_i32(&num_rep_var).unwrap() + ) + }) } Expr::Lit(len_lit) => { // This is something like Tensor diff --git a/cutile-macro/src/_module.rs b/cutile-macro/src/_module.rs index f63f45b..d6d9b8d 100644 --- a/cutile-macro/src/_module.rs +++ b/cutile-macro/src/_module.rs @@ -539,7 +539,7 @@ pub fn structure(mut item: ItemStruct) -> Result { ); // println!("structure {ident}: {attributes:#?}"); let res = match attributes { - Some(attributes) => match attributes.name_as_str().unwrap().as_str() { + Some(attributes) => match attributes.name_as_str().unwrap() { "cuda_tile :: variadic_struct" => { let items = variadic_struct(&attributes, item)?; let structs = items.iter().map(|item| item.0.clone()).collect::>();