diff --git a/src/mintpy/objects/resample.py b/src/mintpy/objects/resample.py index ce7b503af..ef1debe29 100644 --- a/src/mintpy/objects/resample.py +++ b/src/mintpy/objects/resample.py @@ -205,10 +205,18 @@ def run_resample(self, src_data, box_ind=0, print_msg=True): ) if self.software == 'pyresample': - # move 1st/time dimension to the last + + # Move spatial dimensions (row, col) from the end to the front # so that rows/cols axis are the first, as required by pyresample - if len(src_data.shape) == 3: - src_data = np.moveaxis(src_data, 0, -1) + src_data = ut.move_spatial_dimension(src_data, to_front=True) + + # save non-spatial shape for restoring later + # () for 2d data + non_spatial_shape = src_data.shape[2:] + + # and ravel all non-spatial dimensions to shape (row, col, dn1 * dn2 * ...) + # will return just return data for 2d data + src_data = ut.flatten_for_resample(src_data) # resample source data into target data dest_data = self.run_pyresample( @@ -220,9 +228,12 @@ def run_resample(self, src_data, box_ind=0, print_msg=True): **kwargs, ) - # move 1st/time dimension back - if len(dest_data.shape) == 3: - dest_data = np.moveaxis(dest_data, -1, 0) + # Restore original non-spatial dimensions + dest_data = ut.restore_from_resample(dest_data, non_spatial_shape) + + # Move spatial dimensions (row, col) from the front back to the end + # Restores (..., rows, cols) for any number of leading dimensions + dest_data = ut.move_spatial_dimension(dest_data, to_front=False) else: vprint(f'{self.interp_method} resampling using scipy.interpolate.RegularGridInterpolator ...') @@ -233,6 +244,8 @@ def run_resample(self, src_data, box_ind=0, print_msg=True): prog_bar.update(i+1) dest_data[i, :, :] = self.run_regular_grid_interpolator(src_data=src_data[i, :, :], **kwargs) prog_bar.close() + elif len(src_data.shape) > 3: + raise NotImplementedError('scipy resampling for >3D data is NOT implemented yet. Use pyresample instead.') else: dest_data = self.run_regular_grid_interpolator(src_data=src_data, **kwargs) diff --git a/src/mintpy/utils/utils0.py b/src/mintpy/utils/utils0.py index 2d36f0d12..7d35bb42d 100644 --- a/src/mintpy/utils/utils0.py +++ b/src/mintpy/utils/utils0.py @@ -1062,7 +1062,67 @@ def circle_index(atr, circle_par): return idx +def move_spatial_dimension(data, to_front=True): + """Move spatial dimensions (row, col) to front or back for pyresample + + Parameters: + data - np.ndarray, shape (..., rows, cols) or (rows, cols) + to_front - bool + True -> move spatial dims (row, col) to axes (0, 1) + False -> move spatial dims (row, col) back to the end + Returns: data - 2D/3D np.array, output data with moved spatial dimensions + """ + if data.ndim <= 2: + return data + + if to_front: + return np.moveaxis(data, [-2, -1], [0, 1]) + else: + return np.moveaxis(data, [0, 1], [-2, -1]) + +def flatten_for_resample(data): + """ + Flatten non-spatial dimensions for pyresample. + This assumes spatial dimensions are the first two dimensions. + + Parameters: data - np.ndarray, input data with spatial dimensions at the front + + Returns: src_data - np.ndarray, reshaped data with spatial dimensions at the + front and non-spatial dimensions flattened + """ + if data.ndim <= 2: + return data + rows, cols = data.shape[:2] + + # multiple dimensions. We now need to ravel all non-spatial dimensions + data = data.reshape(rows, cols, -1) + return data + +def restore_from_resample(data, non_spatial_shape): + """ + Restore non-spatial dimensions after pyresample. + This assumes spatial dimensions are the first two dimensions. + + Parameters + ---------- + data : np.ndarray + Input data with spatial dimensions at the front. + non_spatial_shape : tuple + Original shape of non-spatial dimensions. + + Returns + ------- + np.ndarray + Reshaped data with original non-spatial dimensions restored. + """ + rows, cols = data.shape[:2] + + if non_spatial_shape: + data = data.reshape(rows, cols, *non_spatial_shape) + else: + data = data.reshape(rows, cols) + return data #################################### User Interaction ##################################### def yes_or_no(question): diff --git a/tests/test_geocoding.py b/tests/test_geocoding.py new file mode 100644 index 000000000..eebf4adab --- /dev/null +++ b/tests/test_geocoding.py @@ -0,0 +1,165 @@ +import numpy as np +import pytest + +from mintpy.utils.utils0 import ( + flatten_for_resample, + move_spatial_dimension, + restore_from_resample, +) + + +def test_flatten_for_resample_2d(): + data = np.random.rand(20, 30) + + out = flatten_for_resample(data) + + assert out.shape == (20, 30) + np.testing.assert_array_equal(out, data) + +def test_flatten_for_resample_3d(): + data = np.random.rand(20, 30, 5) + + out = flatten_for_resample(data) + + assert out.shape == (20, 30, 5) + np.testing.assert_array_equal(out, data) + +def test_flatten_for_resample_4d(): + data = np.random.rand(20, 30, 5, 4) + + out = flatten_for_resample(data) + + assert out.shape == (20, 30, 20) # 5 * 4 + np.testing.assert_array_equal(out.reshape(20, 30, 5, 4), data) + +def test_restore_from_resample_2d(): + rows, cols = 20, 30 + data = np.random.rand(rows, cols) + + out = restore_from_resample( + data, + non_spatial_shape=() + ) + + assert out.shape == (rows, cols) + np.testing.assert_array_equal(out, data) + +def test_restore_from_resample_3d(): + rows, cols = 20, 30 + non_spatial_shape = (5,) + data = np.random.rand(rows, cols, 5) + + out = restore_from_resample( + data, + non_spatial_shape=non_spatial_shape + ) + + assert out.shape == (20, 30, 5) + np.testing.assert_array_equal(out, data) + +def test_restore_from_resample_4d(): + rows, cols = 20, 30 + non_spatial_shape = (5, 4) + data = np.random.rand(rows, cols, 20) # 5 * 4 + + out = restore_from_resample( + data, + non_spatial_shape=non_spatial_shape + ) + + assert out.shape == (20, 30, 5, 4) + np.testing.assert_array_equal(out.reshape(20, 30, 20), data) + +@pytest.mark.parametrize( + "shape", + [ + (20, 30), # 2D + (20, 30, 5), # 3D + (20, 30, 5, 4), # 4D + (20, 30, 2, 3, 4), # 5D + ], +) +def test_flatten_restore_roundtrip(shape): + data = np.random.rand(*shape) + non_spatial_shape = shape[2:] + + flat = flatten_for_resample(data) + restored = restore_from_resample( + flat, + non_spatial_shape=non_spatial_shape + ) + + np.testing.assert_array_equal(restored, data) + +def test_restore_from_resample_invalid_shape(): + rows, cols = 20, 30 + data = np.random.rand(rows, cols, 10) + + with pytest.raises(ValueError): + restore_from_resample( + data, + non_spatial_shape=(3, 4) # 12 != 10 + ) + +def test_move_spatial_dimension_2d_front_and_back(): + # shape: (time, row, col) + arr = np.zeros((20, 30)) + + front = move_spatial_dimension(arr, to_front=True) + assert front.shape == (20, 30) + + back = move_spatial_dimension(front, to_front=False) + assert back.shape == arr.shape + + # ensure data integrity + np.testing.assert_array_equal(back, arr) + +def test_move_spatial_dimension_3d_front_and_back(): + # shape: (time, row, col) + arr = np.zeros((5, 20, 30)) + + front = move_spatial_dimension(arr, to_front=True) + assert front.shape == (20, 30, 5) + + back = move_spatial_dimension(front, to_front=False) + assert back.shape == arr.shape + + # ensure data integrity + np.testing.assert_array_equal(back, arr) + +def test_move_spatial_dimension_4d_front_and_back(): + # shape: (time, band, row, col) + arr = np.random.rand(3, 4, 20, 30) + + front = move_spatial_dimension(arr, to_front=True) + assert front.shape == (20, 30, 3, 4) + + back = move_spatial_dimension(front, to_front=False) + assert back.shape == arr.shape + + np.testing.assert_array_equal(back, arr) + +@pytest.mark.parametrize( + "shape", + [ + (20, 30), # 2D + (20, 30, 5), # 3D + (20, 30, 5, 4), # 4D + (20, 30, 2, 3, 4), # 5D + ], +) +@pytest.mark.parametrize("to_front", [True, False]) +def test_move_spatial_dimension_roundtrip(shape, to_front): + data = np.random.rand(*shape) + + # Move spatial dimensions + moved = move_spatial_dimension(data, to_front=to_front) + + # Invert the move + restored = move_spatial_dimension(moved, to_front=not to_front) + + # Shape should match original + assert restored.shape == data.shape + + # Values should match original exactly + np.testing.assert_array_equal(restored, data)