11module CUDAExt
22
33using cuTile
4- using cuTile: Tiled, TileArray, Constant, CuTileResults,
4+ using cuTile: TileArray, Constant, CuTileResults,
55 emit_code, sanitize_name, constant_eltype, flatten,
66 resolve_hint, format_sm_arch
77
@@ -12,8 +12,7 @@ import Core.Compiler as CC
1212using CUDA: CuArray, CuModule, CuFunction, cudacall, device, capability
1313using CUDA_Compiler_jll
1414
15- import Base. Broadcast
16- import Base. Broadcast: BroadcastStyle, Broadcasted, DefaultArrayStyle
15+ import Base. Broadcast: BroadcastStyle
1716import CUDA: CuArrayStyle
1817
1918public launch
@@ -259,132 +258,7 @@ Other values pass through unchanged.
259258to_tile_arg (x) = x
260259to_tile_arg (arr:: AbstractArray ) = TileArray (arr)
261260
262- #= ============================================================================
263- Tiled Broadcast via Base.Broadcast
264- =============================================================================#
265-
266- struct TiledCuArrayStyle{N} <: BroadcastStyle end
267- TiledCuArrayStyle {M} (:: Val{N} ) where {N,M} = TiledCuArrayStyle {N} ()
268-
269- BroadcastStyle (:: Type{<:Tiled{<:CuArray{T,N}}} ) where {T,N} = TiledCuArrayStyle {N} ()
270-
271- # TiledCuArrayStyle wins over CuArrayStyle and DefaultArrayStyle
272- BroadcastStyle (:: TiledCuArrayStyle{N} , :: CuArrayStyle{M} ) where {N,M} = TiledCuArrayStyle {max(N,M)} ()
273- BroadcastStyle (:: TiledCuArrayStyle{N} , :: DefaultArrayStyle{M} ) where {N,M} = TiledCuArrayStyle {max(N,M)} ()
274- BroadcastStyle (:: TiledCuArrayStyle{N} , :: TiledCuArrayStyle{M} ) where {N,M} = TiledCuArrayStyle {max(N,M)} ()
275-
276- # materialize! dispatch: Tiled(B) .= expr
277- function Base. Broadcast. materialize! (dest:: Tiled , bc:: Broadcasted )
278- _tiled_broadcast! (parent (dest), bc)
279- return dest
280- end
281-
282- # copy dispatch: C = Tiled(A) .+ B (allocating form)
283- function Base. copy (bc:: Broadcasted{TiledCuArrayStyle{N}} ) where N
284- ElType = Broadcast. combine_eltypes (bc. f, bc. args)
285- dest = similar (CuArray{ElType}, axes (bc))
286- _tiled_broadcast! (dest, bc)
287- return dest
288- end
289-
290- """
291- _to_tiled_bc(bc)
292-
293- Walk a Broadcasted tree, converting leaf CuArrays to TileArrays and stripping
294- style/axes (replacing with nothing). Scalars and other leaves pass through.
295- """
296- _to_tiled_bc (arr:: CuArray ) = TileArray (arr)
297- _to_tiled_bc (t:: Tiled ) = TileArray (parent (t))
298- _to_tiled_bc (x:: Number ) = x
299- _to_tiled_bc (x) = x # fallback for other types
300- function _to_tiled_bc (bc:: Broadcasted )
301- new_args = map (_to_tiled_bc, bc. args)
302- Broadcasted {Nothing} (bc. f, new_args, nothing )
303- end
304-
305- # The generic broadcast kernel: evaluates the Broadcasted tree on tiles
306- @generated function _tiled_bc_kernel (dest:: TileArray{T, N} , bc, tile_size, overflow_grids) where {T, N}
307- body = Expr[]
308- bid_vars = [Symbol (" bid_$d " ) for d in 1 : N]
309-
310- if N <= 3
311- for d in 1 : N
312- push! (body, :($ (bid_vars[d]) = cuTile. bid ($ d)))
313- end
314- else
315- push! (body, :($ (bid_vars[1 ]) = cuTile. bid (1 )))
316- push! (body, :($ (bid_vars[2 ]) = cuTile. bid (2 )))
317- push! (body, :(_rem = cuTile. bid (3 ) - Int32 (1 )))
318- for d in 3 : N
319- if d < N
320- push! (body, :($ (bid_vars[d]) = rem (_rem, Int32 (overflow_grids[$ (d- 2 )])) + Int32 (1 )))
321- push! (body, :(_rem = fld (_rem, Int32 (overflow_grids[$ (d- 2 )]))))
322- else
323- push! (body, :($ (bid_vars[d]) = _rem + Int32 (1 )))
324- end
325- end
326- end
327-
328- idx = N == 1 ? bid_vars[1 ] : Expr (:tuple , bid_vars... )
329- push! (body, :(result = _eval_bc (bc, $ idx, tile_size)))
330- push! (body, :(result_converted = convert (cuTile. Tile{$ T}, result)))
331- push! (body, :(cuTile. store (dest, $ idx, result_converted)))
332- push! (body, :(return ))
333- Expr (:block , body... )
334- end
335-
336- # Recursive tree evaluation inside kernel
337- @inline _eval_bc (arr:: TileArray , bid, tile_size) = cuTile. load (arr, bid, tile_size)
338- @inline _eval_bc (x:: Number , bid, tile_size) = x
339-
340- @inline function _eval_bc (bc:: Broadcasted , bid, tile_size)
341- args = _eval_bc_args (bc. args, bid, tile_size)
342- # Use broadcast to get element-wise semantics (not direct call, which
343- # would dispatch to e.g. matmul for * on tiles)
344- broadcast (bc. f, args... )
345- end
346-
347- @inline _eval_bc_args (:: Tuple{} , bid, tile_size) = ()
348- @inline _eval_bc_args (args:: Tuple , bid, tile_size) =
349- (_eval_bc (args[1 ], bid, tile_size), _eval_bc_args (Base. tail (args), bid, tile_size)... )
350-
351- """
352- _compute_tile_sizes(dest_size; budget=4096)
353-
354- Distribute a total element budget greedily across dimensions, skipping singletons.
355- Each tile dimension is a power of 2, capped by the array size in that dimension.
356- """
357- function _compute_tile_sizes (dest_size:: NTuple{N,Int} ; budget:: Int = 4096 ) where N
358- ts = ones (Int, N)
359- remaining = budget
360- for i in 1 : N
361- s = dest_size[i]
362- s == 1 && continue
363- t = prevpow (2 , min (remaining, s))
364- ts[i] = t
365- remaining = remaining ÷ t
366- remaining < 2 && break
367- end
368- return NTuple {N,Int} (ts)
369- end
370-
371- """
372- _tiled_broadcast!(dest, bc)
373-
374- Launch a tiled broadcast kernel for the fused expression `bc` writing to `dest`.
375- """
376- function _tiled_broadcast! (dest:: CuArray{T,N} , bc:: Broadcasted ) where {T, N}
377- dest_ta = TileArray (dest)
378- tiled_bc = _to_tiled_bc (bc)
379-
380- ts = _compute_tile_sizes (size (dest))
381- grid = ntuple (i -> cld (size (dest, i), ts[i]), N)
382-
383- launch_grid = N <= 3 ? grid : (grid[1 ], grid[2 ], prod (grid[i] for i in 3 : N))
384- overflow = N > 3 ? grid[3 : end ] : ()
385-
386- cuTile. launch (_tiled_bc_kernel, launch_grid, dest_ta, tiled_bc,
387- Constant (ts), Constant (overflow))
388- end
261+ # Tiled Broadcast — TiledStyle wins over CuArrayStyle
262+ BroadcastStyle (:: cuTile.TiledStyle{N} , :: CuArrayStyle{M} ) where {N,M} = cuTile. TiledStyle {max(N,M)} ()
389263
390264end
0 commit comments