diff --git a/NEWS.md b/NEWS.md
index 6a6a4ec8d3..a30e43f1d6 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -4,6 +4,8 @@ ClimaCore.jl Release Notes
 main
 -------
 
+- Add a specialized shared-memory tridiagonal solver for the CUDA backend, using the parallel cyclic reduction (PCR) algorithm. [2486](https://github.com/CliMA/ClimaCore.jl/pull/2486)
+
 v0.14.51
 -------
 
diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl
index fbb2ed5f42..56a0d2b5c6 100644
--- a/ext/cuda/matrix_fields_single_field_solve.jl
+++ b/ext/cuda/matrix_fields_single_field_solve.jl
@@ -7,13 +7,24 @@ import ClimaCore.Fields
 import ClimaCore.Spaces
 import ClimaCore.Topologies
 import ClimaCore.MatrixFields
-import ClimaCore.DataLayouts: vindex
+import ClimaCore.DataLayouts: vindex, universal_size
 import ClimaCore.MatrixFields: single_field_solve!
 import ClimaCore.MatrixFields: _single_field_solve!
 import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
 
 function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
-    Ni, Nj, _, _, Nh = size(Fields.field_values(A))
+
+    Ni, Nj, _, Nv, Nh = size(Fields.field_values(A))
+
+    # Tridiagonal solves are handled by a specialized implementation.
+    # The specialized solver launches one thread per vertical level, so it is
+    # limited in Nv by the threads-per-block limit and cannot be used for
+    # very tall columns. Nv <= 512 should run on most GPUs.
+    if eltype(A) <: MatrixFields.TridiagonalMatrixRow && Nv <= 512
+        single_field_solve_tridiagonal!(cache, x, A, b)
+        return
+    end
+
     us = UniversalSize(Fields.field_values(A))
     mask = Spaces.get_mask(axes(x))
     cart_inds = cartesian_indices_columnwise(us)
@@ -210,3 +221,116 @@ function band_matrix_solve_local_mem!(
     end
     return nothing
 end
+
+
+function tridiag_pcr_kernel!(
+    x, a, b, c, d, ::Val{Nv}, ::Val{n_iter},
+) where {Nv, n_iter}
+    (idx_i, idx_j, idx_h) = blockIdx()
+    i = threadIdx().x
+    if i > Nv
+        return nothing
+    end
+
+    s_a = CUDA.CuStaticSharedArray(eltype(a), Nv)
+    s_b = CUDA.CuStaticSharedArray(eltype(b), Nv)
+    s_c = CUDA.CuStaticSharedArray(eltype(c), Nv)
+    s_d = CUDA.CuStaticSharedArray(eltype(d), Nv)
+
+    idx = CartesianIndex(idx_i, idx_j, 1, i, idx_h)
+
+    # Load this thread's row into registers and shared memory
+    @inbounds begin
+        local_ai = a[idx]
+        local_bi = b[idx]
+        local_ci = c[idx]
+        local_di = d[idx]
+
+        s_a[i] = local_ai
+        s_b[i] = local_bi
+        s_c[i] = local_ci
+        s_d[i] = local_di
+    end
+    CUDA.sync_threads()
+
+    # PCR iterations
+    stride = 1
+
+    for _ in 1:n_iter
+        i_minus = max(i - stride, 1)
+        i_plus = min(i + stride, Nv)
+
+        # Compute elimination factors (zero at the boundaries)
+        @inbounds begin
+            k1 = (i > stride) ? -local_ai * inv(s_b[i_minus]) : zero(eltype(a))
+            k2 = (i <= Nv - stride) ? -local_ci * inv(s_b[i_plus]) : zero(eltype(a))
+
+            # Update coefficients
+            local_ai = k1 * s_a[i_minus]
+            local_bi = local_bi + k1 * s_c[i_minus] + k2 * s_a[i_plus]
+            local_ci = k2 * s_c[i_plus]
+            local_di = local_di + k1 * s_d[i_minus] + k2 * s_d[i_plus]
+        end
+
+        CUDA.sync_threads()
+
+        # Copy back for next iteration
+        @inbounds begin
+            s_a[i] = local_ai
+            s_b[i] = local_bi
+            s_c[i] = local_ci
+            s_d[i] = local_di
+        end
+
+        CUDA.sync_threads()
+        stride *= 2
+    end
+
+    # After n_iter rounds the system is diagonal; solve directly into x
+    @inbounds x[idx] = inv(s_b[i]) * s_d[i]
+    return nothing
+end
+
+
+"""
+    single_field_solve_tridiagonal!(cache, x, A, b)
+
+Specialized solver for tridiagonal `MatrixField`s. Solves each column in
+parallel, launching `Nv` threads per block, where `Nv` is the number of
+vertical levels. Works best if `Nv` is a multiple of 32.
+There is an upper limit on `Nv` due to GPU resource limits (register and
+shared memory usage). On an A100 the limit is 1024, but it may differ
+depending on the hardware.
+"""
+function single_field_solve_tridiagonal!(cache, x, A, b)
+
+    device = ClimaComms.device(x)
+    device isa ClimaComms.CUDADevice || error("This solver supports only CUDA devices.")
+
+    eltype(A) <: MatrixFields.TridiagonalMatrixRow || error(
+        "This function expects a tridiagonal matrix field, but got a field with element type $(eltype(A))",
+    )
+
+    # Get field dimensions
+    Ni, Nj, _, Nv, Nh = universal_size(Fields.field_values(A))
+
+    # Unpack the sub-, main-, and super-diagonal bands of A
+    Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
+    A₋₁, A₀, A₊₁ = Aⱼs
+    x_data = Fields.field_values(x)
+    b_data = Fields.field_values(b)
+
+    # Launch one block per column and one thread per vertical level
+    threads_per_block = Nv
+    n_iter = ceil(Int, log2(Nv))
+    args = (x_data, A₋₁, A₀, A₊₁, b_data, Val(Nv), Val(n_iter))
+
+    auto_launch!(
+        tridiag_pcr_kernel!,
+        args;
+        threads_s = (threads_per_block,),
+        blocks_s = (Ni, Nj, Nh),
+    )
+
+    call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
+    return nothing
+end
diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl
index 90634a184c..59c0f9bb0e 100644
--- a/src/MatrixFields/field_matrix_solver.jl
+++ b/src/MatrixFields/field_matrix_solver.jl
@@ -216,7 +216,7 @@ we might want to parallelize it in the future).
 
 If `Aₙₙ` is a diagonal matrix, the equation `Aₙₙ * xₙ = bₙ` is solved by making
 a single pass over the data, setting each `xₙ[i]` to `inv(Aₙₙ[i, i]) * bₙ[i]`.
-Otherwise, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
+Otherwise, on a CPU, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
 (without pivoting), which makes two passes over the data. This is currently
 only implemented for tri-diagonal and penta-diagonal matrices `Aₙₙ`.
 In Gaussian elimination, `Aₙₙ` is effectively factorized into the product `Lₙ * Dₙ * Uₙ`,
@@ -229,6 +229,17 @@
 is referred to as "back substitution". These operations can become numerically
 unstable when `Aₙₙ` has entries with large disparities in magnitude, but
 avoiding this would require swapping the rows of `Aₙₙ` (i.e., replacing `Dₙ`
 with a partial pivoting matrix).
+
+For performance reasons, on a GPU a different solver is used for tri-diagonal
+systems, based on the parallel cyclic reduction (PCR) method, which makes better
+use of GPU parallelism and shared memory. Since this solver launches one thread
+per matrix row, it is only used for systems with at most 512 rows, so as not to
+exceed CUDA's threads-per-block limit. Above that size, the code falls back to
+the Gaussian elimination method used on the CPU (and may show degraded performance).
+
+The PCR method works by recursively decomposing a tridiagonal system into two
+independent systems of half the size. This process continues until we are left
+with systems of size 1, which can each be solved directly.
""" struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end @@ -279,9 +290,15 @@ function run_field_matrix_solver!( case1 = length(names) == 1 case2 = all(name -> cheap_inv(A[name, name]), names.values) case3 = any(name -> cheap_inv(A[name, name]), names.values) + + # Direct all TridiagonalMatrixRow cases to `single_field_solve` path so + # they can be redirected to a specialised solver + # TODO: Group multiple Tridiagonals to the special 'multiple_field' solver + case4 = any(name -> eltype(A[name, name]) <: TridiagonalMatrixRow, names.values) + # TODO: remove case3 and implement _single_field_solve_diag_matrix_row! # in multiple_field_solve! - if case1 || case2 || case3 + if case1 || case2 || case3 || case4 foreach(names) do name single_field_solve!(cache[name], x[name], A[name, name], b[name]) end