Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 58 additions & 12 deletions torax/_src/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,23 @@ class InterpolationMode(enum.Enum):
input greater than x_n.

Options:
NONE: No interpolation. Values are zeros everywhere except for the provided
points. Should only be used internally, for the rho dimension in
InterpolatedVarTimeRho.
PIECEWISE_LINEAR: Does piecewise-linear interpolation between the values
provided. See numpy.interp for a longer description of how it works. (This
uses JAX, but the behavior is the same.)
STEP: Step-function interpolation. For any input value x in the range [x_k,
x_k+1), the output will be y_k.
"""

NONE = 'none'
PIECEWISE_LINEAR = 'piecewise_linear'
STEP = 'step'


InterpolationModeLiteral: TypeAlias = Literal[
'step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR'
'step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR', 'none', 'NONE'
]


Expand Down Expand Up @@ -515,18 +519,60 @@ def __init__(
self._rho_interpolation_mode = rho_interpolation_mode
self._time_interpolation_mode = time_interpolation_mode

sorted_indices = np.array(sorted(values.keys()))
rho_norm_interpolated_values = np.stack(
[
InterpolatedVarSingleAxis(
values[t], rho_interpolation_mode
).get_value(rho_norm)
for t in sorted_indices
],
axis=0,
)
sorted_times = np.array(sorted(values.keys()))
if self._rho_interpolation_mode == InterpolationMode.NONE:
# If no rho interpolation is needed, the given rho_norm locations will be
# quantized onto the rho_norm grid, and the values are delta functions at
# the quantized rho_norm locations.

# Check that the same rho_norm values are used for all times
given_rho_norm_locations = values[sorted_times[0]][0]
for t in sorted_times[1:]:
if not np.array_equal(values[t][0], given_rho_norm_locations):
raise ValueError(
'When rho_interpolation_mode is InterpolationMode.NONE, the'
'rho_norm locations must be the same for all times.'
)

# Find the index of the closest rho_norm values on the grid to the
# requested rho_norm locations for each time.
quantized_rho_norm_indices = np.stack(
[
[
np.argmin(np.abs(rho_norm - given_rho_norm_locations))
# values[t] is a tuple of (rho_norm, values) arrays, so
# values[t][0] is the set of specified rho_norm locations for
# the time t.
for given_rho_norm_locations in values[t][0]
]
for t in sorted_times
],
axis=0,
)

# Convert to a grid of delta functions of height `values` at the quantized
# rho_norm locations.
rho_norm_interpolated_values = np.zeros(
(len(sorted_times), len(rho_norm)),
dtype=jax_utils.get_np_dtype(),
)
rho_norm_interpolated_values[
np.arange(len(sorted_times))[:, None], quantized_rho_norm_indices
] = np.array(
[values[t][1] for t in sorted_times], dtype=jax_utils.get_np_dtype()
)
else:
rho_norm_interpolated_values = np.stack(
[
InterpolatedVarSingleAxis(
values[t], rho_interpolation_mode
).get_value(rho_norm)
for t in sorted_times
],
axis=0,
)
self._time_interpolated_var = InterpolatedVarSingleAxis(
value=(sorted_indices, rho_norm_interpolated_values),
value=(sorted_times, rho_norm_interpolated_values),
interpolation_mode=time_interpolation_mode,
)

Expand Down
18 changes: 18 additions & 0 deletions torax/_src/torax_pydantic/interpolated_param_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Grid1D(model_base.BaseModelFrozen):
of all faces (including boundary faces). For a grid with N cells, there are
N+1 faces.
"""

face_centers: pydantic_types.NumpyArray1DSorted

@pydantic.model_validator(mode='before')
Expand Down Expand Up @@ -444,6 +445,23 @@ def get_cached_interpolated_param_face_right(
)


class TimeVaryingPoints(TimeVaryingArray):
"""A TimeVaryingArray that is defined on a fixed set of rho points, without interpolation in rho."""

rho_interpolation_mode: Literal[interpolated_param.InterpolationMode.NONE] = (
interpolated_param.InterpolationMode.NONE
)

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(
cls, data: interpolated_param.TimeRhoInterpolatedInput | dict[str, Any]
) -> dict[str, Any]:
data = super()._conform_data(data)
data['rho_interpolation_mode'] = interpolated_param.InterpolationMode.NONE
return data


def _is_positive(array: TimeVaryingArray) -> TimeVaryingArray:
for _, value in array.value.values():
if not np.all(value > 0):
Expand Down
34 changes: 34 additions & 0 deletions torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,40 @@ def f(
np.testing.assert_allclose(cell, [8.0, 10.0, 12.0, 14.0])
self.assertEqual(jax_utils.get_number_of_compiles(f), 1)

@parameterized.named_parameters(
dict(
testcase_name='dict_input',
time_rho_interpolated_input={
0.0: {0.0: 1.0, 1.0: 2.0},
1.0: {0.0: 3.0, 1.0: 4.0},
},
expected_values={
0.0: np.array([1.0, 0.0, 0.0, 2.0]),
0.5: np.array([2.0, 0.0, 0.0, 3.0]),
1.0: np.array([3.0, 0.0, 0.0, 4.0]),
},
),
dict(
testcase_name='rho_norm_quantization',
time_rho_interpolated_input={0.0: {0.5: 1.0}},
expected_values={
0.0: np.array([0.0, 1.0, 0.0, 0.0]),
},
),
)
def test_time_varying_points(
self, time_rho_interpolated_input, expected_values
):
tvp = interpolated_param_2d.TimeVaryingPoints.model_validate(
time_rho_interpolated_input
)
face_centers = interpolated_param_2d.get_face_centers(4)
grid = interpolated_param_2d.Grid1D(face_centers=face_centers)
interpolated_param_2d.set_grid(tvp, grid=grid)

for t, expected_value in expected_values.items():
np.testing.assert_array_equal(tvp.get_value(t), expected_value)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions torax/_src/torax_pydantic/torax_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
BaseModelFrozen = model_base.BaseModelFrozen

TimeVaryingScalar = interpolated_param_1d.TimeVaryingScalar
TimeVaryingPoints = interpolated_param_2d.TimeVaryingPoints
TimeVaryingArray = interpolated_param_2d.TimeVaryingArray
NonNegativeTimeVaryingArray = interpolated_param_2d.NonNegativeTimeVaryingArray
PositiveTimeVaryingScalar = interpolated_param_1d.PositiveTimeVaryingScalar
Expand Down
Loading