Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9bf8c52
Fuse copyto by space
petebachant Apr 6, 2026
892e2b3
Detect compile error and fallback
petebachant Apr 6, 2026
0d0aeab
Trigger [perf] pipeline
petebachant Apr 6, 2026
dc6fffe
Update baseline SYPD for [perf] pipeline
petebachant Apr 6, 2026
fcaf674
Merge branch 'main' into pb/fieldname-set
petebachant Apr 6, 2026
ab45488
Install ClimaAtmos on main in [perf] pipeline
petebachant Apr 6, 2026
836a706
Better catch failed fusion
petebachant Apr 7, 2026
58738e4
Merge branch 'main' of github.com:CliMA/ClimaCore.jl into pb/fieldnam…
petebachant Apr 7, 2026
599166b
Another try: not faster
petebachant Apr 8, 2026
f0f78a0
Define risky groups
petebachant Apr 8, 2026
4feb8bc
Merge branch 'main' of github.com:CliMA/ClimaCore.jl into pb/fieldnam…
petebachant Apr 20, 2026
861a40c
Revert perf pipeline back to main
petebachant Apr 20, 2026
e974dd5
Sync back to main
petebachant Apr 29, 2026
8f79c25
Try fusing copyto
petebachant Apr 30, 2026
650c6b3
Merge in PCR solver and sync skips
petebachant Apr 30, 2026
4a3c913
Ignore certain functions within ClimaCore for kernel renaming
imreddyTeja Apr 29, 2026
c3755d6
Merge branch 'tr/rename' of github.com:CliMA/ClimaCore.jl into pb/fie…
petebachant May 4, 2026
8af6743
feat: add PCR based single field tridiagonal solver
Mikolaj-A-Kowalski Apr 7, 2026
e06f442
refactor: redirect TridiagonalMatrixField to special solver
Mikolaj-A-Kowalski Feb 19, 2026
18b453e
doc: note the new PCR solver in relevant documentation
Mikolaj-A-Kowalski Apr 30, 2026
5977c4e
refactor: unbox the operators to conform to new interface
Mikolaj-A-Kowalski May 5, 2026
86ea65b
[perf] trigger end-to-end AMIP
Mikolaj-A-Kowalski May 5, 2026
a1fae05
Merge remote-tracking branch 'origin/main' into pb/fieldname-set
petebachant May 5, 2026
1cd9488
Merge remote-tracking branch 'origin/iccs/pcr-solver' into pb/fieldna…
petebachant May 5, 2026
7589db0
Another attempt at fusion
petebachant May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ ClimaCore.jl Release Notes
main
-------

- Add a specialized shared memory-based tridiagonal solver that uses PCR algorithm for CUDA backend. [2486](https://github.com/CliMA/ClimaCore.jl/pull/2486)

v0.14.51
-------

Expand Down
27 changes: 19 additions & 8 deletions ext/cuda/data_layouts_fused_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us)
return nothing
end
import MultiBroadcastFusion
const MBFCUDA =
mbf_cuda_ext() =
Base.get_extension(MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt)
# https://github.com/JuliaLang/julia/issues/56295
# Julia 1.11's Base.Broadcast currently requires
Expand All @@ -107,14 +107,25 @@ function fused_copyto!(
(Nv > 0 && Nh > 0) || return nothing # short circuit

if pkgversion(MultiBroadcastFusion) >= v"0.3.3"
# Automatically split kernels by available parameter memory space:
fmbs = MBFCUDA.partition_kernels(
fmb,
FusedMultiBroadcast,
fused_multibroadcast_args,
)
for fmb in fmbs
mbfcuda = mbf_cuda_ext()
# Check if the fused kernel fits within parameter memory limits.
# If it fits, launch directly with the concrete-typed fmb to preserve
# type information through the CUDA kernel compilation. Calling through
# partition_kernels erases the fmb type (returns Any), which causes
# dynamic dispatch inside the GPU kernel.
if mbfcuda.param_usage_args(fused_multibroadcast_args(fmb)) ≤
mbfcuda.get_param_lim()
launch_fused_copyto!(fmb)
else
# Rare: FMB too large for one kernel — split, accepting type erasure.
fmbs = mbfcuda.partition_kernels(
fmb,
FusedMultiBroadcast,
fused_multibroadcast_args,
)
for fmb_split in fmbs
launch_fused_copyto!(fmb_split)
end
end
else
launch_fused_copyto!(fmb)
Expand Down
128 changes: 126 additions & 2 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,24 @@ import ClimaCore.Fields
import ClimaCore.Spaces
import ClimaCore.Topologies
import ClimaCore.MatrixFields
import ClimaCore.DataLayouts: vindex
import ClimaCore.DataLayouts: vindex, universal_size
import ClimaCore.MatrixFields: single_field_solve!
import ClimaCore.MatrixFields: _single_field_solve!
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values

function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))

