Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cutile-compiler/src/bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl Add for Bounds<i64> {
fn add(self, rhs: Bounds<i64>) -> Bounds<i64> {
let a = self;
let b = rhs;
let possible_bounds = vec![
let possible_bounds = [
a.start + b.start,
a.start + b.end,
a.end + b.start,
Expand All @@ -155,7 +155,7 @@ impl Sub for Bounds<i64> {
fn sub(self, rhs: Bounds<i64>) -> Bounds<i64> {
let a = self;
let b = rhs;
let possible_bounds = vec![
let possible_bounds = [
a.start - b.start,
a.start - b.end,
a.end - b.start,
Expand All @@ -178,7 +178,7 @@ impl Mul for Bounds<i64> {
fn mul(self, rhs: Bounds<i64>) -> Bounds<i64> {
let a = self;
let b = rhs;
let possible_bounds = vec![
let possible_bounds = [
a.start * b.start,
a.start * b.end,
a.end * b.start,
Expand Down Expand Up @@ -213,7 +213,7 @@ impl Div for Bounds<i64> {
(_, 0) => panic!("Division by zero"),
(0, _) => panic!("Division by zero"),
_ => {
let possible_bounds = vec![
let possible_bounds = [
a.start / b.start,
a.start / b.end,
a.end / b.start,
Expand All @@ -239,7 +239,7 @@ impl Rem for Bounds<i64> {
let a = self;
let b = rhs;
// TODO (hme): Verify this one.
let possible_bounds = vec![
let possible_bounds = [
a.start % b.start,
a.start % b.end,
a.end % b.start,
Expand All @@ -266,7 +266,7 @@ pub fn bop_bounds<F: Fn(i64, i64) -> i64>(a: &Bounds<i64>, b: &Bounds<i64>, f: F
if a.is_exact() && b.is_exact() {
return Bounds::exact(f(a.start, b.start));
}
let possible_bounds = vec![
let possible_bounds = [
f(a.start, b.start),
f(a.start, b.end),
f(a.end, b.start),
Expand Down
20 changes: 10 additions & 10 deletions cutile-compiler/src/compiler/_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct FunctionParamTypes {
}

impl<'m> CUDATileFunctionCompiler<'m> {
#[allow(clippy::too_many_arguments)]
pub fn new(
modules: &'m CUDATileModules,
module_name: &str,
Expand Down Expand Up @@ -127,7 +128,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {

// 7. Build stride_args HashMap.
let stride_args: HashMap<String, Vec<i32>> = stride_args
.into_iter()
.iter()
.map(|(k, v)| (k.to_string(), v.to_vec()))
.collect::<HashMap<_, _>>();

Expand Down Expand Up @@ -155,20 +156,19 @@ impl<'m> CUDATileFunctionCompiler<'m> {
.collect();
let (entry, validator) = generate_entry_point(
modules,
&function,
function,
&generic_vars,
&stride_args,
&spec_args_map,
&scalar_hints_map,
&modules.primitives(),
modules.primitives(),
&optimization_hints,
)?;

// 10. Check namespace collision.
if modules
.functions()
.get(kernel_naming.entry_name().as_str())
.is_some()
.contains_key(kernel_naming.entry_name().as_str())
{
return modules
.resolve_span(module_name, &function.span())
Expand Down Expand Up @@ -517,14 +517,14 @@ impl<'m> CUDATileFunctionCompiler<'m> {
if std::env::var("CUTILE_DEBUG_COMPILER2").is_ok() {
eprintln!(
"compiler2: lowered entry function body:\n{}",
quote::quote!(#lowered_fn_item).to_string()
quote::quote!(#lowered_fn_item)
);
}

let return_value = self.compile_block(
module,
block_id,
&*lowered_fn_item.block,
&lowered_fn_item.block,
generic_vars,
&mut ctx,
None,
Expand Down Expand Up @@ -582,7 +582,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
None
};
let value = self
.compile_expression(module, block_id, &arg, generic_args, ctx, expected)?
.compile_expression(module, block_id, arg, generic_args, ctx, expected)?
.ok_or(self.jit_error(
&arg.span(),
&format!(
Expand All @@ -606,7 +606,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
let rust_ty_str = type_name::<T>();
let rust_ty = syn::parse2::<syn::Type>(rust_ty_str.parse()?).unwrap();
let tr_ty = self
.compile_type(&rust_ty, &generic_vars, &HashMap::new())?
.compile_type(&rust_ty, generic_vars, &HashMap::new())?
.ok_or(self.jit_error(&rust_ty.span(), "failed to compile constant"))?;
self.compile_constant_from_exact_bounds(module, block_id, bounds, tr_ty)
}
Expand Down Expand Up @@ -652,7 +652,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
};
let Some(const_ty_str) = get_cuda_tile_element_type_from_rust_primitive_str(
&type_inst.rust_element_instance_ty,
&self.modules.primitives(),
self.modules.primitives(),
) else {
return self
.jit_error_result(&tr_ty.rust_ty.span(), "failed to compile constant value");
Expand Down
8 changes: 5 additions & 3 deletions cutile-compiler/src/compiler/_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,10 @@ fn trait_impl_matches_call(
return false;
}

let mut ctx = TraitMatchCtx::default();
ctx.caller_array_params = generic_vars.inst_array.clone();
let mut ctx = TraitMatchCtx{
caller_array_params: generic_vars.inst_array.clone(),
..Default::default()
};
collect_generics_for_trait_match(&item_impl.generics, &mut ctx);
collect_generics_for_trait_match(&impl_method.sig.generics, &mut ctx);

Expand Down Expand Up @@ -830,7 +832,7 @@ impl CUDATileModules {
if matches!(
generic_vars.instantiate_type(receiver_rust_ty, self.primitives())?,
TypeInstance::ElementType(_)
) {
) {
if let Some(impls) = self
.name_resolver
.trait_impls()
Expand Down
2 changes: 1 addition & 1 deletion cutile-compiler/src/compiler/_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,6 @@ fn rust_scalar_type(name: &str) -> Option<ScalarType> {
fn extract_pointer_element_type(ty_str: &str) -> Option<String> {
let after_mut = ty_str.split("mut").nth(1)?;
let trimmed = after_mut.trim();
let end = trimmed.find(|c: char| c == ',' || c == '>' || c == ' ')?;
let end = trimmed.find([',', '>', ' '])?;
Some(trimmed[..end].to_string())
}
4 changes: 1 addition & 3 deletions cutile-compiler/src/compiler/_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ impl TileRustValue {
}

pub fn take_type_meta_field(self, name: &str) -> Option<Self> {
let Some(mut type_meta) = self.type_meta else {
return None;
};
let mut type_meta = self.type_meta?;
type_meta.fields.remove(name)
}

Expand Down
7 changes: 4 additions & 3 deletions cutile-compiler/src/compiler/compile_assume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
"expected a simple function path for assume invocation",
);
};
let ident = get_ident_from_path_expr(&path_expr);
let ident = get_ident_from_path_expr(path_expr);
let compiler_op_function = ident.to_string();
let mut args =
self.compile_call_args(module, block_id, &call_expr.args, generic_vars, ctx)?;
Expand All @@ -75,18 +75,19 @@ impl<'m> CUDATileFunctionCompiler<'m> {
"the first argument to `assume` must produce a value",
);
};
Ok(self.compile_value_assumption(
self.compile_value_assumption(
module,
block_id,
val_value,
compiler_op_function.as_str(),
&predicate_args,
return_type,
&call_expr.span(),
)?)
)
}

/// Generates tile-ir assume operation with appropriate predicate attribute.
#[allow(clippy::too_many_arguments)]
pub(crate) fn compile_value_assumption(
&self,
module: &mut Module,
Expand Down
25 changes: 13 additions & 12 deletions cutile-compiler/src/compiler/compile_binary_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
)?))
}

#[allow(clippy::too_many_arguments)]
pub fn compile_binary_op_from_values(
&self,
module: &mut Module,
Expand All @@ -156,8 +157,8 @@ impl<'m> CUDATileFunctionCompiler<'m> {
&format!(
"binary `{:?}` requires operands of the same type, but got `{}` and `{}`",
tile_rust_arithmetic_op,
lhs.ty.rust_ty.to_token_stream().to_string(),
rhs.ty.rust_ty.to_token_stream().to_string()
lhs.ty.rust_ty.to_token_stream(),
rhs.ty.rust_ty.to_token_stream()
),
);
}
Expand All @@ -180,14 +181,14 @@ impl<'m> CUDATileFunctionCompiler<'m> {
let operand_type = lhs.ty.clone();
let operand_rust_ty = &operand_type.rust_ty;
let Some(operand_rust_element_type) =
operand_type.get_instantiated_rust_element_type(&self.modules.primitives())
operand_type.get_instantiated_rust_element_type(self.modules.primitives())
else {
return self.jit_error_result(
span,
&format!(
"unable to determine element type for `{:?}` on `{}`",
tile_rust_arithmetic_op,
operand_type.rust_ty.to_token_stream().to_string()
operand_type.rust_ty.to_token_stream()
),
);
};
Expand All @@ -196,7 +197,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
span,
&format!(
"type `{}` cannot be used with binary `{:?}`",
operand_type.rust_ty.to_token_stream().to_string(),
operand_type.rust_ty.to_token_stream(),
tile_rust_arithmetic_op
),
);
Expand All @@ -206,7 +207,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
let operand_result_ty = module.value_type(lhs_value).clone();

let Some(operand_cuda_tile_element_type) =
operand_type.get_cuda_tile_element_type(&self.modules.primitives())?
operand_type.get_cuda_tile_element_type(self.modules.primitives())?
else {
return self.jit_error_result(
span,
Expand Down Expand Up @@ -414,7 +415,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
span,
&format!(
"Binary operation is not implemented for {}",
operand_rust_ty.to_token_stream().to_string()
operand_rust_ty.to_token_stream()
),
);
}
Expand All @@ -426,7 +427,7 @@ impl<'m> CUDATileFunctionCompiler<'m> {
// Try to infer from lhs/rhs.
if is_cmp {
let bool_ty = syn::parse2::<syn::Type>("bool".parse().unwrap()).unwrap();
self.compile_type(&bool_ty, &generic_vars, &HashMap::new())?
self.compile_type(&bool_ty, generic_vars, &HashMap::new())?
.unwrap()
} else {
operand_type
Expand All @@ -448,17 +449,17 @@ impl<'m> CUDATileFunctionCompiler<'m> {
} else {
None
};
if let Some(bounds) = &op_bounds {
if let Some(bounds) = op_bounds {
if bounds.is_exact() {
// The lower/upper bounds are equivalent — emit a constant
// instead. The op allocated above becomes dead (not appended
// to any block).
return Ok(self.compile_constant_from_exact_bounds(
return self.compile_constant_from_exact_bounds(
module,
block_id,
bounds.clone(),
bounds,
return_type,
)?);
);
}
}

Expand Down
Loading