Skip to content

Commit 00e7a54

Browse files
authored
Merge pull request #129 from JuliaGPU/tb/broadcast
Provide host-level broadcast
2 parents 557ca0f + 5d7dfdc commit 00e7a54

16 files changed

Lines changed: 362 additions & 5 deletions

README.md

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ vector_size = 2^20
4343
tile_size = 16
4444

4545
blocks = cld(vector_size, tile_size)
46-
grid = (blocks, 1, 1)
46+
grid = (blocks, 1, 1)
4747

4848
a, b = CUDA.rand(Float32, vector_size), CUDA.rand(Float32, vector_size)
4949
c = CUDA.zeros(Float32, vector_size)
@@ -232,7 +232,6 @@ uses standard Julia syntax and is overlaid on `Base`.
232232

233233
cuTile.jl follows Julia conventions, which differ from the Python API in several ways:
234234

235-
236235
### Kernel definition syntax
237236

238237
Kernels don't need a decorator, but do have to return `nothing`:
@@ -511,6 +510,36 @@ ct.store(arr, (i, j), t)
511510
```
512511

513512

513+
## Host-level operations
514+
515+
cuTile.jl also provides a limited set of host-level APIs to use cuTile without
516+
writing custom kernels. For example, for element-wise operations on `CuArray`s,
517+
cuTile can automatically generate and launch a fused kernel using Julia's
518+
broadcast machinery:
519+
520+
```julia
521+
using CUDA
522+
import cuTile as ct
523+
524+
A = CUDA.rand(Float32, 1024)
525+
B = CUDA.rand(Float32, 1024)
526+
C = CUDA.zeros(Float32, 1024)
527+
528+
# Wrap arrays in Tiled() to route through cuTile
529+
ct.Tiled(C) .= ct.Tiled(A) .+ ct.Tiled(B)
530+
531+
# Or use the @. macro for convenience
532+
ct.@. C = A + sin(B)
533+
534+
# Allocating form (returns a new CuArray)
535+
D = ct.@. A + B
536+
```
537+
538+
The entire broadcast expression is fused into a single cuTile kernel. Tile sizes
are chosen automatically based on the array dimensions (power-of-two, budget-based).
This works for 1D through N-dimensional arrays.
541+
542+
514543
## Acknowledgments
515544

516545
cuTile.jl is inspired by [cuTile-Python](https://github.com/NVIDIA/cutile-python/),

ext/CUDAExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ using CompilerCaching: CacheView, method_instance, results
99

1010
import Core.Compiler as CC
1111

12-
using CUDA: CuModule, CuFunction, cudacall, device, capability
12+
using CUDA: CuArray, CuModule, CuFunction, cudacall, device, capability
1313
using CUDA_Compiler_jll
1414

15+
import Base.Broadcast: BroadcastStyle
16+
import CUDA: CuArrayStyle
17+
1518
public launch
1619

1720
function run_and_collect(cmd)
@@ -255,4 +258,7 @@ Other values pass through unchanged.
255258
to_tile_arg(x) = x
256259
to_tile_arg(arr::AbstractArray) = TileArray(arr)
257260

261+
# Tiled Broadcast — TiledStyle wins over CuArrayStyle
262+
BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M} = cuTile.TiledStyle{max(N,M)}()
263+
258264
end

src/broadcast.jl

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import Base.Broadcast: BroadcastStyle, Broadcasted
2+
3+
#=============================================================================
4+
Tiled wrapper — routes broadcast expressions through cuTile kernels
5+
=============================================================================#
6+
7+
"""
8+
Tiled(x)
9+
10+
Wrapper that routes broadcast expressions through cuTile kernels.
11+
12+
Tiled(B) .= A .+ A
13+
14+
Uses Julia's `Base.Broadcast` fusion machinery to build a `Broadcasted` tree,
15+
then dispatches to a generic cuTile kernel that evaluates the tree on tiles.
16+
"""
17+
struct Tiled{A <: AbstractArray}
18+
parent::A
19+
end
20+
Tiled(x) = x # passthrough for non-arrays (Numbers, etc.)
21+
Base.parent(t::Tiled) = t.parent
22+
Base.axes(t::Tiled) = axes(parent(t))
23+
Base.size(t::Tiled) = size(parent(t))
24+
Base.ndims(::Tiled{A}) where A = ndims(A)
25+
Base.eltype(::Tiled{A}) where A = eltype(A)
26+
Base.Broadcast.broadcastable(t::Tiled) = t
27+
28+
# Walk a fully-dotted AST (the output of `Base.Broadcast.__dot__`) and wrap
# every value-position leaf in `Tiled()`, opting arrays into the tiled
# broadcast style. `Tiled` is the identity on non-arrays, so over-wrapping a
# scalar variable is harmless.
_wrap_tiled(x) = x # literals pass through
_wrap_tiled(s::Symbol) = :($Tiled($s)) # value leaf: wrap at runtime
function _wrap_tiled(ex::Expr)
    if ex.head === :.=
        # in-place broadcast assignment: wrap destination and source
        Expr(:.=, _wrap_tiled(ex.args[1]), _wrap_tiled(ex.args[2]))
    elseif ex.head === :. && length(ex.args) == 2 &&
           ex.args[2] isa Expr && ex.args[2].head === :tuple
        # f.(args...) — wrap args, NOT function position
        new_args = map(_wrap_tiled, ex.args[2].args)
        Expr(:., ex.args[1], Expr(:tuple, new_args...))
    elseif ex.head === :kw
        # keyword argument `name = value`: the name is syntax, not a value.
        # Wrap only the value — the generic branch below used to wrap the name
        # too, producing invalid keyword syntax like `Tiled(dims) = 1`.
        Expr(:kw, ex.args[1], _wrap_tiled(ex.args[2]))
    else
        # any other expression head (incl. :parameters): recurse into children
        Expr(ex.head, map(_wrap_tiled, ex.args)...)
    end
end
43+
44+
"""
45+
@. expr
46+
47+
Like `Base.@.` but wraps every value-position leaf in `Tiled()`, routing
48+
the broadcast through cuTile kernels.
49+
50+
using cuTile; const ct = cuTile
51+
ct.@. C = A + sin(B)
52+
# equivalent to: Tiled(C) .= Tiled(A) .+ sin.(Tiled(B))
53+
"""
54+
macro __dot__(ex)
55+
esc(_wrap_tiled(Base.Broadcast.__dot__(ex)))
56+
end
57+
58+
#=============================================================================
59+
TiledStyle — routes broadcast through cuTile kernels
60+
=============================================================================#
61+
62+
# Broadcast style selected when any argument is a `Tiled`.
# `N` tracks the largest dimensionality seen among the arguments.
struct TiledStyle{N} <: BroadcastStyle end
TiledStyle{M}(::Val{N}) where {N, M} = TiledStyle{N}()

BroadcastStyle(::Type{<:Tiled{A}}) where {A} = TiledStyle{ndims(A)}()

# Combining with plain arrays (DefaultArrayStyle) or another TiledStyle keeps
# the tiled style, promoting to the larger dimensionality.
BroadcastStyle(::TiledStyle{N}, ::Base.Broadcast.DefaultArrayStyle{M}) where {N, M} =
    TiledStyle{max(N, M)}()
BroadcastStyle(::TiledStyle{N}, ::TiledStyle{M}) where {N, M} =
    TiledStyle{max(N, M)}()
70+
71+
#=============================================================================
72+
materialize! and copy — dispatch to _tiled_broadcast!
73+
=============================================================================#
74+
75+
# In-place form: `Tiled(C) .= …` evaluates the fused tree straight into the
# wrapped array via a cuTile kernel, then hands the wrapper back.
function Base.Broadcast.materialize!(dest::Tiled, bc::Broadcasted)
    target = parent(dest)
    _tiled_broadcast!(target, bc)
    return dest
end
79+
80+
# Allocating form: `Tiled(A) .+ Tiled(B)` returns a freshly-allocated array.
function Base.copy(bc::Broadcasted{TiledStyle{N}}) where {N}
    # Any Tiled leaf serves as the allocation prototype for `similar`.
    proto = _find_tiled_array(bc)
    proto === nothing && error("tiled broadcast requires at least one Tiled() argument")
    T = Base.Broadcast.combine_eltypes(bc.f, bc.args)
    out = similar(proto, T, axes(bc))
    _tiled_broadcast!(out, bc)
    return out
end
87+
88+
"""Find the first underlying array from a Tiled leaf in a Broadcasted tree."""
89+
_find_tiled_array(t::Tiled) = parent(t)
90+
_find_tiled_array(x) = nothing
91+
function _find_tiled_array(bc::Broadcasted)
92+
for arg in bc.args
93+
arr = _find_tiled_array(arg)
94+
arr !== nothing && return arr
95+
end
96+
return nothing
97+
end
98+
99+
#=============================================================================
100+
_tiled_broadcast! — generic AbstractArray implementation
101+
=============================================================================#
102+
103+
# Launch a cuTile kernel that evaluates the broadcast tree `bc` into `dest`.
function _tiled_broadcast!(dest::AbstractArray{T, N}, bc::Broadcasted) where {T, N}
    # Convert leaves to TileArrays so the kernel can `load` from them.
    kernel_bc = _to_tiled_bc(bc)
    dest_tiles = TileArray(dest)

    tiles = _compute_tile_sizes(size(dest))
    grid = ntuple(d -> cld(size(dest, d), tiles[d]), N)

    # The launch grid uses at most three components: dimensions 3..N are
    # folded into the third one, and their extents are passed along so the
    # kernel can decompose the flattened index again.
    if N <= 3
        launch_grid = grid
        extra = ()
    else
        launch_grid = (grid[1], grid[2], prod(grid[i] for i in 3:N))
        extra = grid[3:end]
    end

    launch(_tiled_bc_kernel, launch_grid, dest_tiles, kernel_bc,
           Constant(tiles), Constant(extra))
end
116+
117+
#=============================================================================
118+
Generic tree walk — convert leaves to TileArrays
119+
=============================================================================#
120+
121+
# Leaves of the broadcast tree, as seen by the kernel:
_to_tiled_bc(t::Tiled) = TileArray(parent(t))     # Tiled leaf → TileArray
_to_tiled_bc(arr::AbstractArray) = TileArray(arr) # bare array leaf → TileArray
_to_tiled_bc(x::Number) = x                       # scalars pass through
_to_tiled_bc(x) = x                               # anything else passes through

# Interior nodes: rebuild with converted children; `Broadcasted{Nothing}`
# carries the tree without further style-based processing.
function _to_tiled_bc(bc::Broadcasted)
    converted = map(_to_tiled_bc, bc.args)
    return Broadcasted{Nothing}(bc.f, converted, nothing)
end
129+
130+
#=============================================================================
131+
Broadcast kernel — evaluates Broadcasted tree on tiles
132+
=============================================================================#
133+
134+
# Generated kernel: one tile per block. Computes the block's N-dimensional
# index, evaluates the broadcast tree for that tile, converts the result to
# the destination element type, and stores it.
@generated function _tiled_bc_kernel(dest::TileArray{T, N}, bc, tile_size, overflow_grids) where {T, N}
    stmts = Expr[]
    bids = [Symbol("bid_$d") for d in 1:N]

    if N <= 3
        # One grid axis per array dimension.
        for d in 1:N
            push!(stmts, :($(bids[d]) = cuTile.bid($d)))
        end
    else
        # Dimensions 3..N were folded into the third grid axis at launch time;
        # unflatten with repeated rem/fld over the recorded extents.
        push!(stmts, :($(bids[1]) = cuTile.bid(1)))
        push!(stmts, :($(bids[2]) = cuTile.bid(2)))
        push!(stmts, :(_rem = cuTile.bid(3) - Int32(1)))
        for d in 3:N
            if d < N
                push!(stmts, :($(bids[d]) = rem(_rem, Int32(overflow_grids[$(d - 2)])) + Int32(1)))
                push!(stmts, :(_rem = fld(_rem, Int32(overflow_grids[$(d - 2)]))))
            else
                # last dimension takes whatever remains
                push!(stmts, :($(bids[d]) = _rem + Int32(1)))
            end
        end
    end

    # 1D kernels index with a scalar, N-D kernels with a tuple.
    index_expr = N == 1 ? bids[1] : Expr(:tuple, bids...)
    push!(stmts, :(result = _eval_bc(bc, $index_expr, tile_size)))
    push!(stmts, :(result_converted = convert(cuTile.Tile{$T}, result)))
    push!(stmts, :(cuTile.store(dest, $index_expr, result_converted)))
    push!(stmts, :(return))
    return Expr(:block, stmts...)
end
163+
164+
#=============================================================================
165+
Recursive tree evaluation inside kernel
166+
=============================================================================#
167+
168+
# Recursively evaluate a (tiled) broadcast tree for the tile at block index `bid`.
@inline _eval_bc(arr::TileArray, bid, tile_size) = cuTile.load(arr, bid, tile_size) # leaf: load a tile
@inline _eval_bc(x::Number, bid, tile_size) = x                                     # leaf: scalar

@inline function _eval_bc(bc::Broadcasted, bid, tile_size)
    evaluated = _eval_bc_args(bc.args, bid, tile_size)
    # `broadcast` keeps element-wise semantics; calling `bc.f` directly could
    # dispatch to tile-level methods instead (e.g. `*` meaning matmul on tiles).
    return broadcast(bc.f, evaluated...)
end

# Tuple recursion, unrolled by dispatch on the empty tuple.
@inline _eval_bc_args(::Tuple{}, bid, tile_size) = ()
@inline function _eval_bc_args(args::Tuple, bid, tile_size)
    head = _eval_bc(args[1], bid, tile_size)
    return (head, _eval_bc_args(Base.tail(args), bid, tile_size)...)
end
181+
182+
#=============================================================================
183+
Tile sizing
184+
=============================================================================#
185+
186+
"""
187+
_compute_tile_sizes(dest_size; budget=4096)
188+
189+
Distribute a total element budget greedily across dimensions, skipping singletons.
190+
Each tile dimension is a power of 2, capped by the array size in that dimension.
191+
"""
192+
function _compute_tile_sizes(dest_size::NTuple{N,Int}; budget::Int=4096) where N
193+
ts = ones(Int, N)
194+
remaining = budget
195+
for i in 1:N
196+
s = dest_size[i]
197+
s == 1 && continue
198+
t = prevpow(2, min(remaining, s))
199+
ts[i] = t
200+
remaining = remaining ÷ t
201+
remaining < 2 && break
202+
end
203+
return NTuple{N,Int}(ts)
204+
end

src/cuTile.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ include("language/math.jl")
3838
include("language/operations.jl")
3939
include("language/atomics.jl")
4040

41-
public launch, ByTarget, @compiler_options
41+
# Host-level abstractions
42+
include("broadcast.jl")
43+
44+
public launch, Tiled, ByTarget, @compiler_options, @.
4245
launch(args...) = error("Please import CUDA.jl before using `cuTile.launch`.")
4346

4447
end # module cuTile

0 commit comments

Comments
 (0)