diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl index b32000f190..5ed499c89d 100644 --- a/ext/cuda/matrix_fields_single_field_solve.jl +++ b/ext/cuda/matrix_fields_single_field_solve.jl @@ -7,13 +7,20 @@ import ClimaCore.Fields import ClimaCore.Spaces import ClimaCore.Topologies import ClimaCore.MatrixFields -import ClimaCore.DataLayouts: vindex +import ClimaCore.DataLayouts: vindex, universal_size import ClimaCore.MatrixFields: single_field_solve! import ClimaCore.MatrixFields: _single_field_solve! import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values import ClimaCore.RecursiveApply: ⊠, ⊞, ⊟, rmap, rzero, rdiv function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b) + + # Tridiagonal solvers are handled by special implementation + if eltype(A) <: MatrixFields.TridiagonalMatrixRow + single_field_solve_tridiagonal!(cache, x, A, b) + return + end + Ni, Nj, _, _, Nh = size(Fields.field_values(A)) us = UniversalSize(Fields.field_values(A)) mask = Spaces.get_mask(axes(x)) @@ -211,3 +218,116 @@ function band_matrix_solve_local_mem!( end return nothing end + + +function tridiag_pcr_kernel!( + x, a, b, c, d, ::Val{n}, ::Val{n_iter} +) where {n, n_iter} + (idx_i, idx_j, idx_h) = blockIdx() + i = threadIdx().x + if i > n + return nothing + end + + T = eltype(a) + n_shared = typeof(n)(cld(n, 32) * 32) # Round n to next multiple to avoid bank conflicts (?) + s_a = CUDA.CuStaticSharedArray(T, n_shared) + s_b = CUDA.CuStaticSharedArray(T, n_shared) + s_c = CUDA.CuStaticSharedArray(T, n_shared) + s_d = CUDA.CuStaticSharedArray(T, n_shared) + + idx = CartesianIndex(idx_i, idx_j, 1, i, idx_h) + + # Load into shared memory + @inbounds begin + local_ai = getindex_field(a, idx) + local_bi = getindex_field(b, idx) + local_ci = getindex_field(c, idx) + local_di = getindex_field(d, idx) + + s_a[i] = local_ai + s_b[i] = local_bi + s_c[i] = local_ci + s_d[i] = local_di + end + CUDA.sync_threads() + + # PCR iterations + stride = 1 + + for _ in 1:n_iter + i_minus = max(i - stride, 1) + i_plus = min(i + stride, n) + + # Compute elimination factors + @inbounds begin + k1 = (i > stride) ? -local_ai / s_b[i_minus] : zero(T) + k2 = (i <= n - stride) ? -local_ci / s_b[i_plus] : zero(T) + + # Update coefficients + local_ai = k1 * s_a[i_minus] + local_bi = local_bi + k1 * s_c[i_minus] + k2 * s_a[i_plus] + local_ci = k2 * s_c[i_plus] + local_di = local_di + k1 * s_d[i_minus] + k2 * s_d[i_plus] + end + + CUDA.sync_threads() + + # Copy back for next iteration + @inbounds begin + s_a[i] = local_ai + s_b[i] = local_bi + s_c[i] = local_ci + s_d[i] = local_di + end + + CUDA.sync_threads() + stride *= 2 + end + + # Final solve into x + @inbounds setindex_field!(x, local_di / local_bi, idx) + return nothing +end + + +""" + single_field_solve_tridiagonal!(cache, x, A, b) + +Specialized solver for the tridiagonal MatrixField. Solves each column in +parallel launching Nv threads per block where Nv is the number of vertical levels. +Works best if Nv is multiple of 32. Also must be smaller then 256. +""" +function single_field_solve_tridiagonal!(cache, x, A, b) + + device = ClimaComms.device(x) + device isa ClimaComms.CUDADevice || error("This solver supports only CUDA devices.") + + eltype(A) <: MatrixFields.TridiagonalMatrixRow || error( + "This function expects a tridiagonal matrix field, but got a field with element type $(eltype(A))" + ) + + # Get field dimensions + Ni, Nj, _, Nv, Nh = universal_size(Fields.field_values(A)) + + # Prepare data + Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries)) + A₋₁, A₀, A₊₁ = Aⱼs + x_data = Fields.field_values(x) + b_data = Fields.field_values(b) + + # Solve + threads_per_block = min(Nv, 256) + n_iter = ceil(Int, log2(Nv)) + args = (x_data, A₋₁, A₀, A₊₁, b_data, Val(Nv), Val(n_iter)) + + auto_launch!( + tridiag_pcr_kernel!, + args; + threads_s = (threads_per_block,), + blocks_s = (Ni, Nj, Nh), + ) + + call_post_op_callback() && post_op_callback(x, device, cache, x, A, b) + return nothing +end diff --git a/file b/file new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl index 90634a184c..b9af0f74c8 100644 --- a/src/MatrixFields/field_matrix_solver.jl +++ b/src/MatrixFields/field_matrix_solver.jl @@ -279,9 +279,15 @@ function run_field_matrix_solver!( case1 = length(names) == 1 case2 = all(name -> cheap_inv(A[name, name]), names.values) case3 = any(name -> cheap_inv(A[name, name]), names.values) + + # Direct all TridiagonalMatrixRow cases to `single_field_solve` path so + # they can be redirected to a specialised solver + # TODO: Group multiple Tridiagonals to the special 'multiple_field' solver + case4 = any(name -> eltype(A[name, name]) <: TridiagonalMatrixRow, names.values) + # TODO: remove case3 and implement _single_field_solve_diag_matrix_row! # in multiple_field_solve! - if case1 || case2 || case3 + if case1 || case2 || case3 || case4 foreach(names) do name single_field_solve!(cache[name], x[name], A[name, name], b[name]) end