Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 45 additions & 50 deletions torax/_src/fvm/discrete_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
newton_raphson_solve_block can capture nonlinear dynamics even when
each step is expressed using a matrix multiply.
"""

from typing import TypeAlias

import jax
from jax import numpy as jnp
from torax._src import tridiagonal
from torax._src.fvm import block_1d_coeffs
from torax._src.fvm import cell_variable
from torax._src.fvm import convection_terms
Expand All @@ -41,11 +43,11 @@ def calc_c(
coeffs: Block1DCoeffs,
convection_dirichlet_mode: str = 'ghost',
convection_neumann_mode: str = 'ghost',
) -> tuple[jax.Array, jax.Array]:
"""Calculate C and c such that F = C x + c.
) -> tuple[tridiagonal.BlockTriDiagonal, jax.Array]:
"""Calculate banded blocks and vector c such that F = C x + c.

See docstrings for `Block1DCoeff` and `implicit_solve_block` for
more detail.
Returns the block-tridiagonal representation of C. The matrix structure comes
from the 1D FVM stencil: each cell couples to itself and its two neighbors.

Args:
x: Tuple containing CellVariables for each channel. This function uses only
Expand All @@ -57,8 +59,10 @@ def calc_c(
`neumann_mode` argument.

Returns:
c_mat: matrix C, such that F = C x + c
c: the vector c
A tuple of (c_matrix, c_forcing) where:
c_matrix: BlockTriDiagonal with sub/main/super-diagonal blocks.
c_forcing: An array with the terms arising from explicit sources and
boundary conditions.
"""

d_face = coeffs.d_face
Expand All @@ -75,72 +79,63 @@ def calc_c(
f'but got {x_i.value.shape}.'
)

zero_block = jnp.zeros((num_cells, num_cells))
zero_row_of_blocks = [zero_block] * num_channels
zero_vec = jnp.zeros((num_cells))
zero_block_vec = [zero_vec] * num_channels

# Make a matrix C and vector c that will accumulate contributions from
# diffusion, convection, and source terms.
# C and c are both block structured, with one block per channel.
c_mat = [zero_row_of_blocks.copy() for _ in range(num_channels)]
c = zero_block_vec.copy()

# Add diffusion terms
if d_face is not None:
for i in range(num_channels):
(
diffusion_mat,
diffusion_vec,
) = diffusion_terms.make_diffusion_terms(
d_face[i],
x[i],
)

c_mat[i][i] += diffusion_mat.to_dense()
c[i] += diffusion_vec
if d_face is None:
c_matrix = tridiagonal.BlockTriDiagonal.zeros(num_cells, num_channels)
c_forcing = jnp.zeros((num_cells, num_channels))
else:
d_terms = [
diffusion_terms.make_diffusion_terms(d_face_i, x_i)
for d_face_i, x_i in zip(d_face, x)
]
# stack the forcing terms along the channel axis (axis=1)
c_forcing = jnp.stack([c_forcing for _, c_forcing in d_terms], axis=1)
c_matrix = tridiagonal.BlockTriDiagonal.from_tridiagonals(
[d_mat for d_mat, _ in d_terms]
)

# Add convection terms
if v_face is not None:
conv_terms = []
for i in range(num_channels):
# Resolve diffusion to zeros if it is not specified
d_face_i = d_face[i] if d_face is not None else None
d_face_i = jnp.zeros_like(v_face[i]) if d_face_i is None else d_face_i

