Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion torax/_src/pedestal_model/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torax._src.pedestal_model import set_pped_tpedratio_nped
from torax._src.pedestal_model import set_tped_nped
from torax._src.pedestal_model.formation import power_scaling_formation_model
from torax._src.pedestal_model.saturation import ballooning_stability_saturation_model
from torax._src.pedestal_model.saturation import profile_value_saturation_model
from torax._src.torax_pydantic import torax_pydantic

Expand Down Expand Up @@ -207,9 +208,72 @@ def build_runtime_params(
)


class BallooningStabilitySaturation(torax_pydantic.BaseModelFrozen):
"""Configuration for BallooningStabilitySaturation model.

This saturation model triggers an increase in pedestal transport when the
normalized pressure gradient alpha is above alpha_crit.

The formula is
transport_multiplier = 1 + alpha * base_multiplier,
where alpha is a softplus function of the normalized deviation from the target
value, with given steepness and offset:
x = (current - target) / target - offset
alpha = log(1 + exp(steepness * x))

Attributes:
alpha_crit: Margin above which transport is enhanced.
steepness: Scaling factor applied to the argument of the softplus function,
setting the steepness of the smooth saturation function. Decrease for a
smoother saturation, which may be more numerically stable but may lead to
starting saturation at a temperature or density below the target values.
offset: Bias applied to the argument of the softplus function, setting the
dimensionless offset of the saturation window. Increase to start
saturation at a higher temperature or density.
base_multiplier: The base value of the transport multiplier. Increase for
stronger increases in transport once saturation starts.
"""

model_name: Annotated[
Literal["ballooning_stability"], torax_pydantic.JAX_STATIC
] = "ballooning_stability"
alpha_crit: pydantic.PositiveFloat = 1.0
steepness: pydantic.PositiveFloat = 100.0
# Default offset is > 0 as otherwise saturation starts too early. This is
# because the softplus function is nonzero before the argument is zero.
offset: Annotated[
array_typing.FloatScalar, pydantic.Field(ge=-10.0, le=10.0)
] = 0.1
base_multiplier: Annotated[
array_typing.FloatScalar, pydantic.Field(gt=1.0)
] = 1e6

def build_saturation_model(
self,
) -> ballooning_stability_saturation_model.BallooningStabilitySaturationModel:
return (
ballooning_stability_saturation_model.BallooningStabilitySaturationModel()
)

def build_runtime_params(
self, t: chex.Numeric
) -> (
ballooning_stability_saturation_model.BallooningStabilitySaturationRuntimeParams
):
del t
return ballooning_stability_saturation_model.BallooningStabilitySaturationRuntimeParams(
alpha_crit=self.alpha_crit,
steepness=self.steepness,
offset=self.offset,
base_multiplier=self.base_multiplier,
)


# For new formation and saturation models, add to these TypeAliases via Union.
FormationConfig: TypeAlias = DelabieScalingFormation | MartinScalingFormation
SaturationConfig: TypeAlias = ProfileValueSaturation
SaturationConfig: TypeAlias = (
ProfileValueSaturation | BallooningStabilitySaturation
)


class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2026 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Saturation model based on ballooning stability limit."""

import dataclasses
import jax
import jax.numpy as jnp
from torax._src import array_typing
from torax._src import constants
from torax._src import math_utils
from torax._src import state
from torax._src.config import runtime_params as runtime_params_lib
from torax._src.geometry import geometry
from torax._src.pedestal_model import pedestal_model_output
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params_lib
from torax._src.pedestal_model.saturation import base
from torax._src.physics import formulas

# pylint: disable=invalid-name


def calculate_normalized_pressure_gradient(
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
) -> array_typing.FloatVector:
"""Calculates the normalized pressure gradient (alpha).

Equation:
alpha = -2*mu_0 * (dV/dpsi) * (1/(2*pi)^2) * sqrt(V / (2*pi^2*R_0)) *
(dp/dpsi)

Args:
core_profiles: CoreProfiles object containing information on pressures and
psi.
geo: Geometry object.

Returns:
alpha: Normalized pressure gradient evaluated on the face grid.
"""
dp_dpsi = formulas.calc_pprime(core_profiles)
dpsi_drhon = core_profiles.psi.face_grad()
# vpr is dV/drhon, so dV/dpsi = vpr / dpsi/drhon
dV_dpsi = math_utils.safe_divide(geo.vpr_face, dpsi_drhon)

# Plasma volume enclosed by the flux surface (V) and major radius (R_0)
V = geo.volume_face
R_0 = geo.R_major

# Calculate alpha
return (
-2.0
* constants.CONSTANTS.mu_0
* dV_dpsi
* (1.0 / (2.0 * jnp.pi) ** 2)
* jnp.sqrt(V / (2.0 * jnp.pi**2 * R_0))
* dp_dpsi
)


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class BallooningStabilitySaturationRuntimeParams(
pedestal_runtime_params_lib.SaturationRuntimeParams
):
"""Runtime params for ballooning stability saturation models."""

alpha_crit: array_typing.FloatScalar


@dataclasses.dataclass(frozen=True, eq=False)
class BallooningStabilitySaturationModel(base.SaturationModel):
"""Saturation model based on the maximum pressure gradient alpha_crit."""

def __call__(
self,
runtime_params: runtime_params_lib.RuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
pedestal_output: pedestal_model_output.PedestalModelOutput,
) -> array_typing.FloatScalar:
"""Calculates transport increase multipliers."""
assert isinstance(
runtime_params.pedestal.saturation,
BallooningStabilitySaturationRuntimeParams,
)

alpha = calculate_normalized_pressure_gradient(core_profiles, geo)
max_alpha_ped = jnp.max(
jnp.where(
geo.rho_face_norm >= pedestal_output.rho_norm_ped_top, alpha, 0.0
)
)
multiplier = self._calculate_multiplier(
current=max_alpha_ped,
target=runtime_params.pedestal.saturation.alpha_crit,
pedestal_runtime_params=runtime_params.pedestal,
)
return pedestal_model_output.TransportMultipliers(
chi_e_multiplier=multiplier,
chi_i_multiplier=multiplier,
D_e_multiplier=multiplier,
v_e_multiplier=multiplier,
)

def _calculate_multiplier(
self,
current: array_typing.FloatScalar,
target: array_typing.FloatScalar,
pedestal_runtime_params: pedestal_runtime_params_lib.RuntimeParams,
) -> array_typing.FloatScalar:
"""Calculates the transport increase multiplier.

