diff --git a/torax/_src/fvm/discrete_system.py b/torax/_src/fvm/discrete_system.py
index 44858a07a..ddcb12e37 100644
--- a/torax/_src/fvm/discrete_system.py
+++ b/torax/_src/fvm/discrete_system.py
@@ -23,10 +23,12 @@ newton_raphson_solve_block can capture
 nonlinear dynamics even when each step
 is expressed using a matrix multiply.
 """
+
 from typing import TypeAlias
 
 import jax
 from jax import numpy as jnp
+from torax._src import tridiagonal
 from torax._src.fvm import block_1d_coeffs
 from torax._src.fvm import cell_variable
 from torax._src.fvm import convection_terms
@@ -41,11 +43,11 @@ def calc_c(
     coeffs: Block1DCoeffs,
     convection_dirichlet_mode: str = 'ghost',
     convection_neumann_mode: str = 'ghost',
-) -> tuple[jax.Array, jax.Array]:
-  """Calculate C and c such that F = C x + c.
+) -> tuple[tridiagonal.BlockTriDiagonal, jax.Array]:
+  """Calculate banded blocks and vector c such that F = C x + c.
 
-  See docstrings for `Block1DCoeff` and `implicit_solve_block` for
-  more detail.
+  Returns the block-tridiagonal representation of C. The matrix structure comes
+  from the 1D FVM stencil: each cell couples to itself and its two neighbors.
 
   Args:
     x: Tuple containing CellVariables for each channel. This function uses only
@@ -57,8 +59,10 @@ def calc_c(
       `neumann_mode` argument.
 
   Returns:
-    c_mat: matrix C, such that F = C x + c
-    c: the vector c
+    A tuple of (c_matrix, c_forcing) where:
+      c_matrix: BlockTriDiagonal with sub/main/super-diagonal blocks.
+      c_forcing: An array with the terms arising from explicit sources and
+        boundary conditions.
   """
 
   d_face = coeffs.d_face
@@ -75,72 +79,63 @@ def calc_c(
           f'but got {x_i.value.shape}.'
       )
 
-  zero_block = jnp.zeros((num_cells, num_cells))
-  zero_row_of_blocks = [zero_block] * num_channels
-  zero_vec = jnp.zeros((num_cells))
-  zero_block_vec = [zero_vec] * num_channels
-
-  # Make a matrix C and vector c that will accumulate contributions from
-  # diffusion, convection, and source terms.
-  # C and c are both block structured, with one block per channel.
- c_mat = [zero_row_of_blocks.copy() for _ in range(num_channels)] - c = zero_block_vec.copy() - # Add diffusion terms - if d_face is not None: - for i in range(num_channels): - ( - diffusion_mat, - diffusion_vec, - ) = diffusion_terms.make_diffusion_terms( - d_face[i], - x[i], - ) - - c_mat[i][i] += diffusion_mat.to_dense() - c[i] += diffusion_vec + if d_face is None: + c_matrix = tridiagonal.BlockTriDiagonal.zeros(num_cells, num_channels) + c_forcing = jnp.zeros((num_cells, num_channels)) + else: + d_terms = [ + diffusion_terms.make_diffusion_terms(d_face_i, x_i) + for d_face_i, x_i in zip(d_face, x) + ] + # stack the forcing terms along the channel axis (axis=1) + c_forcing = jnp.stack([c_forcing for _, c_forcing in d_terms], axis=1) + c_matrix = tridiagonal.BlockTriDiagonal.from_tridiagonals( + [d_mat for d_mat, _ in d_terms] + ) # Add convection terms if v_face is not None: + conv_terms = [] for i in range(num_channels): # Resolve diffusion to zeros if it is not specified d_face_i = d_face[i] if d_face is not None else None d_face_i = jnp.zeros_like(v_face[i]) if d_face_i is None else d_face_i - - ( - conv_mat, - conv_vec, - ) = convection_terms.make_convection_terms( + conv_mat, conv_forcing = convection_terms.make_convection_terms( v_face[i], d_face_i, x[i], dirichlet_mode=convection_dirichlet_mode, neumann_mode=convection_neumann_mode, ) - - c_mat[i][i] += conv_mat.to_dense() - c[i] += conv_vec + conv_terms.append((conv_mat, conv_forcing)) + # stack the forcing terms along the channel axis (axis=1) + conv_forcing = jnp.stack( + [conv_forcing for _, conv_forcing in conv_terms], axis=1 + ) + c_matrix += tridiagonal.BlockTriDiagonal.from_tridiagonals( + [conv_mat for conv_mat, _ in conv_terms] + ) + c_forcing += conv_forcing # Add implicit source terms if source_mat_cell is not None: + diag = c_matrix.diagonal for i in range(num_channels): for j in range(num_channels): source = source_mat_cell[i][j] if source is not None: - c_mat[i][j] += jnp.diag(source) + diag = diag.at[:, i, j].add(source) + c_matrix = tridiagonal.BlockTriDiagonal( + lower=c_matrix.lower, + diagonal=diag, + upper=c_matrix.upper, + ) # Add explicit source terms - def add(left: jax.Array, right: jax.Array | None): - """Addition with adding None treated as no-op.""" - if right is not None: - return left + right - return left - if source_cell is not None: - c = [add(c_i, source_i) for c_i, source_i in zip(c, source_cell)] - - # Form block structure - c_mat = jnp.block(c_mat) - c = jnp.block(c) + for i in range(num_channels): + if source_cell[i] is not None: + c_forcing = c_forcing.at[:, i].add(source_cell[i]) - return c_mat, c + return c_matrix, c_forcing diff --git a/torax/_src/fvm/fvm_conversions.py b/torax/_src/fvm/fvm_conversions.py index a176c38f7..8ee6c49cc 100644 --- a/torax/_src/fvm/fvm_conversions.py +++ b/torax/_src/fvm/fvm_conversions.py @@ -33,8 +33,7 @@ def cell_variable_tuple_to_vec( Returns: A flat array of evolving state variables. """ - x_vec = jnp.concatenate([x.value for x in x_tuple]) - return x_vec + return jnp.concatenate([x.value for x in x_tuple]) def vec_to_cell_variable_tuple( @@ -77,3 +76,20 @@ def vec_to_cell_variable_tuple( ] return tuple(x_out) + + +def cell_variable_tuple_to_array( + x_tuple: tuple[cell_variable.CellVariable, ...], + axis: int, +) -> jax.Array: + """Converts a tuple of CellVariables to a multi-dimensional array. + + + Args: + x_tuple: A tuple of CellVariables. + axis: The axis along which to stack the CellVariables. 
+ + Returns: + A multi-dimensional array of CellVariables. + """ + return jnp.stack([var.value for var in x_tuple], axis=axis) diff --git a/torax/_src/fvm/implicit_solve_block.py b/torax/_src/fvm/implicit_solve_block.py index 5cae649b7..7c00b3d30 100644 --- a/torax/_src/fvm/implicit_solve_block.py +++ b/torax/_src/fvm/implicit_solve_block.py @@ -19,7 +19,6 @@ import dataclasses import jax -from jax import numpy as jnp from torax._src.fvm import block_1d_coeffs from torax._src.fvm import cell_variable from torax._src.fvm import fvm_conversions @@ -79,10 +78,9 @@ def implicit_solve_block( # or from Picard iterations with predictor-corrector. # See residual_and_loss.theta_method_matrix_equation for a complete # description of how the equation is set up. + x_old_array = fvm_conversions.cell_variable_tuple_to_array(x_old, axis=1) - x_old_vec = fvm_conversions.cell_variable_tuple_to_vec(x_old) - - lhs_mat, lhs_vec, rhs_mat, rhs_vec = ( + lhs_matrix, lhs_vec, rhs_matrix, rhs_vec = ( residual_and_loss.theta_method_matrix_equation( dt=dt, x_old=x_old, @@ -95,16 +93,12 @@ def implicit_solve_block( ) ) - rhs = jnp.dot(rhs_mat, x_old_vec) + rhs_vec - lhs_vec - x_new = jnp.linalg.solve(lhs_mat, rhs) + rhs_result = rhs_matrix.matvec(x_old_array) + rhs_vec - lhs_vec + x_new = lhs_matrix.solve(rhs_result) # Create updated CellVariable instances based on state_plus_dt which has # updated boundary conditions and prescribed profiles. - x_new = jnp.split(x_new, len(x_old)) - out = [ - dataclasses.replace(var, value=value) - for var, value in zip(x_new_guess, x_new) - ] - out = tuple(out) - - return out + return tuple( + dataclasses.replace(var, value=x_new[:, i]) + for i, var in enumerate(x_new_guess) + ) diff --git a/torax/_src/fvm/residual_and_loss.py b/torax/_src/fvm/residual_and_loss.py index ab1425061..c782750b4 100644 --- a/torax/_src/fvm/residual_and_loss.py +++ b/torax/_src/fvm/residual_and_loss.py @@ -30,6 +30,7 @@ from torax._src import jax_utils from torax._src import models as models_lib from torax._src import state +from torax._src import tridiagonal from torax._src.config import runtime_params as runtime_params_lib from torax._src.core_profiles import updaters from torax._src.fvm import block_1d_coeffs @@ -60,8 +61,13 @@ def theta_method_matrix_equation( theta_implicit: float = 1.0, convection_dirichlet_mode: str = 'ghost', convection_neumann_mode: str = 'ghost', -) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: - """Returns the left-hand and right-hand sides of the theta method equation. +) -> tuple[ + tridiagonal.BlockTriDiagonal, + jax.Array, + tridiagonal.BlockTriDiagonal, + jax.Array, +]: + """Returns the banded left-hand and right-hand sides of the theta method. The theta method solves a differential equation @@ -116,23 +122,21 @@ def theta_method_matrix_equation( `neumann_mode` argument. Returns: - For the equation A x_new + a_vec = B x_old + b_vec. This function returns - - left-hand side matrix, A - - left-hand side vector, a - - right-hand side matrix B - - right-hand side vector, b + A tuple of (lhs, lhs_vec, rhs, rhs_vec) where: + lhs_matrix: BlockTriDiagonal for the LHS matrix A. + lhs_vec: LHS vector a. + rhs_matrix: BlockTriDiagonal for the RHS matrix B. + rhs_vec: RHS vector b. 
""" - x_new_guess_vec = fvm_conversions.cell_variable_tuple_to_vec(x_new_guess) - theta_exp = 1.0 - theta_implicit - tc_in_old = jnp.concatenate(coeffs_old.transient_in_cell) - tc_out_new = jnp.concatenate(coeffs_new.transient_out_cell) - tc_in_new = jnp.concatenate(coeffs_new.transient_in_cell) - chex.assert_rank(tc_in_old, 1) - chex.assert_rank(tc_out_new, 1) - chex.assert_rank(tc_in_new, 1) + tc_in_old = jnp.stack(coeffs_old.transient_in_cell, axis=-1) + tc_out_new = jnp.stack(coeffs_new.transient_out_cell, axis=-1) + tc_in_new = jnp.stack(coeffs_new.transient_in_cell, axis=-1) + chex.assert_rank(tc_in_old, 2) + chex.assert_rank(tc_out_new, 2) + chex.assert_rank(tc_in_new, 2) eps = 1e-7 # adding sanity checks for values in denominators @@ -148,42 +152,63 @@ def theta_method_matrix_equation( msg='|tc_out_new*tc_in_new| unexpectedly < eps', ) - left_transient = jnp.identity(len(x_new_guess_vec)) - right_transient = jnp.diag(jnp.squeeze(tc_in_old / tc_in_new)) + scale_new = dt * theta_implicit / (tc_out_new * tc_in_new) - c_mat_new, c_new = discrete_system.calc_c( + c_new_matrix, c_new_forcing = discrete_system.calc_c( x_new_guess, coeffs_new, convection_dirichlet_mode, convection_neumann_mode, ) - broadcasted = jnp.expand_dims(1 / (tc_out_new * tc_in_new), 1) - - lhs_mat = left_transient - dt * theta_implicit * broadcasted * c_mat_new - lhs_vec = -theta_implicit * dt * (1 / (tc_out_new * tc_in_new)) * c_new + # Compute LHS = I - scale_new * C_new directly, avoiding intermediate + # BlockTriDiagonal objects. The transient part (I) only contributes to the + # diagonal, so off-diagonal blocks are just -scale * C_new. + ch_idx = jnp.arange(len(x_old)) + lhs_diag = -scale_new[:, :, None] * c_new_matrix.diagonal + lhs_diag = lhs_diag.at[:, ch_idx, ch_idx].add(1.0) + lhs_matrix = tridiagonal.BlockTriDiagonal( + lower=-scale_new[1:, :, None] * c_new_matrix.lower, + diagonal=lhs_diag, + upper=-scale_new[:-1, :, None] * c_new_matrix.upper, + ) + lhs_vec = -scale_new * c_new_forcing if theta_exp > 0.0: - tc_out_old = jnp.concatenate(coeffs_old.transient_out_cell) + tc_out_old = jnp.stack(coeffs_old.transient_out_cell, axis=-1) tc_in_new = jax_utils.error_if( tc_in_new, jnp.any(jnp.abs(tc_out_old * tc_in_new) < eps), msg='|tc_out_old*tc_in_new| unexpectedly < eps', ) - c_mat_old, c_old = discrete_system.calc_c( + c_old_matrix, c_old_forcing = discrete_system.calc_c( x_old, coeffs_old, convection_dirichlet_mode, convection_neumann_mode, ) - broadcasted = jnp.expand_dims(1 / (tc_out_old * tc_in_new), 1) - rhs_mat = right_transient + dt * theta_exp * broadcasted * c_mat_old - rhs_vec = dt * theta_exp * (1 / (tc_out_old * tc_in_new)) * c_old + + scale_old = dt * theta_exp / (tc_out_old * tc_in_new) + + # Compute RHS = diag(tc_in_old/tc_in_new) + scale_old * C_old directly. + # The transient part only contributes to the diagonal. 
+ rhs_diag = scale_old[:, :, None] * c_old_matrix.diagonal + rhs_diag = rhs_diag.at[:, ch_idx, ch_idx].add((tc_in_old / tc_in_new)) + rhs_matrix = tridiagonal.BlockTriDiagonal( + lower=scale_old[1:, :, None] * c_old_matrix.lower, + diagonal=rhs_diag, + upper=scale_old[:-1, :, None] * c_old_matrix.upper, + ) + rhs_vec = scale_old * c_old_forcing else: - rhs_mat = right_transient - rhs_vec = jnp.zeros_like(x_new_guess_vec) + rhs_matrix = tridiagonal.BlockTriDiagonal.from_diagonal( + tc_in_old / tc_in_new + ) + rhs_vec = jnp.zeros( + (rhs_matrix.num_blocks, rhs_matrix.block_size), dtype=tc_in_new.dtype + ) - return lhs_mat, lhs_vec, rhs_mat, rhs_vec + return lhs_matrix, lhs_vec, rhs_matrix, rhs_vec @jax.jit( @@ -217,8 +242,8 @@ def theta_method_block_residual( runtime_params_t_plus_dt: Runtime parameters for time t + dt. geo_t_plus_dt: The geometry at time t + dt. x_old: The starting x defined as a tuple of CellVariables. - core_profiles_t: Core plasma profiles which contain all available - prescribed quantities at the start of the time step. + core_profiles_t: Core plasma profiles which contain all available prescribed + quantities at the start of the time step. core_profiles_t_plus_dt: Core plasma profiles which contain all available prescribed quantities at the end of the time step. This includes evolving boundary conditions and prescribed time-dependent profiles that are not @@ -235,7 +260,6 @@ def theta_method_block_residual( Returns: residual: Vector residual between LHS and RHS of the theta method equation. """ - x_old_vec = jnp.concatenate([var.value for var in x_old]) # Prepare core_profiles_t_plus_dt for calc_coeffs. Explanation: # 1. The original (before iterative solving) core_profiles_t_plus_dt contained # updated boundary conditions and prescribed profiles. @@ -267,7 +291,7 @@ def theta_method_block_residual( ) solver_params = runtime_params_t_plus_dt.solver - lhs_mat, lhs_vec, rhs_mat, rhs_vec = theta_method_matrix_equation( + lhs, lhs_vec, rhs, rhs_vec = theta_method_matrix_equation( dt=dt, x_old=x_old, x_new_guess=x_new_guess, @@ -278,11 +302,18 @@ def theta_method_block_residual( convection_neumann_mode=solver_params.convection_neumann_mode, ) - lhs = jnp.dot(lhs_mat, x_new_guess_vec) + lhs_vec - rhs = jnp.dot(rhs_mat, x_old_vec) + rhs_vec + # TODO(b/505253351) Remove the reshape and transpose. + x_old_array = fvm_conversions.cell_variable_tuple_to_array(x_old, axis=1) + # Reshape x_new_guess_vec to a 2D array with shape (num_channels, num_cells) + # then transpose it to (num_cells, num_channels) to allow for block + # tridiagonal matvec multiplication with lhs and rhs. 
+ num_cells, num_channels = x_old_array.shape + x_new_array = x_new_guess_vec.reshape(num_channels, num_cells).T + + lhs_result = lhs.matvec(x_new_array) + lhs_vec + rhs_result = rhs.matvec(x_old_array) + rhs_vec - residual = lhs - rhs - return residual + return (lhs_result - rhs_result).T.reshape(-1) @jax.jit( diff --git a/torax/_src/tests/tridiagonal_test.py b/torax/_src/tests/tridiagonal_test.py index 661f576fc..3c258fdad 100644 --- a/torax/_src/tests/tridiagonal_test.py +++ b/torax/_src/tests/tridiagonal_test.py @@ -58,6 +58,292 @@ def test_tridiag_add(self): np.testing.assert_array_equal(sum_tri_mat.above, expected_above) np.testing.assert_array_equal(sum_tri_mat.below, expected_below) + def test_tridiag_matvec(self): + diag = jnp.array([1.0, 2.0, 3.0]) + above = jnp.array([4.0, 5.0]) + below = jnp.array([6.0, 7.0]) + tri_mat = tridiagonal.TriDiagonal(diag, above, below) + x = jnp.array([1.0, 2.0, 3.0]) + + result = tri_mat.matvec(x) + expected = tri_mat.to_dense() @ x + + np.testing.assert_allclose(result, expected) + + +class BlockTriDiagonalTest(absltest.TestCase): + + def _make_block_tridiag( + self, num_blocks: int, block_size: int + ) -> tridiagonal.BlockTriDiagonal: + """Helper to create a BlockTriDiagonal with deterministic values.""" + rng = np.random.RandomState(42) + lower = jnp.array( + rng.randn(num_blocks - 1, block_size, block_size), dtype=jnp.float64 + ) + diag_blocks = jnp.array( + rng.randn(num_blocks, block_size, block_size), dtype=jnp.float64 + ) + upper = jnp.array( + rng.randn(num_blocks - 1, block_size, block_size), dtype=jnp.float64 + ) + return tridiagonal.BlockTriDiagonal( + lower=lower, diagonal=diag_blocks, upper=upper + ) + + def _make_nonsingular_block_tridiag( + self, num_blocks: int, block_size: int + ) -> tridiagonal.BlockTriDiagonal: + """Helper to create a diagonally-dominant BlockTriDiagonal for solve.""" + rng = np.random.RandomState(0) + lower = jnp.array( + rng.randn(num_blocks - 1, block_size, block_size), dtype=jnp.float64 + ) + upper = jnp.array( + rng.randn(num_blocks - 1, block_size, block_size), dtype=jnp.float64 + ) + # Make diagonal blocks diagonally dominant to ensure non-singularity. + diag_blocks = jnp.array( + rng.randn(num_blocks, block_size, block_size), dtype=jnp.float64 + ) + diag_blocks = diag_blocks + 10.0 * jnp.eye(block_size, dtype=jnp.float64) + return tridiagonal.BlockTriDiagonal( + lower=lower, diagonal=diag_blocks, upper=upper + ) + + def test_num_blocks_and_block_size(self): + bt = self._make_block_tridiag(num_blocks=4, block_size=3) + self.assertEqual(bt.num_blocks, 4) + self.assertEqual(bt.block_size, 3) + + def test_zeros(self): + bt = tridiagonal.BlockTriDiagonal.zeros( + num_blocks=3, block_size=2, dtype=jnp.float64 + ) + np.testing.assert_array_equal(bt.lower, jnp.zeros((2, 2, 2))) + np.testing.assert_array_equal(bt.diagonal, jnp.zeros((3, 2, 2))) + np.testing.assert_array_equal(bt.upper, jnp.zeros((2, 2, 2))) + + def test_from_block_diagonal(self): + vals = jnp.array([ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + ]) + bt = tridiagonal.BlockTriDiagonal.from_block_diagonal(vals) + + np.testing.assert_array_equal(bt.diagonal, vals) + np.testing.assert_array_equal(bt.lower, jnp.zeros((2, 2, 2))) + np.testing.assert_array_equal(bt.upper, jnp.zeros((2, 2, 2))) + + def test_from_diagonal(self): + vals = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + bt = tridiagonal.BlockTriDiagonal.from_diagonal(vals) + + # Each block should be a diagonal matrix. 
+ for i in range(3): + expected_block = jnp.diag(vals[i]) + np.testing.assert_array_equal(bt.diagonal[i], expected_block) + + # Off-diagonals should be zero. + np.testing.assert_array_equal(bt.lower, jnp.zeros((2, 2, 2))) + np.testing.assert_array_equal(bt.upper, jnp.zeros((2, 2, 2))) + + def test_from_tridiagonals(self): + # 3 blocks, 2 channels. + ch0 = tridiagonal.TriDiagonal( + diagonal=jnp.array([1.0, 3.0, 5.0]), + above=jnp.array([7.0, 9.0]), + below=jnp.array([11.0, 13.0]), + ) + ch1 = tridiagonal.TriDiagonal( + diagonal=jnp.array([2.0, 4.0, 6.0]), + above=jnp.array([8.0, 10.0]), + below=jnp.array([12.0, 14.0]), + ) + bt = tridiagonal.BlockTriDiagonal.from_tridiagonals([ch0, ch1]) + + self.assertEqual(bt.num_blocks, 3) + self.assertEqual(bt.block_size, 2) + + # Diagonal blocks should be diagonal matrices. + np.testing.assert_array_equal( + bt.diagonal[0], jnp.diag(jnp.array([1.0, 2.0])) + ) + np.testing.assert_array_equal( + bt.diagonal[1], jnp.diag(jnp.array([3.0, 4.0])) + ) + np.testing.assert_array_equal( + bt.diagonal[2], jnp.diag(jnp.array([5.0, 6.0])) + ) + + # Off-diagonal blocks should also be diagonal matrices. + np.testing.assert_array_equal(bt.upper[0], jnp.diag(jnp.array([7.0, 8.0]))) + np.testing.assert_array_equal(bt.upper[1], jnp.diag(jnp.array([9.0, 10.0]))) + np.testing.assert_array_equal( + bt.lower[0], jnp.diag(jnp.array([11.0, 12.0])) + ) + np.testing.assert_array_equal( + bt.lower[1], jnp.diag(jnp.array([13.0, 14.0])) + ) + + def test_to_dense_single_block(self): + diag = jnp.array([[[1.0, 2.0], [3.0, 4.0]]]) + bt = tridiagonal.BlockTriDiagonal( + lower=jnp.zeros((0, 2, 2)), + diagonal=diag, + upper=jnp.zeros((0, 2, 2)), + ) + dense = bt.to_dense() + np.testing.assert_array_equal(dense, diag[0]) + + def test_to_dense(self): + bt = self._make_block_tridiag(num_blocks=3, block_size=2) + dense = bt.to_dense() + + # Verify shape. + self.assertEqual(dense.shape, (6, 6)) + + # Verify diagonal blocks. + for i in range(3): + np.testing.assert_array_equal( + dense[2 * i : 2 * i + 2, 2 * i : 2 * i + 2], bt.diagonal[i] + ) + + # Verify upper blocks. + for i in range(2): + np.testing.assert_array_equal( + dense[2 * i : 2 * i + 2, 2 * (i + 1) : 2 * (i + 1) + 2], + bt.upper[i], + ) + + # Verify lower blocks. 
+ for i in range(2): + np.testing.assert_array_equal( + dense[2 * (i + 1) : 2 * (i + 1) + 2, 2 * i : 2 * i + 2], + bt.lower[i], + ) + + def test_add(self): + bt1 = self._make_block_tridiag(num_blocks=3, block_size=2) + rng = np.random.RandomState(99) + bt2 = tridiagonal.BlockTriDiagonal( + lower=jnp.array(rng.randn(2, 2, 2), dtype=jnp.float64), + diagonal=jnp.array(rng.randn(3, 2, 2), dtype=jnp.float64), + upper=jnp.array(rng.randn(2, 2, 2), dtype=jnp.float64), + ) + result = bt1 + bt2 + + np.testing.assert_allclose(result.lower, bt1.lower + bt2.lower) + np.testing.assert_allclose(result.diagonal, bt1.diagonal + bt2.diagonal) + np.testing.assert_allclose(result.upper, bt1.upper + bt2.upper) + + def test_add_matches_dense(self): + bt1 = self._make_block_tridiag(num_blocks=3, block_size=2) + rng = np.random.RandomState(99) + bt2 = tridiagonal.BlockTriDiagonal( + lower=jnp.array(rng.randn(2, 2, 2), dtype=jnp.float64), + diagonal=jnp.array(rng.randn(3, 2, 2), dtype=jnp.float64), + upper=jnp.array(rng.randn(2, 2, 2), dtype=jnp.float64), + ) + result = bt1 + bt2 + + np.testing.assert_allclose( + result.to_dense(), bt1.to_dense() + bt2.to_dense() + ) + + def test_matvec(self): + bt = self._make_block_tridiag(num_blocks=4, block_size=3) + rng = np.random.RandomState(7) + x = jnp.array(rng.randn(4, 3), dtype=jnp.float64) + + result = bt.matvec(x) + expected = (bt.to_dense() @ x.flatten()).reshape(4, 3) + + np.testing.assert_allclose(result, expected, atol=1e-12) + + def test_matvec_single_block(self): + diag = jnp.array([[[2.0, 1.0], [0.5, 3.0]]]) + bt = tridiagonal.BlockTriDiagonal( + lower=jnp.zeros((0, 2, 2)), + diagonal=diag, + upper=jnp.zeros((0, 2, 2)), + ) + x = jnp.array([[1.0, 2.0]]) + + result = bt.matvec(x) + expected = jnp.array([[4.0, 6.5]]) + + np.testing.assert_allclose(result, expected) + + def test_solve(self): + bt = self._make_nonsingular_block_tridiag(num_blocks=4, block_size=3) + rng = np.random.RandomState(123) + x_true = jnp.array(rng.randn(4, 3), dtype=jnp.float64) + rhs = bt.matvec(x_true) + + x_solved = bt.solve(rhs) + + np.testing.assert_allclose(x_solved, x_true, atol=1e-10) + + def test_solve_recovers_rhs(self): + """Verify A @ solve(A, b) = b.""" + bt = self._make_nonsingular_block_tridiag(num_blocks=3, block_size=2) + rng = np.random.RandomState(55) + rhs = jnp.array(rng.randn(3, 2), dtype=jnp.float64) + + x = bt.solve(rhs) + reconstructed_rhs = bt.matvec(x) + + np.testing.assert_allclose(reconstructed_rhs, rhs, atol=1e-10) + + def test_from_diagonal_to_dense_matches_manual(self): + """from_diagonal should produce the same dense matrix as manual diag.""" + vals = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + bt = tridiagonal.BlockTriDiagonal.from_diagonal(vals) + dense = bt.to_dense() + + expected = jnp.diag(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + np.testing.assert_array_equal(dense, expected) + + def test_from_tridiagonals_to_dense_matches_per_channel(self): + """Each channel should form an independent scalar tridiagonal system.""" + ch0 = tridiagonal.TriDiagonal( + diagonal=jnp.array([1.0, 3.0, 5.0]), + above=jnp.array([7.0, 9.0]), + below=jnp.array([11.0, 13.0]), + ) + ch1 = tridiagonal.TriDiagonal( + diagonal=jnp.array([2.0, 4.0, 6.0]), + above=jnp.array([8.0, 10.0]), + below=jnp.array([12.0, 14.0]), + ) + bt = tridiagonal.BlockTriDiagonal.from_tridiagonals([ch0, ch1]) + dense = bt.to_dense() + + # Build per-channel scalar tridiag and interleave. 
+    ch0 = tridiagonal.TriDiagonal(
+        diagonal=jnp.array([1.0, 3.0, 5.0]),
+        above=jnp.array([7.0, 9.0]),
+        below=jnp.array([11.0, 13.0]),
+    )
+    ch1 = tridiagonal.TriDiagonal(
+        diagonal=jnp.array([2.0, 4.0, 6.0]),
+        above=jnp.array([8.0, 10.0]),
+        below=jnp.array([12.0, 14.0]),
+    )
+    d0 = ch0.to_dense()
+    d1 = ch1.to_dense()
+    # Channel 0 occupies rows/cols 0, 2, 4 and channel 1 occupies 1, 3, 5.
+    # But block ordering interleaves them as (ch0, ch1) per block.
+    expected_full = jnp.zeros((6, 6))
+    for r in range(3):
+      for c in range(3):
+        expected_full = expected_full.at[2 * r, 2 * c].set(d0[r, c])
+        expected_full = expected_full.at[2 * r + 1, 2 * c + 1].set(d1[r, c])
+
+    np.testing.assert_allclose(dense, expected_full)
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/torax/_src/tridiagonal.py b/torax/_src/tridiagonal.py
index f4a20f9ae..42822f76b 100644
--- a/torax/_src/tridiagonal.py
+++ b/torax/_src/tridiagonal.py
@@ -18,8 +18,11 @@
 
 import jax
 from jax import numpy as jnp
+import jax.scipy.linalg
 import jaxtyping as jt
 from torax._src import array_typing
+from torax._src import jax_utils
+import typing_extensions
 
 
 @jax.tree_util.register_dataclass
@@ -44,7 +47,7 @@ def to_dense(self) -> jt.Float[array_typing.Array, 'size size']:
         + jnp.diag(self.below, -1)
     )
 
-  def __add__(self, other: 'TriDiagonal') -> 'TriDiagonal':
+  def __add__(self, other: typing_extensions.Self) -> typing_extensions.Self:
     return TriDiagonal(
         diagonal=self.diagonal + other.diagonal,
         above=self.above + other.above,
@@ -60,3 +63,174 @@ def matvec(
         + jnp.pad(self.above * x[1:], (0, 1))
         + jnp.pad(self.below * x[:-1], (1, 0))
     )
+
+
+@jax.tree_util.register_dataclass
+@dataclasses.dataclass(frozen=True)
+class BlockTriDiagonal:
+  """A block-tridiagonal matrix stored as its three diagonals.
+
+  Attributes:
+    lower: Sub-diagonal blocks, shape (num_blocks-1, block_size, block_size).
+    diagonal: Main diagonal blocks, shape (num_blocks, block_size, block_size).
+    upper: Super-diagonal blocks, shape (num_blocks-1, block_size, block_size).
+  """
+
+  lower: jt.Float[array_typing.Array, 'num_blocks-1 block_size block_size']
+  diagonal: jt.Float[array_typing.Array, 'num_blocks block_size block_size']
+  upper: jt.Float[array_typing.Array, 'num_blocks-1 block_size block_size']
+
+  @property
+  def num_blocks(self) -> int:
+    """Number of blocks in the main diagonal."""
+    return self.diagonal.shape[0]
+
+  @property
+  def block_size(self) -> int:
+    """Size of each block."""
+    return self.diagonal.shape[1]
+
+  def __add__(self, other: typing_extensions.Self) -> typing_extensions.Self:
+    return BlockTriDiagonal(
+        lower=self.lower + other.lower,
+        diagonal=self.diagonal + other.diagonal,
+        upper=self.upper + other.upper,
+    )
+
+  @classmethod
+  def from_block_diagonal(
+      cls,
+      vals: jt.Float[array_typing.Array, 'num_blocks block_size block_size'],
+  ) -> 'BlockTriDiagonal':
+    """Creates a block-tridiagonal matrix from its main-diagonal blocks."""
+    num_blocks, block_size, _ = vals.shape
+    off_diag = jnp.zeros(
+        (num_blocks - 1, block_size, block_size), dtype=vals.dtype
+    )
+    return cls(
+        lower=off_diag,
+        diagonal=vals,
+        upper=off_diag,
+    )
+
+  @classmethod
+  def zeros(
+      cls,
+      num_blocks: int,
+      block_size: int,
+      dtype: jnp.dtype | None = None,
+  ) -> 'BlockTriDiagonal':
+    """Creates a zero block-tridiagonal matrix."""
+    dtype = dtype if dtype is not None else jax_utils.get_dtype()
+    return cls.from_block_diagonal(
+        vals=jnp.zeros((num_blocks, block_size, block_size), dtype=dtype),
+    )
+
+  @classmethod
+  def from_diagonal(
+      cls,
+      vals: jt.Float[array_typing.Array, 'num_blocks block_size'],
+  ) -> 'BlockTriDiagonal':
+    """Creates a purely diagonal matrix from per-block diagonal entries."""
+    return cls.from_block_diagonal(
+        vals=vals[..., None, :] * jnp.eye(vals.shape[-1], dtype=vals.dtype),
+    )
+
+  @classmethod
+  def from_tridiagonals(
+      cls,
+      tridiagonals: typing_extensions.Iterable[TriDiagonal],
+  ) -> 'BlockTriDiagonal':
+    """Creates a BlockTriDiagonal from an iterable of per-channel TriDiagonals.
+
+    Channel i's scalar tridiagonal system occupies entry (i, i) of every block,
+    so the channels remain uncoupled.
+
+    Args:
+      tridiagonals: Iterable of TriDiagonal objects, one per channel, all with
+        the same size (the number of blocks).
+
+    Returns:
+      BlockTriDiagonal whose block size equals the number of channels, where
+      each block is a diagonal matrix.
+    """
+    tridiagonals_seq = tuple(tridiagonals)
+    stacked = jax.tree.map(
+        lambda *args: jnp.stack(args, axis=1), *tridiagonals_seq
+    )
+    return cls(
+        lower=stacked.below[..., None, :]
+        * jnp.eye(stacked.below.shape[-1], dtype=stacked.below.dtype),
+        diagonal=stacked.diagonal[..., None, :]
+        * jnp.eye(stacked.diagonal.shape[-1], dtype=stacked.diagonal.dtype),
+        upper=stacked.above[..., None, :]
+        * jnp.eye(stacked.above.shape[-1], dtype=stacked.above.dtype),
+    )
+
+  def to_dense(self) -> jt.Float[array_typing.Array, 'total total']:
+    """Constructs the dense matrix representation.
+
+    Returns:
+      Dense matrix of shape (num_blocks * block_size, num_blocks *
+      block_size).
+    """
+    block_size = self.block_size
+    mat = jax.scipy.linalg.block_diag(*self.diagonal)
+    if self.num_blocks == 1:
+      return mat
+    lower_mat = jnp.pad(
+        jax.scipy.linalg.block_diag(*self.lower),
+        ((block_size, 0), (0, block_size)),
+    )
+    upper_mat = jnp.pad(
+        jax.scipy.linalg.block_diag(*self.upper),
+        ((0, block_size), (block_size, 0)),
+    )
+    return mat + lower_mat + upper_mat
+
+  def solve(
+      self, rhs: jt.Float[array_typing.Array, 'num_blocks block_size']
+  ) -> jt.Float[array_typing.Array, 'num_blocks block_size']:
+    """Solves A @ x = rhs.
+
+    Args:
+      rhs: Right-hand side, shape (num_blocks, block_size).
+
+    Returns:
+      Solution x, shape (num_blocks, block_size).
+    """
+    return dense_solve(self, rhs)
+
+  def matvec(
+      self, x: jt.Float[array_typing.Array, 'num_blocks block_size']
+  ) -> jt.Float[array_typing.Array, 'num_blocks block_size']:
+    """Block-tridiagonal matrix-vector multiply: y = A @ x.
+
+    Args:
+      x: Input vector, shape (num_blocks, block_size).
+
+    Returns:
+      Result y, shape (num_blocks, block_size).
+    """
+    y_upper = jnp.pad(
+        jnp.einsum('nij,nj->ni', self.upper, x[1:]), ((0, 1), (0, 0))
+    )
+    y_lower = jnp.pad(
+        jnp.einsum('nij,nj->ni', self.lower, x[:-1]), ((1, 0), (0, 0))
+    )
+    return jnp.einsum('nij,nj->ni', self.diagonal, x) + y_upper + y_lower
+
+
+def dense_solve(
+    block_tridiag: BlockTriDiagonal,
+    rhs: jt.Float[array_typing.Array, 'num_blocks block_size'],
+) -> jt.Float[array_typing.Array, 'num_blocks block_size']:
+  """Solves A @ x = rhs using a dense matrix inversion.
+
+  Args:
+    block_tridiag: Block tridiagonal matrix.
+    rhs: Right-hand side, shape (num_blocks, block_size).
+
+  Returns:
+    Solution x, shape (num_blocks, block_size).
+  """
+  x_flat = jax.scipy.linalg.solve(block_tridiag.to_dense(), rhs.flatten())
+  return x_flat.reshape((block_tridiag.num_blocks, block_tridiag.block_size))
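
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the BlockTriDiagonal API introduced above end to end, building per-channel TriDiagonal systems, combining them with from_tridiagonals, and checking matvec/solve against the dense representation. The cell count, channel count, and numeric values are invented for the example, and the tolerances are loose enough to pass even without float64 enabled.

# Illustrative sketch of the BlockTriDiagonal API added in this change.
# Sizes and values are arbitrary; they are not taken from TORAX itself.
import numpy as np
from jax import numpy as jnp
from torax._src import tridiagonal

num_cells, num_channels = 4, 2  # hypothetical grid and channel count

# One scalar tridiagonal system per channel (e.g. from diffusion terms).
channels = [
    tridiagonal.TriDiagonal(
        diagonal=jnp.arange(1.0, num_cells + 1) + 10.0 * c,
        above=jnp.full((num_cells - 1,), 0.5),
        below=jnp.full((num_cells - 1,), -0.5),
    )
    for c in range(num_channels)
]

# Interleave the channels into one block-tridiagonal operator:
# block index = cell, within-block index = channel.
c_matrix = tridiagonal.BlockTriDiagonal.from_tridiagonals(channels)

# matvec and solve operate on (num_cells, num_channels) arrays, matching
# fvm_conversions.cell_variable_tuple_to_array(..., axis=1).
x = jnp.ones((num_cells, num_channels))
rhs = c_matrix.matvec(x)

# Round-trip through the solver and against the dense representation.
x_back = c_matrix.solve(rhs)
np.testing.assert_allclose(x_back, x, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(
    rhs.flatten(), c_matrix.to_dense() @ x.flatten(), rtol=1e-5, atol=1e-5
)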
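
A second illustrative note (also not part of the patch) on the two layouts used in residual_and_loss.py above: the solver still exposes channel-major flat vectors via cell_variable_tuple_to_vec, while BlockTriDiagonal works on (num_cells, num_channels) arrays from cell_variable_tuple_to_array(..., axis=1). The reshape/transpose in theta_method_block_residual converts between the two; a minimal sketch with invented sizes:

# Minimal sketch of the two layouts; sizes and values are invented.
from jax import numpy as jnp

num_cells, num_channels = 3, 2
per_channel = [jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])]

# Channel-major flat vector, as produced by cell_variable_tuple_to_vec.
flat = jnp.concatenate(per_channel)  # [1, 2, 3, 10, 20, 30]

# (num_cells, num_channels) array, as produced by
# cell_variable_tuple_to_array(..., axis=1).
arr = jnp.stack(per_channel, axis=1)  # rows are cells

# The conversion used in theta_method_block_residual...
assert jnp.array_equal(flat.reshape(num_channels, num_cells).T, arr)
# ...and back, matching the residual's final `.T.reshape(-1)`.
assert jnp.array_equal(arr.T.reshape(-1), flat)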