diff --git a/torax/_src/fvm/calc_coeffs.py b/torax/_src/fvm/calc_coeffs.py index 4939af2e1..dcb51d608 100644 --- a/torax/_src/fvm/calc_coeffs.py +++ b/torax/_src/fvm/calc_coeffs.py @@ -14,6 +14,7 @@ """Calculates Block1DCoeffs for a time step.""" +import dataclasses import functools import jax import jax.numpy as jnp @@ -28,6 +29,7 @@ from torax._src.fvm import cell_variable from torax._src.geometry import geometry from torax._src.internal_boundary_conditions import internal_boundary_conditions as internal_boundary_conditions_lib +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib from torax._src.sources import source_profile_builders from torax._src.sources import source_profiles as source_profiles_lib @@ -72,6 +74,9 @@ def __call__( # Checks if reduced calc_coeffs for explicit terms when theta_implicit=1 # should be called explicit_call: bool = False, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> block_1d_coeffs.Block1DCoeffs: """Returns coefficients given a state x. @@ -84,8 +89,8 @@ def __call__( state x. geo: The geometry of the system at this time step. core_profiles: The core profiles of the system at this time step. - prev_core_profiles: The core profiles of the system at the previous - time step. + prev_core_profiles: The core profiles of the system at the previous time + step. dt: The time step size. x: The state with cell-grid values of the evolving variables. explicit_source_profiles: Precomputed explicit source profiles. These @@ -104,6 +109,9 @@ def __call__( explicit_call: If True, then if theta_implicit=1, only a reduced Block1DCoeffs is calculated since most explicit coefficients will not be used. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: coeffs: The diffusion, convection, etc. coefficients for this state. @@ -133,6 +141,7 @@ def __call__( evolving_names=self.evolving_names, use_pereverzev=use_pereverzev, explicit_call=explicit_call, + pedestal_transition_state=pedestal_transition_state, ) @@ -145,6 +154,9 @@ def calc_coeffs( evolving_names: tuple[str, ...], use_pereverzev: bool = False, explicit_call: bool = False, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> block_1d_coeffs.Block1DCoeffs: """Calculates Block1DCoeffs for the time step described by `core_profiles`. @@ -170,6 +182,9 @@ def calc_coeffs( explicit component of the PDE. Then calculates a reduced Block1DCoeffs if theta_implicit=1. This saves computation for the default fully implicit implementation. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: coeffs: Block1DCoeffs containing the coefficients at this time step. @@ -192,6 +207,7 @@ def calc_coeffs( physics_models=physics_models, evolving_names=evolving_names, use_pereverzev=use_pereverzev, + pedestal_transition_state=pedestal_transition_state, ) @@ -210,6 +226,9 @@ def _calc_coeffs_full( physics_models: physics_models_lib.PhysicsModels, evolving_names: tuple[str, ...], use_pereverzev: bool = False, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> block_1d_coeffs.Block1DCoeffs: """See `calc_coeffs` for details.""" @@ -268,6 +287,7 @@ def _calc_coeffs_full( core_profiles, merged_source_profiles, use_pereverzev, + pedestal_transition_state=pedestal_transition_state, ) ) @@ -415,6 +435,64 @@ def _calc_coeffs_full( runtime_params.pedestal.mode == pedestal_runtime_params_lib.Mode.ADAPTIVE_SOURCE ): + # Get the pedestal-top target values from the pedestal model. + pedestal_top_values = ( + pedestal_model_output.to_internal_boundary_conditions(geo) + ) + + # Apply ramp scaling if use_formation_model_with_adaptive_source is + # enabled. + if runtime_params.pedestal.use_formation_model_with_adaptive_source: + assert pedestal_transition_state is not None, ( + 'pedestal_transition_state must not be None when' + ' use_formation_model_with_adaptive_source is True.' + ) + # Scale the pedestal-top values from the pedestal model by the ramp + # fraction. Will be a no-op in H-mode following the transition_time_width. + internal_boundary_conditions = _apply_transition_ramp_scaling( + pedestal_top_values=pedestal_top_values, + pedestal_transition_state=pedestal_transition_state, + runtime_params=runtime_params, + ) + # If in L-mode and the H->L ramp has completed (fraction >= 1.0), skip + # the adaptive source entirely to revert to standard L-mode modeling. + # ramp_fraction will be 1.0 if simulation initialized in L-mode and has + # remained in L-mode, since initial transition_start_time is -inf. + ramp_fraction = _compute_ramp_fraction( + pedestal_transition_state=pedestal_transition_state, + transition_time_width=runtime_params.pedestal.transition_time_width, + t=runtime_params.t, + ) + # Skip adaptive source if in L-mode and the H->L ramp has completed. + skip_adaptive_source = ~pedestal_transition_state.in_H_mode & ( + ramp_fraction >= 1.0 + ) + else: + internal_boundary_conditions = pedestal_top_values + skip_adaptive_source = jnp.bool_(False) + + def _apply_source(): + return internal_boundary_conditions_lib.apply_adaptive_source( + source_T_i=source_i, + source_T_e=source_e, + source_n_e=source_n_e, + source_mat_ii=source_mat_ii, + source_mat_ee=source_mat_ee, + source_mat_nn=source_mat_nn, + runtime_params=runtime_params, + internal_boundary_conditions=internal_boundary_conditions, + ) + + def _skip_source(): + return ( + source_i, + source_e, + source_n_e, + source_mat_ii, + source_mat_ee, + source_mat_nn, + ) + ( source_i, source_e, @@ -422,19 +500,10 @@ def _calc_coeffs_full( source_mat_ii, source_mat_ee, source_mat_nn, - ) = internal_boundary_conditions_lib.apply_adaptive_source( - source_T_i=source_i, - source_T_e=source_e, - source_n_e=source_n_e, - source_mat_ii=source_mat_ii, - source_mat_ee=source_mat_ee, - source_mat_nn=source_mat_nn, - runtime_params=runtime_params, - # Pedestal contributes an internal boundary condition to the source - # terms at the pedestal top. - internal_boundary_conditions=pedestal_model_output.to_internal_boundary_conditions( - geo - ), + ) = jax.lax.cond( + skip_adaptive_source, + _skip_source, + _apply_source, ) # --- Build arguments to solver --- # @@ -539,3 +608,93 @@ def _calc_coeffs_reduced( transient_in_cell=transient_in_cell, ) return coeffs + + +def _compute_ramp_fraction( + pedestal_transition_state: pedestal_transition_state_lib.PedestalTransitionState, + transition_time_width: array_typing.FloatScalar, + t: array_typing.FloatScalar, +) -> array_typing.FloatScalar: + """Computes the ramp fraction for a pedestal transition. + + Returns a value in [0, 1] representing the progress of the current + transition. 0 means the transition just started, 1 means it is complete. + + Args: + pedestal_transition_state: Current transition state. + transition_time_width: Duration of the transition ramp. + t: Current simulation time (i.e. t + dt when called from the solver). + + Returns: + Ramp fraction clipped to [0, 1]. + """ + elapsed = t - pedestal_transition_state.transition_start_time + fraction = elapsed / transition_time_width + return jnp.clip(fraction, 0.0, 1.0) + + +def _apply_transition_ramp_scaling( + pedestal_top_values: internal_boundary_conditions_lib.InternalBoundaryConditions, + pedestal_transition_state: pedestal_transition_state_lib.PedestalTransitionState, + runtime_params: runtime_params_lib.RuntimeParams, +) -> internal_boundary_conditions_lib.InternalBoundaryConditions: + """Applies ramp scaling to internal boundary conditions during transitions. + + During an L-H transition, linearly ramps from L-mode values to the H-mode + targets. During an H-L transition, ramps from the H-mode targets back to + the L-mode values. + + The L-mode values are stored in the pedestal_transition_state (captured + at the start of an L->H transition). The H-mode targets are the full + pedestal model output. + + Args: + pedestal_top_values: Pedestal-top target internal boundary conditions from + the pedestal model. + pedestal_transition_state: Current transition state containing L-mode + baseline values. + runtime_params: Runtime parameters (provides time t and pedestal config). + + Returns: + Scaled internal boundary conditions. + """ + ramp_fraction = _compute_ramp_fraction( + pedestal_transition_state=pedestal_transition_state, + transition_time_width=runtime_params.pedestal.transition_time_width, + t=runtime_params.t, + ) + + # Extract the nonzero pedestal-top values from the IBC. The IBC arrays are + # cell-grid sized with a single nonzero element at the pedestal top. We use + # jnp.max to extract the nonzero value. + h_mode_T_i_ped = jnp.max(pedestal_top_values.T_i) + h_mode_T_e_ped = jnp.max(pedestal_top_values.T_e) + h_mode_n_e_ped = jnp.max(pedestal_top_values.n_e) + + l_mode_T_i_ped = pedestal_transition_state.T_i_ped_L_mode + l_mode_T_e_ped = pedestal_transition_state.T_e_ped_L_mode + l_mode_n_e_ped = pedestal_transition_state.n_e_ped_L_mode + + # In H-mode: ramp from L-mode to H-mode (L + fraction * (H - L)) + # In L-mode (H->L ramp): ramp from H-mode to L-mode (H + fraction * (L - H)) + def _lerp(l_val, h_val, frac, in_h_mode): + return jnp.where( + in_h_mode, + l_val + frac * (h_val - l_val), # L->H ramp + h_val + frac * (l_val - h_val), # H->L ramp + ) + + in_h_mode = pedestal_transition_state.in_H_mode + scaled_T_i = _lerp(l_mode_T_i_ped, h_mode_T_i_ped, ramp_fraction, in_h_mode) + scaled_T_e = _lerp(l_mode_T_e_ped, h_mode_T_e_ped, ramp_fraction, in_h_mode) + scaled_n_e = _lerp(l_mode_n_e_ped, h_mode_n_e_ped, ramp_fraction, in_h_mode) + + # Reconstruct IBC with scaled values at the same pedestal-top location. + # The nonzero mask from the original pedestal_top_values gives us the + # location. + return dataclasses.replace( + pedestal_top_values, + T_i=jnp.where(pedestal_top_values.T_i != 0.0, scaled_T_i, 0.0), + T_e=jnp.where(pedestal_top_values.T_e != 0.0, scaled_T_e, 0.0), + n_e=jnp.where(pedestal_top_values.n_e != 0.0, scaled_n_e, 0.0), + ) diff --git a/torax/_src/fvm/newton_raphson_solve_block.py b/torax/_src/fvm/newton_raphson_solve_block.py index d7cf174ab..4e5ec24ad 100644 --- a/torax/_src/fvm/newton_raphson_solve_block.py +++ b/torax/_src/fvm/newton_raphson_solve_block.py @@ -34,6 +34,7 @@ from torax._src.fvm import fvm_conversions from torax._src.fvm import residual_and_loss from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import jax_root_finding from torax._src.solver import predictor_corrector_method from torax._src.sources import source_profiles @@ -75,6 +76,9 @@ def newton_raphson_solve_block( delta_reduction_factor: float, tau_min: float, log_iterations: bool = False, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state_module.SolverNumericOutputs, @@ -145,6 +149,8 @@ def newton_raphson_solve_block( routine resets at a lower timestep. log_iterations: If true, output diagnostic information from within iteration loop. + pedestal_transition_state: State of the pedestal transition model if using + the formation model with adaptive source. Returns: x_new: Tuple, with x_new[i] giving channel i of x at the next time step @@ -163,6 +169,7 @@ def newton_raphson_solve_block( x=x_old, explicit_source_profiles=explicit_source_profiles, explicit_call=True, + pedestal_transition_state=pedestal_transition_state, ) match initial_guess_mode: @@ -182,6 +189,7 @@ def newton_raphson_solve_block( explicit_source_profiles=explicit_source_profiles, allow_pereverzev=True, explicit_call=True, + pedestal_transition_state=pedestal_transition_state, ) # See linear_theta_method.py for comments on the predictor_corrector API @@ -199,6 +207,7 @@ def newton_raphson_solve_block( coeffs_exp=coeffs_exp_linear, coeffs_callback=coeffs_callback, explicit_source_profiles=explicit_source_profiles, + pedestal_transition_state=pedestal_transition_state, ) init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new) case enums.InitialGuessMode.X_OLD: @@ -223,6 +232,7 @@ def newton_raphson_solve_block( explicit_source_profiles=explicit_source_profiles, coeffs_old=coeffs_old, evolving_names=evolving_names, + pedestal_transition_state=pedestal_transition_state, ) x_root, metadata = jax_root_finding.root_newton_raphson( diff --git a/torax/_src/fvm/optimizer_solve_block.py b/torax/_src/fvm/optimizer_solve_block.py index 655838134..9a8b6483e 100644 --- a/torax/_src/fvm/optimizer_solve_block.py +++ b/torax/_src/fvm/optimizer_solve_block.py @@ -15,6 +15,7 @@ See function docstring for details. """ + import functools from typing import TypeAlias @@ -32,6 +33,7 @@ from torax._src.fvm import fvm_conversions from torax._src.fvm import residual_and_loss from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import predictor_corrector_method from torax._src.sources import source_profiles @@ -63,6 +65,9 @@ def optimizer_solve_block( initial_guess_mode: enums.InitialGuessMode, maxiter: int, tol: float, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -110,6 +115,9 @@ def optimizer_solve_block( solver for the optional initial guess from the linear solver. maxiter: See docstring of `jaxopt.LBFGS`. tol: See docstring of `jaxopt.LBFGS`. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: x_new: Tuple, with x_new[i] giving channel i of x at the next time step @@ -127,6 +135,7 @@ def optimizer_solve_block( x=x_old, explicit_source_profiles=explicit_source_profiles, explicit_call=True, + pedestal_transition_state=pedestal_transition_state, ) match initial_guess_mode: @@ -147,6 +156,7 @@ def optimizer_solve_block( explicit_source_profiles=explicit_source_profiles, allow_pereverzev=True, explicit_call=True, + pedestal_transition_state=pedestal_transition_state, ) # See linear_theta_method.py for comments on the predictor_corrector API x_new_guess = convertors.core_profiles_to_solver_x_tuple( diff --git a/torax/_src/fvm/residual_and_loss.py b/torax/_src/fvm/residual_and_loss.py index bddda7ed4..dee8a8b2a 100644 --- a/torax/_src/fvm/residual_and_loss.py +++ b/torax/_src/fvm/residual_and_loss.py @@ -38,6 +38,7 @@ from torax._src.fvm import discrete_system from torax._src.fvm import fvm_conversions from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.sources import source_profiles Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs @@ -205,6 +206,9 @@ def theta_method_block_residual( physics_models: physics_models_lib.PhysicsModels, coeffs_old: Block1DCoeffs, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> jax.Array: """Residual of theta-method equation for core profiles at next time-step. @@ -227,6 +231,8 @@ def theta_method_block_residual( coeffs_old: The coefficients calculated at x_old. evolving_names: The names of variables within the core profiles that should evolve. + pedestal_transition_state: State of the pedestal transition model if using + the formation model with adaptive source. Returns: residual: Vector residual between LHS and RHS of the theta method equation. @@ -259,6 +265,7 @@ def theta_method_block_residual( physics_models=physics_models, evolving_names=evolving_names, use_pereverzev=False, + pedestal_transition_state=pedestal_transition_state, ) solver_params = runtime_params_t_plus_dt.solver diff --git a/torax/_src/fvm/tests/calc_coeffs_test.py b/torax/_src/fvm/tests/calc_coeffs_test.py index eeb2c50f6..cffc98910 100644 --- a/torax/_src/fvm/tests/calc_coeffs_test.py +++ b/torax/_src/fvm/tests/calc_coeffs_test.py @@ -16,9 +16,12 @@ from absl.testing import absltest from absl.testing import parameterized +import jax.numpy as jnp from torax._src.config import build_runtime_params from torax._src.core_profiles import initialization from torax._src.fvm import calc_coeffs +from torax._src.internal_boundary_conditions import internal_boundary_conditions +from torax._src.pedestal_model import pedestal_transition_state from torax._src.sources import source_profile_builders from torax._src.test_utils import default_sources from torax._src.torax_pydantic import model_config @@ -159,5 +162,92 @@ def create_coeffs_callback( ) +class TransitionCalculationsTest(parameterized.TestCase): + + def test_pedestal_transition_state_initial_state(self): + state = pedestal_transition_state.PedestalTransitionState.initial_state() + self.assertTrue(jnp.isneginf(state.transition_start_time)) + self.assertEqual(state.T_i_ped_L_mode, 0.0) + self.assertFalse(state.in_H_mode) + + def test_compute_ramp_fraction_very_small_width(self): + state = pedestal_transition_state.PedestalTransitionState( + transition_start_time=jnp.array(1.0), + T_i_ped_L_mode=jnp.array(0.0), + T_e_ped_L_mode=jnp.array(0.0), + n_e_ped_L_mode=jnp.array(0.0), + in_H_mode=jnp.array(True), + rho_norm_ped_top=jnp.array(0.9), + ) + # Very small transition_time_width clips to 1.0 when elapsed > 0. + self.assertEqual( + calc_coeffs._compute_ramp_fraction(state, 1e-10, 1.5), 1.0 + ) + + def test_compute_ramp_fraction_ramp(self): + state = pedestal_transition_state.PedestalTransitionState( + transition_start_time=jnp.array(1.0), + T_i_ped_L_mode=jnp.array(0.0), + T_e_ped_L_mode=jnp.array(0.0), + n_e_ped_L_mode=jnp.array(0.0), + in_H_mode=jnp.array(True), + rho_norm_ped_top=jnp.array(0.9), + ) + # transition_time_width = 1.0. Start at 1.0. + # Clip at both ends + self.assertEqual( + calc_coeffs._compute_ramp_fraction(state, 1.0, 0.5), 0.0 + ) # t < start + self.assertEqual( + calc_coeffs._compute_ramp_fraction(state, 1.0, 1.0), 0.0 + ) # t = start + self.assertEqual( + calc_coeffs._compute_ramp_fraction(state, 1.0, 1.5), 0.5 + ) # t = start + 0.5 + self.assertEqual( + calc_coeffs._compute_ramp_fraction(state, 1.0, 2.0), 1.0 + ) # t = start + 1.0 + self.assertEqual( + calc_coeffs._compute_ramp_fraction(state, 1.0, 2.5), 1.0 + ) # t = start + 1.5 + + def test_apply_transition_ramp_scaling_l_to_h(self): + l_mode_baseline = 1.0 + h_mode_target = 3.0 + + state = pedestal_transition_state.PedestalTransitionState( + transition_start_time=jnp.array(1.0), + T_i_ped_L_mode=jnp.array(l_mode_baseline), + T_e_ped_L_mode=jnp.array(l_mode_baseline), + n_e_ped_L_mode=jnp.array(l_mode_baseline), + in_H_mode=jnp.array(True), # L -> H + rho_norm_ped_top=jnp.array(0.9), + ) + + pedestal_top_values = ( + internal_boundary_conditions.InternalBoundaryConditions( + T_i=jnp.array([0.0, h_mode_target, 0.0]), + T_e=jnp.array([0.0, h_mode_target, 0.0]), + n_e=jnp.array([0.0, h_mode_target, 0.0]), + ) + ) + + class MockPedestalRuntimeParams: + transition_time_width = 1.0 + + class MockRuntimeParams: + pedestal = MockPedestalRuntimeParams() + t = 1.5 # halfway + + scaled_ibc = calc_coeffs._apply_transition_ramp_scaling( # pytype: disable=wrong-arg-types + pedestal_top_values=pedestal_top_values, + pedestal_transition_state=state, + runtime_params=MockRuntimeParams(), + ) + + # Expected: 1.0 + 0.5 * (3.0 - 1.0) = 2.0 + self.assertTrue(jnp.allclose(scaled_ibc.T_i, jnp.array([0.0, 2.0, 0.0]))) + + if __name__ == '__main__': absltest.main() diff --git a/torax/_src/mhd/sawtooth/sawtooth_solver.py b/torax/_src/mhd/sawtooth/sawtooth_solver.py index 398f28ce4..46bff574e 100644 --- a/torax/_src/mhd/sawtooth/sawtooth_solver.py +++ b/torax/_src/mhd/sawtooth/sawtooth_solver.py @@ -24,6 +24,7 @@ from torax._src.core_profiles import convertors from torax._src.fvm import cell_variable from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import solver from torax._src.sources import source_profiles as source_profiles_lib @@ -47,6 +48,9 @@ def _x_new( core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles_lib.SourceProfiles, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -71,6 +75,9 @@ def _x_new( prescribed profiles at time t + crash_dt. explicit_source_profiles: Explicit source profiles at time t. evolving_names: Names of evolving variables. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: Updated tuple of evolving CellVariables from CoreProfiles diff --git a/torax/_src/orchestration/adaptive_step.py b/torax/_src/orchestration/adaptive_step.py index dc8b0a3fb..417e023f7 100644 --- a/torax/_src/orchestration/adaptive_step.py +++ b/torax/_src/orchestration/adaptive_step.py @@ -29,6 +29,7 @@ from torax._src.geometry import geometry from torax._src.geometry import geometry_provider as geometry_provider_lib from torax._src.orchestration import sim_state +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import solver as solver_lib from torax._src.sources import source_profiles as source_profiles_lib @@ -85,6 +86,9 @@ def compute_state( runtime_params_provider: build_runtime_params.RuntimeParamsProvider, geometry_provider: geometry_provider_lib.GeometryProvider, solver: solver_lib.Solver, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[AdaptiveStepState, dict[str, array_typing.IntScalar]]: """Computes the state for attempt i of the adaptive step.""" dt = initial_dt / runtime_params_t.numerics.dt_reduction_factor**i @@ -116,6 +120,7 @@ def compute_state( core_profiles_t=input_state.core_profiles, core_profiles_t_plus_dt=core_profiles_t_plus_dt, explicit_source_profiles=explicit_source_profiles, + pedestal_transition_state=pedestal_transition_state, ) loop_statistics[ 'inner_solver_iterations' diff --git a/torax/_src/orchestration/initial_state.py b/torax/_src/orchestration/initial_state.py index 0475f2851..ffd576117 100644 --- a/torax/_src/orchestration/initial_state.py +++ b/torax/_src/orchestration/initial_state.py @@ -30,6 +30,7 @@ from torax._src.orchestration import step_function from torax._src.output_tools import output from torax._src.output_tools import post_processing +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.sources import source_profile_builders from torax._src.torax_pydantic import file_restart as file_restart_pydantic_model from torax._src.transport_model import transport_coefficients_builder @@ -135,6 +136,27 @@ def _get_initial_state( else: edge_outputs = None + # Initialize the pedestal transition state if the feature is enabled. + if runtime_params.pedestal.use_formation_model_with_adaptive_source: + # Call the pedestal model to get the dynamically computed rho_norm_ped_top. + # This ensures models like EPEDNN that compute the pedestal top location + # are properly supported from initialization. + pedestal_model_output = physics_models.pedestal_model( + runtime_params, geo, initial_core_profiles, initial_core_sources, + ) + rho_norm_ped_top = pedestal_model_output.rho_norm_ped_top + ped_top_idx = jnp.argmin(jnp.abs(geo.rho_norm - rho_norm_ped_top)) + pedestal_transition_state = ( + pedestal_transition_state_lib.PedestalTransitionState.initial_state( + T_i_ped=initial_core_profiles.T_i.value[ped_top_idx], + T_e_ped=initial_core_profiles.T_e.value[ped_top_idx], + n_e_ped=initial_core_profiles.n_e.value[ped_top_idx], + rho_norm_ped_top=rho_norm_ped_top, + ) + ) + else: + pedestal_transition_state = None + transport_coeffs = ( transport_coefficients_builder.calculate_all_transport_coeffs( physics_models.pedestal_model, @@ -144,6 +166,7 @@ def _get_initial_state( geo, initial_core_profiles, initial_core_sources, + pedestal_transition_state=pedestal_transition_state, ) ) @@ -163,6 +186,7 @@ def _get_initial_state( ), geometry=geo, edge_outputs=edge_outputs, + pedestal_transition_state=pedestal_transition_state, ) diff --git a/torax/_src/orchestration/sawtooth_step.py b/torax/_src/orchestration/sawtooth_step.py index 851c850a0..1632d5948 100644 --- a/torax/_src/orchestration/sawtooth_step.py +++ b/torax/_src/orchestration/sawtooth_step.py @@ -31,6 +31,7 @@ from torax._src.orchestration import sim_state from torax._src.orchestration import step_function_processing from torax._src.output_tools import post_processing +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.physics import formulas from torax._src.sources import source_profiles as source_profiles_lib @@ -54,6 +55,9 @@ def sawtooth_step( edge_outputs: edge_base.EdgeModelOutputs | None, input_state: sim_state.SimState, input_post_processed_outputs: post_processing.PostProcessedOutputs, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[sim_state.SimState, post_processing.PostProcessedOutputs]: """Checks for and handles a sawtooth crash. @@ -75,6 +79,9 @@ def sawtooth_step( edge_outputs: Explicit edge outputs at time t. input_state: State at the start of the time step. input_post_processed_outputs: Post-processed outputs from the previous step. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: Returns a tuple (output_state, post_processed_outputs). @@ -152,6 +159,7 @@ def _make_post_crash_state_and_post_processed_outputs(): physics_models=sawtooth_solver.physics_models, evolving_names=runtime_params_t.numerics.evolving_names, input_post_processed_outputs=input_post_processed_outputs, + pedestal_transition_state=pedestal_transition_state, ) return jax.lax.cond( diff --git a/torax/_src/orchestration/step_function.py b/torax/_src/orchestration/step_function.py index 1b5b5d892..5bb1f7881 100644 --- a/torax/_src/orchestration/step_function.py +++ b/torax/_src/orchestration/step_function.py @@ -36,6 +36,7 @@ from torax._src.orchestration import step_function_processing from torax._src.orchestration import whilei_loop from torax._src.output_tools import post_processing +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import solver as solver_lib from torax._src.sources import source_profiles as source_profiles_lib from torax._src.time_step_calculator import time_step_calculator as ts @@ -222,13 +223,17 @@ def __call__( runtime_params_overrides or self.runtime_params_provider ) geometry_provider = geo_overrides or self._geometry_provider - runtime_params_t, geo_t, explicit_source_profiles, edge_outputs = ( - step_function_processing.pre_step( - input_state=input_state, - runtime_params_provider=runtime_params_provider, - geometry_provider=geometry_provider, - physics_models=self._solver.physics_models, - ) + ( + runtime_params_t, + geo_t, + explicit_source_profiles, + edge_outputs, + pedestal_transition_state, + ) = step_function_processing.pre_step( + input_state=input_state, + runtime_params_provider=runtime_params_provider, + geometry_provider=geometry_provider, + physics_models=self._solver.physics_models, ) def _step(): @@ -243,6 +248,7 @@ def _step(): previous_post_processed_outputs, runtime_params_provider, geometry_provider, + pedestal_transition_state, ) # If adaptive dt is enabled, take the adaptive step if the max_dt is # greater than the min_dt, otherwise take the fixed step. @@ -267,6 +273,7 @@ def _step(): previous_post_processed_outputs, runtime_params_provider, geometry_provider, + pedestal_transition_state, ) output_state, post_processed_outputs = jax.lax.cond( @@ -377,6 +384,9 @@ def _sawtooth_step( previous_post_processed_outputs: post_processing.PostProcessedOutputs, runtime_params_provider: build_runtime_params.RuntimeParamsProvider, geometry_provider: geometry_provider_lib.GeometryProvider, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ sim_state.SimState, post_processing.PostProcessedOutputs, @@ -406,6 +416,7 @@ def _sawtooth_step_fn(): edge_outputs=edge_outputs, input_state=input_state, input_post_processed_outputs=previous_post_processed_outputs, + pedestal_transition_state=pedestal_transition_state, ) # If a sawtooth crash is not triggered for any reason,the input @@ -430,6 +441,9 @@ def _adaptive_step( previous_post_processed_outputs: post_processing.PostProcessedOutputs, runtime_params_provider: build_runtime_params.RuntimeParamsProvider, geometry_provider: geometry_provider_lib.GeometryProvider, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ sim_state.SimState, post_processing.PostProcessedOutputs, @@ -457,7 +471,11 @@ def _adaptive_step( result = whilei_loop.whilei_loop( adaptive_step.cond_fun, - functools.partial(adaptive_step.compute_state, solver=self.solver), + functools.partial( + adaptive_step.compute_state, + solver=self.solver, + pedestal_transition_state=pedestal_transition_state, + ), ( initial_state, initial_loop_stats, @@ -504,6 +522,7 @@ def _adaptive_step( physics_models=self._solver.physics_models, evolving_names=evolving_names, input_post_processed_outputs=previous_post_processed_outputs, + pedestal_transition_state=pedestal_transition_state, ) ) return output_state, post_processed_outputs @@ -519,6 +538,9 @@ def _fixed_step( previous_post_processed_outputs: post_processing.PostProcessedOutputs, runtime_params_provider: build_runtime_params.RuntimeParamsProvider, geometry_provider: geometry_provider_lib.GeometryProvider, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ sim_state.SimState, post_processing.PostProcessedOutputs, @@ -561,6 +583,7 @@ def _fixed_step( core_profiles_t=input_state.core_profiles, core_profiles_t_plus_dt=core_profiles_t_plus_dt, explicit_source_profiles=explicit_source_profiles, + pedestal_transition_state=pedestal_transition_state, ) output_state, post_processed_outputs = ( step_function_processing.finalize_outputs( @@ -577,6 +600,7 @@ def _fixed_step( physics_models=self._solver.physics_models, evolving_names=runtime_params_t.numerics.evolving_names, input_post_processed_outputs=previous_post_processed_outputs, + pedestal_transition_state=pedestal_transition_state, ) ) return output_state, post_processed_outputs diff --git a/torax/_src/orchestration/step_function_processing.py b/torax/_src/orchestration/step_function_processing.py index 9a9c29f1a..d5a0195ad 100644 --- a/torax/_src/orchestration/step_function_processing.py +++ b/torax/_src/orchestration/step_function_processing.py @@ -17,6 +17,7 @@ import dataclasses import functools import jax +import jax.numpy as jnp from torax._src import physics_models as physics_models_lib from torax._src import state from torax._src.config import build_runtime_params @@ -28,10 +29,138 @@ from torax._src.geometry import geometry_provider as geometry_provider_lib from torax._src.orchestration import sim_state from torax._src.output_tools import post_processing +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib +from torax._src.pedestal_model.formation import power_scaling_formation_model as power_scaling_formation_model_lib +from torax._src.physics import scaling_laws from torax._src.sources import source_profile_builders from torax._src.sources import source_profiles as source_profiles_lib from torax._src.transport_model import transport_coefficients_builder +# pylint: disable=invalid-name + + +def _update_pedestal_transition_state( + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState + ), + runtime_params: runtime_params_lib.RuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + core_sources: source_profiles_lib.SourceProfiles, + physics_models: physics_models_lib.PhysicsModels, +) -> pedestal_transition_state_lib.PedestalTransitionState: + """Evaluates P_SOL vs P_LH and updates the pedestal transition state. + + Called once per timestep in pre_step. Determines whether the plasma should + enter or exit H-mode based on the power crossing the separatrix (P_SOL) + compared to the L-H transition threshold power (P_LH). + + When transitioning from L-mode to H-mode (P_SOL > P_LH): + - Sets in_H_mode to True + - Records the current simulation time as transition_start_time + - Captures current pedestal-top profile values as L-mode baselines + + When transitioning from H-mode to L-mode (P_SOL < P_LH): + - Sets in_H_mode to False + - Records the current simulation time as transition_start_time + - Preserves the existing L-mode baselines (from the most recent L→H start) + + Args: + pedestal_transition_state: Current transition state from previous timestep. + runtime_params: Runtime parameters at time t. + geo: Geometry at time t. + core_profiles: Core plasma profiles at time t. + core_sources: Source profiles at time t. + physics_models: Physics models (used to access formation model config). + + Returns: + Updated PedestalTransitionState. + """ + formation_model = physics_models.pedestal_model.formation_model + assert isinstance( + formation_model, + power_scaling_formation_model_lib.PowerScalingFormationModel, + ), ( + 'use_formation_model_with_adaptive_source requires a' + f' PowerScalingFormationModel, got {type(formation_model)}' + ) + assert isinstance( + runtime_params.pedestal.formation, + power_scaling_formation_model_lib.PowerScalingFormationRuntimeParams, + ) + + # Calculate P_SOL (total power crossing the separatrix). + P_SOL = power_scaling_formation_model_lib.calculate_P_SOL_total( + internal_plasma_energy=core_profiles.internal_plasma_energy, + core_sources=core_sources, + geo=geo, + ) + + # Calculate P_LH (L-H transition threshold power). + P_LH, _ = scaling_laws.calculate_P_LH( + geo=geo, + core_profiles=core_profiles, + scaling_law=formation_model.scaling_law, + divertor_configuration=formation_model.divertor_configuration, + ) + # Apply the user-specified prefactor. + P_LH = P_LH * runtime_params.pedestal.formation.P_LH_prefactor + + # Determine transition direction. + should_enter_h = ~pedestal_transition_state.in_H_mode & (P_SOL > P_LH) + should_exit_h = pedestal_transition_state.in_H_mode & (P_SOL < P_LH) + transitioning = should_enter_h | should_exit_h + + # Determine the updated H-mode state. + in_H_mode = jnp.where( + should_enter_h, + True, + jnp.where(should_exit_h, False, pedestal_transition_state.in_H_mode), + ) + + # Record transition start time when a mode change occurs. + t = runtime_params.t + new_transition_start_time = jnp.where( + transitioning, t, pedestal_transition_state.transition_start_time + ) + + # Capture current pedestal-top values as L-mode baselines when entering + # H-mode. The pedestal top location is stored in the transition state and + # updated at the end of each timestep from the pedestal model output, so + # models like EPEDNN that compute rho_norm_ped_top dynamically are supported. + rho_norm_ped_top = pedestal_transition_state.rho_norm_ped_top + ped_top_idx = jnp.argmin(jnp.abs(geo.rho_norm - rho_norm_ped_top)) + + current_T_i_at_ped = core_profiles.T_i.value[ped_top_idx] + current_T_e_at_ped = core_profiles.T_e.value[ped_top_idx] + current_n_e_at_ped = core_profiles.n_e.value[ped_top_idx] + + # Only update L-mode baselines when entering H-mode. + new_T_i_ped_L_mode = jnp.where( + should_enter_h, + current_T_i_at_ped, + pedestal_transition_state.T_i_ped_L_mode, + ) + new_T_e_ped_L_mode = jnp.where( + should_enter_h, + current_T_e_at_ped, + pedestal_transition_state.T_e_ped_L_mode, + ) + new_n_e_ped_L_mode = jnp.where( + should_enter_h, + current_n_e_at_ped, + pedestal_transition_state.n_e_ped_L_mode, + ) + + return pedestal_transition_state_lib.PedestalTransitionState( + in_H_mode=in_H_mode, + transition_start_time=new_transition_start_time, + T_i_ped_L_mode=new_T_i_ped_L_mode, + T_e_ped_L_mode=new_T_e_ped_L_mode, + n_e_ped_L_mode=new_n_e_ped_L_mode, + rho_norm_ped_top=rho_norm_ped_top, + ) + def pre_step( input_state: sim_state.SimState, @@ -43,6 +172,7 @@ def pre_step( geometry.Geometry, source_profiles_lib.SourceProfiles, edge_base.EdgeModelOutputs | None, + pedestal_transition_state_lib.PedestalTransitionState | None, ]: """Performs the pre-step operations for the step function.""" runtime_params_t, geo_t = ( @@ -95,7 +225,39 @@ def pre_step( else: edge_outputs = None - return runtime_params_t, geo_t, explicit_source_profiles, edge_outputs + # Update pedestal transition state if use_formation_model_with_adaptive_source + # is enabled. + pedestal_transition_state = input_state.pedestal_transition_state + if runtime_params_t.pedestal.use_formation_model_with_adaptive_source: + assert pedestal_transition_state is not None, ( + 'pedestal_transition_state must not be None when' + ' use_formation_model_with_adaptive_source is True.' + ) + # Merge explicit sources with previous implicit sources for accurate + # P_SOL calculation (same pattern as the edge model above). + merged_sources = dataclasses.replace( + input_state.core_sources, + T_e=input_state.core_sources.T_e | explicit_source_profiles.T_e, + T_i=input_state.core_sources.T_i | explicit_source_profiles.T_i, + n_e=input_state.core_sources.n_e | explicit_source_profiles.n_e, + psi=input_state.core_sources.psi | explicit_source_profiles.psi, + ) + pedestal_transition_state = _update_pedestal_transition_state( + pedestal_transition_state=pedestal_transition_state, + runtime_params=runtime_params_t, + geo=geo_t, + core_profiles=input_state.core_profiles, + core_sources=merged_sources, + physics_models=physics_models, + ) + + return ( + runtime_params_t, + geo_t, + explicit_source_profiles, + edge_outputs, + pedestal_transition_state, + ) @functools.partial( @@ -119,6 +281,9 @@ def finalize_outputs( physics_models: physics_models_lib.PhysicsModels, evolving_names: tuple[str, ...], input_post_processed_outputs: post_processing.PostProcessedOutputs, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[sim_state.SimState, post_processing.PostProcessedOutputs]: """Returns the final state and post-processed outputs.""" final_core_profiles, final_source_profiles = ( @@ -144,9 +309,26 @@ def finalize_outputs( geometry_t_plus_dt, final_core_profiles, final_source_profiles, + pedestal_transition_state=pedestal_transition_state, ) ) + # Update rho_norm_ped_top in the pedestal transition state from the pedestal + # model output at t+dt. This ensures that models which compute + # rho_norm_ped_top dynamically (e.g. EPEDNN) propagate their value to the + # next timestep's pre_step for accurate L-mode baseline extraction. + if pedestal_transition_state is not None: + pedestal_model_output = physics_models.pedestal_model( + runtime_params_t_plus_dt, + geometry_t_plus_dt, + final_core_profiles, + final_source_profiles, + ) + pedestal_transition_state = dataclasses.replace( + pedestal_transition_state, + rho_norm_ped_top=pedestal_model_output.rho_norm_ped_top, + ) + output_state = sim_state.SimState( t=t + dt, dt=dt, @@ -156,6 +338,7 @@ def finalize_outputs( geometry=geometry_t_plus_dt, solver_numeric_outputs=solver_numeric_outputs, edge_outputs=edge_outputs, + pedestal_transition_state=pedestal_transition_state, ) post_processed_outputs = post_processing.make_post_processed_outputs( sim_state=output_state, @@ -163,3 +346,4 @@ def finalize_outputs( previous_post_processed_outputs=input_post_processed_outputs, ) return output_state, post_processed_outputs + diff --git a/torax/_src/pedestal_model/formation/power_scaling_formation_model.py b/torax/_src/pedestal_model/formation/power_scaling_formation_model.py index dfb032519..146175a08 100644 --- a/torax/_src/pedestal_model/formation/power_scaling_formation_model.py +++ b/torax/_src/pedestal_model/formation/power_scaling_formation_model.py @@ -40,7 +40,7 @@ class PowerScalingFormationRuntimeParams( P_LH_prefactor: array_typing.FloatScalar = 1.0 -def _calculate_P_SOL_total( +def calculate_P_SOL_total( internal_plasma_energy: state.PlasmaInternalEnergy, core_sources: source_profiles_lib.SourceProfiles, geo: geometry.Geometry, @@ -84,7 +84,7 @@ def __call__( runtime_params.pedestal.formation, PowerScalingFormationRuntimeParams ) - P_SOL_total = _calculate_P_SOL_total( + P_SOL_total = calculate_P_SOL_total( core_profiles.internal_plasma_energy, core_sources, geo ) diff --git a/torax/_src/pedestal_model/formation/tests/power_scaling_formation_model_test.py b/torax/_src/pedestal_model/formation/tests/power_scaling_formation_model_test.py index 5df9919e2..59c741b30 100644 --- a/torax/_src/pedestal_model/formation/tests/power_scaling_formation_model_test.py +++ b/torax/_src/pedestal_model/formation/tests/power_scaling_formation_model_test.py @@ -56,7 +56,7 @@ def setUp(self): self.runtime_params = step_fn.runtime_params_provider(t=0.0) def test_calculate_P_SOL_total(self): - P_SOL_total = power_scaling_formation_model._calculate_P_SOL_total( + P_SOL_total = power_scaling_formation_model.calculate_P_SOL_total( self.initial_state.core_profiles.internal_plasma_energy, self.initial_state.core_sources, self.initial_state.geometry, diff --git a/torax/_src/pedestal_model/pedestal_transition_state.py b/torax/_src/pedestal_model/pedestal_transition_state.py index 21f4bcbd2..95f3b7ef3 100644 --- a/torax/_src/pedestal_model/pedestal_transition_state.py +++ b/torax/_src/pedestal_model/pedestal_transition_state.py @@ -43,6 +43,10 @@ class PedestalTransitionState: start of the most recent L-mode to H-mode transition [keV]. n_e_ped_L_mode: Electron density at the pedestal top captured at the start of the most recent L-mode to H-mode transition [m^-3]. + rho_norm_ped_top: Location of the pedestal top in normalized rho. Persisted + across timesteps so that models which compute rho_norm_ped_top dynamically + (e.g. EPEDNN) can propagate it to the next pre_step for L-mode baseline + extraction. """ in_H_mode: array_typing.BoolScalar @@ -52,15 +56,37 @@ class PedestalTransitionState: T_i_ped_L_mode: array_typing.FloatScalar T_e_ped_L_mode: array_typing.FloatScalar n_e_ped_L_mode: array_typing.FloatScalar + rho_norm_ped_top: array_typing.FloatScalar @classmethod - def initial_state(cls) -> typing_extensions.Self: - """Creates an initial transition state starting in L-mode.""" + def initial_state( + cls, + T_i_ped: float = 0.0, + T_e_ped: float = 0.0, + n_e_ped: float = 0.0, + rho_norm_ped_top: float = 0.9, + ) -> typing_extensions.Self: + """Creates an initial transition state starting in L-mode. + + Args: + T_i_ped: Ion temperature at the pedestal top from initial conditions + [keV]. + T_e_ped: Electron temperature at the pedestal top from initial conditions + [keV]. + n_e_ped: Electron density at the pedestal top from initial conditions + [m^-3]. + rho_norm_ped_top: Initial pedestal top location in normalized rho. + + Returns: + A PedestalTransitionState initialized in L-mode with the given pedestal + top values as L-mode baselines. + """ dtype = jax_utils.get_dtype() return cls( in_H_mode=jnp.bool_(False), transition_start_time=jnp.array(-jnp.inf, dtype=dtype), - T_i_ped_L_mode=jnp.array(0.0, dtype=dtype), - T_e_ped_L_mode=jnp.array(0.0, dtype=dtype), - n_e_ped_L_mode=jnp.array(0.0, dtype=dtype), + T_i_ped_L_mode=jnp.array(T_i_ped, dtype=dtype), + T_e_ped_L_mode=jnp.array(T_e_ped, dtype=dtype), + n_e_ped_L_mode=jnp.array(n_e_ped, dtype=dtype), + rho_norm_ped_top=jnp.array(rho_norm_ped_top, dtype=dtype), ) diff --git a/torax/_src/pedestal_model/pydantic_model.py b/torax/_src/pedestal_model/pydantic_model.py index 1ce36255a..615ec4829 100644 --- a/torax/_src/pedestal_model/pydantic_model.py +++ b/torax/_src/pedestal_model/pydantic_model.py @@ -258,7 +258,7 @@ class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC): use_formation_model_with_adaptive_source: Annotated[ bool, torax_pydantic.JAX_STATIC ] = False - transition_time_width: torax_pydantic.TimeVaryingScalar = ( + transition_time_width: torax_pydantic.PositiveTimeVaryingScalar = ( torax_pydantic.ValidatedDefault(0.5) ) formation_model: FormationConfig = torax_pydantic.ValidatedDefault( @@ -301,6 +301,18 @@ def _defaults(cls, data: dict[str, Any]) -> dict[str, Any]: return configurable_data + @pydantic.model_validator(mode="after") + def _check_source_mode(self) -> "BasePedestal": + if ( + self.use_formation_model_with_adaptive_source + and self.mode != runtime_params.Mode.ADAPTIVE_SOURCE + ): + raise ValueError( + "use_formation_model_with_adaptive_source can only be True when mode" + " is ADAPTIVE_SOURCE" + ) + return self + @abc.abstractmethod def build_pedestal_model(self) -> pedestal_model.PedestalModel: """Builds the pedestal model.""" diff --git a/torax/_src/pedestal_model/tests/pydantic_model_test.py b/torax/_src/pedestal_model/tests/pydantic_model_test.py index aba4a8a1d..366c4fe1e 100644 --- a/torax/_src/pedestal_model/tests/pydantic_model_test.py +++ b/torax/_src/pedestal_model/tests/pydantic_model_test.py @@ -16,6 +16,7 @@ import jax from torax._src import jax_utils from torax._src.pedestal_model import pydantic_model +from torax._src.pedestal_model import runtime_params class PedestalModelPydanticTest(parameterized.TestCase): @@ -34,18 +35,49 @@ def test_build_and_call_model( def f(x: pydantic_model.BasePedestal): return x.build_runtime_params(t=0.0) - with self.subTest('first_jit_compiles_and_returns_expected_value'): + with self.subTest("first_jit_compiles_and_returns_expected_value"): output = f(pedestal_model) self.assertIsInstance(output, pydantic_model.runtime_params.RuntimeParams) self.assertFalse(output.set_pedestal) self.assertEqual(jax_utils.get_number_of_compiles(f), 1) - with self.subTest('second_jit_updates_value_without_recompile'): - pedestal_model._update_fields({'set_pedestal': True}) + with self.subTest("second_jit_updates_value_without_recompile"): + pedestal_model._update_fields({"set_pedestal": True}) output = f(pedestal_model) self.assertTrue(output.set_pedestal) self.assertEqual(jax_utils.get_number_of_compiles(f), 1) + def test_source_mode_validation(self): -if __name__ == '__main__': + # This should be fine (default mode is ADAPTIVE_SOURCE) + pydantic_model.SetTpedNped.from_dict( + {"use_formation_model_with_adaptive_source": True} + ) + + # This should fail because mode is ADAPTIVE_TRANSPORT but use_formation + # model_with_adaptive_source is True + with self.assertRaisesRegex( + ValueError, + "use_formation_model_with_adaptive_source can only be True when mode" + " is ADAPTIVE_SOURCE", + ): + pydantic_model.SetTpedNped.from_dict({ + "use_formation_model_with_adaptive_source": True, + "mode": runtime_params.Mode.ADAPTIVE_TRANSPORT, + }) + + def test_transition_time_width_must_be_positive(self): + # Positive values should be fine. + pydantic_model.SetTpedNped.from_dict({"transition_time_width": 0.5}) + + # Zero should fail. + with self.assertRaises(ValueError): + pydantic_model.SetTpedNped.from_dict({"transition_time_width": 0.0}) + + # Negative should fail. + with self.assertRaises(ValueError): + pydantic_model.SetTpedNped.from_dict({"transition_time_width": -1.0}) + + +if __name__ == "__main__": absltest.main() diff --git a/torax/_src/solver/linear_theta_method.py b/torax/_src/solver/linear_theta_method.py index fc49c88e1..56bfcb5ad 100644 --- a/torax/_src/solver/linear_theta_method.py +++ b/torax/_src/solver/linear_theta_method.py @@ -24,6 +24,7 @@ from torax._src.fvm import calc_coeffs from torax._src.fvm import cell_variable from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import predictor_corrector_method from torax._src.solver import solver as solver_lib from torax._src.sources import source_profiles @@ -50,6 +51,9 @@ def _x_new( core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -80,6 +84,7 @@ def _x_new( explicit_source_profiles=explicit_source_profiles, allow_pereverzev=True, explicit_call=True, + pedestal_transition_state=pedestal_transition_state, ) # Calculate x_new with the predictor corrector method. Reverts to a @@ -97,6 +102,7 @@ def _x_new( coeffs_exp=coeffs_exp, coeffs_callback=coeffs_callback, explicit_source_profiles=explicit_source_profiles, + pedestal_transition_state=pedestal_transition_state, ) if runtime_params_t_plus_dt.solver.use_predictor_corrector: diff --git a/torax/_src/solver/nonlinear_theta_method.py b/torax/_src/solver/nonlinear_theta_method.py index e2915742b..93703e975 100644 --- a/torax/_src/solver/nonlinear_theta_method.py +++ b/torax/_src/solver/nonlinear_theta_method.py @@ -26,6 +26,7 @@ from torax._src.fvm import newton_raphson_solve_block from torax._src.fvm import optimizer_solve_block from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import runtime_params as solver_runtime_params_lib from torax._src.solver import solver from torax._src.sources import source_profiles @@ -65,6 +66,9 @@ def _x_new( core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -89,6 +93,7 @@ def _x_new( explicit_source_profiles=explicit_source_profiles, coeffs_callback=coeffs_callback, evolving_names=evolving_names, + pedestal_transition_state=pedestal_transition_state, ) return ( @@ -109,6 +114,9 @@ def _x_new_helper( explicit_source_profiles: source_profiles.SourceProfiles, coeffs_callback: calc_coeffs.CoeffsCallback, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -140,6 +148,9 @@ def _x_new_helper( iterative solvers. evolving_names: The names of variables within the core profiles that should evolve. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: A tuple containing: @@ -164,6 +175,9 @@ def _x_new_helper( explicit_source_profiles: source_profiles.SourceProfiles, coeffs_callback: calc_coeffs.CoeffsCallback, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -194,6 +208,7 @@ def _x_new_helper( ), maxiter=solver_params.n_max_iterations, tol=solver_params.loss_tol, + pedestal_transition_state=pedestal_transition_state, ) return ( x_new, @@ -216,6 +231,9 @@ def _x_new_helper( explicit_source_profiles: source_profiles.SourceProfiles, coeffs_callback: calc_coeffs.CoeffsCallback, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -251,6 +269,7 @@ def _x_new_helper( coarse_tol=solver_params.residual_coarse_tol, delta_reduction_factor=solver_params.delta_reduction_factor, tau_min=solver_params.tau_min, + pedestal_transition_state=pedestal_transition_state, ) return ( x_new, diff --git a/torax/_src/solver/predictor_corrector_method.py b/torax/_src/solver/predictor_corrector_method.py index b1fc0ee87..42c9a858a 100644 --- a/torax/_src/solver/predictor_corrector_method.py +++ b/torax/_src/solver/predictor_corrector_method.py @@ -18,6 +18,7 @@ runtime_params_slice.solver.use_predictor_corrector is False, reverts to a standard linear solution. """ + import functools import jax @@ -28,6 +29,7 @@ from torax._src.fvm import cell_variable from torax._src.fvm import implicit_solve_block from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.solver import jax_fixed_point from torax._src.sources import source_profiles @@ -49,6 +51,9 @@ def predictor_corrector_method( coeffs_exp: block_1d_coeffs.Block1DCoeffs, explicit_source_profiles: source_profiles.SourceProfiles, coeffs_callback: calc_coeffs.CoeffsCallback, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[cell_variable.CellVariable, ...]: """Predictor-corrector method. @@ -71,6 +76,7 @@ def predictor_corrector_method( iterations. For sources that are implicit, their explicit profiles are set to all zeros. coeffs_callback: coefficient callback function. + pedestal_transition_state: State tracking the pedestal transitions. Returns: x_new: Solution of evolving core profile state variables @@ -88,6 +94,7 @@ def loop_body(x_new_guess): dt=dt, x=x_new_guess, explicit_source_profiles=explicit_source_profiles, + pedestal_transition_state=pedestal_transition_state, allow_pereverzev=True, ) diff --git a/torax/_src/solver/solver.py b/torax/_src/solver/solver.py index f8b728bc9..1d0d464a1 100644 --- a/torax/_src/solver/solver.py +++ b/torax/_src/solver/solver.py @@ -30,6 +30,7 @@ from torax._src.config import runtime_params as runtime_params_lib from torax._src.fvm import cell_variable from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.sources import source_profiles @@ -60,6 +61,9 @@ def __call__( core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -88,6 +92,9 @@ def __call__( or were independent of the core profiles. Because they were calculated outside the possibly-JAX-jitted solver logic, they can be calculated in non-JAX-friendly ways. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: x_new: Tuple containing new cell-grid values of the evolving variables. @@ -112,6 +119,7 @@ def __call__( core_profiles_t_plus_dt=core_profiles_t_plus_dt, explicit_source_profiles=explicit_source_profiles, evolving_names=runtime_params_t.numerics.evolving_names, + pedestal_transition_state=pedestal_transition_state, ) else: x_new = tuple() @@ -138,6 +146,9 @@ def _x_new( core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, evolving_names: tuple[str, ...], + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs, @@ -163,6 +174,9 @@ def _x_new( are not being evolved by the PDE system. explicit_source_profiles: see the docstring of __call__ evolving_names: The names of core_profiles variables that should evolve. + pedestal_transition_state: State for tracking pedestal L-H and H-L + transitions. Only used when the pedestal mode is ADAPTIVE_SOURCE with + use_formation_model_with_adaptive_source=True. None otherwise. Returns: x_new: The values of the evolving variables at time t + dt. diff --git a/torax/_src/transport_model/transport_coefficients_builder.py b/torax/_src/transport_model/transport_coefficients_builder.py index 92d0ff5ee..107e49e3c 100644 --- a/torax/_src/transport_model/transport_coefficients_builder.py +++ b/torax/_src/transport_model/transport_coefficients_builder.py @@ -23,11 +23,14 @@ from torax._src.geometry import geometry from torax._src.neoclassical import neoclassical_models as neoclassical_models_lib from torax._src.pedestal_model import pedestal_model as pedestal_model_lib +from torax._src.pedestal_model import pedestal_transition_state as pedestal_transition_state_lib from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib from torax._src.sources import source_profiles as source_profiles_lib from torax._src.transport_model import pereverzev as pereverzev_lib from torax._src.transport_model import transport_model as transport_model_lib +# pylint: disable=invalid-name + @jax.jit( static_argnames=( @@ -45,8 +48,33 @@ def calculate_all_transport_coeffs( core_profiles: state.CoreProfiles, source_profiles: source_profiles_lib.SourceProfiles, use_pereverzev: bool = False, + pedestal_transition_state: ( + pedestal_transition_state_lib.PedestalTransitionState | None + ) = None, ) -> state.CoreTransport: """Calculates the transport coefficients from all models.""" + + if ( + runtime_params.pedestal.use_formation_model_with_adaptive_source + and pedestal_transition_state is not None + ): + is_L_mode = jnp.logical_not(pedestal_transition_state.in_H_mode) + is_transitioning = jnp.logical_and( + runtime_params.t >= pedestal_transition_state.transition_start_time, + runtime_params.t + <= pedestal_transition_state.transition_start_time + + runtime_params.pedestal.transition_time_width, + ) + need_mask = jnp.logical_or(jnp.logical_not(is_L_mode), is_transitioning) + + pedestal_params = dataclasses.replace( + runtime_params.pedestal, + set_pedestal=need_mask, + ) + runtime_params = dataclasses.replace( + runtime_params, + pedestal=pedestal_params, + ) pedestal_model_output = pedestal_model( runtime_params, geo, core_profiles, source_profiles ) diff --git a/torax/_src/transport_model/transport_model.py b/torax/_src/transport_model/transport_model.py index 58692e0e2..6b5fee9e0 100644 --- a/torax/_src/transport_model/transport_model.py +++ b/torax/_src/transport_model/transport_model.py @@ -444,13 +444,16 @@ def _build_smoothing_matrix( == pedestal_runtime_params_lib.Mode.ADAPTIVE_SOURCE ): # If in ADAPTIVE_SOURCE mode: if set_pedestal is True, mask according to the - # pedestal top. Otherwise, mask according to the outer patch, if set. + # pedestal top. Otherwise, mask according to the outer patch, if set. If no + # outer patch, do not mask. mask_outer_edge = jnp.where( - jnp.logical_not(runtime_params.pedestal.set_pedestal) - & runtime_params.transport.apply_outer_patch, - runtime_params.transport.rho_outer - consts.eps, - # If pedestal is not set, rho_norm_ped_top is inf. + runtime_params.pedestal.set_pedestal, pedestal_model_output.rho_norm_ped_top - consts.eps, + jnp.where( + runtime_params.transport.apply_outer_patch, + runtime_params.transport.rho_outer - consts.eps, + jnp.inf, + ), ) else: # If in ADAPTIVE_TRANSPORT mode, only mask according to the outer patch. diff --git a/torax/tests/sim_time_dependence_test.py b/torax/tests/sim_time_dependence_test.py index 878c66937..5bbc4524e 100644 --- a/torax/tests/sim_time_dependence_test.py +++ b/torax/tests/sim_time_dependence_test.py @@ -214,6 +214,7 @@ def __call__( core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, + pedestal_transition_state=None, ) -> tuple[ tuple[cell_variable.CellVariable, ...], state.SolverNumericOutputs,