Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ Other values pass through unchanged.
"""
function to_tile_arg(x)
    # Fallback: anything without a more specific method is returned as-is.
    return x
end

function to_tile_arg(arr::AbstractArray)
    # Arrays are wrapped in a TileArray.
    return TileArray(arr)
end

function to_tile_arg(t::Type)
    # Types are wrapped in a Constant.
    return Constant(t)
end

# Tiled broadcast: combining a TiledStyle with a CuArrayStyle yields a TiledStyle
# whose dimensionality parameter is the larger of the two.
function BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M}
    return cuTile.TiledStyle{max(N, M)}()
end
Expand Down
8 changes: 8 additions & 0 deletions src/compiler/codegen/expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type))
# Bounds checking is always disabled in Tile IR kernels.
# Emit false so IfOps referencing this SSA can resolve the condition.
return emit_constant!(ctx, false, Bool)
elseif expr.head === :static_parameter
# Static type parameter reference (e.g., V in `f(::T{V}) where {V}`).
# Look up the concrete value from the method's sptypes.
idx = expr.args[1]::Int
sp = ctx.sci.sptypes[idx]
sptyp = sp isa CC.VarState ? sp.typ : sp
val = sptyp isa CC.Const ? sptyp.val : CC.widenconst(sptyp)
return emit_value!(ctx, val)
elseif expr.head === :code_coverage_effect
return nothing
else
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/passes/canonicalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ function scalar_elim_block!(block::Block)

current_type = value_type(inst)
current_type === nothing && continue
is_token_type(current_type) && continue
T = CC.widenconst(current_type)
T <: Tile && continue # already tile-typed
T <: Number || continue # only promote scalar number types

for op in ops
op_type = value_type(block, op)
op_type === nothing && continue
is_token_type(op_type) && continue
OT = CC.widenconst(op_type)
OT <: Tile || continue
S = OT.parameters[2]
Expand Down Expand Up @@ -162,6 +164,7 @@ function scalar_elim_block!(block::Block)
for inst in instructions(block)
current_type = value_type(inst)
current_type === nothing && continue
is_token_type(current_type) && continue
new_type = promote_scalar_type(CC.widenconst(current_type))
new_type === nothing && continue
update_type!(block, inst, new_type)
Expand All @@ -170,6 +173,7 @@ function scalar_elim_block!(block::Block)
# Phase 5: Promote block argument types (loop IVs, carries).
# BlockArgument is immutable, so we create a new one and replace all uses.
for (i, arg) in enumerate(block.args)
is_token_type(arg.type) && continue
T = CC.widenconst(arg.type)
T <: Tile && continue
T <: Number || continue
Expand Down
10 changes: 8 additions & 2 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,21 @@ Returns `(stripped, nothing)` when no Constant types are present.
function process_const_argtypes(@nospecialize(f), @nospecialize(argtypes))
params = argtypes isa DataType ? argtypes.parameters :
argtypes isa Tuple ? argtypes : fieldtypes(argtypes)
has_consts = any(T -> T <: Constant, params)
has_consts = any(T -> T <: Constant || CC.isconstType(T), params)
stripped_params = map(params) do T
T <: Constant ? constant_eltype(T) : T
end
stripped = Tuple{stripped_params...}
const_argtypes = if has_consts
cats = Any[CC.Const(f)]
for T in params
push!(cats, T <: Constant ? CC.Const(constant_value(T)) : T)
if T <: Constant
push!(cats, CC.Const(constant_value(T)))
elseif CC.isconstType(T)
push!(cats, CC.Const(T.parameters[1]))
else
push!(cats, T)
end
end
cats
else
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ end
# walk_uses! extensions so that IRStructurizer's uses()/replace_uses! see
# operands inside cuTile-specific IR nodes.
#
# JoinTokensNode: `f` is called once per index of `node.tokens`, each time with
# an IndexedUseRef built from `node.tokens` and that index.
# NOTE(review): the two loop lines below look like the pre- and post-change
# variants of the same statement as rendered by the diff view; they differ only
# in whether IndexedUseRef is qualified with `IRStructurizer.` — confirm which
# one the file actually contains.
IRStructurizer.walk_uses!(f, node::JoinTokensNode) =
for i in 1:length(node.tokens); f(IndexedUseRef(node.tokens, i)); end
for i in 1:length(node.tokens); f(IRStructurizer.IndexedUseRef(node.tokens, i)); end
# TokenResultNode / MakeTokenNode: `f` is never invoked; the methods just
# return `nothing`.
IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing
IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing

Expand Down
3 changes: 2 additions & 1 deletion src/language/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,9 @@ argtypes = Tuple{Ptr{Float32}, Constant{Int, 16}}
"""
struct Constant{T, V} end

