diff --git a/torax/_src/interpolated_param.py b/torax/_src/interpolated_param.py index 9e7b176d9..e8794590b 100644 --- a/torax/_src/interpolated_param.py +++ b/torax/_src/interpolated_param.py @@ -58,6 +58,9 @@ class InterpolationMode(enum.Enum): input greater than x_n. Options: + NONE: No interpolation. Values are zeros everywhere except for the provided + points. Should only be used internally, for the rho dimension in + InterpolatedVarTimeRho. PIECEWISE_LINEAR: Does piecewise-linear interpolation between the values provided. See numpy.interp for a longer description of how it works. (This uses JAX, but the behavior is the same.) @@ -65,12 +68,13 @@ class InterpolationMode(enum.Enum): x_k+1), the output will be y_k. """ + NONE = 'none' PIECEWISE_LINEAR = 'piecewise_linear' STEP = 'step' InterpolationModeLiteral: TypeAlias = Literal[ - 'step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR' + 'step', 'STEP', 'piecewise_linear', 'PIECEWISE_LINEAR', 'none', 'NONE' ] @@ -515,18 +519,60 @@ def __init__( self._rho_interpolation_mode = rho_interpolation_mode self._time_interpolation_mode = time_interpolation_mode - sorted_indices = np.array(sorted(values.keys())) - rho_norm_interpolated_values = np.stack( - [ - InterpolatedVarSingleAxis( - values[t], rho_interpolation_mode - ).get_value(rho_norm) - for t in sorted_indices - ], - axis=0, - ) + sorted_times = np.array(sorted(values.keys())) + if self._rho_interpolation_mode == InterpolationMode.NONE: + # If no rho interpolation is needed, the given rho_norm locations will be + # quantized onto the rho_norm grid, and the values are delta functions at + # the quantized rho_norm locations. + + # Check that the same rho_norm values are used for all times + given_rho_norm_locations = values[sorted_times[0]][0] + for t in sorted_times[1:]: + if not np.array_equal(values[t][0], given_rho_norm_locations): + raise ValueError( + 'When rho_interpolation_mode is InterpolationMode.NONE, the' + 'rho_norm locations must be the same for all times.' + ) + + # Find the index of the closest rho_norm values on the grid to the + # requested rho_norm locations for each time. + quantized_rho_norm_indices = np.stack( + [ + [ + np.argmin(np.abs(rho_norm - given_rho_norm_locations)) + # values[t] is a tuple of (rho_norm, values) arrays, so + # values[t][0] is the set of specified rho_norm locations for + # the time t. + for given_rho_norm_locations in values[t][0] + ] + for t in sorted_times + ], + axis=0, + ) + + # Convert to a grid of delta functions of height `values` at the quantized + # rho_norm locations. + rho_norm_interpolated_values = np.zeros( + (len(sorted_times), len(rho_norm)), + dtype=jax_utils.get_np_dtype(), + ) + rho_norm_interpolated_values[ + np.arange(len(sorted_times))[:, None], quantized_rho_norm_indices + ] = np.array( + [values[t][1] for t in sorted_times], dtype=jax_utils.get_np_dtype() + ) + else: + rho_norm_interpolated_values = np.stack( + [ + InterpolatedVarSingleAxis( + values[t], rho_interpolation_mode + ).get_value(rho_norm) + for t in sorted_times + ], + axis=0, + ) self._time_interpolated_var = InterpolatedVarSingleAxis( - value=(sorted_indices, rho_norm_interpolated_values), + value=(sorted_times, rho_norm_interpolated_values), interpolation_mode=time_interpolation_mode, ) diff --git a/torax/_src/torax_pydantic/interpolated_param_2d.py b/torax/_src/torax_pydantic/interpolated_param_2d.py index 08a60ff9f..c5b6d4be0 100644 --- a/torax/_src/torax_pydantic/interpolated_param_2d.py +++ b/torax/_src/torax_pydantic/interpolated_param_2d.py @@ -49,6 +49,7 @@ class Grid1D(model_base.BaseModelFrozen): of all faces (including boundary faces). For a grid with N cells, there are N+1 faces. """ + face_centers: pydantic_types.NumpyArray1DSorted @pydantic.model_validator(mode='before') @@ -444,6 +445,23 @@ def get_cached_interpolated_param_face_right( ) +class TimeVaryingPoints(TimeVaryingArray): + """A TimeVaryingArray that is defined on a fixed set of rho points, without interpolation in rho.""" + + rho_interpolation_mode: Literal[interpolated_param.InterpolationMode.NONE] = ( + interpolated_param.InterpolationMode.NONE + ) + + @pydantic.model_validator(mode='before') + @classmethod + def _conform_data( + cls, data: interpolated_param.TimeRhoInterpolatedInput | dict[str, Any] + ) -> dict[str, Any]: + data = super()._conform_data(data) + data['rho_interpolation_mode'] = interpolated_param.InterpolationMode.NONE + return data + + def _is_positive(array: TimeVaryingArray) -> TimeVaryingArray: for _, value in array.value.values(): if not np.all(value > 0): diff --git a/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py b/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py index 9c54361ab..2ae3b32eb 100644 --- a/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py +++ b/torax/_src/torax_pydantic/tests/interpolated_param_2d_test.py @@ -615,6 +615,40 @@ def f( np.testing.assert_allclose(cell, [8.0, 10.0, 12.0, 14.0]) self.assertEqual(jax_utils.get_number_of_compiles(f), 1) + @parameterized.named_parameters( + dict( + testcase_name='dict_input', + time_rho_interpolated_input={ + 0.0: {0.0: 1.0, 1.0: 2.0}, + 1.0: {0.0: 3.0, 1.0: 4.0}, + }, + expected_values={ + 0.0: np.array([1.0, 0.0, 0.0, 2.0]), + 0.5: np.array([2.0, 0.0, 0.0, 3.0]), + 1.0: np.array([3.0, 0.0, 0.0, 4.0]), + }, + ), + dict( + testcase_name='rho_norm_quantization', + time_rho_interpolated_input={0.0: {0.5: 1.0}}, + expected_values={ + 0.0: np.array([0.0, 1.0, 0.0, 0.0]), + }, + ), + ) + def test_time_varying_points( + self, time_rho_interpolated_input, expected_values + ): + tvp = interpolated_param_2d.TimeVaryingPoints.model_validate( + time_rho_interpolated_input + ) + face_centers = interpolated_param_2d.get_face_centers(4) + grid = interpolated_param_2d.Grid1D(face_centers=face_centers) + interpolated_param_2d.set_grid(tvp, grid=grid) + + for t, expected_value in expected_values.items(): + np.testing.assert_array_equal(tvp.get_value(t), expected_value) + if __name__ == '__main__': absltest.main() diff --git a/torax/_src/torax_pydantic/torax_pydantic.py b/torax/_src/torax_pydantic/torax_pydantic.py index 45a0a82b1..4faa30062 100644 --- a/torax/_src/torax_pydantic/torax_pydantic.py +++ b/torax/_src/torax_pydantic/torax_pydantic.py @@ -51,6 +51,7 @@ BaseModelFrozen = model_base.BaseModelFrozen TimeVaryingScalar = interpolated_param_1d.TimeVaryingScalar +TimeVaryingPoints = interpolated_param_2d.TimeVaryingPoints TimeVaryingArray = interpolated_param_2d.TimeVaryingArray NonNegativeTimeVaryingArray = interpolated_param_2d.NonNegativeTimeVaryingArray PositiveTimeVaryingScalar = interpolated_param_1d.PositiveTimeVaryingScalar