diff --git a/NEWS.md b/NEWS.md index 6a6a4ec8d3..a30e43f1d6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ ClimaCore.jl Release Notes main ------- +- Add a specialized shared memory-based tridiagonal solver that uses PCR algorithm for CUDA backend. [2486](https://github.com/CliMA/ClimaCore.jl/pull/2486) + v0.14.51 ------- diff --git a/ext/cuda/data_layouts_fused_copyto.jl b/ext/cuda/data_layouts_fused_copyto.jl index ced59900dc..3bc5d11806 100644 --- a/ext/cuda/data_layouts_fused_copyto.jl +++ b/ext/cuda/data_layouts_fused_copyto.jl @@ -81,7 +81,7 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us) return nothing end import MultiBroadcastFusion -const MBFCUDA = +mbf_cuda_ext() = Base.get_extension(MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt) # https://github.com/JuliaLang/julia/issues/56295 # Julia 1.11's Base.Broadcast currently requires @@ -107,14 +107,25 @@ function fused_copyto!( (Nv > 0 && Nh > 0) || return nothing # short circuit if pkgversion(MultiBroadcastFusion) >= v"0.3.3" - # Automatically split kernels by available parameter memory space: - fmbs = MBFCUDA.partition_kernels( - fmb, - FusedMultiBroadcast, - fused_multibroadcast_args, - ) - for fmb in fmbs + mbfcuda = mbf_cuda_ext() + # Check if the fused kernel fits within parameter memory limits. + # If it fits, launch directly with the concrete-typed fmb to preserve + # type information through the CUDA kernel compilation. Calling through + # partition_kernels erases the fmb type (returns Any), which causes + # dynamic dispatch inside the GPU kernel. + if mbfcuda.param_usage_args(fused_multibroadcast_args(fmb)) ≤ + mbfcuda.get_param_lim() launch_fused_copyto!(fmb) + else + # Rare: FMB too large for one kernel — split, accepting type erasure. 
+ fmbs = mbfcuda.partition_kernels( + fmb, + FusedMultiBroadcast, + fused_multibroadcast_args, + ) + for fmb_split in fmbs + launch_fused_copyto!(fmb_split) + end end else launch_fused_copyto!(fmb) diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl index fbb2ed5f42..56a0d2b5c6 100644 --- a/ext/cuda/matrix_fields_single_field_solve.jl +++ b/ext/cuda/matrix_fields_single_field_solve.jl @@ -7,13 +7,24 @@ import ClimaCore.Fields import ClimaCore.Spaces import ClimaCore.Topologies import ClimaCore.MatrixFields -import ClimaCore.DataLayouts: vindex +import ClimaCore.DataLayouts: vindex, universal_size import ClimaCore.MatrixFields: single_field_solve! import ClimaCore.MatrixFields: _single_field_solve! import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b) - Ni, Nj, _, _, Nh = size(Fields.field_values(A)) + + Ni, Nj, _, Nv, Nh = size(Fields.field_values(A)) + + # Tridiagonal solvers are handled by special implementation + # The special solver is limited in Nv by the number of threads per block + # hence it cannot be used for very large matrices. 
+ # 512 should run on most GPUs + if eltype(A) <: MatrixFields.TridiagonalMatrixRow && Nv <= 512 + single_field_solve_tridiagonal!(cache, x, A, b) + return + end + us = UniversalSize(Fields.field_values(A)) mask = Spaces.get_mask(axes(x)) cart_inds = cartesian_indices_columnwise(us) @@ -210,3 +221,116 @@ function band_matrix_solve_local_mem!( end return nothing end + + +function tridiag_pcr_kernel!( + x, a, b, c, d, ::Val{Nv}, ::Val{n_iter}, +) where {Nv, n_iter} + (idx_i, idx_j, idx_h) = blockIdx() + i = threadIdx().x + if i > Nv + return nothing + end + + s_a = CUDA.CuStaticSharedArray(eltype(a), Nv) + s_b = CUDA.CuStaticSharedArray(eltype(b), Nv) + s_c = CUDA.CuStaticSharedArray(eltype(c), Nv) + s_d = CUDA.CuStaticSharedArray(eltype(d), Nv) + + idx = CartesianIndex(idx_i, idx_j, 1, i, idx_h) + + # Load into shared memory + @inbounds begin + local_ai = a[idx] + local_bi = b[idx] + local_ci = c[idx] + local_di = d[idx] + + s_a[i] = local_ai + s_b[i] = local_bi + s_c[i] = local_ci + s_d[i] = local_di + end + CUDA.sync_threads() + + # PCR iterations + stride = 1 + + for _ in 1:n_iter + i_minus = max(i - stride, 1) + i_plus = min(i + stride, Nv) + + # Compute elimination factors + @inbounds begin + k1 = (i > stride) ? -local_ai * inv(s_b[i_minus]) : zero(eltype(a)) + k2 = (i <= Nv - stride) ? -local_ci * inv(s_b[i_plus]) : zero(eltype(a)) + + # Update coefficients + local_ai = k1 * s_a[i_minus] + local_bi = local_bi + k1 * s_c[i_minus] + k2 * s_a[i_plus] + local_ci = k2 * s_c[i_plus] + local_di = local_di + k1 * s_d[i_minus] + k2 * s_d[i_plus] + end + + CUDA.sync_threads() + + # Copy back for next iteration + @inbounds begin + s_a[i] = local_ai + s_b[i] = local_bi + s_c[i] = local_ci + s_d[i] = local_di + end + + CUDA.sync_threads() + stride *= 2 + end + + # Final solve into x + @inbounds x[idx] = inv(s_b[i]) * s_d[i] + return nothing +end + + +""" + single_field_solve_tridiagonal!(cache, x, A, b) + +Specialized solver for the tridiagonal MatrixField. 
Solves each column in +parallel launching Nv threads per block where Nv is the number of vertical levels. +Works best if Nv is multiple of 32. There is an upper limit on the size of Nv +due to resource limits of the GPU (register and shared memory usage). For +A100 it is 1024, but may differ depending on the hardware. +""" +function single_field_solve_tridiagonal!(cache, x, A, b) + + device = ClimaComms.device(x) + device isa ClimaComms.CUDADevice || error("This solver supports only CUDA devices.") + + eltype(A) <: MatrixFields.TridiagonalMatrixRow || error( + "This function expects a tridiagonal matrix field, but got a field with element type $(eltype(A))", + ) + + # Get field dimensions + Ni, Nj, _, Nv, Nh = universal_size(Fields.field_values(A)) + + # Prepare data + Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries)) + A₋₁, A₀, A₊₁ = Aⱼs + x_data = Fields.field_values(x) + b_data = Fields.field_values(b) + + # Solve + threads_per_block = Nv + n_iter = ceil(Int, log2(Nv)) + args = (x_data, A₋₁, A₀, A₊₁, b_data, Val(Nv), Val(n_iter)) + + auto_launch!( + tridiag_pcr_kernel!, + args; + threads_s = (threads_per_block,), + blocks_s = (Ni, Nj, Nh), + ) + + call_post_op_callback() && post_op_callback(x, device, cache, x, A, b) + return nothing +end diff --git a/src/DataLayouts/fused_copyto.jl b/src/DataLayouts/fused_copyto.jl index e5e8cdacae..fce78b98be 100644 --- a/src/DataLayouts/fused_copyto.jl +++ b/src/DataLayouts/fused_copyto.jl @@ -21,6 +21,21 @@ Base.@propagate_inbounds function rcopyto_at_linear!(pairs::Tuple, I) unrolled_foreach(Base.Fix2(rcopyto_at_linear!, I), pairs) end +# Normalize a scalar Broadcasted (0-dimensional) so it can be indexed with a +# scalar index. Uses type dispatch so the return type equals the input type — +# no Union is produced, keeping concrete types through the map closure below. 
+@inline _normalize_bc(bc) = bc +@inline _normalize_bc( + bc::Base.Broadcast.Broadcasted{Style}, +) where { + Style <: Union{ + Base.Broadcast.AbstractArrayStyle{0}, + Base.Broadcast.Style{Tuple}, + }, +} = Base.Broadcast.instantiate( + Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()), +) + # Fused multi-broadcast entry point for DataLayouts function Base.copyto!( fmbc::FusedMultiBroadcast{T}, @@ -28,15 +43,7 @@ function Base.copyto!( dest1 = first(fmbc.pairs).first fmb_inst = FusedMultiBroadcast( map(fmbc.pairs) do pair - bc = pair.second - bc′ = if isascalar(bc) - Base.Broadcast.instantiate( - Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()), - ) - else - bc - end - Pair(pair.first, bc′) + Pair(pair.first, _normalize_bc(pair.second)) end, ) # check_fused_broadcast_axes(fmbc) # we should already have checked the axes diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index efeb824844..0b9d4a9d02 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -464,7 +464,14 @@ function Spaces.weighted_dss!( ) end - cuda_synchronize(device; blocking = true) + needs_sync = + dss_buffer1 isa Topologies.DSSBuffer && + !isempty(dss_buffer1.perimeter_elems) || + any( + b isa Topologies.DSSBuffer && !isempty(b.perimeter_elems) + for (_, b) in field_buffer_pairs + ) + needs_sync && cuda_synchronize(device; blocking = true) dss_buffer1 isa Topologies.DSSBuffer && ClimaComms.start(dss_buffer1.graph_context) for (field, dss_buffer) in field_buffer_pairs diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl index 90634a184c..59c0f9bb0e 100644 --- a/src/MatrixFields/field_matrix_solver.jl +++ b/src/MatrixFields/field_matrix_solver.jl @@ -216,7 +216,7 @@ we might want to parallelize it in the future). If `Aₙₙ` is a diagonal matrix, the equation `Aₙₙ * xₙ = bₙ` is solved by making a single pass over the data, setting each `xₙ[i]` to `inv(Aₙₙ[i, i]) * bₙ[i]`. 
-Otherwise, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination +Otherwise, on a CPU, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination (without pivoting), which makes two passes over the data. This is currently only implemented for tri-diagonal and penta-diagonal matrices `Aₙₙ`. In Gaussian elimination, `Aₙₙ` is effectively factorized into the product `Lₙ * Dₙ * Uₙ`, @@ -229,6 +229,17 @@ is referred to as "back substitution". These operations can become numerically unstable when `Aₙₙ` has entries with large disparities in magnitude, but avoiding this would require swapping the rows of `Aₙₙ` (i.e., replacing `Dₙ` with a partial pivoting matrix). + +For performance reasons, on a GPU a different solver is used for tri-diagonal systems, +one based on the parallel cyclic reduction (PCR) method. This makes better use of +GPU parallelism and shared memory. Since this solver needs to launch +one thread per row of the matrix, it is used only for systems with at most 512 rows, +so as not to exceed the CUDA limit on threads per block. Above that size, the code falls back to +the Gaussian elimination method used on the CPU (and may show degraded performance). + +The PCR method works by recursively decomposing a larger tridiagonal system into two +systems of half the size. This process is continued until only systems of size 1 +remain, each of which can be solved directly. 
""" struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end @@ -279,9 +290,15 @@ function run_field_matrix_solver!( case1 = length(names) == 1 case2 = all(name -> cheap_inv(A[name, name]), names.values) case3 = any(name -> cheap_inv(A[name, name]), names.values) + + # Direct all TridiagonalMatrixRow cases to `single_field_solve` path so + # they can be redirected to a specialised solver + # TODO: Group multiple Tridiagonals to the special 'multiple_field' solver + case4 = any(name -> eltype(A[name, name]) <: TridiagonalMatrixRow, names.values) + # TODO: remove case3 and implement _single_field_solve_diag_matrix_row! # in multiple_field_solve! - if case1 || case2 || case3 + if case1 || case2 || case3 || case4 foreach(names) do name single_field_solve!(cache[name], x[name], A[name, name], b[name]) end diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index 589f16e37b..2829bd7492 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -971,18 +971,54 @@ Base.Broadcast.materialize!( vector_or_matrix::FieldNameDict, ) = Base.Broadcast.materialize!(field_vector_view(dest), vector_or_matrix) +# Pairs are eligible for fused multi-broadcast when both sides are plain Fields +# of the same concrete type (so axes / data layout / element type all match) — +# this keeps the GPU codegen for `_newindex` static. 
+@inline _is_fusable_pair(dest_entry::F, entry::F) where {F <: Fields.Field} = + true +@inline _is_fusable_pair(_, _) = false + NVTX.@annotate function copyto_foreach!( dest::FieldNameDict, vector_or_matrix::FieldNameDict, ) - foreach(keys(vector_or_matrix)) do key - entry = vector_or_matrix[key] - if dest[key] isa ScalingFieldMatrixEntry - dest[key] == entry || error("matrix entry at $key is immutable") + key_values = keys(vector_or_matrix).values + triples = unrolled_map(key_values) do key + (key, dest[key], vector_or_matrix[key]) + end + fusable = unrolled_filter(triples) do t + _is_fusable_pair(t[2], t[3]) + end + non_fusable = unrolled_filter(triples) do t + !_is_fusable_pair(t[2], t[3]) + end + # All fusable triples must share the same concrete destination type for + # FusedMultiBroadcast (its `check_mismatched_spaces` requires uniform space + # types across pairs, and uniform `_newindex` types across pairs). + uniform_fusable = if length(fusable) > 1 + F1 = typeof(fusable[1][2]) + unrolled_all(t -> typeof(t[2]) === F1, fusable) + else + false + end + if uniform_fusable + pairs = unrolled_map(fusable) do t + Pair(t[2], Base.broadcasted(identity, t[3])) + end + Base.copyto!(Fields.FusedMultiBroadcast(pairs)) + else + unrolled_foreach(fusable) do t + t[2] .= t[3] + end + end + unrolled_foreach(non_fusable) do t + key, dest_entry, entry = t + if dest_entry isa ScalingFieldMatrixEntry + dest_entry == entry || error("matrix entry at $key is immutable") elseif entry isa ScalingFieldMatrixEntry - dest[key] .= (entry,) + dest_entry .= (entry,) else - dest[key] .= entry + dest_entry .= entry end end end diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index b361cf9ab1..41f17618c8 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -384,7 +384,7 @@ function _Remapper( if horiz_method isa BilinearRemapping quad_pts = quad_points if num_hdims == 1 - # 1D: linear on 2-point cell. 
+ # 1D: linear on 2-point cell. ξ1s = ξs_split[1] i_arr = [clamp(searchsortedlast(quad_pts, ξ1), 1, Nq - 1) for ξ1 in ξ1s] s_arr = [ @@ -396,7 +396,7 @@ function _Remapper( local_bilinear_t = local_bilinear_j = nothing local_horiz_interpolation_weights = nothing else - # 2D: bilinear on 2×2 cell. + # 2D: bilinear on 2×2 cell. n = length(ξs_split[1]) s_arr = Vector{FT}(undef, n) t_arr = Vector{FT}(undef, n) @@ -632,7 +632,7 @@ function _set_interpolated_values_bilinear!( for (field_index, field) in enumerate(fields) fv = Fields.field_values(field) # out_index = horizontal target point - # vindex = vertical target level + # vindex = vertical target level # h = element index # (i, s) = 1D linear stencil. @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) @@ -1030,7 +1030,10 @@ function _collect_interpolated_values!( index_field_end::Int; only_one_field, ) - cuda_synchronize(ClimaComms.device(remapper.comms_ctx)) # Sync streams before MPI calls + # Sync streams before MPI calls if we are on GPU and we have more than one + # process, to ensure that the data is ready + ClimaComms.nprocs(remapper.comms_ctx) > 1 && + cuda_synchronize(ClimaComms.device(remapper.comms_ctx)) if only_one_field ClimaComms.reduce!( remapper.comms_ctx, diff --git a/src/Spaces/dss.jl b/src/Spaces/dss.jl index bd204a2b60..7ec109b49d 100644 --- a/src/Spaces/dss.jl +++ b/src/Spaces/dss.jl @@ -73,7 +73,7 @@ end dss_buffer::Union{DSSBuffer, Nothing}, ) -Computes weighted dss of `data`. +Computes weighted dss of `data`. It comprises of the following steps: @@ -151,14 +151,14 @@ cuda_synchronize(device::ClimaComms.AbstractDevice; kwargs...) = nothing It comprises of the following steps: -1). Apply [`Spaces.dss_transform!`](@ref) on perimeter elements. This weights and tranforms vector -fields to physical basis if needed. Scalar fields are weighted. The transformed and/or weighted +1). Apply [`Spaces.dss_transform!`](@ref) on perimeter elements. 
This weights and transforms vector +fields to physical basis if needed. Scalar fields are weighted. The transformed and/or weighted perimeter `data` is stored in `perimeter_data`. 2). Apply [`Spaces.dss_local_ghost!`](@ref) This computes partial weighted DSS on ghost vertices, using only the information from `local` vertices. -3). [`Spaces.fill_send_buffer!`](@ref) +3). [`Spaces.fill_send_buffer!`](@ref) Loads the send buffer from `perimeter_data`. For unique ghost vertices, only data from the representative ghost vertices which store result of "ghost local" DSS are loaded. @@ -176,7 +176,7 @@ function weighted_dss_start!( sizeof(eltype(data)) > 0 || return nothing device = ClimaComms.device(topology(space)) weighted_dss_prepare!(data, space, dss_buffer) - cuda_synchronize(device; blocking = true) + isempty(dss_buffer.perimeter_elems) || cuda_synchronize(device; blocking = true) ClimaComms.start(dss_buffer.graph_context) return nothing end @@ -203,7 +203,7 @@ weighted_dss_start!(data, space, dss_buffer::Nothing) = nothing dss_buffer::DSSBuffer, ) -1). Apply [`Spaces.dss_transform!`](@ref) on interior elements. Local elements are split into interior +1). Apply [`Spaces.dss_transform!`](@ref) on interior elements. Local elements are split into interior and perimeter elements to facilitate overlapping of communication with computation. 2). Probe communication @@ -285,8 +285,8 @@ end 1). Finish communications. 2). Call [`Spaces.load_from_recv_buffer!`](@ref) -After the communication is complete, this adds data from the recv buffer to the corresponding location in -`perimeter_data`. For ghost vertices, this data is added only to the representative vertices. The values are +After the communication is complete, this adds data from the recv buffer to the corresponding location in +`perimeter_data`. For ghost vertices, this data is added only to the representative vertices. 
The values are then scattered to other local vertices corresponding to each unique ghost vertex in `dss_local_ghost`. 3). Call [`Spaces.dss_untransform!`](@ref) on all local elements.