# Convenience constructors: both type parameters are derived from the value.
function Constant(val::T) where {T}
    # Ordinary value: T is the value's type and the value itself becomes V.
    return Constant{T, val}()
end

function Constant(val::Type{T}) where {T}
    # A type passed as the value: the first parameter is Type{T} and V is T.
    return Constant{Type{T}, T}()
end

# Extract the constant value. Marked @inline with the intent that the lookup
# folds to a constant in IR.
@inline function Base.getindex(::Constant{T, V}) where {T, V}
    return V
end
Expand Down
27 changes: 27 additions & 0 deletions test/codegen/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,33 @@ end
end
end

#=============================================================================
Constant Type Arguments
=============================================================================#

@testset "Constant Type Arguments" begin
    # Array spec shared by both TileArray argument types below.
    spec = ct.ArraySpec{1}(16, true)

    # Kernel taking an element type as a trailing `Type{T}` argument; `T` is used
    # in the body to build a zero-filled value that is added to the loaded tile.
    function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T
        block = ct.bid(1)
        summed = ct.load(a, block, (tile_size,)) + zeros(T, (tile_size,))
        ct.store(b, block, summed)
        return
    end

    @testset "Type parameter used in kernel body" begin
        # Argument types: two Float32 tile arrays, a constant tile size of 16,
        # and the bare type Float32.
        kernel_argtypes = Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec},
                                ct.Constant{Int,16}, Type{Float32}}

        @test @filecheck begin
            @check_label "entry"
            @check "load_view_tko"
            @check "addf"
            @check "store_view_tko"
            code_tiled(_type_param_kernel, kernel_argtypes)
        end
    end
end

#=============================================================================
For Loops
=============================================================================#
Expand Down
27 changes: 27 additions & 0 deletions test/codegen/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,30 @@ end
end
end
end

@testset "Type args" begin
    const_spec = ct.ArraySpec{1}(128, true, (0,), (32,))

    # Wrapping a type gives a Constant whose first parameter is Type{Int}.
    @test ct.Constant(Int) isa ct.Constant{Type{Int}, Int}

    @testset "code_tiled with Type parameter" begin
        # Kernel that adds two loaded tiles plus a zero-filled value of element type T.
        function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T
            block = ct.bid(1)
            lhs = ct.load(a; index=block, shape=(tile_size,))
            rhs = ct.load(b; index=block, shape=(tile_size,))
            ct.store(c; index=block, tile=lhs + rhs + zeros(T, (tile_size,)))
            return
        end

        # All three array arguments share one TileArray type; the tile size is a
        # Constant and the element type is passed as a plain Type{Float32}.
        ArrayTT = ct.TileArray{Float32,1,const_spec}
        ConstTypeTT = Tuple{ArrayTT, ArrayTT, ArrayTT, ct.Constant{Int64, 16}, Type{Float32}}

        @test @filecheck begin
            @check "load_view_tko"
            @check "addf"
            @check "store_view_tko"
            ct.code_tiled(reflect_type_param, ConstTypeTT)
        end
    end
end
20 changes: 20 additions & 0 deletions test/device/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,26 @@ end
@test Array(b) ≈ Array(a)
end

@testset "Type parameter (auto-wrapped)" begin
    # Kernel that converts both loaded tiles to T before adding and storing them.
    function vadd_type_param(a, b, c, tile_size::Int, ::Type{T}) where T
        block = ct.bid(1)
        lhs = ct.load(a; index=block, shape=(tile_size,))
        rhs = ct.load(b; index=block, shape=(tile_size,))
        ct.store(c; index=block, tile=T.(lhs) + T.(rhs))
        return
    end

    len = 1024
    tile = 32
    x = CUDA.rand(Float16, len)
    y = CUDA.rand(Float16, len)
    out = CUDA.zeros(Float32, len)

    # The tile size is wrapped in ct.Constant explicitly; the element type
    # Float32 is handed to launch bare.
    ct.launch(vadd_type_param, cld(len, tile), x, y, out, ct.Constant(tile), Float32)

    expected = Float32.(Array(x)) + Float32.(Array(y))
    @test Array(out) ≈ expected
end

end

@testset "TileArray auto-conversion" begin
Expand Down
Loading