Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions torax/_src/transport_model/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import copy
import dataclasses
import itertools
from typing import Annotated, Any, Literal, Sequence

from absl import logging
import chex
from fusion_surrogates.qlknn.models import registry
Expand All @@ -34,6 +34,7 @@
from torax._src.transport_model import qlknn_10d
from torax._src.transport_model import qlknn_transport_model
from torax._src.transport_model import qualikiz_based_transport_model
from torax._src.transport_model import tglf_transport_model
from torax._src.transport_model import tglfnn_ukaea_transport_model
import typing_extensions

Expand Down Expand Up @@ -220,6 +221,7 @@ class TGLFNNukaeaTransportModel(pydantic_model_base.TransportBase):
# Quasilinear transport options
DV_effective: bool = False
An_min: pydantic.PositiveFloat = 0.05
collisionality_multiplier: float = 1.0

def build_transport_model(
self,
Expand All @@ -237,6 +239,7 @@ def build_runtime_params(
An_min=self.An_min,
rotation_multiplier=self.rotation_multiplier,
use_rotation=self.use_rotation,
collisionality_multiplier=self.collisionality_multiplier,
# From base
**base_kwargs,
)
Expand Down Expand Up @@ -425,6 +428,7 @@ def build_runtime_params(
| ConstantTransportModel
| CriticalGradientTransportModel
| BohmGyroBohmTransportModel
| tglf_transport_model.TGLFTransportModelConfig
| qualikiz_transport_model.QualikizTransportModelConfig
)

Expand All @@ -435,6 +439,7 @@ def build_runtime_params(
| ConstantTransportModel
| CriticalGradientTransportModel
| BohmGyroBohmTransportModel
| tglf_transport_model.TGLFTransportModelConfig
)


Expand Down Expand Up @@ -518,7 +523,11 @@ def _check_fields(self) -> typing_extensions.Self:
any([
np.any(model.apply_inner_patch.value)
or np.any(model.apply_outer_patch.value)
for model in self.transport_models + self.pedestal_transport_models
# Use itertools.chain to iterate over both lists of models without
# needing to make a new list.
for model in itertools.chain(
self.transport_models, self.pedestal_transport_models
)
])
or np.any(self.apply_inner_patch.value)
or np.any(self.apply_outer_patch.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self):
"lref_over_lne",
"lref_over_lni0",
"lref_over_lni1",
"Ti_over_Te",
"T_i_over_T_e",
"r_minor",
"dr_major",
"q",
Expand All @@ -136,7 +136,7 @@ def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self):
"delta",
"delta_shear",
"beta_e",
"Zeff",
"Z_eff",
]
scalar_keys = ["Rmaj", "Rmin"]
expected_vector_length = geo.rho_face_norm.shape[0]
Expand Down Expand Up @@ -224,6 +224,7 @@ def build_runtime_params(self, t: chex.Numeric):
An_min=0.05,
use_rotation=True,
rotation_multiplier=1.0,
collisionality_multiplier=1.0,
**base_kwargs,
)

Expand Down
94 changes: 94 additions & 0 deletions torax/_src/transport_model/tests/tglf_transport_model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2026 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np
from torax._src.config import build_runtime_params
from torax._src.core_profiles import initialization
from torax._src.pedestal_model import pedestal_model_output as pedestal_model_output_lib
from torax._src.test_utils import default_configs
from torax._src.torax_pydantic import model_config


class TGLFTransportModelTest(parameterized.TestCase):

@parameterized.named_parameters(
('with_jit', True),
('without_jit', False),
)
def test_call(self, jit: bool):
"""Tests that the model can be called (with entirely mocked TGLF)."""
config = default_configs.get_default_config_dict()
config['transport'] = {'model_name': 'tglf', 'tglf_exec_path': '~/tglf'}
torax_config = model_config.ToraxConfig.from_dict(config)
source_models = torax_config.sources.build_models()
neoclassical_models = torax_config.neoclassical.build_models()
transport_model = torax_config.transport.build_transport_model()
runtime_params = build_runtime_params.RuntimeParamsProvider.from_config(
torax_config
)(
t=torax_config.numerics.t_initial,
)
geo = torax_config.geometry.build_provider(torax_config.numerics.t_initial)
core_profiles = initialization.initial_core_profiles(
runtime_params=runtime_params,
geo=geo,
source_models=source_models,
neoclassical_models=neoclassical_models,
)

def _mock_subprocess_run(cmd, **kwargs):
"""Write a fake TGLF output file and return a mock subprocess result."""
del kwargs # Unused.

# cmd is [tglf_exec_path, '-n', n_cores_per_process, '-e', run_directory]
# Extract the run directory from the command
run_dir = cmd[-1]

# Populate the run directory with a fake output file.
os.makedirs(run_dir, exist_ok=True)
with open(os.path.join(run_dir, 'out.tglf.gbflux'), 'w') as f:
f.write('\n'.join(['1.0'] * 12))

# Return a mock subprocess result with fake stdout and stderr.
result = mock.Mock()
result.stdout = 'stdout'
result.stderr = 'stderr'
return result

with mock.patch.object(subprocess, 'run', side_effect=_mock_subprocess_run):
model_call = (
jax.jit(transport_model.__call__) if jit else transport_model.__call__
)
model_call(
runtime_params,
geo,
core_profiles,
pedestal_model_output_lib.PedestalModelOutput(
rho_norm_ped_top=np.inf,
rho_norm_ped_top_idx=geo.torax_mesh.nx,
T_i_ped=0.0,
T_e_ped=0.0,
n_e_ped=0.0,
),
)


if __name__ == '__main__':
absltest.main()
Loading
Loading