diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 4708dd0..8006d5e 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -257,6 +257,7 @@ Other values pass through unchanged. """ to_tile_arg(x) = x to_tile_arg(arr::AbstractArray) = TileArray(arr) +to_tile_arg(t::Type) = Constant(t) # Tiled Broadcast — TiledStyle wins over CuArrayStyle BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M} = cuTile.TiledStyle{max(N,M)}() diff --git a/src/compiler/codegen/expressions.jl b/src/compiler/codegen/expressions.jl index ab15707..ca9c521 100644 --- a/src/compiler/codegen/expressions.jl +++ b/src/compiler/codegen/expressions.jl @@ -20,6 +20,14 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type)) # Bounds checking is always disabled in Tile IR kernels. # Emit false so IfOps referencing this SSA can resolve the condition. return emit_constant!(ctx, false, Bool) + elseif expr.head === :static_parameter + # Static type parameter reference (e.g., V in `f(::T{V}) where {V}`). + # Look up the concrete value from the method's sptypes. + idx = expr.args[1]::Int + sp = ctx.sci.sptypes[idx] + sptyp = sp isa CC.VarState ? sp.typ : sp + val = sptyp isa CC.Const ? sptyp.val : CC.widenconst(sptyp) + return emit_value!(ctx, val) elseif expr.head === :code_coverage_effect return nothing else diff --git a/src/compiler/passes/canonicalize.jl b/src/compiler/passes/canonicalize.jl index f6dedb0..dbf6fbd 100644 --- a/src/compiler/passes/canonicalize.jl +++ b/src/compiler/passes/canonicalize.jl @@ -126,6 +126,7 @@ function scalar_elim_block!(block::Block) current_type = value_type(inst) current_type === nothing && continue + is_token_type(current_type) && continue T = CC.widenconst(current_type) T <: Tile && continue # already tile-typed T <: Number || continue # only promote scalar number types @@ -133,6 +134,7 @@ function scalar_elim_block!(block::Block) for op in ops op_type = value_type(block, op) op_type === nothing && continue + is_token_type(op_type) && continue OT = CC.widenconst(op_type) OT <: Tile || continue S = OT.parameters[2] @@ -162,6 +164,7 @@ function scalar_elim_block!(block::Block) for inst in instructions(block) current_type = value_type(inst) current_type === nothing && continue + is_token_type(current_type) && continue new_type = promote_scalar_type(CC.widenconst(current_type)) new_type === nothing && continue update_type!(block, inst, new_type) @@ -170,6 +173,7 @@ function scalar_elim_block!(block::Block) # Phase 5: Promote block argument types (loop IVs, carries). # BlockArgument is immutable, so we create a new one and replace all uses. for (i, arg) in enumerate(block.args) + is_token_type(arg.type) && continue T = CC.widenconst(arg.type) T <: Tile && continue T <: Number || continue diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index fed7afb..3f99177 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -82,7 +82,7 @@ Returns `(stripped, nothing)` when no Constant types are present. function process_const_argtypes(@nospecialize(f), @nospecialize(argtypes)) params = argtypes isa DataType ? argtypes.parameters : argtypes isa Tuple ? argtypes : fieldtypes(argtypes) - has_consts = any(T -> T <: Constant, params) + has_consts = any(T -> T <: Constant || CC.isconstType(T), params) stripped_params = map(params) do T T <: Constant ? constant_eltype(T) : T end @@ -90,7 +90,13 @@ function process_const_argtypes(@nospecialize(f), @nospecialize(argtypes)) const_argtypes = if has_consts cats = Any[CC.Const(f)] for T in params - push!(cats, T <: Constant ? CC.Const(constant_value(T)) : T) + if T <: Constant + push!(cats, CC.Const(constant_value(T))) + elseif CC.isconstType(T) + push!(cats, CC.Const(T.parameters[1])) + else + push!(cats, T) + end end cats else diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index e57b820..c9b6e96 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -77,7 +77,7 @@ end # walk_uses! extensions so that IRStructurizer's uses()/replace_uses! see # operands inside cuTile-specific IR nodes. IRStructurizer.walk_uses!(f, node::JoinTokensNode) = - for i in 1:length(node.tokens); f(IndexedUseRef(node.tokens, i)); end + for i in 1:length(node.tokens); f(IRStructurizer.IndexedUseRef(node.tokens, i)); end IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing diff --git a/src/language/types.jl b/src/language/types.jl index f246e54..0523b25 100644 --- a/src/language/types.jl +++ b/src/language/types.jl @@ -316,8 +316,9 @@ argtypes = Tuple{Ptr{Float32}, Constant{Int, 16}} """ struct Constant{T, V} end -# Convenience constructor that infers type from value +# Convenience constructors that infer type from value Constant(val::T) where {T} = Constant{T, val}() +Constant(val::Type{T}) where {T} = Constant{Type{T}, T}() # Extract constant value - @inline ensures this folds to a constant in IR @inline Base.getindex(::Constant{T, V}) where {T, V} = V diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index c25baf2..f9581b5 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -1262,6 +1262,33 @@ end end end +#============================================================================= + Constant Type Arguments +=============================================================================# + +@testset "Constant Type Arguments" begin + spec = ct.ArraySpec{1}(16, true) + + function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile = ct.load(a, pid, (tile_size,)) + zeros(T, (tile_size,)) + ct.store(b, pid, tile) + return + end + + @testset "Type parameter used in kernel body" begin + @test @filecheck begin + @check_label "entry" + @check "load_view_tko" + @check "addf" + @check "store_view_tko" + code_tiled(_type_param_kernel, + Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, + ct.Constant{Int,16}, Type{Float32}}) + end + end +end + #============================================================================= For Loops =============================================================================# diff --git a/test/codegen/reflection.jl b/test/codegen/reflection.jl index 51a73c3..13e626b 100644 --- a/test/codegen/reflection.jl +++ b/test/codegen/reflection.jl @@ -104,3 +104,30 @@ end end end end + +@testset "Type args" begin + const_spec = ct.ArraySpec{1}(128, true, (0,), (32,)) + + @test ct.Constant(Int) isa ct.Constant{Type{Int}, Int} + + @testset "code_tiled with Type parameter" begin + function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile_a = ct.load(a; index=pid, shape=(tile_size,)) + tile_b = ct.load(b; index=pid, shape=(tile_size,)) + ct.store(c; index=pid, tile=tile_a + tile_b + zeros(T, (tile_size,))) + return + end + + ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec}, + ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16}, + Type{Float32}} + + @test @filecheck begin + @check "load_view_tko" + @check "addf" + @check "store_view_tko" + ct.code_tiled(reflect_type_param, ConstTypeTT) + end + end +end diff --git a/test/device/core.jl b/test/device/core.jl index 5674fcf..7086868 100644 --- a/test/device/core.jl +++ b/test/device/core.jl @@ -387,6 +387,26 @@ end @test Array(b) ≈ Array(a) end +@testset "Type parameter (auto-wrapped)" begin + function vadd_type_param(a, b, c, tile_size::Int, ::Type{T}) where T + pid = ct.bid(1) + tile_a = ct.load(a; index=pid, shape=(tile_size,)) + tile_b = ct.load(b; index=pid, shape=(tile_size,)) + ct.store(c; index=pid, tile=T.(tile_a) + T.(tile_b)) + return + end + + n = 1024 + tile_size = 32 + a = CUDA.rand(Float16, n) + b = CUDA.rand(Float16, n) + c = CUDA.zeros(Float32, n) + + ct.launch(vadd_type_param, cld(n, tile_size), a, b, c, ct.Constant(tile_size), Float32) + + @test Array(c) ≈ Float32.(Array(a)) + Float32.(Array(b)) +end + end @testset "TileArray auto-conversion" begin