Skip to content

Commit 309c0d6

Browse files
AntonOrestenclaude
andcommitted
Add Constant(Type) support and fix static_parameter codegen
Enable `Constant(T)` where `T` is a type (e.g., `Constant(Int)`) to produce `Constant{Type{T}, T}` instead of `Constant{DataType, T}`, so method dispatch correctly binds type parameters. Handle `:static_parameter` expressions in codegen by looking up concrete values from the method's sptypes, unwrapping VarState and Const wrappers. Also fix `IndexedUseRef` qualification in walk_uses! and guard scalar_elim_pass! against TokenType annotations. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent df5a821 commit 309c0d6

File tree

6 files changed

+93
-2
lines changed

6 files changed

+93
-2
lines changed

src/compiler/codegen/expressions.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ function emit_expr!(ctx::CGCtx, expr::Expr, @nospecialize(result_type))
2020
# Bounds checking is always disabled in Tile IR kernels.
2121
# Emit false so IfOps referencing this SSA can resolve the condition.
2222
return emit_constant!(ctx, false, Bool)
23+
elseif expr.head === :static_parameter
24+
# Static type parameter reference (e.g., V in `f(::T{V}) where {V}`).
25+
# Look up the concrete value from the method's sptypes.
26+
idx = expr.args[1]::Int
27+
sp = ctx.sci.sptypes[idx]
28+
sptyp = sp isa CC.VarState ? sp.typ : sp
29+
val = sptyp isa CC.Const ? sptyp.val : CC.widenconst(sptyp)
30+
return emit_value!(ctx, val)
2331
elseif expr.head === :code_coverage_effect
2432
return nothing
2533
else

src/compiler/passes/canonicalize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,15 @@ function scalar_elim_block!(block::Block)
126126

127127
current_type = value_type(inst)
128128
current_type === nothing && continue
129+
is_token_type(current_type) && continue
129130
T = CC.widenconst(current_type)
130131
T <: Tile && continue # already tile-typed
131132
T <: Number || continue # only promote scalar number types
132133

133134
for op in ops
134135
op_type = value_type(block, op)
135136
op_type === nothing && continue
137+
is_token_type(op_type) && continue
136138
OT = CC.widenconst(op_type)
137139
OT <: Tile || continue
138140
S = OT.parameters[2]
@@ -162,6 +164,7 @@ function scalar_elim_block!(block::Block)
162164
for inst in instructions(block)
163165
current_type = value_type(inst)
164166
current_type === nothing && continue
167+
is_token_type(current_type) && continue
165168
new_type = promote_scalar_type(CC.widenconst(current_type))
166169
new_type === nothing && continue
167170
update_type!(block, inst, new_type)
@@ -170,6 +173,7 @@ function scalar_elim_block!(block::Block)
170173
# Phase 5: Promote block argument types (loop IVs, carries).
171174
# BlockArgument is immutable, so we create a new one and replace all uses.
172175
for (i, arg) in enumerate(block.args)
176+
is_token_type(arg.type) && continue
173177
T = CC.widenconst(arg.type)
174178
T <: Tile && continue
175179
T <: Number || continue

src/compiler/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ end
7777
# walk_uses! extensions so that IRStructurizer's uses()/replace_uses! see
7878
# operands inside cuTile-specific IR nodes.
7979
IRStructurizer.walk_uses!(f, node::JoinTokensNode) =
80-
for i in 1:length(node.tokens); f(IndexedUseRef(node.tokens, i)); end
80+
for i in 1:length(node.tokens); f(IRStructurizer.IndexedUseRef(node.tokens, i)); end
8181
IRStructurizer.walk_uses!(f, ::TokenResultNode) = nothing
8282
IRStructurizer.walk_uses!(f, ::MakeTokenNode) = nothing
8383

