diff --git a/torax/_src/transport_model/pydantic_model.py b/torax/_src/transport_model/pydantic_model.py index fec2a2e11..b9f893b6e 100644 --- a/torax/_src/transport_model/pydantic_model.py +++ b/torax/_src/transport_model/pydantic_model.py @@ -220,6 +220,7 @@ class TGLFNNukaeaTransportModel(pydantic_model_base.TransportBase): # Quasilinear transport options DV_effective: bool = False An_min: pydantic.PositiveFloat = 0.05 + collisionality_multiplier: float = 1.0 def build_transport_model( self, @@ -237,6 +238,7 @@ def build_runtime_params( An_min=self.An_min, rotation_multiplier=self.rotation_multiplier, use_rotation=self.use_rotation, + collisionality_multiplier=self.collisionality_multiplier, # From base **base_kwargs, ) @@ -415,6 +417,7 @@ def build_runtime_params( try: from torax._src.transport_model import qualikiz_transport_model # pylint: disable=g-import-not-at-top + from torax._src.transport_model import tglf_transport_model # pylint: disable=g-import-not-at-top # Since CombinedCompatibleTransportModel is not constant, because of the # try/except block, unions using this type will cause invalid-annotation @@ -426,6 +429,7 @@ def build_runtime_params( | CriticalGradientTransportModel | BohmGyroBohmTransportModel | qualikiz_transport_model.QualikizTransportModelConfig + | tglf_transport_model.TGLFTransportModelConfig ) except ImportError: diff --git a/torax/_src/transport_model/tests/tglf_based_transport_model_test.py b/torax/_src/transport_model/tests/tglf_based_transport_model_test.py index 93634b66d..8c5b397dc 100644 --- a/torax/_src/transport_model/tests/tglf_based_transport_model_test.py +++ b/torax/_src/transport_model/tests/tglf_based_transport_model_test.py @@ -224,6 +224,7 @@ def build_runtime_params(self, t: chex.Numeric): An_min=0.05, use_rotation=True, rotation_multiplier=1.0, + collisionality_multiplier=1.0, **base_kwargs, ) diff --git a/torax/_src/transport_model/tglf_based_transport_model.py b/torax/_src/transport_model/tglf_based_transport_model.py index 75f32b831..14c6f7cd2 100644 --- a/torax/_src/transport_model/tglf_based_transport_model.py +++ b/torax/_src/transport_model/tglf_based_transport_model.py @@ -36,6 +36,9 @@ class RuntimeParams(quasilinear_transport_model.RuntimeParams): """Shared parameters for TGLF-based models.""" use_rotation: bool = dataclasses.field(metadata={"static": True}) rotation_multiplier: float + DV_effective: bool + An_min: float + collisionality_multiplier: float # pylint: disable=invalid-name @@ -46,20 +49,28 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs): See https://gacode.io/tglf/tglf_table.html for definitions. + This TGLF interface numbers the plasma species as such: + 1 = electrons, 2 = main ion (i0), 3 = impurity (i1) + Attributes: Ti_over_Te: Ratio of ion temperature to electron temperature. + ni0_over_ne: Ratio of main ion density to electron density. + ni1_over_ne: Ratio of impurity density to electron density. r_minor: Flux surface centroid minor radius. + r_major: Flux surface centroid major radius. dr_major: Gradient of the flux surface centroid major radius with respect to the minor radius (:math:`dr_{major}/dr_{minor}`). q: The safety factor. q_prime: Magnetic shear, defined as :math:`\frac{q^2 a^2 s}{r^2}`. - nu_ee: The electron-electron collision frequency. + nu_ee: Normalized electron-electron collision frequency. + debye: Normalized Debye length. kappa: Plasma elongation. kappa_shear: Shear in elongation, defined as :math:`\frac{r}{\kappa} \frac{d\kappa}{dr}`. delta: Plasma triangularity. delta_shear: Shear in triangularity, defined as :math:`r\frac{d\delta}{dr}`. beta_e: Electron pressure normalized by TGLF's :math:`B_\mathrm{unit}`. + p_prime: Plasma pressure gradient normalized by TGLF's :math:`B_\mathrm{unit}`. Zeff: Effective charge. Q_GB: TGLF heat flux normalisation factor. Gamma_GB: TGLF particle flux normalisation factor. @@ -67,16 +78,21 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs): """ Ti_over_Te: array_typing.FloatVectorFace + ni0_over_ne: array_typing.FloatVectorFace + ni1_over_ne: array_typing.FloatVectorFace r_minor: array_typing.FloatVectorFace + r_major: array_typing.FloatVectorFace dr_major: array_typing.FloatVectorFace q: array_typing.FloatVectorFace q_prime: array_typing.FloatVectorFace nu_ee: array_typing.FloatVectorFace + debye: array_typing.FloatVectorFace kappa: array_typing.FloatVectorFace kappa_shear: array_typing.FloatVectorFace delta: array_typing.FloatVectorFace delta_shear: array_typing.FloatVectorFace beta_e: array_typing.FloatVectorFace + p_prime: array_typing.FloatVectorFace Zeff: array_typing.FloatVectorFace Q_GB: array_typing.FloatVectorFace Gamma_GB: array_typing.FloatVectorFace @@ -87,6 +103,21 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs): def TAUS_2(self) -> array_typing.FloatVectorFace: return self.Ti_over_Te + # Assumes Timp = Ti since TORAX does not track impurity temperature + @property + def TAUS_3(self) -> array_typing.FloatVectorFace: + return self.Ti_over_Te + + @property + def AS_2(self) -> array_typing.FloatVectorFace: + return self.ni0_over_ne + + # Uses nimp as TORAX explicitly tracks impurity density even though + # it is currently not being evolved + @property + def AS_3(self) -> array_typing.FloatVectorFace: + return self.ni1_over_ne + @property def DRMAJDX_LOC(self) -> array_typing.FloatVectorFace: return self.dr_major @@ -103,6 +134,10 @@ def Q_PRIME_LOC(self) -> array_typing.FloatVectorFace: def XNUE(self) -> array_typing.FloatVectorFace: return self.nu_ee + @property + def DEBYE(self) -> array_typing.FloatVectorFace: + return self.debye + @property def KAPPA_LOC(self) -> array_typing.FloatVectorFace: return self.kappa @@ -131,6 +166,16 @@ def ZEFF(self) -> array_typing.FloatVectorFace: def RLNS_1(self) -> array_typing.FloatVectorFace: return self.lref_over_lne + @property + def RLNS_2(self) -> array_typing.FloatVectorFace: + return self.lref_over_lni0 + + # Uses normalized grad_nimp as TORAX explicitly tracks impurity density + # even though it is currently not being evolved + @property + def RLNS_3(self) -> array_typing.FloatVectorFace: + return self.lref_over_lni1 + @property def RLTS_1(self) -> array_typing.FloatVectorFace: return self.lref_over_lte @@ -139,10 +184,24 @@ def RLTS_1(self) -> array_typing.FloatVectorFace: def RLTS_2(self) -> array_typing.FloatVectorFace: return self.lref_over_lti + # Assumes normalized grad_Timp = normalized grad_Ti since TORAX does not track + # impurity temperature + @property + def RLTS_3(self) -> array_typing.FloatVectorFace: + return self.lref_over_lti + + @property + def P_PRIME_LOC(self) -> array_typing.FloatVectorFace: + return self.p_prime + @property def RMIN_LOC(self) -> array_typing.FloatVectorFace: return self.r_minor + @property + def RMAJ_LOC(self) -> array_typing.FloatVectorFace: + return self.r_major + @property def VEXB_SHEAR(self) -> array_typing.FloatVectorFace: return self.v_ExB_shear @@ -230,11 +289,33 @@ def _prepare_tglf_inputs( * core_profiles.psi.face_grad(x=geo.r_mid, x_left=r[0], x_right=r[-1]), (2 * jnp.pi * r), # Note: psi_TGLF is psi_TORAX/2π ) - rho_s = m_D * c_s / (constants.CONSTANTS.q_e * B_unit) # Ion gyroradius + + # Ion gyroradius + rho_s = math_utils.safe_divide( + m_D * c_s / constants.CONSTANTS.q_e, + B_unit, + ) + + # Debye length + # https://gacode.io/tglf/tglf_list.html#debye + # - In the TGLF docs, the prefactor of 743.0 comes from a combination of the + # constants below plus being in CGS units. Below is the SI version. + normalized_debye = math_utils.safe_divide( + ( + (constants.CONSTANTS.epsilon_0 / constants.CONSTANTS.q_e) + * 1.0e3 * core_profiles.T_e.face_value() + / n_e + ) ** 0.5, + rho_s, + ) # Temperature ratio Ti_over_Te = core_profiles.T_i.face_value() / core_profiles.T_e.face_value() + # Ion dilution + ni0_over_ne = core_profiles.n_i.face_value() / core_profiles.n_e.face_value() + ni1_over_ne = core_profiles.n_impurity.face_value() / core_profiles.n_e.face_value() + # Dimensionless gradients normalized_log_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles( core_profiles=core_profiles, @@ -274,7 +355,7 @@ def _prepare_tglf_inputs( - 0.5 * jnp.log(constants.CONSTANTS.m_e) - 1.5 * jnp.log(T_e_J) ) - normalized_nu_ee = jnp.exp(log_nu_ee) / (c_s / a) + normalized_nu_ee = jnp.exp(log_nu_ee) / (c_s / a) * transport.collisionality_multiplier # Dimensionless safety factor shear # https://gacode.io/tglf/tglf_list.html#tglf-q-prime-loc @@ -289,13 +370,29 @@ def _prepare_tglf_inputs( r**2, ) + # Dimensionless pressure gradient + # https://gacode.io/tglf/tglf_list.html#tglf-p-prime-loc + # - In the TGLF docs, p_prime equation is shown in CGS units, this is the SI + # version + # - 8 * pi factor missing since TGLF internally operates on it using beta/(8*pi) + p_prime = math_utils.safe_divide( + 1.0e-7 + * core_profiles.pressure_thermal_total.face_grad(x=geo.r_mid, x_left=r[0], x_right=r[-1]) + * core_profiles.q_face + * a**2, + r * B_unit**2, + ) + # Electron beta # https://gacode.io/tglf/tglf_list.html#tglf-betae # https://gacode.io/cgyro.html#faq # https://gacode.io/cgyro/cgyro_list.html#betae-unit # - In the TGLF docs, beta_e equation shown in CGS units, this is the SI # version - beta_e = 2 * constants.CONSTANTS.mu_0 * n_e * T_e_J / B_unit**2 + beta_e = math_utils.safe_divide( + 2 * constants.CONSTANTS.mu_0 * n_e * T_e_J, + B_unit**2, + ) # Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface # centroid major radius and 'rmin' the flux surface centroid minor radius @@ -390,16 +487,21 @@ def _get_v_ExB_shear( lref_over_lni1=normalized_log_gradients.lref_over_lni1, # From TGLFInputs Ti_over_Te=Ti_over_Te, + ni0_over_ne=ni0_over_ne, + ni1_over_ne=ni1_over_ne, r_minor=r / a, + r_major=r_major / a, dr_major=dr_major, q=core_profiles.q_face, q_prime=q_prime, nu_ee=normalized_nu_ee, + debye=normalized_debye, kappa=kappa, kappa_shear=kappa_shear, delta=geo.delta_face, delta_shear=delta_shear, beta_e=beta_e, + p_prime=p_prime, Zeff=core_profiles.Z_eff_face, Q_GB=Q_GB, Gamma_GB=Gamma_GB, @@ -432,29 +534,41 @@ def _make_core_transport( # Note: g1/vpr = ⟨(∇ρₙ)²⟩ ∂V/∂ρₙ, and has units [m]. dT_e_drhon = core_profiles.T_e.face_grad() * constants.CONSTANTS.keV_to_J dT_i_drhon = core_profiles.T_i.face_grad() * constants.CONSTANTS.keV_to_J - chi_e = -P_e / ( - core_profiles.n_e.face_value() * dT_e_drhon * geo.g1_over_vpr_face + chi_e = math_utils.safe_divide( + -P_e, + core_profiles.n_e.face_value() * dT_e_drhon * geo.g1_over_vpr_face, ) - chi_i = -P_i / ( - core_profiles.n_i.face_value() * dT_i_drhon * geo.g1_over_vpr_face + chi_i = math_utils.safe_divide( + -P_i, + core_profiles.n_i.face_value() * dT_i_drhon * geo.g1_over_vpr_face, ) # Convert from particle rate to D, V using effective # diffusivity/convectivity method. This sets purely diffusive transport in # regions where the flux is with the temperature gradient, otherwise it # sets purely convective transport. - D_eff = -S_e / (core_profiles.n_e.face_grad() * geo.g1_over_vpr_face) - V_eff = S_e / (core_profiles.n_e.face_value() * geo.g0_face) - D_eff_mask = ((S_e >= 0) & (tglf_inputs.lref_over_lne >= 0)) | ( - (S_e < 0) & (tglf_inputs.lref_over_lne < 0) + D_eff = math_utils.safe_divide( + -S_e, + core_profiles.n_e.face_grad() * geo.g1_over_vpr_face, + ) + V_eff = math_utils.safe_divide( + S_e, + core_profiles.n_e.face_value() * geo.g0_face, + ) + D_eff = jnp.where(jnp.isfinite(D_eff), D_eff, 0.0) + V_eff = jnp.where(jnp.isfinite(V_eff), V_eff, 0.0) + D_eff_mask = ( + ((S_e >= 0) & (tglf_inputs.lref_over_lne >= 0)) + | ((S_e < 0) & (tglf_inputs.lref_over_lne < 0)) ) # For stability, we also set purely diffusive transport at some minimum # threshold of the temperature gradient. - D_eff_mask &= abs(tglf_inputs.lref_over_lne) >= transport.An_min + D_eff_mask &= (abs(tglf_inputs.lref_over_lne) >= (transport.An_min * geo.a_minor / geo.R_major)) + V_eff_mask = jnp.invert(D_eff_mask) # Apply the mask. d_face_el = jnp.where(D_eff_mask, D_eff, 0.0) - v_face_el = jnp.where(D_eff_mask, 0.0, V_eff) + v_face_el = jnp.where(V_eff_mask, V_eff, 0.0) return transport_model_lib.TurbulentTransport( chi_face_ion=chi_i, diff --git a/torax/_src/transport_model/tglf_transport_model.py b/torax/_src/transport_model/tglf_transport_model.py new file mode 100644 index 000000000..b4ab6612c --- /dev/null +++ b/torax/_src/transport_model/tglf_transport_model.py @@ -0,0 +1,617 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A transport model that calls TGLF. + +Used for generating ground truth for surrogate model evaluations. +""" + +import dataclasses +import datetime +import os +import copy +import subprocess +import multiprocessing +import tempfile +from typing import Annotated +from typing import Literal +from typing import Sequence, Mapping, Any +import uuid + +import chex +import jax +import numpy as np +import pydantic +from torax._src import jax_utils +from torax._src import constants +from torax._src import state +from torax._src.config import runtime_params as runtime_params_lib +from torax._src.geometry import geometry +from torax._src.pedestal_model import pedestal_model_output as pedestal_model_output_lib +from torax._src.torax_pydantic import torax_pydantic +from torax._src.transport_model import pydantic_model_base +from torax._src.transport_model import tglf_based_transport_model +from torax._src.transport_model import runtime_params as transport_runtime_params_lib +from torax._src.transport_model import transport_model + + +# pylint: disable=invalid-name +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class RuntimeParams(tglf_based_transport_model.RuntimeParams): + n_processes: int + kygrid_model: int + ky: float + n_ky: int + n_modes: int + geometry_flag: int + sat_rule: int + xnu_model: int + n_width: float + width_min: float + width: float + filter: float + theta_trapped: float + w_dia_trapped: float + sign_bt: float + sign_it: float + xnu_factor: float + debye_factor: float + etg_factor: float + find_width: bool + use_mhd_rule: bool + use_bpar: bool + use_bper: bool + use_inboard_detrapped: bool + use_ave_ion_grid: bool + alpha_e: float + alpha_zf: float + alpha_quench: float + n_xgrid: int + n_basis_min: int + n_basis_max: int + + +_DEFAULT_TGLFRUN_NAME_PREFIX = 'torax_tglf_runs' + + +def _get_tglf_exec_path() -> str: + default_tglf_exec_path = 'tglf' + return os.environ.get('TORAX_TGLF_EXEC_PATH', default_tglf_exec_path) + + +class TGLFTransportModel( + tglf_based_transport_model.TGLFBasedTransportModel +): + """Calculates turbulent transport coefficients with TGLF.""" + + def __init__(self): + self._tglfrun_parentdir = tempfile.TemporaryDirectory() + # Include UUID to prevent collisions when multiple simulations start + # simultaneously (e.g., in SLURM distributed systems) + timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + short_uuid = uuid.uuid4().hex[:8] + # Check for SLURM job ID if available (common in distributed computing) + slurm_job_id = os.environ.get('SLURM_JOB_ID') + if slurm_job_id: + unique_suffix = f'job_{slurm_job_id}_uuid_{short_uuid}' + else: + unique_suffix = f'uuid_{short_uuid}' + self._tglfrun_name = ( + _DEFAULT_TGLFRUN_NAME_PREFIX + '_' + timestamp + '_' + unique_suffix + ) + self._runpath = os.path.join(self._tglfrun_parentdir.name, self._tglfrun_name) + self._frozen = True + + def call_implementation( + self, + transport_runtime_params: transport_runtime_params_lib.RuntimeParams, + runtime_params: runtime_params_lib.RuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + pedestal_model_output: pedestal_model_output_lib.PedestalModelOutput, + ) -> transport_model.TurbulentTransport: + """Calculates several transport coefficients simultaneously. + + Args: + transport_runtime_params: Input runtime parameters for this + transport model. + runtime_params: Input runtime parameters for all components + of the simulation at the current time. + geo: Geometry of the torus. + core_profiles: Core plasma profiles. + pedestal_model_output: Output of the pedestal model. + + Returns: + coeffs: transport coefficients + """ + del pedestal_model_output # Unused. + + # Required for pytype + assert isinstance(transport_runtime_params, RuntimeParams) + + tglf_inputs = self._prepare_tglf_inputs( + transport=transport_runtime_params, + geo=geo, + core_profiles=core_profiles, + poloidal_velocity_multiplier=runtime_params.neoclassical.poloidal_velocity_multiplier, # Carried over from TGLFNN + ) + + def callback(tglf_inputs, transport_runtime_params, geo, core_profiles): + # Keep mapping to numpy arrays + (tglf_inputs, transport_runtime_params, geo, core_profiles) = ( + jax.tree.map( + np.asarray, + (tglf_inputs, transport_runtime_params, geo, core_profiles), + ) + ) + tglf_plan = _extract_tglf_plan( + tglf_inputs=tglf_inputs, + transport=transport_runtime_params, + geo=geo, + core_profiles=core_profiles, + ) + self._run_tglf(tglf_plan, transport_runtime_params.n_processes) + core_transport = self._extract_run_data( + tglf_plan=tglf_plan, + tglf_inputs=tglf_inputs, + transport=transport_runtime_params, + geo=geo, + core_profiles=core_profiles, + ) + return core_transport + + face_array_shape_dtype = jax.ShapeDtypeStruct( + shape=(geo.torax_mesh.nx+1,), dtype=jax_utils.get_dtype() + ) + result_shape_dtypes = transport_model.TurbulentTransport( + chi_face_ion=face_array_shape_dtype, + chi_face_el=face_array_shape_dtype, + d_face_el=face_array_shape_dtype, + v_face_el=face_array_shape_dtype, + ) + # Even though tglf has side-effects (writing and reading from disk) we + # still use a pure_callback here as: + # 1. Nothing outside of this method depends on the side-effect. + # 2. We don't mind if results are cached or recomputed. + # 3. DCE will not happen here as we make use of the `core_transport` result. + # This is based on the current implementation of pure_callback and JAX + # may change the implementation making this not appropriate down the line. + core_transport = jax.pure_callback( + callback, + result_shape_dtypes, + tglf_inputs, + transport_runtime_params, + geo, + core_profiles, + ) + + return core_transport + + + def _run_tglf( + self, + tglf_plan: Sequence[Mapping[str, Any]], + n_processes: int, + verbose: bool = True, + ) -> None: + """Runs TGLF using command line tools. Loose coupling with TORAX.""" + execution_path = _get_tglf_exec_path() + + os.makedirs(self._runpath, exist_ok=True) + for run in tglf_plan: + run.update({'execution_path': execution_path, 'run_path': self._runpath, 'verbose': verbose}) + path = os.path.join(self._runpath, run['location']) + os.makedirs(path, exist_ok=True) + fstr = '\n'.join([f'{k}={v}' for k, v in run['inputs'].items()]) + with open(path+'/input.tglf','w+') as f: + f.write(fstr) + + # Spawns a new "server" process, from which child processes are forked to + # farm out single-radius TGLF runs in parallel. + # Safer than using fork method, and faster than using spawn method. + ctx = multiprocessing.get_context('forkserver') + queue = ctx.Queue() + with ctx.Manager() as manager: + queue = manager.Queue() + with ctx.Pool(processes=n_processes) as pool: + arguments = [(run, queue) for run in tglf_plan] + _ = pool.starmap(_run_tglf_single, arguments) + if verbose: + # Use to print detailed stdout / stderr messages from child processes + for i in range(len(arguments)): + print(queue.get()) + + + def _extract_run_data( + self, + tglf_plan: Sequence[Mapping[str, Any]], + tglf_inputs: tglf_based_transport_model.TGLFInputs, + transport: RuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + ) -> transport_model.TurbulentTransport: + """Extracts TGLF run data from runpath.""" + + qe = np.zeros((len(tglf_plan), )) + qi = np.zeros((len(tglf_plan), )) + ge = np.zeros((len(tglf_plan), )) + for i, run in enumerate(tglf_plan): + gbfluxes = np.loadtxt(os.path.join(self._runpath, run['location'], 'out.tglf.gbflux')) + nspecies = len(gbfluxes) // 4 + qe[i] = float(gbfluxes[1*nspecies+0]) # Defined TGLF species 1 as electrons + qi[i] = float(sum(gbfluxes[1*nspecies+1:1*nspecies+nspecies])) + ge[i] = float(gbfluxes[0*nspecies+0]) + + return self._make_core_transport( + electron_heat_flux_GB=qe, + ion_heat_flux_GB=qi, + electron_particle_flux_GB=ge, + tglf_inputs=tglf_inputs, + transport=transport, + geo=geo, + core_profiles=core_profiles, + #gradient_reference_length=geo.R_major, + #gyrobohm_flux_reference_length=geo.a_minor, + ) + + def __hash__(self) -> int: + return hash(('TGLFTransportModel' + self._runpath)) + + def __eq__(self, other) -> bool: + return ( + isinstance(other, TGLFTransportModel) + and self._runpath == other._runpath + ) + + +def _extract_tglf_plan( + tglf_inputs: tglf_based_transport_model.TGLFInputs, + transport: RuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, +) -> Sequence[Mapping[str, Any]]: + """Converts TORAX parameters to TGLF input dictionary. + + Args: + tglf_inputs: Precomputed physics data. + transport: Runtime parameters for the qualikiz transport model. + geo: TORAX geometry object. + core_profiles: TORAX CoreProfiles object, containing time-evolvable + quantities like q + + Returns: + A list of dictionaries containing TGLF input namelists + """ + + species_template = { + 'ZS': None, + 'MASS': None, + 'RLNS': None, + 'RLTS': None, + 'TAUS': None, + 'AS': None, + 'VPAR': 0.0, + 'VPAR_SHEAR': 0.0, + } + tglf_input_template = { + # Control + 'UNITS': 'CGYRO', + 'NS': 3, + 'USE_TRANSPORT_MODEL': '.true.', + 'GEOMETRY_FLAG': transport.geometry_flag, + 'USE_BPER': '.true.' if transport.use_bper else '.false.', + 'USE_BPAR': '.true.' if transport.use_bpar else '.false.', + 'USE_BISECTION': '.true.', + 'USE_MHD_RULE': '.true.' if transport.use_mhd_rule else '.false.', + 'USE_INBOARD_DETRAPPED': '.true.' if transport.use_inboard_detrapped else '.false.', + 'USE_AVE_ION_GRID': '.true.' if transport.use_ave_ion_grid else '.false.', + 'SAT_RULE': transport.sat_rule, + 'KYGRID_MODEL': transport.kygrid_model, + 'XNU_MODEL': transport.xnu_model, + 'VPAR_MODEL': 0, + 'SIGN_BT': transport.sign_bt, + 'SIGN_IT': transport.sign_it, + 'KY': transport.ky, + 'NEW_EIKONAL': '.true.', + 'VEXB': 0.0, + 'VEXB_SHEAR': 0.0, + 'BETAE': 0.0, + 'XNUE': 0.0, + 'ZEFF': 1.0, + 'DEBYE': 0.0, + 'IFLUX': '.true.', + 'IBRANCH': -1, + 'NMODES': transport.n_modes, + 'NBASIS_MIN': transport.n_basis_min, + 'NBASIS_MAX': transport.n_basis_max, + 'NXGRID': transport.n_xgrid, + 'NKY': transport.n_ky, + 'ADIABATIC_ELEC': '.false.', + # Multipliers + 'ALPHA_P': 1.0, + 'ALPHA_MACH': 0.0, + 'ALPHA_E': transport.alpha_e, + 'ALPHA_QUENCH': transport.alpha_quench, + 'ALPHA_ZF': transport.alpha_zf, + 'XNU_FACTOR': transport.xnu_factor, + 'DEBYE_FACTOR': transport.debye_factor, + 'ETG_FACTOR': transport.etg_factor, + 'B_MODEL_SA': 1, + 'FT_MODEL_SA': 1, + # Gaussian mode width search + 'WRITE_WAVEFUNCTION_FLAG': 0, + 'WIDTH_MIN': transport.width_min, + 'WIDTH': transport.width, + 'NWIDTH': transport.n_width, + 'FIND_WIDTH': '.true.' if transport.find_width else '.false.', + # Miller shape parameters + 'RMIN_LOC': 0.5, + 'RMAJ_LOC': 3.0, + 'ZMAJ_LOC': 0.0, + 'Q_LOC': 2.0, + 'Q_PRIME_LOC': 16.0, + 'P_PRIME_LOC': 0.0, + 'DRMINDX_LOC': 1.0, + 'DRMAJDX_LOC': 0.0, + 'DZMAJDX_LOC': 0.0, + 'KAPPA_LOC': 1.0, + 'S_KAPPA_LOC': 0.0, + 'DELTA_LOC': 0.0, + 'S_DELTA_LOC': 0.0, + 'ZETA_LOC': 0.0, + 'S_ZETA_LOC': 0.0, + 'KX0_LOC': 0.0, + # Expert options + 'THETA_TRAPPED': transport.theta_trapped, + 'PARK': 1.0, + 'GHAT': 1.0, + 'GCHAT': 1.0, + 'WD_ZERO': 0.1, + 'LINSKER_FACTOR': 0.0, + 'GRADB_FACTOR': 0.0, + 'FILTER': transport.filter, + 'DAMP_PSI': 0.0, + 'DAMP_SIG': 0.0, + 'WDIA_TRAPPED': transport.w_dia_trapped, + } + + for species_number in range(1, 4): + tglf_input_template.update({f'{key}_{species_number}': value for key, value in species_template.items()}) + + tglf_plan = [] + zi0 = np.array(core_profiles.Z_i_face) + ai0 = np.array(core_profiles.A_i) + zi1 = np.array(core_profiles.Z_impurity_face) + ai1 = np.array(core_profiles.A_impurity_face) + for i, rho in enumerate(np.array(geo.rho_face_norm)): + tglf_runpars = copy.deepcopy(tglf_input_template) + tglf_runpars['BETAE'] = float(tglf_inputs.BETAE[i]) + tglf_runpars['XNUE'] = float(tglf_inputs.XNUE[i]) + tglf_runpars['ZEFF'] = float(tglf_inputs.ZEFF[i]) + tglf_runpars['DEBYE'] = float(tglf_inputs.DEBYE[i]) + tglf_runpars['RMIN_LOC'] = float(tglf_inputs.RMIN_LOC[i]) + tglf_runpars['RMAJ_LOC'] = float(tglf_inputs.RMAJ_LOC[i]) + tglf_runpars['Q_LOC'] = float(tglf_inputs.Q_LOC[i]) + tglf_runpars['Q_PRIME_LOC'] = float(tglf_inputs.Q_PRIME_LOC[i]) + tglf_runpars['P_PRIME_LOC'] = float(tglf_inputs.P_PRIME_LOC[i]) + tglf_runpars['DRMAJDX_LOC'] = float(tglf_inputs.DRMAJDX_LOC[i]) + tglf_runpars['KAPPA_LOC'] = float(tglf_inputs.KAPPA_LOC[i]) + tglf_runpars['S_KAPPA_LOC'] = float(tglf_inputs.S_KAPPA_LOC[i]) + tglf_runpars['DELTA_LOC'] = float(tglf_inputs.DELTA_LOC[i]) + tglf_runpars['S_DELTA_LOC'] = float(tglf_inputs.S_DELTA_LOC[i]) + tglf_runpars['ZS_1'] = -1.0 + tglf_runpars['MASS_1'] = float(constants.CONSTANTS.m_e / (constants.CONSTANTS.m_amu * constants.ION_PROPERTIES_DICT['D'].A)) + tglf_runpars['RLNS_1'] = float(tglf_inputs.RLNS_1[i]) + tglf_runpars['RLTS_1'] = float(tglf_inputs.RLTS_1[i]) + tglf_runpars['TAUS_1'] = 1.0 + tglf_runpars['AS_1'] = 1.0 + tglf_runpars['ZS_2'] = float(zi0[i]) + tglf_runpars['MASS_2'] = float(ai0 / constants.ION_PROPERTIES_DICT['D'].A) + tglf_runpars['RLNS_2'] = float(tglf_inputs.RLNS_2[i]) + tglf_runpars['RLTS_2'] = float(tglf_inputs.RLTS_2[i]) + tglf_runpars['TAUS_2'] = float(tglf_inputs.TAUS_2[i]) + tglf_runpars['AS_2'] = float(tglf_inputs.AS_2[i]) + tglf_runpars['ZS_3'] = float(zi1[i]) + tglf_runpars['MASS_3'] = float(ai1[i] / constants.ION_PROPERTIES_DICT['D'].A) + tglf_runpars['RLNS_3'] = float(tglf_inputs.RLNS_3[i]) + tglf_runpars['RLTS_3'] = float(tglf_inputs.RLTS_3[i]) + tglf_runpars['TAUS_3'] = float(tglf_inputs.TAUS_3[i]) + tglf_runpars['AS_3'] = float(tglf_inputs.AS_3[i]) + tglf_plan.append({ + 'inputs': tglf_runpars, + 'location': f'tglf_run_{i:04d}', + }) + + return tglf_plan + + +def _run_tglf_single(run, queue): + """ Function to insert TGLF run command into multiprocessing queue for parallelization. """ + + # Each TGLF run hardcoded to run with 2 cores, benefit from having more simultaneous runs + # seen to outweigh higher parallelization of single run. + command = [ + run['execution_path'], + '-n', + str(2), + '-e', + str(run['location']) + ] + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=run['run_path'], + ) + + if run['verbose']: + # Get output and error messages + stdout, stderr = process.communicate() + + # Print the output + ostr = stdout.decode() + + # Print any error messages + if stderr: + ostr += stderr.decode() + + queue.put(ostr) + + +class TGLFTransportModelConfig(pydantic_model_base.TransportBase): + r"""Model for the TGLF transport model. + + Attributes: + model_name: The transport model to use. Hardcoded to 'tglf'. + n_processes: Set number of parallel TGLF calculations to run, each using 2 cores. + DV_effective: Effective D / effective V approach for particle transport. + An_min: Minimum |R/Lne| below which effective V is used instead of effective + D. + collisionality_multiplier: Collisionality multiplier. + kygrid_model: 0 = user-defined with n_ky points equally spaced with kymin = ky/n_ky. + 1 = standard ky spectrum for SAT0 and SAT1 with kymin=0.1/rho_i. + 4 = standard ky spectrum with more low ky points and kymin=0.05*grad_r0/rho_i. + ky: Specify wavenumber for single wavenumber call, or set user-defined ky grid with kygrid_model=0. + n_ky: Number of ky points with kygrid_model=0, else number of additional logarithmically + equally spaced points within 1 < ky < 24 when using kygrid_model=4. + n_modes: Number of eigenmodes to track, advise to use num_species+2 for efficiency. + geometry_flag: 0 = s-alpha, 1 = Miller/MXH, 2 = Fourier, 3 = ELITE. + sat_rule: Specify quasilinear saturation rule used to compute transport fluxes. + xnu_model: Specify collision model. 2 = default for SAT0 and SAT1, 3 = default for SAT2 and SAT3. + n_width: Maximum number of mode widths in mode width scan. + width_min: Minimum value for mode width scan, set negative for electromagnetic search. + width: Maximum value for mode width scan. + filter: Set frequency threshold to filter out non-drift-wave instabilities. + theta_trapped: Adjustment parameter for trapped fraction model. Set to 0.4 with n_basis_max = 8 + for low aspect ratio (https://eprints.whiterose.ac.uk/159700/). + w_dia_trapped: Non-standard option. Set to 1.0 for SAT2 and SAT3. + sign_bt: Sign of toroidal field, positive = CCW from the top. + sign_it: Sign of toroidal current, positive = CCW from the top. + xnu_factor: Multiplier for the trapped/passing boundary collision terms, not the same as collisionality_multiplier. + debye_factor: Multiplier for the normalized Debye length. + etg_factor: Exponent for the ETG saturation rule. + find_width: Toggle mode width scan for maximum growth rate search, uses width argument if False. + use_mhd_rule: If True, ignore pressure gradient contribution to curvature drift. Recommended to set False for high beta. + use_bpar: If True, include compressional magnetic fluctuations, :math:`\delta B_{\par}`. + use_bper: If true, include transverse magnetic fluctuations, :math:`\delta B_{\perp}`. + use_inboard_detrapped: If True, set trapped fraction to zero if eigenmode is inward ballooning. + use_ave_ion_grid: If True, use average ion gyroradius as opposed to main ion gyroradius + alpha_e: Multiplier for ExB velocity shear for spectral shift model. + alpha_zf: Zonal flow mixing coefficient. If -1.0, toggles switch that avoids picking lowest ky as maximum gamma/ky + for intensity spectrum shape in quasilinear saturation rules. + alpha_quench: 0.0 = use spectral shift model, 1.0 = use quench rule. + n_xgrid: Number of nodes in Gauss-Hermite quadrature. Recommended to use 4 * n_basis_max + n_basis_min: Minimum number of parallel basis functions (Hermite polynomials) used to find mode width. + n_basis_max: Maximum number of parallel basis functions (Hermite polynomials) used to find mode width. + Recommended 4 for SAT0 and SAT1, 6 for SAT2 and SAT3. + """ + + model_name: Annotated[Literal['tglf'], torax_pydantic.JAX_STATIC] = ( + 'tglf' + ) + n_processes: pydantic.PositiveInt = 8 + use_rotation: bool = False + rotation_multiplier: pydantic.NonNegativeFloat = 1.0 + DV_effective: bool = False + An_min: pydantic.PositiveFloat = 0.05 + collisionality_multiplier: float = 1.0 + + # Species settings + #n_species: pydantic.PositiveInt = 3 # currently hardcoded as 3 in TORAX + + # Mode settings + kygrid_model: pydantic.PositiveInt = 4 + ky: pydantic.PositiveFloat = 0.3 + n_ky: pydantic.PositiveInt = 19 + n_modes: pydantic.PositiveInt = 5 + + # Model settings + geometry_flag: pydantic.PositiveInt = 1 + sat_rule: pydantic.PositiveInt = 3 + xnu_model: pydantic.PositiveInt = 3 + n_width: pydantic.PositiveInt = 21 + width_min: pydantic.FiniteFloat = -0.3 + width: pydantic.PositiveFloat = 1.65 + filter: pydantic.FiniteFloat = 2.0 + theta_trapped: pydantic.PositiveFloat = 0.7 + w_dia_trapped: pydantic.PositiveFloat = 1.0 + sign_bt: pydantic.FiniteFloat = 1.0 + sign_it: pydantic.FiniteFloat = 1.0 + xnu_factor: pydantic.PositiveFloat = 1.0 + debye_factor: pydantic.PositiveFloat = 1.0 + etg_factor: pydantic.FiniteFloat = 1.25 + + # Flags + find_width: bool = True + use_mhd_rule: bool = False + use_bpar: bool = True + use_bper: bool = False + use_inboard_detrapped: bool = False + use_ave_ion_grid: bool = False + + # Multipliers + alpha_e: pydantic.FiniteFloat = 1.0 + alpha_zf: pydantic.FiniteFloat = 1.0 + alpha_quench: pydantic.FiniteFloat = 0.0 + + # Numerical grid settings + n_xgrid: pydantic.PositiveInt = 16 + n_basis_min: pydantic.PositiveInt = 2 + n_basis_max: pydantic.PositiveInt = 6 + + def build_transport_model(self) -> TGLFTransportModel: + return TGLFTransportModel() + + def build_runtime_params(self, t: chex.Numeric) -> RuntimeParams: + base_kwargs = dataclasses.asdict(super().build_runtime_params(t)) + return RuntimeParams( + n_processes=self.n_processes, + use_rotation=self.use_rotation, + rotation_multiplier=self.rotation_multiplier, + DV_effective=self.DV_effective, + collisionality_multiplier=self.collisionality_multiplier, + An_min=self.An_min, + kygrid_model = self.kygrid_model, + ky = self.ky, + n_ky = self.n_ky, + n_modes = self.n_modes, + geometry_flag = self.geometry_flag, + sat_rule = self.sat_rule, + xnu_model = self.xnu_model, + n_width = self.n_width, + width_min = self.width_min, + width = self.width, + filter = self.filter, + theta_trapped = self.theta_trapped, + w_dia_trapped = self.w_dia_trapped, + sign_bt = self.sign_bt, + sign_it = self.sign_it, + xnu_factor = self.xnu_factor, + debye_factor = self.debye_factor, + etg_factor = self.etg_factor, + find_width = self.find_width, + use_mhd_rule = self.use_mhd_rule, + use_bpar = self.use_bpar, + use_bper = self.use_bper, + use_inboard_detrapped = self.use_inboard_detrapped, + use_ave_ion_grid = self.use_ave_ion_grid, + alpha_e = self.alpha_e, + alpha_zf = self.alpha_zf, + alpha_quench = self.alpha_quench, + n_xgrid = self.n_xgrid, + n_basis_min = self.n_basis_min, + n_basis_max = self.n_basis_max, + **base_kwargs, + ) diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea.nc b/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea.nc index 8929448a2..3fcde66fe 100644 Binary files a/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea.nc and b/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea.nc differ diff --git a/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea_rotation.nc b/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea_rotation.nc index 299ec1271..aaa52bc24 100644 Binary files a/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea_rotation.nc and b/torax/tests/test_data/test_iterhybrid_predictor_corrector_tglfnn_ukaea_rotation.nc differ