Ni, Nj, _, Nv, Nh = size(Fields.field_values(A))

# Tridiagonal solvers are handled by special implementation
# The special solver is limited in Nv by the number of threads per block
# hence it cannot be used for very large matrices.
# 512 should run on most GPUs
if eltype(A) <: MatrixFields.TridiagonalMatrixRow && Nv <= 512
single_field_solve_tridiagonal!(cache, x, A, b)
return
end

us = UniversalSize(Fields.field_values(A))
mask = Spaces.get_mask(axes(x))
cart_inds = cartesian_indices_columnwise(us)
Expand Down Expand Up @@ -210,3 +221,116 @@ function band_matrix_solve_local_mem!(
end
return nothing
end


function tridiag_pcr_kernel!(
x, a, b, c, d, ::Val{Nv}, ::Val{n_iter},
) where {Nv, n_iter}
(idx_i, idx_j, idx_h) = blockIdx()
i = threadIdx().x
if i > Nv
return nothing
end

s_a = CUDA.CuStaticSharedArray(eltype(a), Nv)
s_b = CUDA.CuStaticSharedArray(eltype(b), Nv)
s_c = CUDA.CuStaticSharedArray(eltype(c), Nv)
s_d = CUDA.CuStaticSharedArray(eltype(d), Nv)

idx = CartesianIndex(idx_i, idx_j, 1, i, idx_h)

# Load into shared memory
@inbounds begin
local_ai = a[idx]
local_bi = b[idx]
local_ci = c[idx]
local_di = d[idx]

s_a[i] = local_ai
s_b[i] = local_bi
s_c[i] = local_ci
s_d[i] = local_di
end
CUDA.sync_threads()

# PCR iterations
stride = 1

for _ in 1:n_iter
i_minus = max(i - stride, 1)
i_plus = min(i + stride, Nv)

# Compute elimination factors
@inbounds begin
k1 = (i > stride) ? -local_ai * inv(s_b[i_minus]) : zero(eltype(a))
k2 = (i <= Nv - stride) ? -local_ci * inv(s_b[i_plus]) : zero(eltype(a))

# Update coefficients
local_ai = k1 * s_a[i_minus]
local_bi = local_bi + k1 * s_c[i_minus] + k2 * s_a[i_plus]
local_ci = k2 * s_c[i_plus]
local_di = local_di + k1 * s_d[i_minus] + k2 * s_d[i_plus]
end

CUDA.sync_threads()

# Copy back for next iteration
@inbounds begin
s_a[i] = local_ai
s_b[i] = local_bi
s_c[i] = local_ci
s_d[i] = local_di
end

CUDA.sync_threads()
stride *= 2
end

# Final solve into x
@inbounds x[idx] = inv(s_b[i]) * s_d[i]
return nothing
end


"""
single_field_solve_tridiagonal!(cache, x, A, b)

Specialized solver for the tridiagonal MatrixField. Solves each column in
parallel launching Nv threads per block where Nv is the number of vertical levels.
Works best if Nv is multiple of 32. There is an upper limit on the size of Nv
due to resource limits of the GPU (register and shared memory usage). For
A100 it is 1024, but may differ depending on the hardware.
"""
function single_field_solve_tridiagonal!(cache, x, A, b)

device = ClimaComms.device(x)
device isa ClimaComms.CUDADevice || error("This solver supports only CUDA devices.")

eltype(A) <: MatrixFields.TridiagonalMatrixRow || error(
"This function expects a tridiagonal matrix field, but got a field with element type $(eltype(A))",
)

# Get field dimensions
Ni, Nj, _, Nv, Nh = universal_size(Fields.field_values(A))

# Prepare data
Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
A₋₁, A₀, A₊₁ = Aⱼs
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)

