diff --git a/torax/_src/edge/extended_lengyel_solvers.py b/torax/_src/edge/extended_lengyel_solvers.py index 4964c0c08..0b924f956 100644 --- a/torax/_src/edge/extended_lengyel_solvers.py +++ b/torax/_src/edge/extended_lengyel_solvers.py @@ -794,8 +794,10 @@ def _solve_for_qcc( # argument is order 1 (or 0), allowing a fixed dimensionless epsilon to be # effective. This prevents vanishing gradients for deep negative excursions. qcc_norm = math_utils.smooth_sqrt( - math_utils.safe_divide(qcc_squared, qu**2), epsilon=1e-3 - ) + math_utils.safe_divide( + # TODO(b/512078510): Pick a reasonable eps value for safe_divide here. + num=qcc_squared, denom=qu**2, eps=1e-7), epsilon=1e-3 + ) qcc = qcc_norm * jnp.sqrt(qu**2) diff --git a/torax/_src/edge/updaters.py b/torax/_src/edge/updaters.py index c47e5abc0..546f59dbb 100644 --- a/torax/_src/edge/updaters.py +++ b/torax/_src/edge/updaters.py @@ -153,7 +153,9 @@ def _calculate_impurity_scaling_factor( # Calculate scaling from the current value of the profile at the lcfs. # This scales the whole profile shape to match the edge value. current_val_at_edge = impurity_params.n_e_ratios_face[species][-1] - return math_utils.safe_divide(conc_lcfs, current_val_at_edge) + return math_utils.safe_divide( + num=conc_lcfs, denom=current_val_at_edge, eps=1e-7 + ) def _update_impurities( diff --git a/torax/_src/geometry/fbt.py b/torax/_src/geometry/fbt.py index ec4381975..75798e7fd 100644 --- a/torax/_src/geometry/fbt.py +++ b/torax/_src/geometry/fbt.py @@ -445,7 +445,7 @@ def _from_fbt( # (either here or upstream in MEQ) # Approximate with analytical expressions for circular geometry. flux_surf_avg_B2 = math_utils.safe_divide( - B_0**2, np.sqrt(1.0 - LY['epsilon'] ** 2) + num=B_0**2, denom=np.sqrt(1.0 - LY['epsilon'] ** 2), eps=1e-7 ) flux_surf_avg_1_over_B2 = B_0**-2 * (1.0 + 1.5 * LY['epsilon'] ** 2) diff --git a/torax/_src/math_utils.py b/torax/_src/math_utils.py index d15323381..7ceae71f8 100644 --- a/torax/_src/math_utils.py +++ b/torax/_src/math_utils.py @@ -300,8 +300,21 @@ def cumulative_volume_integration( return cumulative_cell_integration(value * geo.vpr, geo) -def safe_divide(y: chex.Array, x: chex.Array) -> chex.Array: - return y / (x + constants.CONSTANTS.eps) +def safe_divide( + *, num: chex.Array, denom: chex.Array, eps: float +) -> chex.Array: + """Divides y by x, adding eps to the denominator for numerical stability. + + Args: + num: Numerator. + denom: Denominator. + eps: Small value added to the denominator to prevent division by zero. + Choose a value that is small relative to the expected magnitude of x. + + Returns: + num / (denom + eps). + """ + return num / (denom + eps) def inverse_softplus(x: jax.Array) -> jax.Array: diff --git a/torax/_src/neoclassical/transport/angioni_sauter.py b/torax/_src/neoclassical/transport/angioni_sauter.py index 04e722155..222ce0a3e 100644 --- a/torax/_src/neoclassical/transport/angioni_sauter.py +++ b/torax/_src/neoclassical/transport/angioni_sauter.py @@ -251,16 +251,24 @@ def _calculate_angioni_sauter_transport( # --- Step 4: Calculate thermodynamic forces --- dpsi_drhon = core_profiles.psi.face_grad() dlnne_dpsi = math_utils.safe_divide( - core_profiles.n_e.face_grad() / core_profiles.n_e.face_value(), dpsi_drhon + num=core_profiles.n_e.face_grad() / core_profiles.n_e.face_value(), + denom=dpsi_drhon, + eps=1e-7, ) dlnte_dpsi = math_utils.safe_divide( - core_profiles.T_e.face_grad() / core_profiles.T_e.face_value(), dpsi_drhon + num=core_profiles.T_e.face_grad() / core_profiles.T_e.face_value(), + denom=dpsi_drhon, + eps=1e-7, ) dlnni_dpsi = math_utils.safe_divide( - core_profiles.n_i.face_grad() / core_profiles.n_i.face_value(), dpsi_drhon + num=core_profiles.n_i.face_grad() / core_profiles.n_i.face_value(), + denom=dpsi_drhon, + eps=1e-7, ) dlnti_dpsi = math_utils.safe_divide( - core_profiles.T_i.face_grad() / core_profiles.T_i.face_value(), dpsi_drhon + num=core_profiles.T_i.face_grad() / core_profiles.T_i.face_value(), + denom=dpsi_drhon, + eps=1e-7, ) # --- Step 5: Calculate neoclassical fluxes --- diff --git a/torax/_src/output_tools/post_processing.py b/torax/_src/output_tools/post_processing.py index d472046d8..fc109118d 100644 --- a/torax/_src/output_tools/post_processing.py +++ b/torax/_src/output_tools/post_processing.py @@ -1010,10 +1010,14 @@ def cumulative_values(): I_external=I_external, I_non_inductive=I_non_inductive, f_non_inductive=math_utils.safe_divide( - I_non_inductive, sim_state.core_profiles.Ip_profile_face[-1] + num=I_non_inductive, + denom=sim_state.core_profiles.Ip_profile_face[-1], + eps=1e-7, ), f_bootstrap=math_utils.safe_divide( - I_bootstrap, sim_state.core_profiles.Ip_profile_face[-1] + num=I_bootstrap, + denom=sim_state.core_profiles.Ip_profile_face[-1], + eps=1e-7, ), beta_tor=beta_tor, beta_pol=beta_pol, diff --git a/torax/_src/output_tools/tests/post_processing_test.py b/torax/_src/output_tools/tests/post_processing_test.py index 84a6065c8..43e59f0f7 100644 --- a/torax/_src/output_tools/tests/post_processing_test.py +++ b/torax/_src/output_tools/tests/post_processing_test.py @@ -311,10 +311,10 @@ def test_current_outputs(self): # f_non_inductive = I_non_inductive / Ip # Ip comes from core_profiles.Ip_profile_face[-1] ip = self.core_profiles.Ip_profile_face[-1] - # Code uses constants.CONSTANTS.eps for division guard + # Code uses eps=1e-7 for division guard np.testing.assert_allclose( outputs.f_non_inductive, - math_utils.safe_divide(outputs.I_non_inductive, ip), + math_utils.safe_divide(num=outputs.I_non_inductive, denom=ip, eps=1e-7), rtol=1e-5, ) @@ -322,7 +322,7 @@ def test_current_outputs(self): # f_bootstrap = I_bootstrap / Ip np.testing.assert_allclose( outputs.f_bootstrap, - math_utils.safe_divide(outputs.I_bootstrap, ip), + math_utils.safe_divide(num=outputs.I_bootstrap, denom=ip, eps=1e-7), rtol=1e-5, ) diff --git a/torax/_src/physics/fast_ion_utils.py b/torax/_src/physics/fast_ion_utils.py index 1ab925146..f3d44fb1d 100644 --- a/torax/_src/physics/fast_ion_utils.py +++ b/torax/_src/physics/fast_ion_utils.py @@ -69,7 +69,7 @@ def _nu_epsilon( * ln_lambda ) denom = jnp.power(m_b_amu * T_a_ev + m_a_amu * T_b_ev, 1.5) - return jnp.asarray(math_utils.safe_divide(num, denom)) + return jnp.asarray(math_utils.safe_divide(num=num, denom=denom, eps=1e-7)) def _compute_T_tail( @@ -106,8 +106,9 @@ def _compute_T_tail( n_e_cm3 = n_e / 1.0e6 tau_s = math_utils.safe_divide( - 6.27e8 * mass_number * jnp.power(T_e_eV, 1.5), - charge_number**2 * n_e_cm3 * log_lambda_ei, + num=6.27e8 * mass_number * jnp.power(T_e_eV, 1.5), + denom=charge_number**2 * n_e_cm3 * log_lambda_ei, + eps=1e-7, ) T_e_J = T_e * constants.CONSTANTS.keV_to_J energy_density = 1.5 * n_total * T_e_J @@ -116,7 +117,7 @@ def _compute_T_tail( energy_slowing_down_time = 0.5 * tau_s xi = math_utils.safe_divide( - P_density_W * energy_slowing_down_time, energy_density + num=P_density_W * energy_slowing_down_time, denom=energy_density, eps=1e-7 ) return T_e * (1.0 + xi) @@ -238,7 +239,10 @@ def bimaxwellian_split( ) ) - n_tail = math_utils.safe_divide(P_density_W, energy_loss_rate_per_particle) + # TODO(b/512078510): Choose a reasonable eps value for safe_divide here. + n_tail = math_utils.safe_divide( + num=P_density_W, denom=energy_loss_rate_per_particle, eps=1e-7 + ) n_tail = jnp.clip(n_tail, 0.0, n_total * 0.99) n_tail = jnp.where(P_density_W <= 1.0e-6, 0.0, n_tail) diff --git a/torax/_src/physics/formulas.py b/torax/_src/physics/formulas.py index 4170b20d6..0ff314d5b 100644 --- a/torax/_src/physics/formulas.py +++ b/torax/_src/physics/formulas.py @@ -243,7 +243,7 @@ def calculate_betas( magnetic_pressure_on_axis = geo.B_0**2 / (2 * constants.CONSTANTS.mu_0) # Add a division guard though B0 should typically be non-zero. beta_tor = math_utils.safe_divide( - p_total_volume_avg, magnetic_pressure_on_axis + num=p_total_volume_avg, denom=magnetic_pressure_on_axis, eps=1e-7 ) beta_pol = ( diff --git a/torax/_src/physics/rotation.py b/torax/_src/physics/rotation.py index b0891455d..526f8623c 100644 --- a/torax/_src/physics/rotation.py +++ b/torax/_src/physics/rotation.py @@ -96,7 +96,9 @@ def _calculate_radial_electric_field( denominator = Z_i_face * constants.CONSTANTS.q_e * n_i.face_value() Er_poloidal_and_pressure_face = ( - math_utils.safe_divide(jnp.array(1.0), denominator) * dpi_dr + math_utils.safe_divide( + num=jnp.array(1.0), denom=denominator, eps=1e-7 + ) * dpi_dr + poloidal_velocity.face_value() * B_tor_face ) diff --git a/torax/_src/transport_model/quasilinear_transport_model.py b/torax/_src/transport_model/quasilinear_transport_model.py index 5694679b7..c09e7921a 100644 --- a/torax/_src/transport_model/quasilinear_transport_model.py +++ b/torax/_src/transport_model/quasilinear_transport_model.py @@ -392,7 +392,7 @@ def apply_fast_ion_stabilization( ) * transport.fast_ion_stabilization_multiplier + 1 return jnp.where( transport.fast_ion_stabilization, - math_utils.safe_divide(lref_over_lti, fi_stab_factor), + math_utils.safe_divide(num=lref_over_lti, denom=fi_stab_factor, eps=1e-7), lref_over_lti, ) diff --git a/torax/_src/transport_model/tglf_based_transport_model.py b/torax/_src/transport_model/tglf_based_transport_model.py index a6a57920f..153c998fc 100644 --- a/torax/_src/transport_model/tglf_based_transport_model.py +++ b/torax/_src/transport_model/tglf_based_transport_model.py @@ -296,9 +296,10 @@ def _prepare_tglf_inputs( # r is zero on axis, so use safe_divide to avoid division by zero. # Use face_grad to correctly handle constraints on the psi CellVariable. B_unit = math_utils.safe_divide( - core_profiles.q_face + num=core_profiles.q_face * core_profiles.psi.face_grad(x=geo.r_mid, x_left=r[0], x_right=r[-1]), - (2 * jnp.pi * r), # Note: psi_TGLF is psi_TORAX/2π + denom=(2 * jnp.pi * r), # Note: psi_TGLF is psi_TORAX/2π + eps=1e-7, ) # Ion gyroradius @@ -306,8 +307,9 @@ def _prepare_tglf_inputs( # avoid being swamped by the eps in the denominator. rho_s = ( math_utils.safe_divide( - m_D * c_s, - B_unit, + num=m_D * c_s, + denom=B_unit, + eps=1e-7, ) / constants.CONSTANTS.q_e ) @@ -317,13 +319,14 @@ def _prepare_tglf_inputs( # - In the TGLF docs, the prefactor of 743.0 comes from a combination of the # constants below plus being in CGS units. Below is the SI version. normalized_debye = math_utils.safe_divide( - ( + num=( (constants.CONSTANTS.epsilon_0 / constants.CONSTANTS.q_e) * (core_profiles.T_e.face_value() * 1e3) # keV -> eV / n_e ) ** 0.5, - rho_s, + denom=rho_s, + eps=1e-7, ) # Temperature ratio @@ -389,10 +392,11 @@ def _prepare_tglf_inputs( # (https://gacode.io/cgyro/cgyro_list.html#s) # - r_mid is zero on axis, so use safe_divide to avoid division by zero. q_prime = math_utils.safe_divide( - psi_calculations.calc_s_rmid(geo, core_profiles.psi) + num=psi_calculations.calc_s_rmid(geo, core_profiles.psi) * core_profiles.q_face**2 * a**2, - r**2, + denom=r**2, + eps=1e-7, ) # Dimensionless pressure gradient @@ -402,13 +406,14 @@ def _prepare_tglf_inputs( # - 8 * pi factor missing since TGLF internally operates on it using # beta/(8*pi) p_prime = math_utils.safe_divide( - 1.0e-7 + num=1.0e-7 * core_profiles.pressure_thermal_total.face_grad( x=geo.r_mid, x_left=r[0], x_right=r[-1] ) * core_profiles.q_face * a**2, - r * B_unit**2, + denom=r * B_unit**2, + eps=1e-7, ) # Electron beta @@ -418,8 +423,9 @@ def _prepare_tglf_inputs( # - In the TGLF docs, beta_e equation shown in CGS units, this is the SI # version beta_e = math_utils.safe_divide( - 2 * constants.CONSTANTS.mu_0 * n_e * T_e_J, - B_unit**2, + num=2 * constants.CONSTANTS.mu_0 * n_e * T_e_J, + denom=B_unit**2, + eps=1e-7, ) # Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface @@ -563,12 +569,18 @@ def _make_core_transport( dT_e_drhon = core_profiles.T_e.face_grad() * constants.CONSTANTS.keV_to_J dT_i_drhon = core_profiles.T_i.face_grad() * constants.CONSTANTS.keV_to_J chi_e = math_utils.safe_divide( - -P_e, - core_profiles.n_e.face_value() * dT_e_drhon * geo.g1_over_vpr_face, + num=-P_e, + denom=core_profiles.n_e.face_value() + * dT_e_drhon + * geo.g1_over_vpr_face, + eps=1e-7, ) chi_i = math_utils.safe_divide( - -P_i, - core_profiles.n_i.face_value() * dT_i_drhon * geo.g1_over_vpr_face, + num=-P_i, + denom=core_profiles.n_i.face_value() + * dT_i_drhon + * geo.g1_over_vpr_face, + eps=1e-7, ) # Convert from particle rate to D, V using effective @@ -576,12 +588,14 @@ def _make_core_transport( # regions where the flux is with the temperature gradient, otherwise it # sets purely convective transport. D_eff = math_utils.safe_divide( - -S_e, - core_profiles.n_e.face_grad() * geo.g1_over_vpr_face, + num=-S_e, + denom=core_profiles.n_e.face_grad() * geo.g1_over_vpr_face, + eps=1e-7, ) V_eff = math_utils.safe_divide( - S_e, - core_profiles.n_e.face_value() * geo.g0_face, + num=S_e, + denom=core_profiles.n_e.face_value() * geo.g0_face, + eps=1e-7, ) D_eff = jnp.where(jnp.isfinite(D_eff), D_eff, 0.0) V_eff = jnp.where(jnp.isfinite(V_eff), V_eff, 0.0)