Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions torax/_src/edge/extended_lengyel_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,8 +794,10 @@ def _solve_for_qcc(
# argument is order 1 (or 0), allowing a fixed dimensionless epsilon to be
# effective. This prevents vanishing gradients for deep negative excursions.
qcc_norm = math_utils.smooth_sqrt(
math_utils.safe_divide(qcc_squared, qu**2), epsilon=1e-3
)
math_utils.safe_divide(
# TODO(b/512078510): Pick a reasonable eps value for safe_divide here.
num=qcc_squared, denom=qu**2, eps=1e-7), epsilon=1e-3
)

qcc = qcc_norm * jnp.sqrt(qu**2)

Expand Down
4 changes: 3 additions & 1 deletion torax/_src/edge/updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def _calculate_impurity_scaling_factor(
# Calculate scaling from the current value of the profile at the lcfs.
# This scales the whole profile shape to match the edge value.
current_val_at_edge = impurity_params.n_e_ratios_face[species][-1]
return math_utils.safe_divide(conc_lcfs, current_val_at_edge)
return math_utils.safe_divide(
num=conc_lcfs, denom=current_val_at_edge, eps=1e-7
)


def _update_impurities(
Expand Down
2 changes: 1 addition & 1 deletion torax/_src/geometry/fbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _from_fbt(
# (either here or upstream in MEQ)
# Approximate with analytical expressions for circular geometry.
flux_surf_avg_B2 = math_utils.safe_divide(
B_0**2, np.sqrt(1.0 - LY['epsilon'] ** 2)
num=B_0**2, denom=np.sqrt(1.0 - LY['epsilon'] ** 2), eps=1e-7
)
flux_surf_avg_1_over_B2 = B_0**-2 * (1.0 + 1.5 * LY['epsilon'] ** 2)

Expand Down
17 changes: 15 additions & 2 deletions torax/_src/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,21 @@ def cumulative_volume_integration(
return cumulative_cell_integration(value * geo.vpr, geo)


def safe_divide(y: chex.Array, x: chex.Array) -> chex.Array:
return y / (x + constants.CONSTANTS.eps)
def safe_divide(
*, num: chex.Array, denom: chex.Array, eps: float
) -> chex.Array:
"""Divides y by x, adding eps to the denominator for numerical stability.

Args:
num: Numerator.
denom: Denominator.
eps: Small value added to the denominator to prevent division by zero.
Choose a value that is small relative to the expected magnitude of x.

Returns:
num / (denom + eps).
"""
return num / (denom + eps)


def inverse_softplus(x: jax.Array) -> jax.Array:
Expand Down
16 changes: 12 additions & 4 deletions torax/_src/neoclassical/transport/angioni_sauter.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,24 @@ def _calculate_angioni_sauter_transport(
# --- Step 4: Calculate thermodynamic forces ---
dpsi_drhon = core_profiles.psi.face_grad()
dlnne_dpsi = math_utils.safe_divide(
core_profiles.n_e.face_grad() / core_profiles.n_e.face_value(), dpsi_drhon
num=core_profiles.n_e.face_grad() / core_profiles.n_e.face_value(),
denom=dpsi_drhon,
eps=1e-7,
)
dlnte_dpsi = math_utils.safe_divide(
core_profiles.T_e.face_grad() / core_profiles.T_e.face_value(), dpsi_drhon
num=core_profiles.T_e.face_grad() / core_profiles.T_e.face_value(),
denom=dpsi_drhon,
eps=1e-7,
)
dlnni_dpsi = math_utils.safe_divide(
core_profiles.n_i.face_grad() / core_profiles.n_i.face_value(), dpsi_drhon
num=core_profiles.n_i.face_grad() / core_profiles.n_i.face_value(),
denom=dpsi_drhon,
eps=1e-7,
)
dlnti_dpsi = math_utils.safe_divide(
core_profiles.T_i.face_grad() / core_profiles.T_i.face_value(), dpsi_drhon
num=core_profiles.T_i.face_grad() / core_profiles.T_i.face_value(),
denom=dpsi_drhon,
eps=1e-7,
)

# --- Step 5: Calculate neoclassical fluxes ---
Expand Down
8 changes: 6 additions & 2 deletions torax/_src/output_tools/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,10 +1010,14 @@ def cumulative_values():
I_external=I_external,
I_non_inductive=I_non_inductive,
f_non_inductive=math_utils.safe_divide(
I_non_inductive, sim_state.core_profiles.Ip_profile_face[-1]
num=I_non_inductive,
denom=sim_state.core_profiles.Ip_profile_face[-1],
eps=1e-7,
),
f_bootstrap=math_utils.safe_divide(
I_bootstrap, sim_state.core_profiles.Ip_profile_face[-1]
num=I_bootstrap,
denom=sim_state.core_profiles.Ip_profile_face[-1],
eps=1e-7,
),
beta_tor=beta_tor,
beta_pol=beta_pol,
Expand Down
6 changes: 3 additions & 3 deletions torax/_src/output_tools/tests/post_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,18 +311,18 @@ def test_current_outputs(self):
# f_non_inductive = I_non_inductive / Ip
# Ip comes from core_profiles.Ip_profile_face[-1]
ip = self.core_profiles.Ip_profile_face[-1]
# Code uses constants.CONSTANTS.eps for division guard
# Code uses eps=1e-7 for division guard
np.testing.assert_allclose(
outputs.f_non_inductive,
math_utils.safe_divide(outputs.I_non_inductive, ip),
math_utils.safe_divide(num=outputs.I_non_inductive, denom=ip, eps=1e-7),
rtol=1e-5,
)

# Check bootstrap fraction
# f_bootstrap = I_bootstrap / Ip
np.testing.assert_allclose(
outputs.f_bootstrap,
math_utils.safe_divide(outputs.I_bootstrap, ip),
math_utils.safe_divide(num=outputs.I_bootstrap, denom=ip, eps=1e-7),
rtol=1e-5,
)

Expand Down
14 changes: 9 additions & 5 deletions torax/_src/physics/fast_ion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _nu_epsilon(
* ln_lambda
)
denom = jnp.power(m_b_amu * T_a_ev + m_a_amu * T_b_ev, 1.5)
return jnp.asarray(math_utils.safe_divide(num, denom))
return jnp.asarray(math_utils.safe_divide(num=num, denom=denom, eps=1e-7))


def _compute_T_tail(
Expand Down Expand Up @@ -106,8 +106,9 @@ def _compute_T_tail(
n_e_cm3 = n_e / 1.0e6

tau_s = math_utils.safe_divide(
6.27e8 * mass_number * jnp.power(T_e_eV, 1.5),
charge_number**2 * n_e_cm3 * log_lambda_ei,
num=6.27e8 * mass_number * jnp.power(T_e_eV, 1.5),
denom=charge_number**2 * n_e_cm3 * log_lambda_ei,
eps=1e-7,
)
T_e_J = T_e * constants.CONSTANTS.keV_to_J
energy_density = 1.5 * n_total * T_e_J
Expand All @@ -116,7 +117,7 @@ def _compute_T_tail(
energy_slowing_down_time = 0.5 * tau_s

xi = math_utils.safe_divide(
P_density_W * energy_slowing_down_time, energy_density
num=P_density_W * energy_slowing_down_time, denom=energy_density, eps=1e-7
)

return T_e * (1.0 + xi)
Expand Down Expand Up @@ -238,7 +239,10 @@ def bimaxwellian_split(
)
)

n_tail = math_utils.safe_divide(P_density_W, energy_loss_rate_per_particle)
# TODO(b/512078510): Choose a reasonable eps value for safe_divide here.
n_tail = math_utils.safe_divide(
num=P_density_W, denom=energy_loss_rate_per_particle, eps=1e-7
)
n_tail = jnp.clip(n_tail, 0.0, n_total * 0.99)

n_tail = jnp.where(P_density_W <= 1.0e-6, 0.0, n_tail)
Expand Down
2 changes: 1 addition & 1 deletion torax/_src/physics/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def calculate_betas(
magnetic_pressure_on_axis = geo.B_0**2 / (2 * constants.CONSTANTS.mu_0)
# Add a division guard though B0 should typically be non-zero.
beta_tor = math_utils.safe_divide(
p_total_volume_avg, magnetic_pressure_on_axis
num=p_total_volume_avg, denom=magnetic_pressure_on_axis, eps=1e-7
)

beta_pol = (
Expand Down
4 changes: 3 additions & 1 deletion torax/_src/physics/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def _calculate_radial_electric_field(
denominator = Z_i_face * constants.CONSTANTS.q_e * n_i.face_value()

Er_poloidal_and_pressure_face = (
math_utils.safe_divide(jnp.array(1.0), denominator) * dpi_dr
math_utils.safe_divide(
num=jnp.array(1.0), denom=denominator, eps=1e-7
) * dpi_dr
+ poloidal_velocity.face_value() * B_tor_face
)

Expand Down
2 changes: 1 addition & 1 deletion torax/_src/transport_model/quasilinear_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def apply_fast_ion_stabilization(
) * transport.fast_ion_stabilization_multiplier + 1
return jnp.where(
transport.fast_ion_stabilization,
math_utils.safe_divide(lref_over_lti, fi_stab_factor),
math_utils.safe_divide(num=lref_over_lti, denom=fi_stab_factor, eps=1e-7),
lref_over_lti,
)

Expand Down
54 changes: 34 additions & 20 deletions torax/_src/transport_model/tglf_based_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,18 +296,20 @@ def _prepare_tglf_inputs(
# r is zero on axis, so use safe_divide to avoid division by zero.
# Use face_grad to correctly handle constraints on the psi CellVariable.
B_unit = math_utils.safe_divide(
core_profiles.q_face
num=core_profiles.q_face
* core_profiles.psi.face_grad(x=geo.r_mid, x_left=r[0], x_right=r[-1]),
(2 * jnp.pi * r), # Note: psi_TGLF is psi_TORAX/2π
denom=(2 * jnp.pi * r), # Note: psi_TGLF is psi_TORAX/2π
eps=1e-7,
)

# Ion gyroradius
# TODO(b/502473098): Currently, q_e has to be outside of the safe_divide to
# avoid being swamped by the eps in the denominator.
rho_s = (
math_utils.safe_divide(
m_D * c_s,
B_unit,
num=m_D * c_s,
denom=B_unit,
eps=1e-7,
)
/ constants.CONSTANTS.q_e
)
Expand All @@ -317,13 +319,14 @@ def _prepare_tglf_inputs(
# - In the TGLF docs, the prefactor of 743.0 comes from a combination of the
# constants below plus being in CGS units. Below is the SI version.
normalized_debye = math_utils.safe_divide(
(
num=(
(constants.CONSTANTS.epsilon_0 / constants.CONSTANTS.q_e)
* (core_profiles.T_e.face_value() * 1e3) # keV -> eV
/ n_e
)
** 0.5,
rho_s,
denom=rho_s,
eps=1e-7,
)

# Temperature ratio
Expand Down Expand Up @@ -389,10 +392,11 @@ def _prepare_tglf_inputs(
# (https://gacode.io/cgyro/cgyro_list.html#s)
# - r_mid is zero on axis, so use safe_divide to avoid division by zero.
q_prime = math_utils.safe_divide(
psi_calculations.calc_s_rmid(geo, core_profiles.psi)
num=psi_calculations.calc_s_rmid(geo, core_profiles.psi)
* core_profiles.q_face**2
* a**2,
r**2,
denom=r**2,
eps=1e-7,
)

# Dimensionless pressure gradient
Expand All @@ -402,13 +406,14 @@ def _prepare_tglf_inputs(
# - 8 * pi factor missing since TGLF internally operates on it using
# beta/(8*pi)
p_prime = math_utils.safe_divide(
1.0e-7
num=1.0e-7
* core_profiles.pressure_thermal_total.face_grad(
x=geo.r_mid, x_left=r[0], x_right=r[-1]
)
* core_profiles.q_face
* a**2,
r * B_unit**2,
denom=r * B_unit**2,
eps=1e-7,
)

# Electron beta
Expand All @@ -418,8 +423,9 @@ def _prepare_tglf_inputs(
# - In the TGLF docs, beta_e equation shown in CGS units, this is the SI
# version
beta_e = math_utils.safe_divide(
2 * constants.CONSTANTS.mu_0 * n_e * T_e_J,
B_unit**2,
num=2 * constants.CONSTANTS.mu_0 * n_e * T_e_J,
denom=B_unit**2,
eps=1e-7,
)

# Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface
Expand Down Expand Up @@ -563,25 +569,33 @@ def _make_core_transport(
dT_e_drhon = core_profiles.T_e.face_grad() * constants.CONSTANTS.keV_to_J
dT_i_drhon = core_profiles.T_i.face_grad() * constants.CONSTANTS.keV_to_J
chi_e = math_utils.safe_divide(
-P_e,
core_profiles.n_e.face_value() * dT_e_drhon * geo.g1_over_vpr_face,
num=-P_e,
denom=core_profiles.n_e.face_value()
* dT_e_drhon
* geo.g1_over_vpr_face,
eps=1e-7,
)
chi_i = math_utils.safe_divide(
-P_i,
core_profiles.n_i.face_value() * dT_i_drhon * geo.g1_over_vpr_face,
num=-P_i,
denom=core_profiles.n_i.face_value()
* dT_i_drhon
* geo.g1_over_vpr_face,
eps=1e-7,
)

# Convert from particle rate to D, V using effective
# diffusivity/convectivity method. This sets purely diffusive transport in
# regions where the flux is with the temperature gradient, otherwise it
# sets purely convective transport.
D_eff = math_utils.safe_divide(
-S_e,
core_profiles.n_e.face_grad() * geo.g1_over_vpr_face,
num=-S_e,
denom=core_profiles.n_e.face_grad() * geo.g1_over_vpr_face,
eps=1e-7,
)
V_eff = math_utils.safe_divide(
S_e,
core_profiles.n_e.face_value() * geo.g0_face,
num=S_e,
denom=core_profiles.n_e.face_value() * geo.g0_face,
eps=1e-7,
)
D_eff = jnp.where(jnp.isfinite(D_eff), D_eff, 0.0)
V_eff = jnp.where(jnp.isfinite(V_eff), V_eff, 0.0)
Expand Down
Loading