# Solve
threads_per_block = Nv
n_iter = ceil(Int, log2(Nv))
args = (x_data, A₋₁, A₀, A₊₁, b_data, Val(Nv), Val(n_iter))

auto_launch!(
tridiag_pcr_kernel!,
args;
threads_s = (threads_per_block,),
blocks_s = (Ni, Nj, Nh),
)

call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
return nothing
end
25 changes: 16 additions & 9 deletions src/DataLayouts/fused_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,29 @@ Base.@propagate_inbounds function rcopyto_at_linear!(pairs::Tuple, I)
unrolled_foreach(Base.Fix2(rcopyto_at_linear!, I), pairs)
end

# Normalize a scalar Broadcasted (0-dimensional) so it can be indexed with a
# scalar index. Uses type dispatch so the return type equals the input type —
# no Union is produced, keeping concrete types through the map closure below.
@inline _normalize_bc(bc) = bc
@inline _normalize_bc(
bc::Base.Broadcast.Broadcasted{Style},
) where {
Style <: Union{
Base.Broadcast.AbstractArrayStyle{0},
Base.Broadcast.Style{Tuple},
},
} = Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)

# Fused multi-broadcast entry point for DataLayouts
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}}
dest1 = first(fmbc.pairs).first
fmb_inst = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = pair.second
bc′ = if isascalar(bc)
Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)
else
bc
end
Pair(pair.first, bc′)
Pair(pair.first, _normalize_bc(pair.second))
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
Expand Down
9 changes: 8 additions & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,14 @@ function Spaces.weighted_dss!(
)
end

cuda_synchronize(device; blocking = true)
needs_sync =
dss_buffer1 isa Topologies.DSSBuffer &&
!isempty(dss_buffer1.perimeter_elems) ||
any(
b isa Topologies.DSSBuffer && !isempty(b.perimeter_elems)
for (_, b) in field_buffer_pairs
)
needs_sync && cuda_synchronize(device; blocking = true)
dss_buffer1 isa Topologies.DSSBuffer &&
ClimaComms.start(dss_buffer1.graph_context)
for (field, dss_buffer) in field_buffer_pairs
Expand Down
21 changes: 19 additions & 2 deletions src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ we might want to parallelize it in the future).
If `Aₙₙ` is a diagonal matrix, the equation `Aₙₙ * xₙ = bₙ` is solved by making a
single pass over the data, setting each `xₙ[i]` to `inv(Aₙₙ[i, i]) * bₙ[i]`.

Otherwise, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
Otherwise, on a CPU, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
(without pivoting), which makes two passes over the data. This is currently only
implemented for tri-diagonal and penta-diagonal matrices `Aₙₙ`. In Gaussian
elimination, `Aₙₙ` is effectively factorized into the product `Lₙ * Dₙ * Uₙ`,
Expand All @@ -229,6 +229,17 @@ is referred to as "back substitution". These operations can become numerically
unstable when `Aₙₙ` has entries with large disparities in magnitude, but avoiding
this would require swapping the rows of `Aₙₙ` (i.e., replacing `Dₙ` with a
partial pivoting matrix).

For performance reasons, on a GPU different solver is used for tri-diagonal systems
that makes use of parallel cyclic reduction (PCR) method. This is to make better use of
the GPU parallelism and shared memory. Since in this solver we need to launch
a thread per each row of the matrix, it is used only for systems smaller than 512
as to not violate CUDA limitations. Above that size, the code will fall back to
the Gaussian elimination method used for CPU (and may show degraded performance).

