diff --git a/torax/_src/pedestal_model/pydantic_model.py b/torax/_src/pedestal_model/pydantic_model.py index e251f1f6a..0a9f90539 100644 --- a/torax/_src/pedestal_model/pydantic_model.py +++ b/torax/_src/pedestal_model/pydantic_model.py @@ -26,6 +26,7 @@ from torax._src.pedestal_model import set_pped_tpedratio_nped from torax._src.pedestal_model import set_tped_nped from torax._src.pedestal_model.formation import power_scaling_formation_model +from torax._src.pedestal_model.saturation import ballooning_stability_saturation_model from torax._src.pedestal_model.saturation import profile_value_saturation_model from torax._src.torax_pydantic import torax_pydantic @@ -207,9 +208,72 @@ def build_runtime_params( ) +class BallooningStabilitySaturation(torax_pydantic.BaseModelFrozen): + """Configuration for BallooningStabilitySaturation model. + + This saturation model triggers an increase in pedestal transport when the + normalized pressure gradient alpha is above alpha_crit. + + The formula is + transport_multiplier = 1 + alpha * base_multiplier, + where alpha is a softplus function of the normalized deviation from the target + value, with given steepness and offset: + x = (current - target) / target - offset + alpha = log(1 + exp(steepness * x)) + + Attributes: + alpha_crit: Margin above which transport is enhanced. + steepness: Scaling factor applied to the argument of the softplus function, + setting the steepness of the smooth saturation function. Decrease for a + smoother saturation, which may be more numerically stable but may lead to + starting saturation at a temperature or density below the target values. + offset: Bias applied to the argument of the softplus function, setting the + dimensionless offset of the saturation window. Increase to start + saturation at a higher temperature or density. + base_multiplier: The base value of the transport multiplier. Increase for + stronger increases in transport once saturation starts. + """ + + model_name: Annotated[ + Literal["ballooning_stability"], torax_pydantic.JAX_STATIC + ] = "ballooning_stability" + alpha_crit: pydantic.PositiveFloat = 1.0 + steepness: pydantic.PositiveFloat = 100.0 + # Default offset is > 0 as otherwise saturation starts too early. This is + # because the softplus function is nonzero before the argument is zero. + offset: Annotated[ + array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0) + ] = 0.1 + base_multiplier: Annotated[ + array_typing.FloatScalar, pydantic.Field(gt=1.0) + ] = 1e6 + + def build_saturation_model( + self, + ) -> ballooning_stability_saturation_model.BallooningStabilitySaturationModel: + return ( + ballooning_stability_saturation_model.BallooningStabilitySaturationModel() + ) + + def build_runtime_params( + self, t: chex.Numeric + ) -> ( + ballooning_stability_saturation_model.BallooningStabilitySaturationRuntimeParams + ): + del t + return ballooning_stability_saturation_model.BallooningStabilitySaturationRuntimeParams( + alpha_crit=self.alpha_crit, + steepness=self.steepness, + offset=self.offset, + base_multiplier=self.base_multiplier, + ) + + # For new formation and saturation models, add to these TypeAliases via Union. FormationConfig: TypeAlias = DelabieScalingFormation | MartinScalingFormation -SaturationConfig: TypeAlias = ProfileValueSaturation +SaturationConfig: TypeAlias = ( + ProfileValueSaturation | BallooningStabilitySaturation +) class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC): diff --git a/torax/_src/pedestal_model/saturation/ballooning_stability_saturation_model.py b/torax/_src/pedestal_model/saturation/ballooning_stability_saturation_model.py new file mode 100644 index 000000000..d06521f9d --- /dev/null +++ b/torax/_src/pedestal_model/saturation/ballooning_stability_saturation_model.py @@ -0,0 +1,143 @@ +# Copyright 2026 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Saturation model based on ballooning stability limit.""" + +import dataclasses +import jax +import jax.numpy as jnp +from torax._src import array_typing +from torax._src import constants +from torax._src import math_utils +from torax._src import state +from torax._src.config import runtime_params as runtime_params_lib +from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_model_output +from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib +from torax._src.pedestal_model.saturation import base +from torax._src.physics import formulas + +# pylint: disable=invalid-name + + +def calculate_normalized_pressure_gradient( + core_profiles: state.CoreProfiles, + geo: geometry.Geometry, +) -> array_typing.FloatVector: + """Calculates the normalized pressure gradient (alpha). + + Equation: + alpha = -2*mu_0 * (dV/dpsi) * (1/(2*pi)^2) * sqrt(V / (2*pi^2*R_0)) * + (dp/dpsi) + + Args: + core_profiles: CoreProfiles object containing information on pressures and + psi. + geo: Geometry object. + + Returns: + alpha: Normalized pressure gradient evaluated on the face grid. + """ + dp_dpsi = formulas.calc_pprime(core_profiles) + dpsi_drhon = core_profiles.psi.face_grad() + # vpr is dV/drhon, so dV/dpsi = vpr / dpsi/drhon + dV_dpsi = math_utils.safe_divide(geo.vpr_face, dpsi_drhon) + + # Plasma volume enclosed by the flux surface (V) and major radius (R_0) + V = geo.volume_face + R_0 = geo.R_major + + # Calculate alpha + return ( + -2.0 + * constants.CONSTANTS.mu_0 + * dV_dpsi + * (1.0 / (2.0 * jnp.pi) ** 2) + * jnp.sqrt(V / (2.0 * jnp.pi**2 * R_0)) + * dp_dpsi + ) + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class BallooningStabilitySaturationRuntimeParams( + pedestal_runtime_params_lib.SaturationRuntimeParams +): + """Runtime params for ballooning stability saturation models.""" + + alpha_crit: array_typing.FloatScalar + + +@dataclasses.dataclass(frozen=True, eq=False) +class BallooningStabilitySaturationModel(base.SaturationModel): + """Saturation model based on the maximum pressure gradient alpha_crit.""" + + def __call__( + self, + runtime_params: runtime_params_lib.RuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + pedestal_output: pedestal_model_output.PedestalModelOutput, + ) -> array_typing.FloatScalar: + """Calculates transport increase multipliers.""" + assert isinstance( + runtime_params.pedestal.saturation, + BallooningStabilitySaturationRuntimeParams, + ) + + alpha = calculate_normalized_pressure_gradient(core_profiles, geo) + max_alpha_ped = jnp.max( + jnp.where( + geo.rho_face_norm >= pedestal_output.rho_norm_ped_top, alpha, 0.0 + ) + ) + multiplier = self._calculate_multiplier( + current=max_alpha_ped, + target=runtime_params.pedestal.saturation.alpha_crit, + pedestal_runtime_params=runtime_params.pedestal, + ) + return pedestal_model_output.TransportMultipliers( + chi_e_multiplier=multiplier, + chi_i_multiplier=multiplier, + D_e_multiplier=multiplier, + v_e_multiplier=multiplier, + ) + + def _calculate_multiplier( + self, + current: array_typing.FloatScalar, + target: array_typing.FloatScalar, + pedestal_runtime_params: pedestal_runtime_params_lib.RuntimeParams, + ) -> array_typing.FloatScalar: + """Calculates the transport increase multiplier. + + If current >> target, multiplier -> infinity. + If current << target, multiplier -> 1. + + Args: + current: The current value of the profile at the pedestal top. + target: The target value of the profile at the pedestal top. + pedestal_runtime_params: The runtime parameters for the pedestal model. + + Returns: + The transport increase multiplier. + """ + steepness = pedestal_runtime_params.saturation.steepness + offset = pedestal_runtime_params.saturation.offset + base_multiplier = pedestal_runtime_params.saturation.base_multiplier + normalized_deviation = (current - target) / target - offset + transport_multiplier = 1 + base_multiplier * jax.nn.softplus( + normalized_deviation * steepness + ) + return transport_multiplier diff --git a/torax/_src/pedestal_model/saturation/tests/ballooning_stability_saturation_model_test.py b/torax/_src/pedestal_model/saturation/tests/ballooning_stability_saturation_model_test.py new file mode 100644 index 000000000..8d694fa2d --- /dev/null +++ b/torax/_src/pedestal_model/saturation/tests/ballooning_stability_saturation_model_test.py @@ -0,0 +1,123 @@ +# Copyright 2026 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from torax._src.config import build_runtime_params +from torax._src.core_profiles import initialization +from torax._src.pedestal_model import pedestal_model_output +from torax._src.pedestal_model.saturation import ballooning_stability_saturation_model +from torax._src.test_utils import default_configs +from torax._src.torax_pydantic import model_config + +# pylint: disable=invalid-name + + +class BallooningStabilitySaturationModelTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + config = default_configs.get_default_config_dict() + # Add ballooning stability config with a default value for alpha_crit. + config['pedestal']['saturation_model'] = { + 'model_name': 'ballooning_stability', + 'alpha_crit': 1.0, # Will be overridden in specific tests. + } + self.torax_config = model_config.ToraxConfig.from_dict(config) + self.provider = build_runtime_params.RuntimeParamsProvider.from_config( + self.torax_config + ) + self.runtime_params = self.provider(t=0.0) + self.geo = self.torax_config.geometry.build_provider(t=0.0) + self.source_models = self.torax_config.sources.build_models() + self.neoclassical_models = self.torax_config.neoclassical.build_models() + self.core_profiles = initialization.initial_core_profiles( + self.runtime_params, + self.geo, + self.source_models, + self.neoclassical_models, + ) + + @parameterized.named_parameters( + dict( + testcase_name='active', + # Low alpha_crit -> saturation is active. + alpha_crit=1e-3, + ), + dict( + testcase_name='inactive', + # High alpha_crit -> saturation is inactive. + alpha_crit=1e3, + ), + ) + def test_saturation_multiplier( + self, + alpha_crit: float, + ): + assert isinstance( + self.runtime_params.pedestal.saturation, + ballooning_stability_saturation_model.BallooningStabilitySaturationRuntimeParams, + ) + saturation_model = ( + ballooning_stability_saturation_model.BallooningStabilitySaturationModel() + ) + + # For this test, we put the pedestal top at the last grid point. + pedestal_output = pedestal_model_output.PedestalModelOutput( + rho_norm_ped_top=1.0, + rho_norm_ped_top_idx=-1, + # The following values are not used in the saturation model. + T_i_ped=1.0, + T_e_ped=1.0, + n_e_ped=1.0, + ) + + # Set alpha_crit in the runtime params. + new_saturation_params = dataclasses.replace( + self.runtime_params.pedestal.saturation, + alpha_crit=alpha_crit, + steepness=self.runtime_params.pedestal.saturation.steepness, + offset=self.runtime_params.pedestal.saturation.offset, + base_multiplier=self.runtime_params.pedestal.saturation.base_multiplier, + ) + new_pedestal_params = dataclasses.replace( + self.runtime_params.pedestal, saturation=new_saturation_params + ) + runtime_params = dataclasses.replace( + self.runtime_params, pedestal=new_pedestal_params + ) + + # Get actual alpha in the pedestal. + alpha = ballooning_stability_saturation_model.calculate_normalized_pressure_gradient( + self.core_profiles, self.geo + ) + max_alpha_ped = alpha[-1] + + # Calculate the multiplier. + transport_multipliers = saturation_model( + runtime_params, + self.geo, + self.core_profiles, + pedestal_output, + ) + + if max_alpha_ped < alpha_crit: + np.testing.assert_allclose(transport_multipliers.chi_e_multiplier, 1.0) + else: + self.assertGreater(transport_multipliers.chi_e_multiplier, 1.0) + + +if __name__ == '__main__': + absltest.main()