2 changes: 2 additions & 0 deletions NEWS.md
@@ -4,6 +4,8 @@ ClimaCore.jl Release Notes
main
-------

- Add a specialized shared-memory tridiagonal solver that uses the parallel cyclic reduction (PCR) algorithm for the CUDA backend. [2486](https://github.com/CliMA/ClimaCore.jl/pull/2486)

v0.14.51
-------

128 changes: 126 additions & 2 deletions ext/cuda/matrix_fields_single_field_solve.jl
@@ -7,13 +7,24 @@ import ClimaCore.Fields
import ClimaCore.Spaces
import ClimaCore.Topologies
import ClimaCore.MatrixFields
import ClimaCore.DataLayouts: vindex
import ClimaCore.DataLayouts: vindex, universal_size
import ClimaCore.MatrixFields: single_field_solve!
import ClimaCore.MatrixFields: _single_field_solve!
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values

function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))

Ni, Nj, _, Nv, Nh = size(Fields.field_values(A))

# Tridiagonal solvers are handled by a special implementation.
# The special solver is limited in Nv by the number of threads per block,
# hence it cannot be used for very large matrices.
# 512 threads per block should run on most GPUs.
if eltype(A) <: MatrixFields.TridiagonalMatrixRow && Nv <= 512
single_field_solve_tridiagonal!(cache, x, A, b)
return
end

us = UniversalSize(Fields.field_values(A))
mask = Spaces.get_mask(axes(x))
cart_inds = cartesian_indices_columnwise(us)
@@ -210,3 +221,116 @@ function band_matrix_solve_local_mem!(
end
return nothing
end


function tridiag_pcr_kernel!(
x, a, b, c, d, ::Val{Nv}, ::Val{n_iter},
) where {Nv, n_iter}
(idx_i, idx_j, idx_h) = blockIdx()
i = threadIdx().x
if i > Nv
return nothing
end

s_a = CUDA.CuStaticSharedArray(eltype(a), Nv)
s_b = CUDA.CuStaticSharedArray(eltype(b), Nv)
s_c = CUDA.CuStaticSharedArray(eltype(c), Nv)
s_d = CUDA.CuStaticSharedArray(eltype(d), Nv)

idx = CartesianIndex(idx_i, idx_j, 1, i, idx_h)

# Load into shared memory
@inbounds begin
local_ai = a[idx]
local_bi = b[idx]
local_ci = c[idx]
local_di = d[idx]

s_a[i] = local_ai
s_b[i] = local_bi
s_c[i] = local_ci
s_d[i] = local_di
end
CUDA.sync_threads()

# PCR iterations
stride = 1

for _ in 1:n_iter
i_minus = max(i - stride, 1)
i_plus = min(i + stride, Nv)

# Compute elimination factors
@inbounds begin
k1 = (i > stride) ? -local_ai * inv(s_b[i_minus]) : zero(eltype(a))
k2 = (i <= Nv - stride) ? -local_ci * inv(s_b[i_plus]) : zero(eltype(a))

# Update coefficients
local_ai = k1 * s_a[i_minus]
local_bi = local_bi + k1 * s_c[i_minus] + k2 * s_a[i_plus]
local_ci = k2 * s_c[i_plus]
local_di = local_di + k1 * s_d[i_minus] + k2 * s_d[i_plus]
end

CUDA.sync_threads()

# Copy back for next iteration
@inbounds begin
s_a[i] = local_ai
s_b[i] = local_bi
s_c[i] = local_ci
s_d[i] = local_di
end

CUDA.sync_threads()
stride *= 2
end

# Final solve into x
@inbounds x[idx] = inv(s_b[i]) * s_d[i]
return nothing
end
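For readers unfamiliar with PCR, the recurrence applied by the kernel above can also be written as a serial CPU sketch (illustrative only and not part of this PR; the name `pcr_reference` and the padding convention `a[1] == 0`, `c[end] == 0` for the sub- and super-diagonals are assumptions of the sketch). It is handy for sanity-checking the kernel's update formulas against a direct solve:

```julia
using LinearAlgebra

# Serial version of the PCR recurrence used in the kernel above.
# `a`, `b`, `c` are the sub-, main and super-diagonals (all of length n,
# with a[1] == 0 and c[end] == 0); `d` is the right-hand side.
function pcr_reference(a, b, c, d)
    n = length(b)
    a, b, c, d = copy(a), copy(b), copy(c), copy(d)
    stride = 1
    for _ in 1:ceil(Int, log2(n))
        a2, b2, c2, d2 = copy(a), copy(b), copy(c), copy(d)
        for i in 1:n
            im, ip = max(i - stride, 1), min(i + stride, n)
            k1 = i > stride      ? -a[i] / b[im] : zero(eltype(b))
            k2 = i <= n - stride ? -c[i] / b[ip] : zero(eltype(b))
            a2[i] = k1 * a[im]
            b2[i] = b[i] + k1 * c[im] + k2 * a[ip]
            c2[i] = k2 * c[ip]
            d2[i] = d[i] + k1 * d[im] + k2 * d[ip]
        end
        a, b, c, d = a2, b2, c2, d2
        stride *= 2
    end
    return d ./ b   # after ceil(log2(n)) sweeps every equation is decoupled
end

# Sanity check against a direct tridiagonal solve.
n = 64
dl, dm, du = rand(n - 1), rand(n) .+ 2, rand(n - 1)   # diagonally dominant
rhs = rand(n)
x_pcr = pcr_reference(vcat(0.0, dl), dm, vcat(du, 0.0), rhs)
@assert x_pcr ≈ Tridiagonal(dl, dm, du) \ rhs
```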


"""
single_field_solve_tridiagonal!(cache, x, A, b)

Specialized solver for tridiagonal MatrixFields. Solves each column in
parallel, launching Nv threads per block, where Nv is the number of vertical
levels. Works best if Nv is a multiple of 32. There is an upper limit on Nv
due to GPU resource limits (register and shared-memory usage); on an A100 it
is 1024, but it may differ depending on the hardware.
"""
function single_field_solve_tridiagonal!(cache, x, A, b)

device = ClimaComms.device(x)
device isa ClimaComms.CUDADevice || error("This solver supports only CUDA devices.")

eltype(A) <: MatrixFields.TridiagonalMatrixRow || error(
"This function expects a tridiagonal matrix field, but got a field with element type $(eltype(A))",
)

# Get field dimensions
Ni, Nj, _, Nv, Nh = universal_size(Fields.field_values(A))

# Prepare data
Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
A₋₁, A₀, A₊₁ = Aⱼs
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)

# Solve
threads_per_block = Nv
n_iter = ceil(Int, log2(Nv))
args = (x_data, A₋₁, A₀, A₊₁, b_data, Val(Nv), Val(n_iter))

auto_launch!(
tridiag_pcr_kernel!,
args;
threads_s = (threads_per_block,),
Member

This is probably not a big issue because most long simulations run with 64 vertical faces, but wouldn't this result in a theoretical occupancy of 50% when Nv==32?

Contributor Author

Exactly! This is the biggest drawback of the kernel as it stands. When Nv is 33, we basically end up with 50% effective occupancy, which is the worst-case scenario. But I have to admit that it is annoyingly difficult to avoid...

As I mentioned in the description, I tried to improve it by 'concatenating' multiple tridiagonal systems into a bigger one (for a test, I basically loaded all the matrices in a single horizontal element at once, i.e. over the i, j dimensions) and solving that instead with the same solver. This would avoid the "trailing masked threads" problem, basically by averaging it out. But in the end it did not work. (You can see the attempt here.)

Basically what happens is that, since more threads participate in the synchronisation, they spend more time waiting for their brethren to catch up. In addition, there are extra PCR steps because the system is larger. For the worst-case scenario of Nv=33, the concatenated solver was about as fast as the one in this PR; in general it was significantly slower. (I looked at a standalone "testbed", though, not at AMIP.)

What we still need to try is the Thomas algorithm with a single matrix per thread, but in shared memory.

However, there is also a possibility that switching to shared memory decreased the fraction of the time spent in the solver to the point where we perhaps don't need to worry about the occupancy. Basically, going from Nv of 32 to Nv of 33 would increase the runtime of the solver by a factor of about 2, but if it is a small enough fraction of the total time, it will have a small overall effect on performance. I still need to test that, though (the nsys results indicate that tridiagonal solves are a very small fraction of the time, but I am not sure I counted them all because of the kernel renaming; I have made nsys and ncu runs in AMIP without the renaming to verify, but have not processed the results yet).
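To put a rough number on the idle-lane effect described above (an editorial back-of-envelope sketch, not from the PR; `warp_utilization` is a made-up helper, and real occupancy also depends on register, shared-memory, and blocks-per-SM limits): a block of Nv threads is scheduled in whole warps of 32, so a trailing partial warp leaves lanes idle.

```julia
# Fraction of launched lanes that carry a vertical level, for a block of Nv threads.
warp_utilization(Nv; warp = 32) = Nv / (cld(Nv, warp) * warp)

warp_utilization(32)   # 1.0      -> all lanes active
warp_utilization(33)   # 0.515625 -> a second warp is launched for one extra level
warp_utilization(64)   # 1.0      -> the common 64-level configuration is fine
```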

Contributor Author

To update on the above: I had a look at the nsys run without the kernel renaming, and it seems that the GPU time in the tridiagonal solver is still rather large (3.4%) no_rename.tar.gz. This implies that something has gone wrong with the kernel renaming in the original report...

3.4% is a large number, sufficiently large to worry about the loss of occupancy.

The only thing I am thinking is that it may not be trivial to change, so perhaps it would be better to address it in a separate, follow-up PR.
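As a quick, purely illustrative estimate of the overall impact using the numbers quoted in this thread (`overall_slowdown` is a made-up helper, not part of the PR): if the solver accounts for a fraction f of GPU time and an unfavourable Nv slows it by a factor s, the run as a whole slows by roughly 1 + f * (s - 1).

```julia
# Amdahl-style estimate: a fraction f of the runtime slowed down by a factor s.
overall_slowdown(f, s) = 1 + f * (s - 1)

overall_slowdown(0.034, 2.0)   # ≈ 1.034, i.e. about a 3.4% longer run overall
```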

Member

I have a fix for that. I will try to open a PR for it by the end of the day (PST).

blocks_s = (Ni, Nj, Nh),
)

call_post_op_callback() && post_op_callback(x, device, cache, x, A, b)
return nothing
end
21 changes: 19 additions & 2 deletions src/MatrixFields/field_matrix_solver.jl
@@ -216,7 +216,7 @@ we might want to parallelize it in the future).
If `Aₙₙ` is a diagonal matrix, the equation `Aₙₙ * xₙ = bₙ` is solved by making a
single pass over the data, setting each `xₙ[i]` to `inv(Aₙₙ[i, i]) * bₙ[i]`.

Otherwise, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
Otherwise, on a CPU, the equation `Aₙₙ * xₙ = bₙ` is solved using Gaussian elimination
(without pivoting), which makes two passes over the data. This is currently only
implemented for tri-diagonal and penta-diagonal matrices `Aₙₙ`. In Gaussian
elimination, `Aₙₙ` is effectively factorized into the product `Lₙ * Dₙ * Uₙ`,
@@ -229,6 +229,17 @@ is referred to as "back substitution". These operations can become numerically
unstable when `Aₙₙ` has entries with large disparities in magnitude, but avoiding
this would require swapping the rows of `Aₙₙ` (i.e., replacing `Dₙ` with a
partial pivoting matrix).

For performance reasons, on a GPU a different solver, based on the parallel cyclic
reduction (PCR) method, is used for tri-diagonal systems. This makes better use of
GPU parallelism and shared memory. Since this solver launches one thread per row of
the matrix, it is only used for systems with at most 512 rows, so as not to exceed
CUDA limits on threads per block. Above that size, the code falls back to the
Gaussian elimination method used on the CPU (and may show degraded performance).

The PCR method works by recursively decomposing a tridiagonal system into two
independent systems of half the size. This process continues until we are left
with systems of size 1, which can be solved directly.
"""
struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end
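To make the size-halving step concrete, here is a minimal sketch (illustrative only, not part of this change) that performs one PCR sweep on a small dense copy of a tridiagonal system and shows how the odd- and even-indexed unknowns decouple:

```julia
using LinearAlgebra

n = 8
A = Matrix(Tridiagonal(rand(n - 1), rand(n) .+ 2, rand(n - 1)))
d = rand(n)

# One PCR sweep: use rows i-1 and i+1 of the *original* system to eliminate
# A[i, i-1] and A[i, i+1] from row i.
A1, d1 = copy(A), copy(d)
for i in 1:n
    if i > 1
        k = -A[i, i - 1] / A[i - 1, i - 1]
        A1[i, :] .+= k .* A[i - 1, :]
        d1[i] += k * d[i - 1]
    end
    if i < n
        k = -A[i, i + 1] / A[i + 1, i + 1]
        A1[i, :] .+= k .* A[i + 1, :]
        d1[i] += k * d[i + 1]
    end
end

# Each row of A1 now only touches columns i-2, i and i+2, so the odd-indexed
# and even-indexed unknowns form two independent tridiagonal systems of half
# the size; repeating the sweep ceil(log2(n)) times leaves size-1 systems.
```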

@@ -279,9 +290,15 @@ function run_field_matrix_solver!(
case1 = length(names) == 1
case2 = all(name -> cheap_inv(A[name, name]), names.values)
case3 = any(name -> cheap_inv(A[name, name]), names.values)

# Direct all TridiagonalMatrixRow cases to the `single_field_solve!` path so
# that they can be redirected to the specialised solver.
# TODO: Group multiple Tridiagonals to the special 'multiple_field' solver
case4 = any(name -> eltype(A[name, name]) <: TridiagonalMatrixRow, names.values)

# TODO: remove case3 and implement _single_field_solve_diag_matrix_row!
# in multiple_field_solve!
if case1 || case2 || case3
if case1 || case2 || case3 || case4
foreach(names) do name
single_field_solve!(cache[name], x[name], A[name, name], b[name])
end