If current >> target, multiplier -> infinity.
If current << target, multiplier -> 1.

Args:
current: The current value of the profile at the pedestal top.
target: The target value of the profile at the pedestal top.
pedestal_runtime_params: The runtime parameters for the pedestal model.

Returns:
The transport increase multiplier.
"""
steepness = pedestal_runtime_params.saturation.steepness
offset = pedestal_runtime_params.saturation.offset
base_multiplier = pedestal_runtime_params.saturation.base_multiplier
normalized_deviation = (current - target) / target - offset
transport_multiplier = 1 + base_multiplier * jax.nn.softplus(
normalized_deviation * steepness
)
return transport_multiplier
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2026 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from torax._src.config import build_runtime_params
from torax._src.core_profiles import initialization
from torax._src.pedestal_model import pedestal_model_output
from torax._src.pedestal_model.saturation import ballooning_stability_saturation_model
from torax._src.test_utils import default_configs
from torax._src.torax_pydantic import model_config

# pylint: disable=invalid-name


class BallooningStabilitySaturationModelTest(parameterized.TestCase):

def setUp(self):
super().setUp()
config = default_configs.get_default_config_dict()
# Add ballooning stability config with a default value for alpha_crit.
config['pedestal']['saturation_model'] = {
'model_name': 'ballooning_stability',
'alpha_crit': 1.0, # Will be overridden in specific tests.
}
self.torax_config = model_config.ToraxConfig.from_dict(config)
self.provider = build_runtime_params.RuntimeParamsProvider.from_config(
self.torax_config
)
self.runtime_params = self.provider(t=0.0)
self.geo = self.torax_config.geometry.build_provider(t=0.0)
self.source_models = self.torax_config.sources.build_models()
self.neoclassical_models = self.torax_config.neoclassical.build_models()
self.core_profiles = initialization.initial_core_profiles(
self.runtime_params,
self.geo,
self.source_models,
self.neoclassical_models,
)

@parameterized.named_parameters(
dict(
testcase_name='active',
# Low alpha_crit -> saturation is active.
alpha_crit=1e-3,
),
dict(
testcase_name='inactive',
# High alpha_crit -> saturation is inactive.
alpha_crit=1e3,
),
)
def test_saturation_multiplier(
self,
alpha_crit: float,
):
assert isinstance(
self.runtime_params.pedestal.saturation,
ballooning_stability_saturation_model.BallooningStabilitySaturationRuntimeParams,
)
saturation_model = (
ballooning_stability_saturation_model.BallooningStabilitySaturationModel()
)

# For this test, we put the pedestal top at the last grid point.
pedestal_output = pedestal_model_output.PedestalModelOutput(
rho_norm_ped_top=1.0,
rho_norm_ped_top_idx=-1,
# The following values are not used in the saturation model.
T_i_ped=1.0,
T_e_ped=1.0,
n_e_ped=1.0,
)

# Set alpha_crit in the runtime params.
new_saturation_params = dataclasses.replace(
self.runtime_params.pedestal.saturation,
alpha_crit=alpha_crit,
steepness=self.runtime_params.pedestal.saturation.steepness,
offset=self.runtime_params.pedestal.saturation.offset,
base_multiplier=self.runtime_params.pedestal.saturation.base_multiplier,
)
new_pedestal_params = dataclasses.replace(
self.runtime_params.pedestal, saturation=new_saturation_params
)
runtime_params = dataclasses.replace(
self.runtime_params, pedestal=new_pedestal_params
)

# Get actual alpha in the pedestal.
alpha = ballooning_stability_saturation_model.calculate_normalized_pressure_gradient(
self.core_profiles, self.geo
)
max_alpha_ped = alpha[-1]

# Calculate the multiplier.
transport_multipliers = saturation_model(
runtime_params,
self.geo,
self.core_profiles,
pedestal_output,
)

if max_alpha_ped < alpha_crit:
np.testing.assert_allclose(transport_multipliers.chi_e_multiplier, 1.0)
else:
self.assertGreater(transport_multipliers.chi_e_multiplier, 1.0)


if __name__ == '__main__':
absltest.main()
Loading