Skip to content

Commit 5d7dfdc

Browse files
committed
Reorganize.
1 parent 0f5815d commit 5d7dfdc

16 files changed

Lines changed: 358 additions & 300 deletions

README.md

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ vector_size = 2^20
4343
tile_size = 16
4444

4545
blocks = cld(vector_size, tile_size)
46-
grid = (blocks, 1, 1)
46+
grid = (blocks, 1, 1)
4747

4848
a, b = CUDA.rand(Float32, vector_size), CUDA.rand(Float32, vector_size)
4949
c = CUDA.zeros(Float32, vector_size)
@@ -232,7 +232,6 @@ uses standard Julia syntax and is overlaid on `Base`.
232232

233233
cuTile.jl follows Julia conventions, which differ from the Python API in several ways:
234234

235-
236235
### Kernel definition syntax
237236

238237
Kernels don't need a decorator, but do have to return `nothing`:
@@ -511,6 +510,36 @@ ct.store(arr, (i, j), t)
511510
```
512511

513512

513+
## Host-level operations
514+
515+
cuTile.jl also provides a limited set of host-level APIs to use cuTile without
516+
writing custom kernels. For example, for element-wise operations on `CuArray`s,
517+
cuTile can automatically generate and launch a fused kernel using Julia's
518+
broadcast machinery:
519+
520+
```julia
521+
using CUDA
522+
import cuTile as ct
523+
524+
A = CUDA.rand(Float32, 1024)
525+
B = CUDA.rand(Float32, 1024)
526+
C = CUDA.zeros(Float32, 1024)
527+
528+
# Wrap arrays in Tiled() to route through cuTile
529+
ct.Tiled(C) .= ct.Tiled(A) .+ ct.Tiled(B)
530+
531+
# Or use the @. macro for convenience
532+
ct.@. C = A + sin(B)
533+
534+
# Allocating form (returns a new CuArray)
535+
D = ct.@. A + B
536+
```
537+
538+
The entire broadcast expression is fused into a single cuTile kernel. Tile sizes
539+
are chosen automatically from the array dimensions (each a power of 2, within a total element budget).
540+
This works with arrays of any dimensionality, from 1D through N-dimensional.
541+
542+
514543
## Acknowledgments
515544

516545
cuTile.jl is inspired by [cuTile-Python](https://github.com/NVIDIA/cutile-python/),

ext/CUDAExt.jl

Lines changed: 4 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module CUDAExt
22

33
using cuTile
4-
using cuTile: Tiled, TileArray, Constant, CuTileResults,
4+
using cuTile: TileArray, Constant, CuTileResults,
55
emit_code, sanitize_name, constant_eltype, flatten,
66
resolve_hint, format_sm_arch
77

@@ -12,8 +12,7 @@ import Core.Compiler as CC
1212
using CUDA: CuArray, CuModule, CuFunction, cudacall, device, capability
1313
using CUDA_Compiler_jll
1414

15-
import Base.Broadcast
16-
import Base.Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle
15+
import Base.Broadcast: BroadcastStyle
1716
import CUDA: CuArrayStyle
1817

1918
public launch
@@ -259,132 +258,7 @@ Other values pass through unchanged.
259258
to_tile_arg(x) = x
260259
to_tile_arg(arr::AbstractArray) = TileArray(arr)
261260

262-
#=============================================================================
263-
Tiled Broadcast via Base.Broadcast
264-
=============================================================================#
265-
266-
struct TiledCuArrayStyle{N} <: BroadcastStyle end
267-
TiledCuArrayStyle{M}(::Val{N}) where {N,M} = TiledCuArrayStyle{N}()
268-
269-
BroadcastStyle(::Type{<:Tiled{<:CuArray{T,N}}}) where {T,N} = TiledCuArrayStyle{N}()
270-
271-
# TiledCuArrayStyle wins over CuArrayStyle and DefaultArrayStyle
272-
BroadcastStyle(::TiledCuArrayStyle{N}, ::CuArrayStyle{M}) where {N,M} = TiledCuArrayStyle{max(N,M)}()
273-
BroadcastStyle(::TiledCuArrayStyle{N}, ::DefaultArrayStyle{M}) where {N,M} = TiledCuArrayStyle{max(N,M)}()
274-
BroadcastStyle(::TiledCuArrayStyle{N}, ::TiledCuArrayStyle{M}) where {N,M} = TiledCuArrayStyle{max(N,M)}()
275-
276-
# materialize! dispatch: Tiled(B) .= expr
277-
function Base.Broadcast.materialize!(dest::Tiled, bc::Broadcasted)
278-
_tiled_broadcast!(parent(dest), bc)
279-
return dest
280-
end
281-
282-
# copy dispatch: C = Tiled(A) .+ B (allocating form)
283-
function Base.copy(bc::Broadcasted{TiledCuArrayStyle{N}}) where N
284-
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
285-
dest = similar(CuArray{ElType}, axes(bc))
286-
_tiled_broadcast!(dest, bc)
287-
return dest
288-
end
289-
290-
"""
291-
_to_tiled_bc(bc)
292-
293-
Walk a Broadcasted tree, converting leaf CuArrays to TileArrays and stripping
294-
style/axes (replacing with nothing). Scalars and other leaves pass through.
295-
"""
296-
_to_tiled_bc(arr::CuArray) = TileArray(arr)
297-
_to_tiled_bc(t::Tiled) = TileArray(parent(t))
298-
_to_tiled_bc(x::Number) = x
299-
_to_tiled_bc(x) = x # fallback for other types
300-
function _to_tiled_bc(bc::Broadcasted)
301-
new_args = map(_to_tiled_bc, bc.args)
302-
Broadcasted{Nothing}(bc.f, new_args, nothing)
303-
end
304-
305-
# The generic broadcast kernel: evaluates the Broadcasted tree on tiles
306-
@generated function _tiled_bc_kernel(dest::TileArray{T, N}, bc, tile_size, overflow_grids) where {T, N}
307-
body = Expr[]
308-
bid_vars = [Symbol("bid_$d") for d in 1:N]
309-
310-
if N <= 3
311-
for d in 1:N
312-
push!(body, :($(bid_vars[d]) = cuTile.bid($d)))
313-
end
314-
else
315-
push!(body, :($(bid_vars[1]) = cuTile.bid(1)))
316-
push!(body, :($(bid_vars[2]) = cuTile.bid(2)))
317-
push!(body, :(_rem = cuTile.bid(3) - Int32(1)))
318-
for d in 3:N
319-
if d < N
320-
push!(body, :($(bid_vars[d]) = rem(_rem, Int32(overflow_grids[$(d-2)])) + Int32(1)))
321-
push!(body, :(_rem = fld(_rem, Int32(overflow_grids[$(d-2)]))))
322-
else
323-
push!(body, :($(bid_vars[d]) = _rem + Int32(1)))
324-
end
325-
end
326-
end
327-
328-
idx = N == 1 ? bid_vars[1] : Expr(:tuple, bid_vars...)
329-
push!(body, :(result = _eval_bc(bc, $idx, tile_size)))
330-
push!(body, :(result_converted = convert(cuTile.Tile{$T}, result)))
331-
push!(body, :(cuTile.store(dest, $idx, result_converted)))
332-
push!(body, :(return))
333-
Expr(:block, body...)
334-
end
335-
336-
# Recursive tree evaluation inside kernel
337-
@inline _eval_bc(arr::TileArray, bid, tile_size) = cuTile.load(arr, bid, tile_size)
338-
@inline _eval_bc(x::Number, bid, tile_size) = x
339-
340-
@inline function _eval_bc(bc::Broadcasted, bid, tile_size)
341-
args = _eval_bc_args(bc.args, bid, tile_size)
342-
# Use broadcast to get element-wise semantics (not direct call, which
343-
# would dispatch to e.g. matmul for * on tiles)
344-
broadcast(bc.f, args...)
345-
end
346-
347-
@inline _eval_bc_args(::Tuple{}, bid, tile_size) = ()
348-
@inline _eval_bc_args(args::Tuple, bid, tile_size) =
349-
(_eval_bc(args[1], bid, tile_size), _eval_bc_args(Base.tail(args), bid, tile_size)...)
350-
351-
"""
352-
_compute_tile_sizes(dest_size; budget=4096)
353-
354-
Distribute a total element budget greedily across dimensions, skipping singletons.
355-
Each tile dimension is a power of 2, capped by the array size in that dimension.
356-
"""
357-
function _compute_tile_sizes(dest_size::NTuple{N,Int}; budget::Int=4096) where N
358-
ts = ones(Int, N)
359-
remaining = budget
360-
for i in 1:N
361-
s = dest_size[i]
362-
s == 1 && continue
363-
t = prevpow(2, min(remaining, s))
364-
ts[i] = t
365-
remaining = remaining ÷ t
366-
remaining < 2 && break
367-
end
368-
return NTuple{N,Int}(ts)
369-
end
370-
371-
"""
372-
_tiled_broadcast!(dest, bc)
373-
374-
Launch a tiled broadcast kernel for the fused expression `bc` writing to `dest`.
375-
"""
376-
function _tiled_broadcast!(dest::CuArray{T,N}, bc::Broadcasted) where {T, N}
377-
dest_ta = TileArray(dest)
378-
tiled_bc = _to_tiled_bc(bc)
379-
380-
ts = _compute_tile_sizes(size(dest))
381-
grid = ntuple(i -> cld(size(dest, i), ts[i]), N)
382-
383-
launch_grid = N <= 3 ? grid : (grid[1], grid[2], prod(grid[i] for i in 3:N))
384-
overflow = N > 3 ? grid[3:end] : ()
385-
386-
cuTile.launch(_tiled_bc_kernel, launch_grid, dest_ta, tiled_bc,
387-
Constant(ts), Constant(overflow))
388-
end
261+
# Tiled Broadcast — TiledStyle wins over CuArrayStyle
262+
BroadcastStyle(::cuTile.TiledStyle{N}, ::CuArrayStyle{M}) where {N,M} = cuTile.TiledStyle{max(N,M)}()
389263

390264
end

0 commit comments

Comments
 (0)