Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/compiler/codegen/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type))
# Bounds checking is always disabled in Tile IR kernels.
# Emit false so IfOps referencing this SSA can resolve the condition.
return emit_constant!(ctx, false, Bool)
elseif expr.head === :static_parameter
# Static type parameter reference (e.g., V in `f(::T{V}) where {V}`).
# Look up the concrete value from the method's sptypes.
idx = expr.args[1]::Int
sp = ctx.sci.sptypes[idx]
sptyp = sp isa CC.VarState ? sp.typ : sp
val = sptyp isa CC.Const ? sptyp.val : CC.widenconst(sptyp)
return emit_value!(ctx, val)
elseif expr.head === :code_coverage_effect
return nothing
else
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/passes/canonicalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ function scalar_elim_block!(block::Block)

current_type = value_type(inst)
current_type === nothing && continue
is_token_type(current_type) && continue
T = CC.widenconst(current_type)
T <: Tile && continue # already tile-typed
T <: Number || continue # only promote scalar number types

for op in ops
op_type = value_type(block, op)
op_type === nothing && continue
is_token_type(op_type) && continue
OT = CC.widenconst(op_type)
OT <: Tile || continue
S = OT.parameters[2]
Expand Down Expand Up @@ -162,6 +164,7 @@ function scalar_elim_block!(block::Block)
for inst in instructions(block)
current_type = value_type(inst)
current_type === nothing && continue
is_token_type(current_type) && continue
new_type = promote_scalar_type(CC.widenconst(current_type))
new_type === nothing && continue
update_type!(block, inst, new_type)
Expand All @@ -170,6 +173,7 @@ function scalar_elim_block!(block::Block)
# Phase 5: Promote block argument types (loop IVs, carries).
# BlockArgument is immutable, so we create a new one and replace all uses.
for (i, arg) in enumerate(block.args)
is_token_type(arg.type) && continue
T = CC.widenconst(arg.type)
T <: Tile && continue
T <: Number || continue
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ end
# walk_uses! extensions so that IRStructurizer's uses()/replace_uses! see
# operands inside cuTile-specific IR nodes.
IRStructurizer.walk_uses!(f, node::JoinTokensNode) =
for i in 1:length(node.tokens); f(IndexedUseRef(node.tokens, i)); end
for i in 1:length(node.tokens); f(IRStructurizer.IndexedUseRef(node.tokens, i)); end
IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing
IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing

Expand Down
3 changes: 2 additions & 1 deletion src/language/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,9 @@ argtypes = Tuple{Ptr{Float32}, Constant{Int, 16}}
"""
struct Constant{T, V} end

# Convenience constructor that infers type from value
# Convenience constructors that infer type from value
Constant(val::T) where {T} = Constant{T, val}()
Constant(val::Type{T}) where {T} = Constant{Type{T}, T}()

# Extract constant value - @inline ensures this folds to a constant in IR
@inline Base.getindex(::Constant{T, V}) where {T, V} = V
Expand Down
45 changes: 45 additions & 0 deletions test/codegen/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,51 @@ end
end
end

#=============================================================================
Constant Type Arguments
=============================================================================#

@testset "Constant Type Arguments" begin
spec = ct.ArraySpec{1}(16, true)

function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T
pid = ct.bid(1)
tile = ct.load(a, pid, (tile_size,))
ct.store(b, pid, tile)
return
end

@testset "Type parameter via static_parameter" begin
@test @filecheck begin
@check_label "entry"
@check "load_view_tko"
@check "store_view_tko"
code_tiled(_type_param_kernel,
Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec},
ct.Constant{Int,16}, ct.Constant{Type{Nothing},Nothing}})
end
end

# Test that Constant(Type) constructor produces correct types
function _use_type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T
pid = ct.bid(1)
tile = ct.load(a, pid, (tile_size,))
ct.store(b, pid, tile)
return
end

@testset "Constant(Type) via convenience constructor" begin
@test @filecheck begin
@check_label "entry"
@check "load_view_tko"
@check "store_view_tko"
code_tiled(_use_type_param_kernel,
Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec},
ct.Constant{Int,16}, ct.Constant{Type{Float32},Float32}})
end
end
end

#=============================================================================
For Loops
=============================================================================#
Expand Down
27 changes: 27 additions & 0 deletions test/codegen/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,30 @@ end
end
end
end

@testset "Constant Type args" begin
const_spec = ct.ArraySpec{1}(128, true, (0,), (32,))

@test ct.Constant(Int) isa ct.Constant{Type{Int}, Int}

@testset "code_tiled with Constant Type parameter" begin
function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T
pid = ct.bid(1)
tile_a = ct.load(a; index=pid, shape=(tile_size,))
tile_b = ct.load(b; index=pid, shape=(tile_size,))
ct.store(c; index=pid, tile=tile_a + tile_b)
return
end

ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec},
ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16},
ct.Constant{Type{Nothing}, Nothing}}

@test @filecheck begin
@check "load_view_tko"
@check "addf"
@check "store_view_tko"
ct.code_tiled(reflect_type_param, ConstTypeTT)
end
end
end
Loading