Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 174 additions & 15 deletions torax/_src/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Calculates Block1DCoeffs for a time step."""

import dataclasses
import functools
import jax
import jax.numpy as jnp
Expand All @@ -28,6 +29,7 @@
from torax._src.fvm import cell_variable
from torax._src.geometry import geometry
from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib
from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib
from torax._src.sources import source_profile_builders
from torax._src.sources import source_profiles as source_profiles_lib
Expand Down Expand Up @@ -72,6 +74,9 @@ def __call__(
# Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
# should be called
explicit_call: bool = False,
pedestal_transition_state: (
pedestal_transition_state_lib.PedestalTransitionState | None
) = None,
) -> block_1d_coeffs.Block1DCoeffs:
"""Returns coefficients given a state x.

Expand All @@ -84,8 +89,8 @@ def __call__(
state x.
geo: The geometry of the system at this time step.
core_profiles: The core profiles of the system at this time step.
prev_core_profiles: The core profiles of the system at the previous
time step.
prev_core_profiles: The core profiles of the system at the previous time
step.
dt: The time step size.
x: The state with cell-grid values of the evolving variables.
explicit_source_profiles: Precomputed explicit source profiles. These
Expand All @@ -104,6 +109,9 @@ def __call__(
explicit_call: If True, then if theta_implicit=1, only a reduced
Block1DCoeffs is calculated since most explicit coefficients will not be
used.
pedestal_transition_state: State for tracking pedestal L-H and H-L
transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
use_formation_model_with_adaptive_source=True. None otherwise.

Returns:
coeffs: The diffusion, convection, etc. coefficients for this state.
Expand Down Expand Up @@ -133,6 +141,7 @@ def __call__(
evolving_names=self.evolving_names,
use_pereverzev=use_pereverzev,
explicit_call=explicit_call,
pedestal_transition_state=pedestal_transition_state,
)


Expand All @@ -145,6 +154,9 @@ def calc_coeffs(
evolving_names: tuple[str, ...],
use_pereverzev: bool = False,
explicit_call: bool = False,
pedestal_transition_state: (
pedestal_transition_state_lib.PedestalTransitionState | None
) = None,
) -> block_1d_coeffs.Block1DCoeffs:
"""Calculates Block1DCoeffs for the time step described by `core_profiles`.

Expand All @@ -170,6 +182,9 @@ def calc_coeffs(
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
theta_implicit=1. This saves computation for the default fully implicit
implementation.
pedestal_transition_state: State for tracking pedestal L-H and H-L
transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
use_formation_model_with_adaptive_source=True. None otherwise.

Returns:
coeffs: Block1DCoeffs containing the coefficients at this time step.
Expand All @@ -192,6 +207,7 @@ def calc_coeffs(
physics_models=physics_models,
evolving_names=evolving_names,
use_pereverzev=use_pereverzev,
pedestal_transition_state=pedestal_transition_state,
)


Expand All @@ -210,6 +226,9 @@ def _calc_coeffs_full(
physics_models: physics_models_lib.PhysicsModels,
evolving_names: tuple[str, ...],
use_pereverzev: bool = False,
pedestal_transition_state: (
pedestal_transition_state_lib.PedestalTransitionState | None
) = None,
) -> block_1d_coeffs.Block1DCoeffs:
"""See `calc_coeffs` for details."""

Expand Down Expand Up @@ -268,6 +287,7 @@ def _calc_coeffs_full(
core_profiles,
merged_source_profiles,
use_pereverzev,
pedestal_transition_state=pedestal_transition_state,
)
)

Expand Down Expand Up @@ -415,26 +435,75 @@ def _calc_coeffs_full(
runtime_params.pedestal.mode
== pedestal_runtime_params_lib.Mode.ADAPTIVE_SOURCE
):
# Get the pedestal-top target values from the pedestal model.
pedestal_top_values = (
pedestal_model_output.to_internal_boundary_conditions(geo)
)

# Apply ramp scaling if use_formation_model_with_adaptive_source is
# enabled.
if runtime_params.pedestal.use_formation_model_with_adaptive_source:
assert pedestal_transition_state is not None, (
'pedestal_transition_state must not be None when'
' use_formation_model_with_adaptive_source is True.'
)
# Scale the pedestal-top values from the pedestal model by the ramp
# fraction. Will be a no-op in H-mode following the transition_time_width.
internal_boundary_conditions = _apply_transition_ramp_scaling(
pedestal_top_values=pedestal_top_values,
pedestal_transition_state=pedestal_transition_state,
runtime_params=runtime_params,
)
# If in L-mode and the H->L ramp has completed (fraction >= 1.0), skip
# the adaptive source entirely to revert to standard L-mode modeling.
# ramp_fraction will be 1.0 if simulation initialized in L-mode and has
# remained in L-mode, since initial transition_start_time is -inf.
ramp_fraction = _compute_ramp_fraction(
pedestal_transition_state=pedestal_transition_state,
transition_time_width=runtime_params.pedestal.transition_time_width,
t=runtime_params.t,
)
# Skip adaptive source if in L-mode and the H->L ramp has completed.
skip_adaptive_source = ~pedestal_transition_state.in_H_mode & (
ramp_fraction >= 1.0
)
else:
internal_boundary_conditions = pedestal_top_values
skip_adaptive_source = jnp.bool_(False)

def _apply_source():
return internal_boundary_conditions_lib.apply_adaptive_source(
source_T_i=source_i,
source_T_e=source_e,
source_n_e=source_n_e,
source_mat_ii=source_mat_ii,
source_mat_ee=source_mat_ee,
source_mat_nn=source_mat_nn,
runtime_params=runtime_params,
internal_boundary_conditions=internal_boundary_conditions,
)

def _skip_source():
return (
source_i,
source_e,
source_n_e,
source_mat_ii,
source_mat_ee,
source_mat_nn,
)

(
source_i,
source_e,
source_n_e,
source_mat_ii,
source_mat_ee,
source_mat_nn,
) = internal_boundary_conditions_lib.apply_adaptive_source(
source_T_i=source_i,
source_T_e=source_e,
source_n_e=source_n_e,
source_mat_ii=source_mat_ii,
source_mat_ee=source_mat_ee,
source_mat_nn=source_mat_nn,
runtime_params=runtime_params,
# Pedestal contributes an internal boundary condition to the source
# terms at the pedestal top.
internal_boundary_conditions=pedestal_model_output.to_internal_boundary_conditions(
geo
),
) = jax.lax.cond(
skip_adaptive_source,
_skip_source,
_apply_source,
)

# --- Build arguments to solver --- #
Expand Down Expand Up @@ -539,3 +608,93 @@ def _calc_coeffs_reduced(
transient_in_cell=transient_in_cell,
)
return coeffs


def _compute_ramp_fraction(
pedestal_transition_state: pedestal_transition_state_lib.PedestalTransitionState,
transition_time_width: array_typing.FloatScalar,
t: array_typing.FloatScalar,
) -> array_typing.FloatScalar:
"""Computes the ramp fraction for a pedestal transition.

Returns a value in [0, 1] representing the progress of the current
transition. 0 means the transition just started, 1 means it is complete.

Args:
pedestal_transition_state: Current transition state.
transition_time_width: Duration of the transition ramp.
t: Current simulation time (i.e. t + dt when called from the solver).

Returns:
Ramp fraction clipped to [0, 1].
"""
elapsed = t - pedestal_transition_state.transition_start_time
fraction = elapsed / transition_time_width
return jnp.clip(fraction, 0.0, 1.0)


def _apply_transition_ramp_scaling(
pedestal_top_values: internal_boundary_conditions_lib.InternalBoundaryConditions,
pedestal_transition_state: pedestal_transition_state_lib.PedestalTransitionState,
runtime_params: runtime_params_lib.RuntimeParams,
) -> internal_boundary_conditions_lib.InternalBoundaryConditions:
"""Applies ramp scaling to internal boundary conditions during transitions.

During an L-H transition, linearly ramps from L-mode values to the H-mode
targets. During an H-L transition, ramps from the H-mode targets back to
the L-mode values.

The L-mode values are stored in the pedestal_transition_state (captured
at the start of an L->H transition). The H-mode targets are the full
pedestal model output.

Args:
pedestal_top_values: Pedestal-top target internal boundary conditions from
the pedestal model.
pedestal_transition_state: Current transition state containing L-mode
baseline values.
runtime_params: Runtime parameters (provides time t and pedestal config).

Returns:
Scaled internal boundary conditions.
"""
ramp_fraction = _compute_ramp_fraction(
pedestal_transition_state=pedestal_transition_state,
transition_time_width=runtime_params.pedestal.transition_time_width,
t=runtime_params.t,
)

# Extract the nonzero pedestal-top values from the IBC. The IBC arrays are
# cell-grid sized with a single nonzero element at the pedestal top. We use
# jnp.max to extract the nonzero value.
h_mode_T_i_ped = jnp.max(pedestal_top_values.T_i)
h_mode_T_e_ped = jnp.max(pedestal_top_values.T_e)
h_mode_n_e_ped = jnp.max(pedestal_top_values.n_e)

l_mode_T_i_ped = pedestal_transition_state.T_i_ped_L_mode
l_mode_T_e_ped = pedestal_transition_state.T_e_ped_L_mode
l_mode_n_e_ped = pedestal_transition_state.n_e_ped_L_mode

# In H-mode: ramp from L-mode to H-mode (L + fraction * (H - L))
# In L-mode (H->L ramp): ramp from H-mode to L-mode (H + fraction * (L - H))
def _lerp(l_val, h_val, frac, in_h_mode):
return jnp.where(
in_h_mode,
l_val + frac * (h_val - l_val), # L->H ramp
h_val + frac * (l_val - h_val), # H->L ramp
)

in_h_mode = pedestal_transition_state.in_H_mode
scaled_T_i = _lerp(l_mode_T_i_ped, h_mode_T_i_ped, ramp_fraction, in_h_mode)
scaled_T_e = _lerp(l_mode_T_e_ped, h_mode_T_e_ped, ramp_fraction, in_h_mode)
scaled_n_e = _lerp(l_mode_n_e_ped, h_mode_n_e_ped, ramp_fraction, in_h_mode)

# Reconstruct IBC with scaled values at the same pedestal-top location.
# The nonzero mask from the original pedestal_top_values gives us the
# location.
return dataclasses.replace(
pedestal_top_values,
T_i=jnp.where(pedestal_top_values.T_i != 0.0, scaled_T_i, 0.0),
T_e=jnp.where(pedestal_top_values.T_e != 0.0, scaled_T_e, 0.0),
n_e=jnp.where(pedestal_top_values.n_e != 0.0, scaled_n_e, 0.0),
)
10 changes: 10 additions & 0 deletions torax/_src/fvm/newton_raphson_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torax._src.fvm import fvm_conversions
from torax._src.fvm import residual_and_loss
from torax._src.geometry import geometry
from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib
from torax._src.solver import jax_root_finding
from torax._src.solver import predictor_corrector_method
from torax._src.sources import source_profiles
Expand Down Expand Up @@ -75,6 +76,9 @@ def newton_raphson_solve_block(
delta_reduction_factor: float,
tau_min: float,
log_iterations: bool = False,
pedestal_transition_state: (
pedestal_transition_state_lib.PedestalTransitionState | None
) = None,
) -> tuple[
tuple[cell_variable.CellVariable, ...],
state_module.SolverNumericOutputs,
Expand Down Expand Up @@ -145,6 +149,8 @@ def newton_raphson_solve_block(
routine resets at a lower timestep.
log_iterations: If true, output diagnostic information from within iteration
loop.
pedestal_transition_state: State of the pedestal transition model if using
the formation model with adaptive source.

Returns:
x_new: Tuple, with x_new[i] giving channel i of x at the next time step
Expand All @@ -163,6 +169,7 @@ def newton_raphson_solve_block(
x=x_old,
explicit_source_profiles=explicit_source_profiles,
explicit_call=True,
pedestal_transition_state=pedestal_transition_state,
)

match initial_guess_mode:
Expand All @@ -182,6 +189,7 @@ def newton_raphson_solve_block(
explicit_source_profiles=explicit_source_profiles,
allow_pereverzev=True,
explicit_call=True,
pedestal_transition_state=pedestal_transition_state,
)

# See linear_theta_method.py for comments on the predictor_corrector API
Expand All @@ -199,6 +207,7 @@ def newton_raphson_solve_block(
coeffs_exp=coeffs_exp_linear,
coeffs_callback=coeffs_callback,
explicit_source_profiles=explicit_source_profiles,
pedestal_transition_state=pedestal_transition_state,
)
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
case enums.InitialGuessMode.X_OLD:
Expand All @@ -223,6 +232,7 @@ def newton_raphson_solve_block(
explicit_source_profiles=explicit_source_profiles,
coeffs_old=coeffs_old,
evolving_names=evolving_names,
pedestal_transition_state=pedestal_transition_state,
)

x_root, metadata = jax_root_finding.root_newton_raphson(
Expand Down
10 changes: 10 additions & 0 deletions torax/_src/fvm/optimizer_solve_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

See function docstring for details.
"""

import functools
from typing import TypeAlias

Expand All @@ -32,6 +33,7 @@
from torax._src.fvm import fvm_conversions
from torax._src.fvm import residual_and_loss
from torax._src.geometry import geometry
from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib
from torax._src.solver import predictor_corrector_method
from torax._src.sources import source_profiles

Expand Down Expand Up @@ -63,6 +65,9 @@ def optimizer_solve_block(
initial_guess_mode: enums.InitialGuessMode,
maxiter: int,
tol: float,
pedestal_transition_state: (
pedestal_transition_state_lib.PedestalTransitionState | None
) = None,
) -> tuple[
tuple[cell_variable.CellVariable, ...],
state.SolverNumericOutputs,
Expand Down Expand Up @@ -110,6 +115,9 @@ def optimizer_solve_block(
solver for the optional initial guess from the linear solver.
maxiter: See docstring of `jaxopt.LBFGS`.
tol: See docstring of `jaxopt.LBFGS`.
pedestal_transition_state: State for tracking pedestal L-H and H-L
transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with
use_formation_model_with_adaptive_source=True. None otherwise.

Returns:
x_new: Tuple, with x_new[i] giving channel i of x at the next time step
Expand All @@ -127,6 +135,7 @@ def optimizer_solve_block(
x=x_old,
explicit_source_profiles=explicit_source_profiles,
explicit_call=True,
pedestal_transition_state=pedestal_transition_state,
)

match initial_guess_mode:
Expand All @@ -147,6 +156,7 @@ def optimizer_solve_block(
explicit_source_profiles=explicit_source_profiles,
allow_pereverzev=True,
explicit_call=True,
pedestal_transition_state=pedestal_transition_state,
)
# See linear_theta_method.py for comments on the predictor_corrector API
x_new_guess = convertors.core_profiles_to_solver_x_tuple(
Expand Down
Loading
Loading