Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 121 additions & 1 deletion ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@ import ClimaCore.Fields
import ClimaCore.Spaces
import ClimaCore.Topologies
import ClimaCore.MatrixFields
import ClimaCore.DataLayouts: vindex
import ClimaCore.DataLayouts: vindex, universal_size
import ClimaCore.MatrixFields: single_field_solve!
import ClimaCore.MatrixFields: _single_field_solve!
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
import ClimaCore.RecursiveApply: ⊠, ⊞, ⊟, rmap, rzero, rdiv

function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)

# Tridiagonal systems are handled by a specialized implementation
if eltype(A) <: MatrixFields.TridiagonalMatrixRow
single_field_solve_tridiagonal!(cache, x, A, b)
return
end

Ni, Nj, _, _, Nh = size(Fields.field_values(A))
us = UniversalSize(Fields.field_values(A))
mask = Spaces.get_mask(axes(x))
Expand Down Expand Up @@ -211,3 +218,116 @@ function band_matrix_solve_local_mem!(
end
return nothing
end


"""
    tridiag_pcr_kernel!(x, a, b, c, d, ::Val{n}, ::Val{n_iter})

CUDA kernel that solves one tridiagonal system per thread block with parallel
cyclic reduction (PCR), using one thread per row.

The block index selects the column `(i, j, h)` and the thread index selects the
row within it. `a`, `b` and `c` supply the three coefficients of each row, `d`
the right-hand side, and the solution is written into `x`. `n` is the number of
rows and `n_iter` the number of reduction sweeps; the stride between coupled
rows starts at 1 and doubles after every sweep.
"""
function tridiag_pcr_kernel!(
    x, a, b, c, d, ::Val{n}, ::Val{n_iter}
) where {n, n_iter}
    (blk_i, blk_j, blk_h) = blockIdx()
    row = threadIdx().x
    # Threads past the last row have no equation to work on.
    if row > n
        return nothing
    end

    FT = eltype(a)
    # Buffer length: `n` rounded up to a whole multiple of 32.
    padded_len = typeof(n)(cld(n, 32) * 32)
    sh_lower = CUDA.CuStaticSharedArray(FT, padded_len)
    sh_diag = CUDA.CuStaticSharedArray(FT, padded_len)
    sh_upper = CUDA.CuStaticSharedArray(FT, padded_len)
    sh_rhs = CUDA.CuStaticSharedArray(FT, padded_len)

    pos = CartesianIndex(blk_i, blk_j, 1, row, blk_h)

    # Fetch this row's coefficients and publish them to the shared buffers.
    @inbounds begin
        row_lower = getindex_field(a, pos)
        row_diag = getindex_field(b, pos)
        row_upper = getindex_field(c, pos)
        row_rhs = getindex_field(d, pos)

        sh_lower[row] = row_lower
        sh_diag[row] = row_diag
        sh_upper[row] = row_upper
        sh_rhs[row] = row_rhs
    end
    CUDA.sync_threads()

    # Reduction sweeps: each one eliminates the coupling to the rows `stride`
    # away, leaving the row coupled to the rows `2 * stride` away.
    stride = 1
    for _ in 1:n_iter
        # Neighbour indices are clamped into 1:n; a clamped neighbour is always
        # paired with a zero factor below, so it does not contribute.
        nbr_lo = max(row - stride, 1)
        nbr_hi = min(row + stride, n)

        @inbounds begin
            # Elimination factors (zero when the neighbour does not exist)
            f_lo = (row <= stride) ? zero(FT) : -row_lower / sh_diag[nbr_lo]
            f_hi = (row > n - stride) ? zero(FT) : -row_upper / sh_diag[nbr_hi]

            # Combine this row with its two neighbours
            row_lower = f_lo * sh_lower[nbr_lo]
            row_diag = row_diag + f_lo * sh_upper[nbr_lo] + f_hi * sh_lower[nbr_hi]
            row_upper = f_hi * sh_upper[nbr_hi]
            row_rhs = row_rhs + f_lo * sh_rhs[nbr_lo] + f_hi * sh_rhs[nbr_hi]
        end

        # Every thread must finish reading the old values before any overwrite.
        CUDA.sync_threads()

        @inbounds begin
            sh_lower[row] = row_lower
            sh_diag[row] = row_diag
            sh_upper[row] = row_upper
            sh_rhs[row] = row_rhs
        end

        # Make the new values visible before the next sweep reads them.
        CUDA.sync_threads()
        stride = 2 * stride
    end

    # After the last sweep each row is solved on its own: x = d / b.
    @inbounds setindex_field!(x, row_rhs / row_diag, pos)
    return nothing
end


"""
single_field_solve_tridiagonal!(cache, x, A, b)

Specialized solver for a tridiagonal MatrixField. Solves each column in
parallel, launching Nv threads per block, where Nv is the number of vertical levels.
Works best if Nv is a multiple of 32. Nv must not exceed 256.
"""
function single_field_solve_tridiagonal!(cache, x, A, b)
    # The PCR kernel is CUDA-only, so reject any other device up front.
    device = ClimaComms.device(x)
    device isa ClimaComms.CUDADevice ||
        error("This solver supports only CUDA devices.")

    eltype(A) <: MatrixFields.TridiagonalMatrixRow || error(
        "This function expects a tridiagonal matrix field, but got a field with element type $(eltype(A))",
    )

    # Get field dimensions
    Ni, Nj, _, Nv, Nh = universal_size(Fields.field_values(A))

    # The kernel uses exactly one thread per vertical level and is launched
    # with at most `max_threads_per_block` threads. With more levels than
    # threads, the extra rows would never be loaded or solved and the result
    # would be silently wrong, so fail loudly instead.
    max_threads_per_block = 256
    1 <= Nv <= max_threads_per_block || error(
        "The tridiagonal solver requires 1 <= Nv <= $(max_threads_per_block), but got Nv = $(Nv).",
    )

    # Prepare data
    Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
    A₋₁, A₀, A₊₁ = Aⱼs
    x_data = Fields.field_values(x)
    b_data = Fields.field_values(b)

    # Solve: one block per column, one thread per level. The kernel doubles
    # its stride every sweep, so ceil(log2(Nv)) sweeps are requested.
    threads_per_block = Nv
    n_iter = ceil(Int, log2(Nv))
    args = (x_data, A₋₁, A₀, A₊₁, b_data, Val(Nv), Val(n_iter))

    auto_launch!(
        tridiag_pcr_kernel!,
        args;
        threads_s = (threads_per_block,),
        blocks_s = (Ni, Nj, Nh),
    )

    call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
    return nothing
end
Empty file added file
Empty file.
8 changes: 7 additions & 1 deletion src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,15 @@ function run_field_matrix_solver!(
case1 = length(names) == 1
case2 = all(name -> cheap_inv(A[name, name]), names.values)
case3 = any(name -> cheap_inv(A[name, name]), names.values)

# Direct all TridiagonalMatrixRow cases to `single_field_solve` path so
# they can be redirected to a specialised solver
# TODO: Group multiple Tridiagonals to the special 'multiple_field' solver
case4 = any(name -> eltype(A[name, name]) <: TridiagonalMatrixRow, names.values)

# TODO: remove case3 and implement _single_field_solve_diag_matrix_row!
# in multiple_field_solve!
if case1 || case2 || case3
if case1 || case2 || case3 || case4
foreach(names) do name
single_field_solve!(cache[name], x[name], A[name, name], b[name])
end
Expand Down
Loading