(
conv_mat,
conv_vec,
) = convection_terms.make_convection_terms(
conv_mat, conv_forcing = convection_terms.make_convection_terms(
v_face[i],
d_face_i,
x[i],
dirichlet_mode=convection_dirichlet_mode,
neumann_mode=convection_neumann_mode,
)

c_mat[i][i] += conv_mat.to_dense()
c[i] += conv_vec
conv_terms.append((conv_mat, conv_forcing))
# stack the forcing terms along the channel axis (axis=1)
conv_forcing = jnp.stack(
[conv_forcing for _, conv_forcing in conv_terms], axis=1
)
c_matrix += tridiagonal.BlockTriDiagonal.from_tridiagonals(
[conv_mat for conv_mat, _ in conv_terms]
)
c_forcing += conv_forcing

# Add implicit source terms
if source_mat_cell is not None:
diag = c_matrix.diagonal
for i in range(num_channels):
for j in range(num_channels):
source = source_mat_cell[i][j]
if source is not None:
c_mat[i][j] += jnp.diag(source)
diag = diag.at[:, i, j].add(source)
c_matrix = tridiagonal.BlockTriDiagonal(
lower=c_matrix.lower,
diagonal=diag,
upper=c_matrix.upper,
)

# Add explicit source terms
def add(left: jax.Array, right: jax.Array | None):
"""Addition with adding None treated as no-op."""
if right is not None:
return left + right
return left

if source_cell is not None:
c = [add(c_i, source_i) for c_i, source_i in zip(c, source_cell)]

# Form block structure
c_mat = jnp.block(c_mat)
c = jnp.block(c)
for i in range(num_channels):
if source_cell[i] is not None:
c_forcing = c_forcing.at[:, i].add(source_cell[i])

return c_mat, c
return c_matrix, c_forcing
20 changes: 18 additions & 2 deletions torax/_src/fvm/fvm_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def cell_variable_tuple_to_vec(
Returns:
A flat array of evolving state variables.
"""
x_vec = jnp.concatenate([x.value for x in x_tuple])
return x_vec
return jnp.concatenate([x.value for x in x_tuple])


def vec_to_cell_variable_tuple(
Expand Down Expand Up @@ -77,3 +76,20 @@ def vec_to_cell_variable_tuple(
]

return tuple(x_out)


def cell_variable_tuple_to_array(
x_tuple: tuple[cell_variable.CellVariable, ...],
axis: int,
) -> jax.Array:
"""Converts a tuple of CellVariables to a multi-dimensional array.


Args:
x_tuple: A tuple of CellVariables.
axis: The axis along which to stack the CellVariables.

Returns:
A multi-dimensional array of CellVariables.
"""
return jnp.stack([var.value for var in x_tuple], axis=axis)
22 changes: 8 additions & 14 deletions torax/_src/fvm/implicit_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import dataclasses

import jax
from jax import numpy as jnp
from torax._src.fvm import block_1d_coeffs
from torax._src.fvm import cell_variable
from torax._src.fvm import fvm_conversions
Expand Down Expand Up @@ -79,10 +78,9 @@ def implicit_solve_block(
# or from Picard iterations with predictor-corrector.
# See residual_and_loss.theta_method_matrix_equation for a complete
# description of how the equation is set up.
x_old_array = fvm_conversions.cell_variable_tuple_to_array(x_old, axis=1)

x_old_vec = fvm_conversions.cell_variable_tuple_to_vec(x_old)

lhs_mat, lhs_vec, rhs_mat, rhs_vec = (
lhs_matrix, lhs_vec, rhs_matrix, rhs_vec = (
residual_and_loss.theta_method_matrix_equation(
dt=dt,
x_old=x_old,
Expand All @@ -95,16 +93,12 @@ def implicit_solve_block(
)
)

rhs = jnp.dot(rhs_mat, x_old_vec) + rhs_vec - lhs_vec
x_new = jnp.linalg.solve(lhs_mat, rhs)
rhs_result = rhs_matrix.matvec(x_old_array) + rhs_vec - lhs_vec
x_new = lhs_matrix.solve(rhs_result)

# Create updated CellVariable instances based on state_plus_dt which has
# updated boundary conditions and prescribed profiles.
x_new = jnp.split(x_new, len(x_old))
out = [
dataclasses.replace(var, value=value)
for var, value in zip(x_new_guess, x_new)
]
out = tuple(out)

return out
return tuple(
dataclasses.replace(var, value=x_new[:, i])
for i, var in enumerate(x_new_guess)
)
Loading
Loading