src/language/types.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,9 @@ argtypes = Tuple{Ptr{Float32}, Constant{Int, 16}}
316316
"""
317317
struct Constant{T, V} end
318318

319-
# Convenience constructor that infers type from value
319+
# Convenience constructors that infer type from value
320320
Constant(val::T) where {T} = Constant{T, val}()
321+
Constant(val::Type{T}) where {T} = Constant{Type{T}, T}()
321322

322323
# Extract constant value - @inline ensures this folds to a constant in IR
323324
@inline Base.getindex(::Constant{T, V}) where {T, V} = V

test/codegen/integration.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,51 @@ end
12621262
end
12631263
end
12641264

1265+
#=============================================================================
1266+
Constant Type Arguments
1267+
=============================================================================#
1268+
1269+
@testset "Constant Type Arguments" begin
1270+
spec = ct.ArraySpec{1}(16, true)
1271+
1272+
function _type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T
1273+
pid = ct.bid(1)
1274+
tile = ct.load(a, pid, (tile_size,))
1275+
ct.store(b, pid, tile)
1276+
return
1277+
end
1278+
1279+
@testset "Type parameter via static_parameter" begin
1280+
@test @filecheck begin
1281+
@check_label "entry"
1282+
@check "load_view_tko"
1283+
@check "store_view_tko"
1284+
code_tiled(_type_param_kernel,
1285+
Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec},
1286+
ct.Constant{Int,16}, ct.Constant{Type{Nothing},Nothing}})
1287+
end
1288+
end
1289+
1290+
# Test that Constant(Type) constructor produces correct types
1291+
function _use_type_param_kernel(a, b, tile_size::Int, ::Type{T}) where T
1292+
pid = ct.bid(1)
1293+
tile = ct.load(a, pid, (tile_size,))
1294+
ct.store(b, pid, tile)
1295+
return
1296+
end
1297+
1298+
@testset "Constant(Type) via convenience constructor" begin
1299+
@test @filecheck begin
1300+
@check_label "entry"
1301+
@check "load_view_tko"
1302+
@check "store_view_tko"
1303+
code_tiled(_use_type_param_kernel,
1304+
Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec},
1305+
ct.Constant{Int,16}, ct.Constant{Type{Float32},Float32}})
1306+
end
1307+
end
1308+
end
1309+
12651310
#=============================================================================
12661311
For Loops
12671312
=============================================================================#

test/codegen/reflection.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,36 @@ end
104104
end
105105
end
106106
end
107+
108+
@testset "Constant Type args" begin
109+
const_spec = ct.ArraySpec{1}(128, true, (0,), (32,))
110+
111+
@testset "Constant(Type) constructor" begin
112+
@test typeof(ct.Constant(Int)) === ct.Constant{Type{Int}, Int}
113+
@test typeof(ct.Constant(Nothing)) === ct.Constant{Type{Nothing}, Nothing}
114+
@test typeof(ct.Constant(Float32)) === ct.Constant{Type{Float32}, Float32}
115+
# Non-type values still work as before
116+
@test typeof(ct.Constant(42)) === ct.Constant{Int, 42}
117+
end
118+
119+
@testset "code_tiled with Constant Type parameter" begin
120+
function reflect_type_param(a, b, c, tile_size::Int, ::Type{T}) where T
121+
pid = ct.bid(1)
122+
tile_a = ct.load(a; index=pid, shape=(tile_size,))
123+
tile_b = ct.load(b; index=pid, shape=(tile_size,))
124+
ct.store(c; index=pid, tile=tile_a + tile_b)
125+
return
126+
end
127+
128+
ConstTypeTT = Tuple{ct.TileArray{Float32,1,const_spec}, ct.TileArray{Float32,1,const_spec},
129+
ct.TileArray{Float32,1,const_spec}, ct.Constant{Int64, 16},
130+
ct.Constant{Type{Nothing}, Nothing}}
131+
132+
@test @filecheck begin
133+
@check "load_view_tko"
134+
@check "addf"
135+
@check "store_view_tko"
136+
ct.code_tiled(reflect_type_param, ConstTypeTT)
137+
end
138+
end
139+
end

0 commit comments

Comments
 (0)