PCR method works by recursively decomposing a larger tridiagonal system into two
system of half the size. This process is continued until we are left with a system
of size 1, which can be solved directly.
"""
struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end

Expand Down Expand Up @@ -279,9 +290,15 @@ function run_field_matrix_solver!(
case1 = length(names) == 1
case2 = all(name -> cheap_inv(A[name, name]), names.values)
case3 = any(name -> cheap_inv(A[name, name]), names.values)

# Direct all TridiagonalMatrixRow cases to `single_field_solve` path so
# they can be redirected to a specialised solver
# TODO: Group multiple Tridiagonals to the special 'multiple_field' solver
case4 = any(name -> eltype(A[name, name]) <: TridiagonalMatrixRow, names.values)

# TODO: remove case3 and implement _single_field_solve_diag_matrix_row!
# in multiple_field_solve!
if case1 || case2 || case3
if case1 || case2 || case3 || case4
foreach(names) do name
single_field_solve!(cache[name], x[name], A[name, name], b[name])
end
Expand Down
48 changes: 42 additions & 6 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -971,18 +971,54 @@ Base.Broadcast.materialize!(
vector_or_matrix::FieldNameDict,
) = Base.Broadcast.materialize!(field_vector_view(dest), vector_or_matrix)

# Pairs are eligible for fused multi-broadcast when both sides are plain Fields
# of the same concrete type (so axes / data layout / element type all match) —
# this keeps the GPU codegen for `_newindex` static.
@inline _is_fusable_pair(dest_entry::F, entry::F) where {F <: Fields.Field} =
true
@inline _is_fusable_pair(_, _) = false

NVTX.@annotate function copyto_foreach!(
dest::FieldNameDict,
vector_or_matrix::FieldNameDict,
)
foreach(keys(vector_or_matrix)) do key
entry = vector_or_matrix[key]
if dest[key] isa ScalingFieldMatrixEntry
dest[key] == entry || error("matrix entry at $key is immutable")
key_values = keys(vector_or_matrix).values
triples = unrolled_map(key_values) do key
(key, dest[key], vector_or_matrix[key])
end
fusable = unrolled_filter(triples) do t
_is_fusable_pair(t[2], t[3])
end
non_fusable = unrolled_filter(triples) do t
!_is_fusable_pair(t[2], t[3])
end
# All fusable triples must share the same concrete destination type for
# FusedMultiBroadcast (its `check_mismatched_spaces` requires uniform space
# types across pairs, and uniform `_newindex` types across pairs).
uniform_fusable = if length(fusable) > 1
F1 = typeof(fusable[1][2])
unrolled_all(t -> typeof(t[2]) === F1, fusable)
else
false
end
if uniform_fusable
pairs = unrolled_map(fusable) do t
Pair(t[2], Base.broadcasted(identity, t[3]))
end
Base.copyto!(Fields.FusedMultiBroadcast(pairs))
else
unrolled_foreach(fusable) do t
t[2] .= t[3]
end
end
unrolled_foreach(non_fusable) do t
key, dest_entry, entry = t
if dest_entry isa ScalingFieldMatrixEntry
dest_entry == entry || error("matrix entry at $key is immutable")
elseif entry isa ScalingFieldMatrixEntry
dest[key] .= (entry,)
dest_entry .= (entry,)
else
dest[key] .= entry
dest_entry .= entry
end
end
end
Expand Down
11 changes: 7 additions & 4 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ function _Remapper(
if horiz_method isa BilinearRemapping
quad_pts = quad_points
if num_hdims == 1
# 1D: linear on 2-point cell.
# 1D: linear on 2-point cell.
ξ1s = ξs_split[1]
i_arr = [clamp(searchsortedlast(quad_pts, ξ1), 1, Nq - 1) for ξ1 in ξ1s]
s_arr = [
Expand All @@ -396,7 +396,7 @@ function _Remapper(
local_bilinear_t = local_bilinear_j = nothing
local_horiz_interpolation_weights = nothing
else
# 2D: bilinear on 2×2 cell.
# 2D: bilinear on 2×2 cell.
n = length(ξs_split[1])
s_arr = Vector{FT}(undef, n)
t_arr = Vector{FT}(undef, n)
Expand Down Expand Up @@ -632,7 +632,7 @@ function _set_interpolated_values_bilinear!(
for (field_index, field) in enumerate(fields)
fv = Fields.field_values(field)
# out_index = horizontal target point
# vindex = vertical target level
# vindex = vertical target level
# h = element index
# (i, s) = 1D linear stencil.
@inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights)
Expand Down Expand Up @@ -1030,7 +1030,10 @@ function _collect_interpolated_values!(
index_field_end::Int;
only_one_field,
)
cuda_synchronize(ClimaComms.device(remapper.comms_ctx)) # Sync streams before MPI calls
# Sync streams before MPI calls if we are on GPU and we have more than one
# process, to ensure that the data is ready
ClimaComms.nprocs(remapper.comms_ctx) > 1 &&
cuda_synchronize(ClimaComms.device(remapper.comms_ctx))
if only_one_field
ClimaComms.reduce!(
remapper.comms_ctx,
Expand Down
Loading
Loading