From 6879bf3a996644802c0117cadf250a0497ff160e Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 14:26:27 -0600 Subject: [PATCH 01/63] Add a source bathy class --- mom6_forge/_source_bathy.py | 117 ++++++++++++++++++++++++++++++++++++ mom6_forge/topo.py | 48 ++++++++------- 2 files changed, 144 insertions(+), 21 deletions(-) create mode 100644 mom6_forge/_source_bathy.py diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py new file mode 100644 index 00000000..b156a22e --- /dev/null +++ b/mom6_forge/_source_bathy.py @@ -0,0 +1,117 @@ +"""Source bathymetry loader for mom6_forge. + +``SourceBathy`` is a lightweight data container for a regional slice of a +source bathymetry dataset. ``Topo._get_src()`` creates and caches one +automatically when ``set_from_dataset`` is called. Users who call pipeline +methods directly (e.g. ``high_res_regrid``, ``generate_mask_ocean_frac``) +should construct a ``SourceBathy`` explicitly:: + + from mom6_forge._source_bathy import SourceBathy + src = SourceBathy("gebco_2023.nc").slice_to_domain(topo) +""" + +import numpy as np +import xarray as xr +from pathlib import Path +from mom6_forge.utils import longitude_slicer + + +class SourceBathy: + """Regional slice of a source bathymetry dataset (e.g. GEBCO). + + Holds the loaded, domain-clipped elevation DataArray together with its + coordinate-name metadata. Per-cell depth statistics are computed and + cached here so repeated calls with the same source file skip the + expensive sub-sampling step. + + Parameters + ---------- + path : str or Path + lon_name : str — longitude coordinate name. Default ``"lon"``. + lat_name : str — latitude coordinate name. Default ``"lat"``. + elevation_name : str — elevation variable (positive-up). Default ``"elevation"``. + """ + + def __init__( + self, + path, + lon_name="lon", + lat_name="lat", + elevation_name="elevation", + ): + self.path = Path(path) + self.lon_name = lon_name + self.lat_name = lat_name + self.elevation_name = elevation_name + self._da = None # set by slice_to_domain + self._topo_stats = None # set by compute_topo_stats + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + + def slice_to_domain(self, topo, buf=0.5): + """Load and clip elevation to the topo grid extent plus ``buf`` degrees. + + Handles the global-longitude seam automatically. Mutates ``self`` + in place and returns ``self`` for chaining. + + Parameters + ---------- + topo : Topo — only ``topo._grid.qlon`` / ``topo._grid.qlat`` are used. + buf : float — degree buffer around the Q-grid bounding box. Default 0.5. + """ + lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max())) + lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max())) + + ds_src = xr.open_dataset(self.path, chunks="auto") + da = ds_src[self.elevation_name] + + da = da.sel({self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)}) + + dlon = float(da[self.lon_name][1] - da[self.lon_name][0]) + total_lon = float(da[self.lon_name][-1] - da[self.lon_name][0] + dlon) + if np.isclose(total_lon, 360): + da = longitude_slicer( + da, + np.array(lon_extent) + np.array([-buf, buf]), + self.lon_name, + ) + else: + da = da.sel( + {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} + ) + + self._da = da.load() + return self + + # ------------------------------------------------------------------ + # Accessors + # ------------------------------------------------------------------ + + @property + def lon(self): + """1-D longitude array.""" + return self._da[self.lon_name].values + + @property + def lat(self): + """1-D latitude array.""" + return self._da[self.lat_name].values + + @property + def depth(self): + """2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src).""" + return -self._da.values.astype(float) + + @property + def da(self): + """Raw elevation DataArray with source coordinate names (positive-up).""" + return self._da + + def __repr__(self): + shape = self._da.shape if self._da is not None else "not loaded" + return ( + f"SourceBathy({self.path.name!r}, lon={self.lon_name!r}, " + f"lat={self.lat_name!r}, elevation={self.elevation_name!r}, shape={shape})" + ) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 78d19951..efeb537e 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -12,7 +12,7 @@ from mom6_forge.edit_command import * from mom6_forge.command_manager import TopoCommandManager, CommandType from mom6_forge.mapping import regrid_dataset_via_xesmf - +from mom6_forge._source_bathy import SourceBathy class Topo: """ @@ -38,6 +38,7 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary"): attrs={"units": "m"}, ) # Initialize depth with NaNs self._min_depth = min_depth + self._src = None # cached SourceBathy; set by _get_src() if version_control_dir is None: raise ValueError( @@ -217,26 +218,6 @@ def umask(self): return umask - @property - def umask(self): - """ - Ocean domain mask on U grid. 1 if ocean, 0 if land. - """ - tmask = self.tmask - - # Create empty mask DataArray for umask - umask = xr.DataArray( - np.ones(self._grid.ulat.shape, dtype=int), - dims=["yh", "xq"], - attrs={"name": "U mask"}, - ) - - # Fill umask with mask values - umask[:, :-1] &= tmask.values # h-point translates to the left u-point - umask[:, 1:] &= tmask.values # h-point translates to the right u-point - - return umask - @property def vmask(self): """ @@ -310,6 +291,31 @@ def supergridmask(self): supergridmask[1::2, ::2] = self.umask.values supergridmask[1::2, 1::2] = self.tmask.values return supergridmask + + def _get_src( + self, + bathymetry_path, + longitude_coordinate_name, + latitude_coordinate_name, + vertical_coordinate_name, + ): + """Return a cached :class:`SourceBathy`, creating and slicing a new one + only when the path or coordinate names differ from the current cache.""" + path = Path(bathymetry_path) + if ( + self._src is None + or self._src.path != path + or self._src.lon_name != longitude_coordinate_name + or self._src.lat_name != latitude_coordinate_name + or self._src.elevation_name != vertical_coordinate_name + ): + self._src = SourceBathy( + path, + longitude_coordinate_name, + latitude_coordinate_name, + vertical_coordinate_name, + ).slice_to_domain(self) + return self._src def point_is_ocean(self, lons, lats): """ From 8a524ec0366df0d2dd61492a8bd764552d3fe4df Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 14:26:44 -0600 Subject: [PATCH 02/63] Black --- mom6_forge/topo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index efeb537e..ef464513 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -14,6 +14,7 @@ from mom6_forge.mapping import regrid_dataset_via_xesmf from mom6_forge._source_bathy import SourceBathy + class Topo: """ Bathymetry Generator for MOM6 grids (mom6_forge.grid.Grid). @@ -291,7 +292,7 @@ def supergridmask(self): supergridmask[1::2, ::2] = self.umask.values supergridmask[1::2, 1::2] = self.tmask.values return supergridmask - + def _get_src( self, bathymetry_path, From cdff74c9bd88050065d4a3a6102350a328e78e95 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 14:38:32 -0600 Subject: [PATCH 03/63] Bleh --- setup.py | 1 + tests/test_source_bathy.py | 98 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 tests/test_source_bathy.py diff --git a/setup.py b/setup.py index 4b4ca999..f4050260 100644 --- a/setup.py +++ b/setup.py @@ -37,5 +37,6 @@ "gitpython>=3.1,<3.2.0", "cartopy>=0.23,<0.30", "xesmf>=0.8.10,<1.0.0", + "dask", ], ) diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py new file mode 100644 index 00000000..4beca408 --- /dev/null +++ b/tests/test_source_bathy.py @@ -0,0 +1,98 @@ +"""Smoke test for SourceBathy loader.""" + +import numpy as np +import pytest +import tempfile +from pathlib import Path +import xarray as xr +from mom6_forge._source_bathy import SourceBathy + + +@pytest.fixture +def synthetic_bathy_file(): + """Create a temporary synthetic bathymetry NetCDF file for testing. + + Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). + """ + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + bathy_file = tmp.name + + # Create synthetic bathymetry covering the Panama region + # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 + # So we need lon 278-282, lat 7-10 with some buffer + lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer + lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer + + # Create synthetic elevation data (positive-up, like GEBCO) + # Ocean is negative (water), land is positive + elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + + # Add synthetic land masses (islands) + # Create an island around (280, 8.5) + lon_2d, lat_2d = np.meshgrid(lon, lat) + island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 + elevation[island_mask] = 200.0 # Synthetic island + + ds = xr.Dataset( + { + "elevation": (["lat", "lon"], elevation), + }, + coords={ + "lon": lon, + "lat": lat, + }, + ) + ds.to_netcdf(bathy_file) + + yield bathy_file + + # Cleanup + Path(bathy_file).unlink() + + +def test_source_bathy_initialization(synthetic_bathy_file): + """Test SourceBathy initialization and coordinate names.""" + src = SourceBathy( + synthetic_bathy_file, + lon_name="lon", + lat_name="lat", + elevation_name="elevation", + ) + + assert src.path == Path(synthetic_bathy_file) + assert src.lon_name == "lon" + assert src.lat_name == "lat" + assert src.elevation_name == "elevation" + assert src._da is None # Not loaded yet + + +def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): + """Smoke test: load and slice elevation to topo domain.""" + topo = get_rect_topo + + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.5) + + # Verify data was loaded + assert src._da is not None + assert src.lon is not None + assert src.lat is not None + + # Verify shape makes sense + assert len(src.lon) > 0 + assert len(src.lat) > 0 + assert src._da.shape == (len(src.lat), len(src.lon)) + + +def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): + """Test that elevation is converted to positive-down depth.""" + topo = get_rect_topo + + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.5) + + # Get depth and verify sign conversion + depth = src.depth + + assert not bool(np.isnan(depth).all()) + assert depth.shape == src.da.shape From 2ce1f06362f024e30a3758d54dbe1c378298349f Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 14:39:24 -0600 Subject: [PATCH 04/63] Black + other --- tests/test_source_bathy.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index 4beca408..14defd1b 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -79,8 +79,8 @@ def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): assert src.lat is not None # Verify shape makes sense - assert len(src.lon) > 0 - assert len(src.lat) > 0 + assert len(src.lon) > 0, f"Expected lon data, got empty array" + assert len(src.lat) > 0, f"Expected lat data, got empty array" assert src._da.shape == (len(src.lat), len(src.lon)) @@ -94,5 +94,12 @@ def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): # Get depth and verify sign conversion depth = src.depth - assert not bool(np.isnan(depth).all()) + # Verify no NaNs in the result + assert not bool(np.isnan(depth).all()), "All depth values are NaN" + + # Verify positive depth values for ocean (elevation is negative) + non_nan_values = depth[~np.isnan(depth)] + assert len(non_nan_values) > 0, "No valid depth values" + assert np.any(non_nan_values > 0), "Expected positive depth values for ocean" + assert depth.shape == src.da.shape From b3f7a2cfb5ba8bd86b7f5a99ece5a08dbfe757cd Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 14:47:17 -0600 Subject: [PATCH 05/63] Bump --- tests/test_source_bathy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index 14defd1b..a96b6b3f 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -71,7 +71,7 @@ def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo src = SourceBathy(synthetic_bathy_file) - src.slice_to_domain(topo, buf=0.5) + src.slice_to_domain(topo, buf=0.4) # Verify data was loaded assert src._da is not None From 8da58c9c6e2e339fc47d140943997ee7fbb778d5 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 15:01:05 -0600 Subject: [PATCH 06/63] Compute topo stats func --- mom6_forge/topo.py | 145 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index ef464513..718dcb7d 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -571,6 +571,151 @@ def set_bowl(self, max_depth, dedge, rad_earth=6.378e6, expdecay=400000.0): # Save to object (Build TCM Object) self.send_entire_depth_change_to_tcm(new_values) + def _compute_topo_stats(self, src, nx_sub, ny_sub, mask_hmin): + """Compute per-cell depth statistics by uniform sub-sampling. + + Results are cached on ``src._topo_stats`` so a second call with the + same source file returns immediately without recomputation. + + Parameters + ---------- + src : SourceBathy + nx_sub, ny_sub : int + mask_hmin : float + + Returns + ------- + xr.Dataset — ``OCN_FRAC``, ``D_mean``, ``D_min``, ``D_max``, ``D2_mean``. + """ + if src._topo_stats is not None: + return src._topo_stats + + dlon = float(src.lon[1] - src.lon[0]) + dlat = float(src.lat[1] - src.lat[0]) + + SW_lon = self._grid.qlon.values[:-1, :-1] + SE_lon = self._grid.qlon.values[:-1, 1:] + NE_lon = self._grid.qlon.values[1:, 1:] + NW_lon = self._grid.qlon.values[1:, :-1] + SW_lat = self._grid.qlat.values[:-1, :-1] + SE_lat = self._grid.qlat.values[:-1, 1:] + NE_lat = self._grid.qlat.values[1:, 1:] + NW_lat = self._grid.qlat.values[1:, :-1] + + # Fix 2: ensure all corners are in the same 360° period as NE, + # matching Fortran create_model_topo.f90 lines 322-333. + # Cells straddling the antimeridian would otherwise produce garbage + # bilinear-interpolated sub_lon values. + def _fix_lon_period(lon, ref): + diff = lon - ref + lon = np.where(diff > 270, lon - 360, lon) + lon = np.where(diff < -270, lon + 360, lon) + return lon + + SW_lon = _fix_lon_period(SW_lon, NE_lon) + SE_lon = _fix_lon_period(SE_lon, NE_lon) + NW_lon = _fix_lon_period(NW_lon, NE_lon) + + ifrac = (np.arange(1, nx_sub + 1) / (nx_sub + 1)).astype(float) + jfrac = (np.arange(1, ny_sub + 1) / (ny_sub + 1)).astype(float) + + i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] + j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] + SW_lon = SW_lon[:, :, np.newaxis, np.newaxis] + SE_lon = SE_lon[:, :, np.newaxis, np.newaxis] + NE_lon = NE_lon[:, :, np.newaxis, np.newaxis] + NW_lon = NW_lon[:, :, np.newaxis, np.newaxis] + SW_lat = SW_lat[:, :, np.newaxis, np.newaxis] + SE_lat = SE_lat[:, :, np.newaxis, np.newaxis] + NE_lat = NE_lat[:, :, np.newaxis, np.newaxis] + NW_lat = NW_lat[:, :, np.newaxis, np.newaxis] + + sub_lon = ( + (1 - i_) * (1 - j_) * SW_lon + + i_ * (1 - j_) * SE_lon + + i_ * j_ * NE_lon + + (1 - i_) * j_ * NW_lon + ) + sub_lat = ( + (1 - i_) * (1 - j_) * SW_lat + + i_ * (1 - j_) * SE_lat + + i_ * j_ * NE_lat + + (1 - i_) * j_ * NW_lat + ) + + # Create destination subpoints + ds_sub = xr.Dataset( + coords={ + "lat": (["y", "x"], sub_lat), + "lon": (["y", "x"], sub_lon), + } + ) + ds_sub["depth"] = xr.zeros_like(ds_sub["lon"]) + # Create source dataset + ds_src = xr.Dataset( + coords={ + "lat": (["y"], src.lat), + "lon": (["x"], src.lon), + }, + data_vars={"depth": (["y", "x"], src.depth)} + ) + + + depth_sub = regrid_dataset_via_xesmf( + input_dataset = ds_src, + output_dataset = ds_sub, + regridding_method="nearest_s2d", # Gets the closest source point + write_to_file=False, + ).depth.data + + is_ocean = depth_sub > mask_hmin + ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub) + + depth_ocean = np.where(is_ocean, depth_sub, np.nan) + with np.errstate(all="ignore"): + D_mean = np.nanmean(depth_ocean, axis=(-2, -1)) + D_min = np.nanmin(depth_ocean, axis=(-2, -1)) + D_max = np.nanmax(depth_ocean, axis=(-2, -1)) + D2_mean = np.nanmean(depth_ocean**2, axis=(-2, -1)) + + dims = ["ny", "nx"] + src._topo_stats = xr.Dataset( + { + "OCN_FRAC": xr.DataArray( + ocn_frac, + dims=dims, + attrs={ + "long_name": "ocean fraction from sub-sampling", + "units": "1", + }, + ), + "D_mean": xr.DataArray( + D_mean, + dims=dims, + attrs={"long_name": "mean ocean depth in cell", "units": "m"}, + ), + "D_min": xr.DataArray( + D_min, + dims=dims, + attrs={"long_name": "minimum ocean depth in cell", "units": "m"}, + ), + "D_max": xr.DataArray( + D_max, + dims=dims, + attrs={"long_name": "maximum ocean depth in cell", "units": "m"}, + ), + "D2_mean": xr.DataArray( + D2_mean, + dims=dims, + attrs={ + "long_name": "mean squared ocean depth in cell", + "units": "m2", + }, + ), + } + ) + return src._topo_stats + def set_from_dataset( self, bathymetry_path, From 58d67d1ba242fb54afc7d89181e9acf93667bfb8 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 15:02:15 -0600 Subject: [PATCH 07/63] Add reqs --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f4050260..facab1e4 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,6 @@ "gitpython>=3.1,<3.2.0", "cartopy>=0.23,<0.30", "xesmf>=0.8.10,<1.0.0", - "dask", + "dask>=2026.3.0,<2026.4.0", ], ) From a07f7b17ad2c62bbed19cf8cb0784ea220dc26d6 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 15:06:18 -0600 Subject: [PATCH 08/63] Bleh --- mom6_forge/topo.py | 21 ++++++++++++--------- tests/test_topo_bathymetry_workflows.py | 1 + 2 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 tests/test_topo_bathymetry_workflows.py diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 718dcb7d..ed13803b 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -571,7 +571,7 @@ def set_bowl(self, max_depth, dedge, rad_earth=6.378e6, expdecay=400000.0): # Save to object (Build TCM Object) self.send_entire_depth_change_to_tcm(new_values) - def _compute_topo_stats(self, src, nx_sub, ny_sub, mask_hmin): + def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): """Compute per-cell depth statistics by uniform sub-sampling. Results are cached on ``src._topo_stats`` so a second call with the @@ -579,7 +579,7 @@ def _compute_topo_stats(self, src, nx_sub, ny_sub, mask_hmin): Parameters ---------- - src : SourceBathy + src : SourceBathy (part of class) nx_sub, ny_sub : int mask_hmin : float @@ -587,6 +587,10 @@ def _compute_topo_stats(self, src, nx_sub, ny_sub, mask_hmin): ------- xr.Dataset — ``OCN_FRAC``, ``D_mean``, ``D_min``, ``D_max``, ``D2_mean``. """ + assert ( + self._src is not None + ), "Source bathymetry must be loaded to compute topo stats" + src = self._src if src._topo_stats is not None: return src._topo_stats @@ -602,7 +606,7 @@ def _compute_topo_stats(self, src, nx_sub, ny_sub, mask_hmin): NE_lat = self._grid.qlat.values[1:, 1:] NW_lat = self._grid.qlat.values[1:, :-1] - # Fix 2: ensure all corners are in the same 360° period as NE, + # ensure all corners are in the same 360° period as NE, # matching Fortran create_model_topo.f90 lines 322-333. # Cells straddling the antimeridian would otherwise produce garbage # bilinear-interpolated sub_lon values. @@ -657,15 +661,14 @@ def _fix_lon_period(lon, ref): "lat": (["y"], src.lat), "lon": (["x"], src.lon), }, - data_vars={"depth": (["y", "x"], src.depth)} + data_vars={"depth": (["y", "x"], src.depth)}, ) - depth_sub = regrid_dataset_via_xesmf( - input_dataset = ds_src, - output_dataset = ds_sub, - regridding_method="nearest_s2d", # Gets the closest source point - write_to_file=False, + input_dataset=ds_src, + output_dataset=ds_sub, + regridding_method="nearest_s2d", # Gets the closest source point + write_to_file=False, ).depth.data is_ocean = depth_sub > mask_hmin diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py new file mode 100644 index 00000000..52c36d10 --- /dev/null +++ b/tests/test_topo_bathymetry_workflows.py @@ -0,0 +1 @@ +from mom6_forge.topo import * From 28949024bcf70662e096fe994f2ce41aad978253 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 17 Apr 2026 15:29:28 -0600 Subject: [PATCH 09/63] Smoke test --- mom6_forge/topo.py | 41 +++++++++--------- tests/conftest.py | 46 ++++++++++++++++++++ tests/test_topo.py | 48 +++++++++++++++++++++ tests/test_topo_bathymetry_workflows.py | 56 +++++++++++++++++++++++++ 4 files changed, 169 insertions(+), 22 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index ed13803b..be5b70b5 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -606,7 +606,7 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): NE_lat = self._grid.qlat.values[1:, 1:] NW_lat = self._grid.qlat.values[1:, :-1] - # ensure all corners are in the same 360° period as NE, + # Fix 2: ensure all corners are in the same 360° period as NE, # matching Fortran create_model_topo.f90 lines 322-333. # Cells straddling the antimeridian would otherwise produce garbage # bilinear-interpolated sub_lon values. @@ -648,28 +648,25 @@ def _fix_lon_period(lon, ref): ) # Create destination subpoints - ds_sub = xr.Dataset( - coords={ - "lat": (["y", "x"], sub_lat), - "lon": (["y", "x"], sub_lon), - } - ) - ds_sub["depth"] = xr.zeros_like(ds_sub["lon"]) - # Create source dataset - ds_src = xr.Dataset( - coords={ - "lat": (["y"], src.lat), - "lon": (["x"], src.lon), - }, - data_vars={"depth": (["y", "x"], src.depth)}, - ) + ii = np.round((sub_lon - src.lon[0]) / dlon).astype(int) + jj = np.round((sub_lat - src.lat[0]) / dlat).astype(int) + # wrap longitude index periodically rather than clipping — + # sub-points near the antimeridian must find the correct source pixel + # on the other side, not snap to the edge. + if np.any((ii < 0) | (ii >= len(src.lon))): + src_span = float(src.lon[-1] - src.lon[0]) + if src_span < 355: + raise ValueError( + f"Sub-points fall outside the source longitude range " + f"[{float(src.lon[0]):.2f}, {float(src.lon[-1]):.2f}] " + f"(span {src_span:.1f}°). Longitude wraparound requires a " + f"global source (~360° span); got {src_span:.1f}°. " + f"Pass a global SourceBathy rather than a regional slice." + ) + ii = ii % len(src.lon) + jj = np.clip(jj, 0, len(src.lat) - 1) # latitude: clamp, no wraparound - depth_sub = regrid_dataset_via_xesmf( - input_dataset=ds_src, - output_dataset=ds_sub, - regridding_method="nearest_s2d", # Gets the closest source point - write_to_file=False, - ).depth.data + depth_sub = src.depth[jj, ii] # positive-down is_ocean = depth_sub > mask_hmin ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub) diff --git a/tests/conftest.py b/tests/conftest.py index 0b83f72e..63ac1745 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,11 @@ import numpy as np import pytest +import tempfile +from pathlib import Path +import xarray as xr from mom6_forge.grid import Grid from mom6_forge.topo import Topo +from mom6_forge._source_bathy import SourceBathy @pytest.fixture @@ -252,3 +256,45 @@ def get_rect_topo(get_rect_grid, tmp_path): topo = Topo(get_rect_grid, min_depth=0, version_control_dir=tmp_path) topo.set_flat(1000) return topo + + +@pytest.fixture +def synthetic_bathy_file(): + """Create a temporary synthetic bathymetry NetCDF file for testing. + + Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). + """ + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + bathy_file = tmp.name + + # Create synthetic bathymetry covering the Panama region + # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 + # So we need lon 278-282, lat 7-10 with some buffer + lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer + lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer + + # Create synthetic elevation data (positive-up, like GEBCO) + # Ocean is negative (water), land is positive + elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + + # Add synthetic land masses (islands) + # Create an island around (280, 8.5) + lon_2d, lat_2d = np.meshgrid(lon, lat) + island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 + elevation[island_mask] = 200.0 # Synthetic island + + ds = xr.Dataset( + { + "elevation": (["lat", "lon"], elevation), + }, + coords={ + "lon": lon, + "lat": lat, + }, + ) + ds.to_netcdf(bathy_file) + + yield bathy_file + + # Cleanup + Path(bathy_file).unlink() diff --git a/tests/test_topo.py b/tests/test_topo.py index 5da340d8..7e89b0e2 100644 --- a/tests/test_topo.py +++ b/tests/test_topo.py @@ -1,4 +1,52 @@ +import numpy as np +import pytest +import tempfile +from pathlib import Path +import xarray as xr from mom6_forge.topo import * +from mom6_forge._source_bathy import SourceBathy + + +@pytest.fixture +def synthetic_bathy_file(): + """Create a temporary synthetic bathymetry NetCDF file for testing. + + Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). + """ + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + bathy_file = tmp.name + + # Create synthetic bathymetry covering the Panama region + # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 + # So we need lon 278-282, lat 7-10 with some buffer + lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer + lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer + + # Create synthetic elevation data (positive-up, like GEBCO) + # Ocean is negative (water), land is positive + elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + + # Add synthetic land masses (islands) + # Create an island around (280, 8.5) + lon_2d, lat_2d = np.meshgrid(lon, lat) + island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 + elevation[island_mask] = 200.0 # Synthetic island + + ds = xr.Dataset( + { + "elevation": (["lat", "lon"], elevation), + }, + coords={ + "lon": lon, + "lat": lat, + }, + ) + ds.to_netcdf(bathy_file) + + yield bathy_file + + # Cleanup + Path(bathy_file).unlink() def test_topo_from_version_control(get_rect_topo): diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 52c36d10..5a5dbf30 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -1 +1,57 @@ +import numpy as np +import pytest from mom6_forge.topo import * +from mom6_forge._source_bathy import SourceBathy + + +def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): + """Test _compute_topo_stats: per-cell depth statistics via xesmf nearest neighbor regridding. + + This test validates the refactored _compute_topo_stats method which: + - Generates sub-points within each grid cell + - Uses xesmf nearest_s2d regridding to snap sub-points to nearest source data + - Computes per-cell statistics (OCN_FRAC, D_mean, D_min, D_max, D2_mean) + """ + topo = get_rect_topo + + # Load source bathymetry and slice to topo domain + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.5) + topo._src = src + + # Test with different sub-sampling densities + for nx_sub, ny_sub in [(2, 2), (3, 3)]: + # Call _compute_topo_stats + stats = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) + + # Verify output is a Dataset with expected variables + assert isinstance(stats, xr.Dataset) + required_vars = ["OCN_FRAC", "D_mean", "D_min", "D_max", "D2_mean"] + for var in required_vars: + assert var in stats.data_vars, f"Missing {var} in output" + + # Verify shapes match topo grid + expected_shape = (topo.depth.shape[0], topo.depth.shape[1]) + assert stats["OCN_FRAC"].shape == expected_shape + assert stats["D_mean"].shape == expected_shape + assert stats["D_min"].shape == expected_shape + assert stats["D_max"].shape == expected_shape + assert stats["D2_mean"].shape == expected_shape + + # Verify OCN_FRAC is between 0 and 1 + assert (stats["OCN_FRAC"] >= 0).all() + assert (stats["OCN_FRAC"] <= 1).all() + + # Verify D_min <= D_mean <= D_max + ocean_cells = stats["OCN_FRAC"].values > 0 + assert ( + stats["D_min"].values[ocean_cells] <= stats["D_mean"].values[ocean_cells] + ).all() + assert ( + stats["D_mean"].values[ocean_cells] <= stats["D_max"].values[ocean_cells] + ).all() + + # Verify caching: second call should return cached result + stats2 = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) + # Should be the exact same object (cached) + assert stats2 is stats From 422caa26da870dd639b5347ee4d7b6c5c4d55791 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 20 Apr 2026 13:46:37 -0600 Subject: [PATCH 10/63] Reformat with xesmf (provides significant seam protection) --- mom6_forge/chl.py | 39 +++++++----- mom6_forge/mapping.py | 140 ++++++++++++++++++++++++++++++++++++++++++ mom6_forge/topo.py | 82 +++---------------------- mom6_forge/utils.py | 27 ++++++++ tests/conftest.py | 53 ++++++++++++++++ tests/test_mapping.py | 84 +++++++++++++++++++++++++ 6 files changed, 337 insertions(+), 88 deletions(-) create mode 100644 tests/test_mapping.py diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py index 98e8d40c..13e9bbb5 100644 --- a/mom6_forge/chl.py +++ b/mom6_forge/chl.py @@ -6,7 +6,8 @@ from datetime import datetime from pathlib import Path from os.path import isfile -from mom6_forge.utils import fill_missing_data +from mom6_forge.utils import fill_missing_data, compute_subsampling_factor +from mom6_forge.mapping import regrid_with_subsampling def interpolate_and_fill_seawifs( @@ -64,7 +65,7 @@ def interpolate_and_fill_seawifs( (np.arange(src_ni) + 0.5) / src_ni ) * 360.0 + src_x0 # Recompute as doubles - spr_lat, spr_lon = grid.tlat.values, grid.tlon.values + ny_sub, nx_sub = compute_subsampling_factor(src_nj, src_ni, ocn_nj, ocn_ni) # Set output path if output_path is None: @@ -88,21 +89,27 @@ def interpolate_and_fill_seawifs( # Iterate through time for t in range(src_data.shape[0]): - # Bilinearly interpolate the source data onto the super-sampled grid - # adj lon to -180 to 180 - adj_lon = spr_lon - 360 - data = src_data[t, ::-1, :].values - - src = {"lon": src_lon, "lat": src_lat} - dst = {"lon": adj_lon, "lat": spr_lat} - regridder = xe.Regridder( - src, - dst, - "bilinear", - filename="bilin_weights.nc", - reuse_weights=isfile("bilin_weights.nc"), + # Build source dataset for this timestep + src_ds = xr.Dataset( + { + "chlor_a": xr.DataArray( + src_data[t, ::-1, :].values, + dims=["lat", "lon"], + coords={"lat": src_lat, "lon": src_lon}, + ) + } + ) + + # Regrid to super-sampled sub-point grid and average back to model grid + q_sub = regrid_with_subsampling( + input_dataset=src_ds, + qlon=grid.qlon.values, + qlat=grid.qlat.values, + nx_sub=nx_sub, + ny_sub=ny_sub, + regridding_method="bilinear", ) - q_int = regridder(data) + q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values # Fill any missing data q = q_int * ocn_mask diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 065c7e97..e495575d 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -994,6 +994,146 @@ def regrid_dataset_via_xesmf( return dataset +def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub): + """ + Given corner coordinates of a low-res grid (qlon/qlat on the q-point staggering), + return sub-sampled lon/lat arrays of shape (ny, nx, ny_sub, nx_sub). + + Parameters + ---------- + qlon, qlat : np.ndarray shape (ny+1, nx+1) + Corner (q-point) coordinates of the low-res grid. + nx_sub, ny_sub : int + Number of sub-points per cell in each direction. + + Returns + ------- + sub_lon, sub_lat : np.ndarray shape (ny, nx, ny_sub, nx_sub) + """ + SW_lon = qlon[:-1, :-1] + SW_lat = qlat[:-1, :-1] + SE_lon = qlon[:-1, 1:] + SE_lat = qlat[:-1, 1:] + NE_lon = qlon[1:, 1:] + NE_lat = qlat[1:, 1:] + NW_lon = qlon[1:, :-1] + NW_lat = qlat[1:, :-1] + + # Fix antimeridian-straddling cells + def _fix_period(lon, ref): + lon = np.where(lon - ref > 270, lon - 360, lon) + lon = np.where(lon - ref < -270, lon + 360, lon) + return lon + + SW_lon = _fix_period(SW_lon, NE_lon) + SE_lon = _fix_period(SE_lon, NE_lon) + NW_lon = _fix_period(NW_lon, NE_lon) + + ifrac = (np.arange(1, nx_sub + 1) / (nx_sub + 1)).astype(float) + jfrac = (np.arange(1, ny_sub + 1) / (ny_sub + 1)).astype(float) + + i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] # (1,1,1,nx_sub) + j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] # (1,1,ny_sub,1) + + # Broadcast all corners to (ny, nx, 1, 1) + SW_lon = SW_lon[:, :, np.newaxis, np.newaxis] + SE_lon = SE_lon[:, :, np.newaxis, np.newaxis] + NE_lon = NE_lon[:, :, np.newaxis, np.newaxis] + NW_lon = NW_lon[:, :, np.newaxis, np.newaxis] + SW_lat = SW_lat[:, :, np.newaxis, np.newaxis] + SE_lat = SE_lat[:, :, np.newaxis, np.newaxis] + NE_lat = NE_lat[:, :, np.newaxis, np.newaxis] + NW_lat = NW_lat[:, :, np.newaxis, np.newaxis] + + sub_lon = ( + (1 - i_) * (1 - j_) * SW_lon + + i_ * (1 - j_) * SE_lon + + i_ * j_ * NE_lon + + (1 - i_) * j_ * NW_lon + ) + sub_lat = ( + (1 - i_) * (1 - j_) * SW_lat + + i_ * (1 - j_) * SE_lat + + i_ * j_ * NE_lat + + (1 - i_) * j_ * NW_lat + ) + + return sub_lon, sub_lat + + +def regrid_with_subsampling( + input_dataset, + qlon, + qlat, + nx_sub, + ny_sub, + regridding_method="nearest_s2d", +): + """ + Regrids input_dataset to sub_sampled_grid to + properly analyze high-res source data into each coarse cell. + + Parameters + ---------- + input_dataset : xr.Dataset + qlon, qlat : np.ndarray shape (ny+1, nx+1) + Corner coordinates of the destination grid. + nx_sub, ny_sub : int + Number of sub-points per cell (typically from compute_subsampling_factor). + Returns + ------- + regridded_dataset : xr.Dataset + Regridded dataset with dimensions (..., ny, nx, ny_sub, nx_sub), where the sub-sampling points are kept as separate dimensions. (User should perform stats calc) + """ + ny, nx = qlon.shape[0] - 1, qlon.shape[1] - 1 + + # Build the (ny, nx, ny_sub, nx_sub) sub-point grid + sub_lon, sub_lat = _make_subgrid_points(qlon, qlat, nx_sub, ny_sub) + + # Flatten to (ny, nx*ny_sub*nx_sub) so xesmf sees a 2D destination + flat_lon = sub_lon.reshape(ny, nx * ny_sub * nx_sub) + flat_lat = sub_lat.reshape(ny, nx * ny_sub * nx_sub) + + flat_output = xr.Dataset( + { + "lon": xr.DataArray(flat_lon, dims=["ny", "nx"]), + "lat": xr.DataArray(flat_lat, dims=["ny", "nx"]), + } + ) + + regridded_flat = regrid_dataset_via_xesmf( + input_dataset, + flat_output, + regridding_method=regridding_method, + write_to_file=False, + ) + + # Reshape to 4D, keeping sub-points as their own dimension + data_vars = {} + for var in regridded_flat.data_vars: + data = regridded_flat[var].values # (..., ny, nx*ny_sub*nx_sub) + reshaped = data.reshape(ny, nx, ny_sub, nx_sub) + + original_dims = regridded_flat[var].dims + new_dims = (*original_dims, "ny_sub", "nx_sub") + + data_vars[var] = xr.DataArray( + reshaped, + dims=new_dims, + attrs=regridded_flat[var].attrs, + ) + + coords = { + k: v + for k, v in regridded_flat.coords.items() + if "ny" not in v.dims and "nx" not in v.dims + } + coords["ny_sub"] = np.arange(ny_sub) + coords["nx_sub"] = np.arange(nx_sub) + + return xr.Dataset(data_vars, coords=coords, attrs=input_dataset.attrs) + + def main(args): if args.parallel: diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index be5b70b5..e9da40f5 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -11,7 +11,7 @@ from pathlib import Path from mom6_forge.edit_command import * from mom6_forge.command_manager import TopoCommandManager, CommandType -from mom6_forge.mapping import regrid_dataset_via_xesmf +from mom6_forge.mapping import regrid_dataset_via_xesmf, regrid_with_subsampling from mom6_forge._source_bathy import SourceBathy @@ -594,79 +594,17 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): if src._topo_stats is not None: return src._topo_stats - dlon = float(src.lon[1] - src.lon[0]) - dlat = float(src.lat[1] - src.lat[0]) - - SW_lon = self._grid.qlon.values[:-1, :-1] - SE_lon = self._grid.qlon.values[:-1, 1:] - NE_lon = self._grid.qlon.values[1:, 1:] - NW_lon = self._grid.qlon.values[1:, :-1] - SW_lat = self._grid.qlat.values[:-1, :-1] - SE_lat = self._grid.qlat.values[:-1, 1:] - NE_lat = self._grid.qlat.values[1:, 1:] - NW_lat = self._grid.qlat.values[1:, :-1] - - # Fix 2: ensure all corners are in the same 360° period as NE, - # matching Fortran create_model_topo.f90 lines 322-333. - # Cells straddling the antimeridian would otherwise produce garbage - # bilinear-interpolated sub_lon values. - def _fix_lon_period(lon, ref): - diff = lon - ref - lon = np.where(diff > 270, lon - 360, lon) - lon = np.where(diff < -270, lon + 360, lon) - return lon - - SW_lon = _fix_lon_period(SW_lon, NE_lon) - SE_lon = _fix_lon_period(SE_lon, NE_lon) - NW_lon = _fix_lon_period(NW_lon, NE_lon) - - ifrac = (np.arange(1, nx_sub + 1) / (nx_sub + 1)).astype(float) - jfrac = (np.arange(1, ny_sub + 1) / (ny_sub + 1)).astype(float) - - i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] - j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] - SW_lon = SW_lon[:, :, np.newaxis, np.newaxis] - SE_lon = SE_lon[:, :, np.newaxis, np.newaxis] - NE_lon = NE_lon[:, :, np.newaxis, np.newaxis] - NW_lon = NW_lon[:, :, np.newaxis, np.newaxis] - SW_lat = SW_lat[:, :, np.newaxis, np.newaxis] - SE_lat = SE_lat[:, :, np.newaxis, np.newaxis] - NE_lat = NE_lat[:, :, np.newaxis, np.newaxis] - NW_lat = NW_lat[:, :, np.newaxis, np.newaxis] - - sub_lon = ( - (1 - i_) * (1 - j_) * SW_lon - + i_ * (1 - j_) * SE_lon - + i_ * j_ * NE_lon - + (1 - i_) * j_ * NW_lon + # Compute subsampling factor and generate sub-point grid + ds = regrid_with_subsampling( + input_dataset=src.dataset, + qlon=self._grid.qlon.values, + qlat=self._grid.qlat.values, + nx_sub=nx_sub, + ny_sub=ny_sub, + regridding_method="nearest_s2d", ) - sub_lat = ( - (1 - i_) * (1 - j_) * SW_lat - + i_ * (1 - j_) * SE_lat - + i_ * j_ * NE_lat - + (1 - i_) * j_ * NW_lat - ) - - # Create destination subpoints - ii = np.round((sub_lon - src.lon[0]) / dlon).astype(int) - jj = np.round((sub_lat - src.lat[0]) / dlat).astype(int) - # wrap longitude index periodically rather than clipping — - # sub-points near the antimeridian must find the correct source pixel - # on the other side, not snap to the edge. - if np.any((ii < 0) | (ii >= len(src.lon))): - src_span = float(src.lon[-1] - src.lon[0]) - if src_span < 355: - raise ValueError( - f"Sub-points fall outside the source longitude range " - f"[{float(src.lon[0]):.2f}, {float(src.lon[-1]):.2f}] " - f"(span {src_span:.1f}°). Longitude wraparound requires a " - f"global source (~360° span); got {src_span:.1f}°. " - f"Pass a global SourceBathy rather than a regional slice." - ) - ii = ii % len(src.lon) - jj = np.clip(jj, 0, len(src.lat) - 1) # latitude: clamp, no wraparound - depth_sub = src.depth[jj, ii] # positive-down + depth_sub = ds[src.elevation_name].values # (ny, nx, ny_sub, nx_sub) is_ocean = depth_sub > mask_hmin ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub) diff --git a/mom6_forge/utils.py b/mom6_forge/utils.py index adc09bca..1fcbce39 100644 --- a/mom6_forge/utils.py +++ b/mom6_forge/utils.py @@ -697,3 +697,30 @@ def quadrilateral_area(v1, v2, v3, v4): a4 = angle_between(v4, v1, v3) return (a1 + a2 + a3 + a4 - 2 * np.pi) * R**2 + + +def compute_subsampling_factor(src_nj, src_ni, dst_nj, dst_ni): + """ + Compute the sub-sampling factors needed so that the super-sampled + destination grid has at least as many points as the source grid. + + Parameters + ---------- + src_nj, src_ni : int + Source grid dimensions. + dst_nj, dst_ni : int + Destination grid dimensions. + + Returns + ------- + ny_sub, nx_sub : int + """ + nx_sub = 1 + while nx_sub * dst_ni < src_ni: + nx_sub += 1 + + ny_sub = 1 + while ny_sub * dst_nj < src_nj: + ny_sub += 1 + + return ny_sub, nx_sub diff --git a/tests/conftest.py b/tests/conftest.py index 63ac1745..36161409 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -298,3 +298,56 @@ def synthetic_bathy_file(): # Cleanup Path(bathy_file).unlink() + + +@pytest.fixture +def get_simple_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=1, # min longitude in [0, 360] + lenx=2, # longitude extent in degrees + ystart=1, # min latitude in [-90, 90] + leny=2, # latitude extent in degrees + name="panama1", + ) + return grid + + +@pytest.fixture +def get_simple_global_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=0, # min longitude in [0, 360] + lenx=360, # longitude extent in degrees + ystart=-90, # min latitude in [-90, 90] + leny=180, # latitude extent in degrees + name="panama1", + cyclic_x=True, + ) + return grid + + +@pytest.fixture +def get_PM_seam_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=359, # min longitude in [0, 360] + lenx=2, # longitude extent in degrees + ystart=-1, # min latitude in [-90, 90] + leny=2, # latitude extent in degrees + name="panama1", + ) + return grid + + +@pytest.fixture +def get_dateline_seam_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=-1, # min longitude in [0, 360] + lenx=2, # longitude extent in degrees + ystart=-1, # min latitude in [-90, 90] + leny=2, # latitude extent in degrees + name="panama1", + ) + return grid diff --git a/tests/test_mapping.py b/tests/test_mapping.py new file mode 100644 index 00000000..5abe51ff --- /dev/null +++ b/tests/test_mapping.py @@ -0,0 +1,84 @@ +from mom6_forge.mapping import _make_subgrid_points, regrid_with_subsampling +import numpy as np +import xarray as xr + + +def test_make_subgrid_points(get_simple_grid): + # Test with a simple 2x2 grid and 2 sub-points per cell + nx_sub = ny_sub = 2 + grid = get_simple_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + + expected_sub_lon = np.array( + [ + [[[4 / 3, 5 / 3], [4 / 3, 5 / 3]], [[7 / 3, 8 / 3], [7 / 3, 8 / 3]]], + [[[4 / 3, 5 / 3], [4 / 3, 5 / 3]], [[7 / 3, 8 / 3], [7 / 3, 8 / 3]]], + ] + ) + expected_sub_lat = np.array( + [ + [[[4 / 3, 4 / 3], [5 / 3, 5 / 3]], [[4 / 3, 4 / 3], [5 / 3, 5 / 3]]], + [[[7 / 3, 7 / 3], [8 / 3, 8 / 3]], [[7 / 3, 7 / 3], [8 / 3, 8 / 3]]], + ] + ) + + assert np.allclose( + sub_lon, expected_sub_lon + ), "Sub-grid longitudes do not match expected values." + assert np.allclose( + sub_lat, expected_sub_lat + ), "Sub-grid latitudes do not match expected values." + + +def test_smoke_seams_and_global_make_subgrid_points( + get_dateline_seam_grid, get_PM_seam_grid, get_simple_global_grid +): + # Test with a simple 2x2 grid and 2 sub-points per cell + nx_sub = ny_sub = 2 + grid = get_dateline_seam_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + grid = get_PM_seam_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + grid = get_simple_global_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + + +def test_regrid_with_subsampling(get_simple_grid): + # Test with a simple 2x2 grid and 2 sub-points per cell with data that lands exactly on the sub points (subtracted by 0.1 to show snapping to sub points) + nx_sub = ny_sub = 2 + grid = get_simple_grid + lon = [4 / 3, 5 / 3, 7 / 3, 8 / 3] + lat = [4 / 3, 5 / 3, 7 / 3, 8 / 3] + input_ds = xr.Dataset( + { + "data": ( + ["lon", "lat"], + [ + np.arange(1, 5, 1), + np.arange(1, 5, 1), + np.arange(1, 5, 1), + np.arange(1, 5, 1), + ], + ) + }, + coords={ + "lon": (["lon"], [x - 0.1 for x in lon]), + "lat": (["lat"], [x - 0.1 for x in lat]), + }, + ) + ds = regrid_with_subsampling(input_ds, grid.qlon, grid.qlat, nx_sub, ny_sub) + assert ds["data"].shape == (2, 2, 2, 2), "Output shape is incorrect." + expected_data = np.array( + [[[[1, 1], [2, 2]], [[1, 1], [2, 2]]], [[[3, 3], [4, 4]], [[3, 3], [4, 4]]]] + ) + assert np.allclose( + ds["data"].values, expected_data + ), "Regridded data does not match expected values." From 7ea25736871938bd2bba9809b98b8c187353ab3c Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 20 Apr 2026 13:50:37 -0600 Subject: [PATCH 11/63] Bleh --- mom6_forge/mapping.py | 4 +++- tests/test_mapping.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index e495575d..19277a03 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -1010,6 +1010,8 @@ def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub): ------- sub_lon, sub_lat : np.ndarray shape (ny, nx, ny_sub, nx_sub) """ + assert type(qlon) == type(qlat) == np.ndarray, "qlon and qlat must be numpy arrays" + SW_lon = qlon[:-1, :-1] SW_lat = qlat[:-1, :-1] SE_lon = qlon[:-1, 1:] @@ -1035,7 +1037,7 @@ def _fix_period(lon, ref): i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] # (1,1,1,nx_sub) j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] # (1,1,ny_sub,1) - # Broadcast all corners to (ny, nx, 1, 1) + # Broadcast all corners to (ny, nx, 1, 1), SW_lon = SW_lon[:, :, np.newaxis, np.newaxis] SE_lon = SE_lon[:, :, np.newaxis, np.newaxis] NE_lon = NE_lon[:, :, np.newaxis, np.newaxis] diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 5abe51ff..81407562 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -74,7 +74,9 @@ def test_regrid_with_subsampling(get_simple_grid): "lat": (["lat"], [x - 0.1 for x in lat]), }, ) - ds = regrid_with_subsampling(input_ds, grid.qlon, grid.qlat, nx_sub, ny_sub) + ds = regrid_with_subsampling( + input_ds, grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) assert ds["data"].shape == (2, 2, 2, 2), "Output shape is incorrect." expected_data = np.array( [[[[1, 1], [2, 2]], [[1, 1], [2, 2]]], [[[3, 3], [4, 4]], [[3, 3], [4, 4]]]] From e4f7b2281e504fd382e44c7a6afbb8abad66b601 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 20 Apr 2026 13:52:10 -0600 Subject: [PATCH 12/63] Changes --- mom6_forge/mapping.py | 4 ++-- mom6_forge/topo.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 19277a03..6cc6da86 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -997,7 +997,7 @@ def regrid_dataset_via_xesmf( def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub): """ Given corner coordinates of a low-res grid (qlon/qlat on the q-point staggering), - return sub-sampled lon/lat arrays of shape (ny, nx, ny_sub, nx_sub). + return sub-sampled lon/lat arrays of shape (ny, nx, ny_sub, nx_sub). (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) Parameters ---------- @@ -1073,7 +1073,7 @@ def regrid_with_subsampling( ): """ Regrids input_dataset to sub_sampled_grid to - properly analyze high-res source data into each coarse cell. + properly analyze high-res source data into each coarse cell. (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) Parameters ---------- diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index e9da40f5..11e6c796 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -576,6 +576,7 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): Results are cached on ``src._topo_stats`` so a second call with the same source file returns immediately without recomputation. + (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) Parameters ---------- From 94057e74f767346ee39ec94d70f6f69123915579 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 20 Apr 2026 13:55:34 -0600 Subject: [PATCH 13/63] NewName --- mom6_forge/topo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 11e6c796..b88acf03 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -597,7 +597,7 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( - input_dataset=src.dataset, + input_dataset=src._da, qlon=self._grid.qlon.values, qlat=self._grid.qlat.values, nx_sub=nx_sub, From 4d1fc7ace3bf7daf8c6dbabd672f679209cf238d Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 20 Apr 2026 14:01:08 -0600 Subject: [PATCH 14/63] this --- mom6_forge/_source_bathy.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index b156a22e..f94abd8e 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -65,24 +65,25 @@ def slice_to_domain(self, topo, buf=0.5): lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max())) ds_src = xr.open_dataset(self.path, chunks="auto") - da = ds_src[self.elevation_name] - da = da.sel({self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)}) + ds = ds_src.sel( + {self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)} + ) - dlon = float(da[self.lon_name][1] - da[self.lon_name][0]) - total_lon = float(da[self.lon_name][-1] - da[self.lon_name][0] + dlon) + dlon = float(ds[self.lon_name][1] - ds[self.lon_name][0]) + total_lon = float(ds[self.lon_name][-1] - ds[self.lon_name][0] + dlon) if np.isclose(total_lon, 360): - da = longitude_slicer( - da, + ds = longitude_slicer( + ds, np.array(lon_extent) + np.array([-buf, buf]), self.lon_name, ) else: - da = da.sel( + ds = ds.sel( {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} ) - - self._da = da.load() + self._ds = ds + self._da = ds[self.elevation_name].load() return self # ------------------------------------------------------------------ @@ -109,6 +110,11 @@ def da(self): """Raw elevation DataArray with source coordinate names (positive-up).""" return self._da + @property + def ds(self): + """Raw dataset with source coordinate names (positive-up).""" + return self._ds + def __repr__(self): shape = self._da.shape if self._da is not None else "not loaded" return ( From 6298b0b6bb61b16dfd441d6e0b4bb16c43116e26 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 20 Apr 2026 14:02:18 -0600 Subject: [PATCH 15/63] Bug Fixes --- mom6_forge/topo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index b88acf03..5823cee3 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -597,7 +597,7 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( - input_dataset=src._da, + input_dataset=src.ds, qlon=self._grid.qlon.values, qlat=self._grid.qlat.values, nx_sub=nx_sub, From 5277c238daf7cf81d10ab337a74337c4bf1c924f Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 23 Apr 2026 12:37:02 -0600 Subject: [PATCH 16/63] Print --- mom6_forge/_source_bathy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index f94abd8e..31842f34 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -63,6 +63,7 @@ def slice_to_domain(self, topo, buf=0.5): """ lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max())) lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max())) + print(f"Slicing source bathymetry to domain: {lon_extent} x {lat_extent} with buffer {buf}") ds_src = xr.open_dataset(self.path, chunks="auto") From abe58c3e9323573ebe25b0160310e8a14d57ce16 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 23 Apr 2026 12:41:36 -0600 Subject: [PATCH 17/63] Black --- mom6_forge/_source_bathy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 31842f34..cf7b4f80 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -63,7 +63,9 @@ def slice_to_domain(self, topo, buf=0.5): """ lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max())) lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max())) - print(f"Slicing source bathymetry to domain: {lon_extent} x {lat_extent} with buffer {buf}") + print( + f"Slicing source bathymetry to domain: {lon_extent} x {lat_extent} with buffer {buf}" + ) ds_src = xr.open_dataset(self.path, chunks="auto") From 3ce309d0c062024976f641d730a225c8911fede8 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 14:27:47 -0600 Subject: [PATCH 18/63] Bleh --- mom6_forge/topo.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index b3b73f9f..1c05dcff 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -814,6 +814,21 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): ) return src._topo_stats + def direct_stats_depth(self, statistic): + """Set the topo depth to a statistic from compute_topo_stats""" + assert statistic in [ + "D_mean", + "D_min", + "D_max", + ], f"Invalid statistic {statistic}, must be one of OCN_FRAC, D_mean, D_min, D_max, D2_mean" + assert ( + self._src is not None + ), "Source bathymetry must be loaded to compute topo stats" + assert ( + self._src._topo_stats is not None + ), "Source bathymetry must have topo stats computed, please call _compute_topo_stats first if you have not already" + self.send_entire_depth_change_to_tcm(self._src._topo_stats[statistic]) + def set_from_dataset( self, bathymetry_path, From cf08bc1429f584317b5579403d22da5693e1ae8c Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 14:33:34 -0600 Subject: [PATCH 19/63] Source Property --- mom6_forge/_source_bathy.py | 1 + mom6_forge/topo.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index cf7b4f80..11931370 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -61,6 +61,7 @@ def slice_to_domain(self, topo, buf=0.5): topo : Topo — only ``topo._grid.qlon`` / ``topo._grid.qlat`` are used. buf : float — degree buffer around the Q-grid bounding box. Default 0.5. """ + self.topo = topo lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max())) lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max())) print( diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 732383e5..2cee6bc4 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -44,7 +44,7 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary"): None # Binary ocean/land mask (None = no mask applied) ) self._min_depth = min_depth - self._src = None # cached SourceBathy; set by _get_src() + self._src = None # cached SourceBathy; set by _set_src() self.land_fillval = 0.0 # Depth value for land cells if version_control_dir is None: @@ -184,6 +184,17 @@ def masked_depth(self): ) return masked_depth + @property + def src(self): + """ + Cached SourceBathy object representing the source bathymetry dataset sliced to the topo grid extent. This is set by set_src() when a new source bathymetry is specified, and can be accessed for any cached source dataset. + """ + return self._src + + @src.setter + def src(self, new_src): + self._src = new_src + @property def depth(self): return self._depth @@ -423,14 +434,14 @@ def supergridmask(self): supergridmask[1::2, 1::2] = self.tmask.values return supergridmask - def _get_src( + def set_src( self, bathymetry_path, longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, ): - """Return a cached :class:`SourceBathy`, creating and slicing a new one + """Set a :class:`SourceBathy`, creating and slicing a new one only when the path or coordinate names differ from the current cache.""" path = Path(bathymetry_path) if ( @@ -440,13 +451,13 @@ def _get_src( or self._src.lat_name != latitude_coordinate_name or self._src.elevation_name != vertical_coordinate_name ): - self._src = SourceBathy( + self.src = SourceBathy( path, longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, ).slice_to_domain(self) - return self._src + return self.src def clear_user_mask(self): cmd = ClearMaskCommand( From a71834d1d20e0d9fb37e17d5e3f8785e1ee4f629 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 14:35:13 -0600 Subject: [PATCH 20/63] Src --- mom6_forge/topo.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 2cee6bc4..705724d9 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -445,11 +445,11 @@ def set_src( only when the path or coordinate names differ from the current cache.""" path = Path(bathymetry_path) if ( - self._src is None - or self._src.path != path - or self._src.lon_name != longitude_coordinate_name - or self._src.lat_name != latitude_coordinate_name - or self._src.elevation_name != vertical_coordinate_name + self.src is None + or self.src.path != path + or self.src.lon_name != longitude_coordinate_name + or self.src.lat_name != latitude_coordinate_name + or self.src.elevation_name != vertical_coordinate_name ): self.src = SourceBathy( path, From 77a57383788fc5802ce1c9f1ef200a1414d20f43 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 14:38:34 -0600 Subject: [PATCH 21/63] This --- mom6_forge/topo.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 72eb8d7f..1d8c63ef 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -741,10 +741,10 @@ def set_bowl(self, max_depth, dedge, rad_earth=6.378e6, expdecay=400000.0): # Save to object (Build TCM Object) self.send_entire_depth_change_to_tcm(new_values) - def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): + def _compute_stats(self, nx_sub, ny_sub, mask_hmin): """Compute per-cell depth statistics by uniform sub-sampling. - Results are cached on ``src._topo_stats`` so a second call with the + Results are cached on ``stats`` so a second call with the same source file returns immediately without recomputation. (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) @@ -759,11 +759,11 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): xr.Dataset — ``OCN_FRAC``, ``D_mean``, ``D_min``, ``D_max``, ``D2_mean``. """ assert ( - self._src is not None + self.src is not None ), "Source bathymetry must be loaded to compute topo stats" - src = self._src - if src._topo_stats is not None: - return src._topo_stats + src = self.src + if self._stats is not None: + return self._stats # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( @@ -788,7 +788,7 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): D2_mean = np.nanmean(depth_ocean**2, axis=(-2, -1)) dims = ["ny", "nx"] - src._topo_stats = xr.Dataset( + self._stats = xr.Dataset( { "OCN_FRAC": xr.DataArray( ocn_frac, @@ -823,7 +823,7 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): ), } ) - return src._topo_stats + return self._stats def set_from_dataset( self, From 0d3e55ecc5d71f78a0cd553022f25536407d2235 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 14:42:26 -0600 Subject: [PATCH 22/63] Bleh --- mom6_forge/topo.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 1c05dcff..5a2916b2 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -816,18 +816,19 @@ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): def direct_stats_depth(self, statistic): """Set the topo depth to a statistic from compute_topo_stats""" - assert statistic in [ - "D_mean", - "D_min", - "D_max", - ], f"Invalid statistic {statistic}, must be one of OCN_FRAC, D_mean, D_min, D_max, D2_mean" + assert ( - self._src is not None - ), "Source bathymetry must be loaded to compute topo stats" + self._stats is not None + ), "Source bathymetry must have topo stats computed, please call _compute_stats first if you have not already" + approved_list = [] + for key in self._stats: + if key.startswith("D_"): + approved_list.append(key[2:]) assert ( - self._src._topo_stats is not None - ), "Source bathymetry must have topo stats computed, please call _compute_topo_stats first if you have not already" - self.send_entire_depth_change_to_tcm(self._src._topo_stats[statistic]) + statistic in approved_list + ), f"Invalid statistic {statistic}, must be one of {approved_list}" + + self.send_entire_depth_change_to_tcm(self._stats[statistic]) def set_from_dataset( self, From 3006ab00ed217f2715ef7b6b0f81b73bb5c1529f Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 14:49:53 -0600 Subject: [PATCH 23/63] Bug Fix --- mom6_forge/topo.py | 2 +- tests/test_topo_bathymetry_workflows.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 1d8c63ef..5bcbfede 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -762,7 +762,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): self.src is not None ), "Source bathymetry must be loaded to compute topo stats" src = self.src - if self._stats is not None: + if hasattr(self, "_stats") and self._stats is not None: return self._stats # Compute subsampling factor and generate sub-point grid diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 5a5dbf30..ca1f9116 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -22,7 +22,7 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): # Test with different sub-sampling densities for nx_sub, ny_sub in [(2, 2), (3, 3)]: # Call _compute_topo_stats - stats = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) + stats = topo._compute_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) # Verify output is a Dataset with expected variables assert isinstance(stats, xr.Dataset) @@ -52,6 +52,6 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): ).all() # Verify caching: second call should return cached result - stats2 = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) + stats2 = topo._compute_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) # Should be the exact same object (cached) assert stats2 is stats From f62bf9dacb17206866125a43520b3c46fc01f416 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 15:02:39 -0600 Subject: [PATCH 24/63] Fix --- mom6_forge/topo.py | 2 +- tests/conftest.py | 6 +++--- tests/test_topo.py | 6 +++--- tests/test_topo_bathymetry_workflows.py | 17 +++++++++++++++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 6d013a7b..efd3f123 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -839,7 +839,7 @@ def direct_stats_depth(self, statistic): statistic in approved_list ), f"Invalid statistic {statistic}, must be one of {approved_list}" - self.send_entire_depth_change_to_tcm(self._stats[statistic]) + self.send_entire_depth_change_to_tcm(self._stats[f"D_{statistic}"]) def set_from_dataset( self, diff --git a/tests/conftest.py b/tests/conftest.py index bbdbbbb1..ff6e8b11 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -307,14 +307,14 @@ def synthetic_bathy_file(): lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer # Create synthetic elevation data (positive-up, like GEBCO) - # Ocean is negative (water), land is positive - elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + # Ocean is positive (water), land is negative + elevation = np.full((len(lat), len(lon)), 500.0) # Ocean baseline = 500m deep # Add synthetic land masses (islands) # Create an island around (280, 8.5) lon_2d, lat_2d = np.meshgrid(lon, lat) island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 - elevation[island_mask] = 200.0 # Synthetic island + elevation[island_mask] = -200.0 # Synthetic island ds = xr.Dataset( { diff --git a/tests/test_topo.py b/tests/test_topo.py index 56c7db2a..5e3390a4 100644 --- a/tests/test_topo.py +++ b/tests/test_topo.py @@ -23,14 +23,14 @@ def synthetic_bathy_file(): lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer # Create synthetic elevation data (positive-up, like GEBCO) - # Ocean is negative (water), land is positive - elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + # Ocean is positive (water), land is negative + elevation = np.full((len(lat), len(lon)), 500.0) # Ocean baseline = 500m deep # Add synthetic land masses (islands) # Create an island around (280, 8.5) lon_2d, lat_2d = np.meshgrid(lon, lat) island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 - elevation[island_mask] = 200.0 # Synthetic island + elevation[island_mask] = -200.0 # Synthetic island ds = xr.Dataset( { diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index ca1f9116..eaeaf702 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -55,3 +55,20 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): stats2 = topo._compute_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) # Should be the exact same object (cached) assert stats2 is stats + + +def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): + """Test direct_stats_depth method which computes depth stats directly from source bathymetry without sub-sampling.""" + topo = get_rect_topo + + # Load source bathymetry and slice to topo domain + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.5) + topo.src = src + topo._compute_stats(nx_sub=2, ny_sub=2, mask_hmin=0.0) + + # Call direct_stats_depth + topo.direct_stats_depth("mean") + + mask = ~np.isnan(topo.depth.values) + assert np.isclose(topo.depth.values[mask], topo._stats["D_mean"].values[mask]).all() From 525445e4c4acad8d4116e438b7bbf9e870e04244 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 15:45:44 -0600 Subject: [PATCH 25/63] positive down --- mom6_forge/_source_bathy.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 11931370..92aba566 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -38,11 +38,13 @@ def __init__( lon_name="lon", lat_name="lat", elevation_name="elevation", + positive_down=False, ): self.path = Path(path) self.lon_name = lon_name self.lat_name = lat_name self.elevation_name = elevation_name + self.positive_down = positive_down # depth should be positive down (ocean > 0) if True, otherwise positive up (ocean < 0) self._da = None # set by slice_to_domain self._topo_stats = None # set by compute_topo_stats @@ -86,7 +88,8 @@ def slice_to_domain(self, topo, buf=0.5): ds = ds.sel( {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} ) - self._ds = ds + if not self.positive_down: + ds[self.elevation_name] = -ds[self.elevation_name] self._da = ds[self.elevation_name].load() return self @@ -107,7 +110,10 @@ def lat(self): @property def depth(self): """2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src).""" - return -self._da.values.astype(float) + if self.positive_down: + return self._da.values.astype(float) + else: + return -self._da.values.astype(float) @property def da(self): From 294beeac41a4accd0b5f7299054cdec3cec3dc94 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 30 Apr 2026 15:48:08 -0600 Subject: [PATCH 26/63] Sb --- mom6_forge/_source_bathy.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 92aba566..98d893a5 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -46,6 +46,7 @@ def __init__( self.elevation_name = elevation_name self.positive_down = positive_down # depth should be positive down (ocean > 0) if True, otherwise positive up (ocean < 0) self._da = None # set by slice_to_domain + self._ds = None # set by slice_to_domain self._topo_stats = None # set by compute_topo_stats # ------------------------------------------------------------------ @@ -88,8 +89,7 @@ def slice_to_domain(self, topo, buf=0.5): ds = ds.sel( {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} ) - if not self.positive_down: - ds[self.elevation_name] = -ds[self.elevation_name] + self._ds = ds self._da = ds[self.elevation_name].load() return self @@ -118,12 +118,20 @@ def depth(self): @property def da(self): """Raw elevation DataArray with source coordinate names (positive-up).""" - return self._da + if not self.positive_down: + return -self._da + else: + return self._da @property def ds(self): """Raw dataset with source coordinate names (positive-up).""" - return self._ds + if not self.positive_down: + ds = self._ds.copy() + ds[self.elevation_name] = -ds[self.elevation_name] + return ds + else: + return self._ds def __repr__(self): shape = self._da.shape if self._da is not None else "not loaded" From 700740aac1d9bff7729530d41cb8510a175f2829 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 1 May 2026 11:35:35 -0600 Subject: [PATCH 27/63] Have Slice be on Init like tcm --- mom6_forge/_source_bathy.py | 5 ++++- mom6_forge/topo.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 98d893a5..fb70e625 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -34,11 +34,13 @@ class SourceBathy: def __init__( self, + topo, path, lon_name="lon", lat_name="lat", elevation_name="elevation", positive_down=False, + buf=0.5, ): self.path = Path(path) self.lon_name = lon_name @@ -48,12 +50,13 @@ def __init__( self._da = None # set by slice_to_domain self._ds = None # set by slice_to_domain self._topo_stats = None # set by compute_topo_stats + self._slice_to_domain(topo, buf=buf) # ------------------------------------------------------------------ # Loading # ------------------------------------------------------------------ - def slice_to_domain(self, topo, buf=0.5): + def _slice_to_domain(self, topo, buf=0.5): """Load and clip elevation to the topo grid extent plus ``buf`` degrees. Handles the global-longitude seam automatically. Mutates ``self`` diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 705724d9..830bc0b7 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -452,11 +452,12 @@ def set_src( or self.src.elevation_name != vertical_coordinate_name ): self.src = SourceBathy( + self, path, longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, - ).slice_to_domain(self) + ) return self.src def clear_user_mask(self): From 045b7135173b8ffeea317e245e528e140d48fa17 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 1 May 2026 11:48:18 -0600 Subject: [PATCH 28/63] Changes --- mom6_forge/topo.py | 4 ++++ tests/test_source_bathy.py | 10 ++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 830bc0b7..fac2e14b 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -440,6 +440,8 @@ def set_src( longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, + positive_down=False, + buf=0.5, ): """Set a :class:`SourceBathy`, creating and slicing a new one only when the path or coordinate names differ from the current cache.""" @@ -457,6 +459,8 @@ def set_src( longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, + positive_down=positive_down, + buf=buf, ) return self.src diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index a96b6b3f..4d2b5e81 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -50,9 +50,10 @@ def synthetic_bathy_file(): Path(bathy_file).unlink() -def test_source_bathy_initialization(synthetic_bathy_file): +def test_source_bathy_initialization(synthetic_bathy_file, get_rect_topo): """Test SourceBathy initialization and coordinate names.""" src = SourceBathy( + get_rect_topo, synthetic_bathy_file, lon_name="lon", lat_name="lat", @@ -63,15 +64,13 @@ def test_source_bathy_initialization(synthetic_bathy_file): assert src.lon_name == "lon" assert src.lat_name == "lat" assert src.elevation_name == "elevation" - assert src._da is None # Not loaded yet def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): """Smoke test: load and slice elevation to topo domain.""" topo = get_rect_topo - src = SourceBathy(synthetic_bathy_file) - src.slice_to_domain(topo, buf=0.4) + src = SourceBathy(topo, synthetic_bathy_file) # Verify data was loaded assert src._da is not None @@ -88,8 +87,7 @@ def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): """Test that elevation is converted to positive-down depth.""" topo = get_rect_topo - src = SourceBathy(synthetic_bathy_file) - src.slice_to_domain(topo, buf=0.5) + src = SourceBathy(topo, synthetic_bathy_file) # Get depth and verify sign conversion depth = src.depth From 7b46dcf6ad3749a6de7491d612b41c79c21ebcaf Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 1 May 2026 12:59:15 -0600 Subject: [PATCH 29/63] Bug Fix --- tests/test_topo_bathymetry_workflows.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index eaeaf702..41a4834b 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -15,8 +15,7 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(synthetic_bathy_file) - src.slice_to_domain(topo, buf=0.5) + src = SourceBathy(topo, synthetic_bathy_file) topo._src = src # Test with different sub-sampling densities @@ -62,8 +61,7 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(synthetic_bathy_file) - src.slice_to_domain(topo, buf=0.5) + src = SourceBathy(topo, synthetic_bathy_file) topo.src = src topo._compute_stats(nx_sub=2, ny_sub=2, mask_hmin=0.0) From 594f898b72c18c7afdafba940755eb7983858662 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 1 May 2026 13:22:38 -0600 Subject: [PATCH 30/63] Source XESMF ready dataset --- mom6_forge/_source_bathy.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index fb70e625..165885c5 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -118,6 +118,24 @@ def depth(self): else: return -self._da.values.astype(float) + @property + def xesmf_ready_ds(self): + """Dataset with standardized coordinate names and positive-down elevation.""" + ds = self.ds.copy() + ds.rename( + { + self.lon_name: "lon", + self.lat_name: "lat", + self.elevation_name: "depth", + } + ) + ds.depth.attrs["_FillValue"] = -1e20 + ds.depth.attrs["units"] = "meters" + ds.depth.attrs["standard_name"] = "height_above_reference_ellipsoid" + ds.depth.attrs["long_name"] = "Depth relative to sea level" + ds.depth.attrs["coordinates"] = "lon lat" + return ds + @property def da(self): """Raw elevation DataArray with source coordinate names (positive-up).""" From d10c07f847adb3911ac7ef0af7ae82f4d43e8d71 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 1 May 2026 15:26:21 -0600 Subject: [PATCH 31/63] Bleh --- tests/test_topo_bathymetry_workflows.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index ca1f9116..214581c4 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -15,8 +15,7 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(synthetic_bathy_file) - src.slice_to_domain(topo, buf=0.5) + src = SourceBathy(topo, synthetic_bathy_file) topo._src = src # Test with different sub-sampling densities From 24bc6af9733d9dafbaea27c3282b1301468a603b Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 4 May 2026 10:01:17 -0600 Subject: [PATCH 32/63] Review Comments --- mom6_forge/_source_bathy.py | 100 +++++++++++++++--------------------- 1 file changed, 42 insertions(+), 58 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 165885c5..44113779 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -19,7 +19,7 @@ class SourceBathy: """Regional slice of a source bathymetry dataset (e.g. GEBCO). - Holds the loaded, domain-clipped elevation DataArray together with its + Holds the loaded, domain-clipped, ESMF prepped elevation DataArray together with its coordinate-name metadata. Per-cell depth statistics are computed and cached here so repeated calls with the same source file skip the expensive sub-sampling step. @@ -38,24 +38,36 @@ def __init__( path, lon_name="lon", lat_name="lat", - elevation_name="elevation", - positive_down=False, + depth_name="elevation", + depth_positive=True, buf=0.5, ): self.path = Path(path) - self.lon_name = lon_name - self.lat_name = lat_name - self.elevation_name = elevation_name - self.positive_down = positive_down # depth should be positive down (ocean > 0) if True, otherwise positive up (ocean < 0) - self._da = None # set by slice_to_domain - self._ds = None # set by slice_to_domain - self._topo_stats = None # set by compute_topo_stats + self.ds = xr.open_dataset(self.path, chunks="auto") + self._rename_dims( + lon_name=lon_name, lat_name=lat_name, depth_name=depth_name + ) # ensure consistent coordinate names for slicing self._slice_to_domain(topo, buf=buf) + self._ensure_depth_positive(depth_positive) # ------------------------------------------------------------------ # Loading # ------------------------------------------------------------------ + def _rename_dims(self, lon_name, lat_name, depth_name): + """Rename dimensions in the source dataset to match the provided names. This helps prep the dataset for ESMF regridding, which expects specific coordinate names.""" + + self._ds = self._ds.rename( + { + lon_name: "lon", + lat_name: "lat", + depth_name: "depth", + } + ) + self.lon_name = "lon" + self.lat_name = "lat" + self.depth_name = "depth" + def _slice_to_domain(self, topo, buf=0.5): """Load and clip elevation to the topo grid extent plus ``buf`` degrees. @@ -74,27 +86,30 @@ def _slice_to_domain(self, topo, buf=0.5): f"Slicing source bathymetry to domain: {lon_extent} x {lat_extent} with buffer {buf}" ) - ds_src = xr.open_dataset(self.path, chunks="auto") - - ds = ds_src.sel( + self._ds = self._ds.sel( {self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)} ) - dlon = float(ds[self.lon_name][1] - ds[self.lon_name][0]) - total_lon = float(ds[self.lon_name][-1] - ds[self.lon_name][0] + dlon) + dlon = float(self._ds[self.lon_name][1] - self._ds[self.lon_name][0]) + total_lon = float( + self._ds[self.lon_name][-1] - self._ds[self.lon_name][0] + dlon + ) if np.isclose(total_lon, 360): - ds = longitude_slicer( - ds, + self._ds = longitude_slicer( + self._ds, np.array(lon_extent) + np.array([-buf, buf]), self.lon_name, ) else: - ds = ds.sel( + self._ds = self._ds.sel( {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} ) - self._ds = ds - self._da = ds[self.elevation_name].load() - return self + return self._ds + + def _ensure_depth_positive(self, depth_positive): + """Ensure depth is positive-down. Mutates self in place.""" + if not depth_positive: + self._ds[self.depth_name] = -self._ds[self.depth_name] # ------------------------------------------------------------------ # Accessors @@ -103,56 +118,25 @@ def _slice_to_domain(self, topo, buf=0.5): @property def lon(self): """1-D longitude array.""" - return self._da[self.lon_name].values + return self.ds[self.lon_name].values @property def lat(self): """1-D latitude array.""" - return self._da[self.lat_name].values + return self.ds[self.lat_name].values @property def depth(self): """2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src).""" if self.positive_down: - return self._da.values.astype(float) + return self.ds[self.depth_name].values else: - return -self._da.values.astype(float) - - @property - def xesmf_ready_ds(self): - """Dataset with standardized coordinate names and positive-down elevation.""" - ds = self.ds.copy() - ds.rename( - { - self.lon_name: "lon", - self.lat_name: "lat", - self.elevation_name: "depth", - } - ) - ds.depth.attrs["_FillValue"] = -1e20 - ds.depth.attrs["units"] = "meters" - ds.depth.attrs["standard_name"] = "height_above_reference_ellipsoid" - ds.depth.attrs["long_name"] = "Depth relative to sea level" - ds.depth.attrs["coordinates"] = "lon lat" - return ds - - @property - def da(self): - """Raw elevation DataArray with source coordinate names (positive-up).""" - if not self.positive_down: - return -self._da - else: - return self._da + return -self.ds[self.depth_name].values @property def ds(self): - """Raw dataset with source coordinate names (positive-up).""" - if not self.positive_down: - ds = self._ds.copy() - ds[self.elevation_name] = -ds[self.elevation_name] - return ds - else: - return self._ds + """Raw dataset with source coordinate names (positive-down).""" + return self._ds def __repr__(self): shape = self._da.shape if self._da is not None else "not loaded" From fdca2352daff5f25749038bebe764054695915b1 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 4 May 2026 10:04:07 -0600 Subject: [PATCH 33/63] Pass Tests --- mom6_forge/_source_bathy.py | 9 +++------ tests/test_source_bathy.py | 9 ++++----- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 44113779..6eb5e6ef 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -43,7 +43,7 @@ def __init__( buf=0.5, ): self.path = Path(path) - self.ds = xr.open_dataset(self.path, chunks="auto") + self._ds = xr.open_dataset(self.path, chunks="auto") self._rename_dims( lon_name=lon_name, lat_name=lat_name, depth_name=depth_name ) # ensure consistent coordinate names for slicing @@ -127,11 +127,8 @@ def lat(self): @property def depth(self): - """2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src).""" - if self.positive_down: - return self.ds[self.depth_name].values - else: - return -self.ds[self.depth_name].values + """2-D depth array, depth positive (ocean > 0)""" + return self.ds[self.depth_name].values @property def ds(self): diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index 4d2b5e81..dd1a4554 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -57,13 +57,13 @@ def test_source_bathy_initialization(synthetic_bathy_file, get_rect_topo): synthetic_bathy_file, lon_name="lon", lat_name="lat", - elevation_name="elevation", + depth_name="elevation", ) assert src.path == Path(synthetic_bathy_file) assert src.lon_name == "lon" assert src.lat_name == "lat" - assert src.elevation_name == "elevation" + assert src.depth_name == "depth" def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): @@ -73,14 +73,13 @@ def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): src = SourceBathy(topo, synthetic_bathy_file) # Verify data was loaded - assert src._da is not None assert src.lon is not None assert src.lat is not None # Verify shape makes sense assert len(src.lon) > 0, f"Expected lon data, got empty array" assert len(src.lat) > 0, f"Expected lat data, got empty array" - assert src._da.shape == (len(src.lat), len(src.lon)) + assert src.depth.shape == (len(src.lat), len(src.lon)) def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): @@ -100,4 +99,4 @@ def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): assert len(non_nan_values) > 0, "No valid depth values" assert np.any(non_nan_values > 0), "Expected positive depth values for ocean" - assert depth.shape == src.da.shape + assert depth.shape == src.depth.shape From 7dcf6d3d21155ca41158f71a64977d82ecb50a9f Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 4 May 2026 10:54:38 -0600 Subject: [PATCH 34/63] Review Comments --- mom6_forge/mapping.py | 8 ++++++-- mom6_forge/topo.py | 14 ++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 6cc6da86..0b5aed2b 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -1010,7 +1010,8 @@ def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub): ------- sub_lon, sub_lat : np.ndarray shape (ny, nx, ny_sub, nx_sub) """ - assert type(qlon) == type(qlat) == np.ndarray, "qlon and qlat must be numpy arrays" + assert isinstance(qlon, np.ndarray), "qlon must be a numpy array" + assert isinstance(qlat, np.ndarray), "qlat must be a numpy array" SW_lon = qlon[:-1, :-1] SW_lat = qlat[:-1, :-1] @@ -1077,7 +1078,7 @@ def regrid_with_subsampling( Parameters ---------- - input_dataset : xr.Dataset + input_dataset : xr.Dataset (not curvilinear) qlon, qlat : np.ndarray shape (ny+1, nx+1) Corner coordinates of the destination grid. nx_sub, ny_sub : int @@ -1087,6 +1088,9 @@ def regrid_with_subsampling( regridded_dataset : xr.Dataset Regridded dataset with dimensions (..., ny, nx, ny_sub, nx_sub), where the sub-sampling points are kept as separate dimensions. (User should perform stats calc) """ + assert input_dataset.lon.dims == ("lon",) and input_dataset.lat.dims == ( + "lat", + ), "input_dataset must have 1D 'lon' and 'lat' coordinates" ny, nx = qlon.shape[0] - 1, qlon.shape[1] - 1 # Build the (ny, nx, ny_sub, nx_sub) sub-point grid diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 2e38bf29..ba9e01b8 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -768,8 +768,13 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): self.src is not None ), "Source bathymetry must be loaded to compute topo stats" src = self.src - if hasattr(self, "_stats") and self._stats is not None: - return self._stats + if hasattr(self, "_stats") and isinstance(self._stats, dict): + if self._stats["args"] == { + "nx_sub": nx_sub, + "ny_sub": ny_sub, + "mask_hmin": mask_hmin, + }: + return self._stats # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( @@ -796,6 +801,11 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): dims = ["ny", "nx"] self._stats = xr.Dataset( { + "args": { + "nx_sub": nx_sub, + "ny_sub": ny_sub, + "mask_hmin": mask_hmin, + }, "OCN_FRAC": xr.DataArray( ocn_frac, dims=dims, From c81311b546b3da24286c5ea616d4d71a21dc6aef Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 4 May 2026 11:01:31 -0600 Subject: [PATCH 35/63] tests pass --- mom6_forge/mapping.py | 2 +- mom6_forge/topo.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 0b5aed2b..8d452da5 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -1088,7 +1088,7 @@ def regrid_with_subsampling( regridded_dataset : xr.Dataset Regridded dataset with dimensions (..., ny, nx, ny_sub, nx_sub), where the sub-sampling points are kept as separate dimensions. (User should perform stats calc) """ - assert input_dataset.lon.dims == ("lon",) and input_dataset.lat.dims == ( + assert len(input_dataset.lon.dims) == 1 and input_dataset.lat.dims == ( "lat", ), "input_dataset must have 1D 'lon' and 'lat' coordinates" ny, nx = qlon.shape[0] - 1, qlon.shape[1] - 1 diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index ba9e01b8..1e052ee2 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -768,12 +768,12 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): self.src is not None ), "Source bathymetry must be loaded to compute topo stats" src = self.src - if hasattr(self, "_stats") and isinstance(self._stats, dict): - if self._stats["args"] == { - "nx_sub": nx_sub, - "ny_sub": ny_sub, - "mask_hmin": mask_hmin, - }: + if hasattr(self, "_stats") and isinstance(self._stats, xr.Dataset): + if ( + self._stats.attrs.get("nx_sub") == nx_sub + and self._stats.attrs.get("ny_sub") == ny_sub + and self._stats.attrs.get("mask_hmin") == mask_hmin + ): return self._stats # Compute subsampling factor and generate sub-point grid @@ -786,7 +786,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): regridding_method="nearest_s2d", ) - depth_sub = ds[src.elevation_name].values # (ny, nx, ny_sub, nx_sub) + depth_sub = ds[src.depth_name].values # (ny, nx, ny_sub, nx_sub) is_ocean = depth_sub > mask_hmin ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub) @@ -801,11 +801,6 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): dims = ["ny", "nx"] self._stats = xr.Dataset( { - "args": { - "nx_sub": nx_sub, - "ny_sub": ny_sub, - "mask_hmin": mask_hmin, - }, "OCN_FRAC": xr.DataArray( ocn_frac, dims=dims, @@ -837,7 +832,12 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): "units": "m2", }, ), - } + }, + attrs={ + "nx_sub": nx_sub, + "ny_sub": ny_sub, + "mask_hmin": mask_hmin, + }, ) return self._stats From 292c57fdc1deb8bc40fc4743c43ed18473d957c1 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 4 May 2026 14:32:59 -0600 Subject: [PATCH 36/63] Ensure attributes are there on the source ds --- mom6_forge/_source_bathy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 6eb5e6ef..fd36ca09 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -67,6 +67,10 @@ def _rename_dims(self, lon_name, lat_name, depth_name): self.lon_name = "lon" self.lat_name = "lat" self.depth_name = "depth" + if "units" not in self._ds[self.lon_name].attrs: + self._ds[self.lon_name].attrs["units"] = "degrees_east" + if "units" not in self._ds[self.lat_name].attrs: + self._ds[self.lat_name].attrs["units"] = "degrees_north" def _slice_to_domain(self, topo, buf=0.5): """Load and clip elevation to the topo grid extent plus ``buf`` degrees. From 8db1346edc34c5680bf4c6ca9ca5993c5b605723 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 8 May 2026 14:00:30 -0600 Subject: [PATCH 37/63] Review Commetns --- mom6_forge/_source_bathy.py | 10 +++++----- mom6_forge/topo.py | 26 +++++++++----------------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index fd36ca09..dae6e573 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -29,7 +29,7 @@ class SourceBathy: path : str or Path lon_name : str — longitude coordinate name. Default ``"lon"``. lat_name : str — latitude coordinate name. Default ``"lat"``. - elevation_name : str — elevation variable (positive-up). Default ``"elevation"``. + depth_name : str — depth variable. Default ``"depth"``. """ def __init__( @@ -38,8 +38,8 @@ def __init__( path, lon_name="lon", lat_name="lat", - depth_name="elevation", - depth_positive=True, + depth_name="depth", + is_input_depth_positive_below_msl=True, buf=0.5, ): self.path = Path(path) @@ -48,7 +48,7 @@ def __init__( lon_name=lon_name, lat_name=lat_name, depth_name=depth_name ) # ensure consistent coordinate names for slicing self._slice_to_domain(topo, buf=buf) - self._ensure_depth_positive(depth_positive) + self._ensure_depth_is_positive_below_msl(is_input_depth_positive_below_msl) # ------------------------------------------------------------------ # Loading @@ -110,7 +110,7 @@ def _slice_to_domain(self, topo, buf=0.5): ) return self._ds - def _ensure_depth_positive(self, depth_positive): + def _ensure_depth_is_positive_below_msl(self, depth_positive): """Ensure depth is positive-down. Mutates self in place.""" if not depth_positive: self._ds[self.depth_name] = -self._ds[self.depth_name] diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 8b217b3a..eb7a77f3 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -446,23 +446,15 @@ def set_src( ): """Set a :class:`SourceBathy`, creating and slicing a new one only when the path or coordinate names differ from the current cache.""" - path = Path(bathymetry_path) - if ( - self.src is None - or self.src.path != path - or self.src.lon_name != longitude_coordinate_name - or self.src.lat_name != latitude_coordinate_name - or self.src.elevation_name != vertical_coordinate_name - ): - self.src = SourceBathy( - self, - path, - longitude_coordinate_name, - latitude_coordinate_name, - vertical_coordinate_name, - positive_down=positive_down, - buf=buf, - ) + self.src = SourceBathy( + self, + Path(bathymetry_path), + longitude_coordinate_name, + latitude_coordinate_name, + vertical_coordinate_name, + positive_down=positive_down, + buf=buf, + ) return self.src def clear_user_mask(self): From feb866e24ebe20f79697858acf162720239973bc Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 8 May 2026 14:04:54 -0600 Subject: [PATCH 38/63] Bleh --- mom6_forge/_source_bathy.py | 4 ++-- tests/test_source_bathy.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index dae6e573..030b574a 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -39,7 +39,7 @@ def __init__( lon_name="lon", lat_name="lat", depth_name="depth", - is_input_depth_positive_below_msl=True, + is_input_positive_below_msl=True, buf=0.5, ): self.path = Path(path) @@ -48,7 +48,7 @@ def __init__( lon_name=lon_name, lat_name=lat_name, depth_name=depth_name ) # ensure consistent coordinate names for slicing self._slice_to_domain(topo, buf=buf) - self._ensure_depth_is_positive_below_msl(is_input_depth_positive_below_msl) + self._ensure_depth_is_positive_below_msl(is_input_positive_below_msl) # ------------------------------------------------------------------ # Loading diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index dd1a4554..97cdac93 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -70,7 +70,12 @@ def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): """Smoke test: load and slice elevation to topo domain.""" topo = get_rect_topo - src = SourceBathy(topo, synthetic_bathy_file) + src = SourceBathy( + topo, + synthetic_bathy_file, + depth_name="elevation", + is_input_positive_below_msl=False, + ) # Verify data was loaded assert src.lon is not None @@ -86,7 +91,12 @@ def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): """Test that elevation is converted to positive-down depth.""" topo = get_rect_topo - src = SourceBathy(topo, synthetic_bathy_file) + src = SourceBathy( + topo, + synthetic_bathy_file, + depth_name="elevation", + is_input_positive_below_msl=False, + ) # Get depth and verify sign conversion depth = src.depth From dfb4b0805c0811b24270c59b5d4cb178f4083aef Mon Sep 17 00:00:00 2001 From: manishvenu Date: Mon, 11 May 2026 10:19:55 -0600 Subject: [PATCH 39/63] Doc --- mom6_forge/_source_bathy.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 030b574a..7b141060 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -1,9 +1,8 @@ """Source bathymetry loader for mom6_forge. ``SourceBathy`` is a lightweight data container for a regional slice of a -source bathymetry dataset. ``Topo._get_src()`` creates and caches one -automatically when ``set_from_dataset`` is called. Users who call pipeline -methods directly (e.g. ``high_res_regrid``, ``generate_mask_ocean_frac``) +source bathymetry dataset. Users who call pipeline +methods directly should construct a ``SourceBathy`` explicitly:: from mom6_forge._source_bathy import SourceBathy From 9bd4a3e20ec1f0da4814d0e939d53dac1409f3d9 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 15:53:55 -0400 Subject: [PATCH 40/63] Replace config_dataset --- .vscode/settings.json | 7 ++ mom6_forge/_source_bathy.py | 29 ++--- mom6_forge/grid.py | 17 +++ mom6_forge/topo.py | 209 ++++++------------------------------ mom6_forge/utils.py | 9 +- setup.py | 2 +- tests/test_utils.py | 42 ++++++++ 7 files changed, 121 insertions(+), 194 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9b388533 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 7b141060..83ca3454 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -43,7 +43,7 @@ def __init__( ): self.path = Path(path) self._ds = xr.open_dataset(self.path, chunks="auto") - self._rename_dims( + self._rename_dims_and_format_ds( lon_name=lon_name, lat_name=lat_name, depth_name=depth_name ) # ensure consistent coordinate names for slicing self._slice_to_domain(topo, buf=buf) @@ -53,7 +53,7 @@ def __init__( # Loading # ------------------------------------------------------------------ - def _rename_dims(self, lon_name, lat_name, depth_name): + def _rename_dims_and_format_ds(self, lon_name, lat_name, depth_name): """Rename dimensions in the source dataset to match the provided names. This helps prep the dataset for ESMF regridding, which expects specific coordinate names.""" self._ds = self._ds.rename( @@ -66,6 +66,14 @@ def _rename_dims(self, lon_name, lat_name, depth_name): self.lon_name = "lon" self.lat_name = "lat" self.depth_name = "depth" + self._ds.depth.attrs["missing_value"] = ( + -1e20 + ) # missing value expected by FRE tools + self._ds.depth.attrs["_FillValue"] = -1e20 + self._ds.depth.attrs["units"] = "meters" + self._ds.depth.attrs["standard_name"] = "height_above_reference_ellipsoid" + self._ds.depth.attrs["long_name"] = "Elevation relative to sea level" + self._ds.depth.attrs["coordinates"] = "lon lat" if "units" not in self._ds[self.lon_name].attrs: self._ds[self.lon_name].attrs["units"] = "degrees_east" if "units" not in self._ds[self.lat_name].attrs: @@ -94,19 +102,12 @@ def _slice_to_domain(self, topo, buf=0.5): ) dlon = float(self._ds[self.lon_name][1] - self._ds[self.lon_name][0]) - total_lon = float( - self._ds[self.lon_name][-1] - self._ds[self.lon_name][0] + dlon + self._ds = longitude_slicer( + self._ds, + np.array(lon_extent) + np.array([-buf, buf]), + self.lon_name, ) - if np.isclose(total_lon, 360): - self._ds = longitude_slicer( - self._ds, - np.array(lon_extent) + np.array([-buf, buf]), - self.lon_name, - ) - else: - self._ds = self._ds.sel( - {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} - ) + return self._ds def _ensure_depth_is_positive_below_msl(self, depth_positive): diff --git a/mom6_forge/grid.py b/mom6_forge/grid.py index 2b43f6d6..f0b57a3a 100644 --- a/mom6_forge/grid.py +++ b/mom6_forge/grid.py @@ -940,3 +940,20 @@ def write_supergrid( ds.attrs["filename"] = os.path.basename(path) ds.to_netcdf(path, format="NETCDF3_64BIT") return ds + + def get_esmf_ready_tracer_ds(self): + """This is to generate a ds that esmf/xesmf can use for regridding. It contains the tlat and tlon coordinates and the area variable (tarea) that can be used as weights for conservative regridding.""" + ds = xr.Dataset( + { + "lat": self.tlat, + "lon": self.tlon, + "tarea": self.tarea, + } + ) + ds.lon.attrs["units"] = "degrees_east" + ds.lon.attrs["_FillValue"] = 1e20 + ds.lat.attrs["units"] = "degrees_north" + ds.lat.attrs["_FillValue"] = 1e20 + ds.depth.attrs["units"] = "meters" + ds.depth.attrs["coordinates"] = "lon lat" + return ds diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 9cc8710a..880ed7e1 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -202,7 +202,7 @@ def masked_depth(self): @property def src(self): """ - Cached SourceBathy object representing the source bathymetry dataset sliced to the topo grid extent. This is set by set_src() when a new source bathymetry is specified, and can be accessed for any cached source dataset. + SourceBathy object representing the source bathymetry dataset sliced to the topo grid extent. This is set by set_src() when a new source bathymetry is specified, and can be accessed for any source dataset. """ return self._src @@ -477,7 +477,7 @@ def set_src( longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, - positive_down=False, + is_input_positive_below_msl=False, buf=0.5, ): """Set a :class:`SourceBathy`, creating and slicing a new one @@ -488,7 +488,7 @@ def set_src( longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, - positive_down=positive_down, + is_input_positive_below_msl=is_input_positive_below_msl, buf=buf, ) return self.src @@ -783,11 +783,10 @@ def set_from_dataset( latitude_coordinate_name, vertical_coordinate_name, fill_channels=False, - positive_down=False, + is_input_positive_below_msl=False, output_dir=Path(""), write_to_file=True, regridding_method="bilinear", - run_config_dataset=True, run_regrid_dataset=True, run_tidy_dataset=True, ): @@ -813,7 +812,7 @@ def set_from_dataset( fill_channels (Optional[bool]): Whether or not to fill in diagonal channels. This removes more narrow inlets, but can also connect extra islands to land. Default: ``False``. - positive_down (Optional[bool]): If ``True``, it assumes that the + is_input_positive_below_msl (Optional[bool]): If ``True``, it assumes that the bathymetry vertical coordinate is positive downwards. Default: ``False``. write_to_file (Optional[bool]): Whether to write the bathymetry to a file. Default: ``True``. regridding_method (Optional[str]): The type of regridding method to use. Defaults to self.regridding_method @@ -825,22 +824,26 @@ def set_from_dataset( Call ``[topo_object_name].mpi_set_from_dataset()`` instead. Follow the given instructions for using mpi and ESMF_Regrid outside of a python environment. This breaks up the process, so be sure to call ``[topo_object_name].tidy_dataset() after regridding with mpi.""") - if run_config_dataset: - self.bathymetry_output, self.empty_bathy = self.config_dataset( - bathymetry_path=bathymetry_path, - longitude_coordinate_name=longitude_coordinate_name, - latitude_coordinate_name=latitude_coordinate_name, - vertical_coordinate_name=vertical_coordinate_name, - fill_channels=fill_channels, - positive_down=positive_down, - output_dir=output_dir, - write_to_file=write_to_file, + output_dir = Path(output_dir) + self.set_src( + bathymetry_path=bathymetry_path, + longitude_coordinate_name=longitude_coordinate_name, + latitude_coordinate_name=latitude_coordinate_name, + vertical_coordinate_name=vertical_coordinate_name, + is_input_positive_below_msl=is_input_positive_below_msl, + ) + self.src_bathymetry_dataset = self.src.ds + self.destination_bathymetry = self._grid.get_esmf_ready_tracer_ds() + if write_to_file: + self.src_bathymetry_dataset.to_netcdf(output_dir / "bathymetry_original.nc") + self.destination_bathymetry.to_netcdf( + output_dir / "bathymetry_unfinished.nc" ) if run_regrid_dataset: self.regridded_bathy = regrid_dataset_via_xesmf( - input_dataset=self.bathymetry_output, - output_dataset=self.empty_bathy, + input_dataset=self.src_bathymetry_dataset, + output_dataset=self.destination_bathymetry, regridding_method=regridding_method, write_to_file=write_to_file, output_path=output_dir / "bathymetry_unfinished.nc", @@ -850,7 +853,7 @@ def set_from_dataset( # Set directly into self.depth in this function self.tidy_dataset( fill_channels=fill_channels, - positive_down=positive_down, + is_input_positive_below_msl=is_input_positive_below_msl, vertical_coordinate_name="depth", bathymetry=self.regridded_bathy, output_dir=output_dir, @@ -866,10 +869,8 @@ def mpi_set_from_dataset( longitude_coordinate_name, latitude_coordinate_name, vertical_coordinate_name, - fill_channels=False, - positive_down=False, + is_input_positive_below_msl=False, output_dir=Path(""), - write_to_file=True, verbose=True, ): if verbose: @@ -895,171 +896,27 @@ def mpi_set_from_dataset( For additional details see: https://xesmf.readthedocs.io/en/latest/large_problems_on_HPC.html """) - self.bathymetry_output, self.empty_bathy = self.config_dataset( + output_dir = Path(output_dir) + self.set_src( bathymetry_path=bathymetry_path, longitude_coordinate_name=longitude_coordinate_name, latitude_coordinate_name=latitude_coordinate_name, vertical_coordinate_name=vertical_coordinate_name, - fill_channels=fill_channels, - positive_down=positive_down, - output_dir=output_dir, - write_to_file=write_to_file, + is_input_positive_below_msl=is_input_positive_below_msl, ) + self.src_bathymetry_dataset = self.src.ds + self.destination_bathymetry = self._grid.get_esmf_ready_tracer_ds() + self.src_bathymetry_dataset.to_netcdf(output_dir / "bathymetry_original.nc") + self.destination_bathymetry.to_netcdf(output_dir / "bathymetry_unfinished.nc") print( "Configuration complete. Ready for regridding with MPI. See documentation for more details." ) - def config_dataset( - self, - bathymetry_path, - longitude_coordinate_name, - latitude_coordinate_name, - vertical_coordinate_name, - fill_channels=False, - positive_down=False, - output_dir=Path(""), - write_to_file=True, - ): - """ - Sets up necessary objects/files for regridding bathymetry. Can be flexibly used with - mapping.regrid_bathy_dataset() or user can manually regrid with ESMF_regrid. - - If manual regridding is necessary, write_to_file must be set to True. - - Arguments: - bathymetry_path (str): Path to netCDF file with bathymetry data. - longitude_coordinate_name (Optional[str]): The name of the longitude coordinate in the bathymetry - dataset at ``bathymetry_path``. For example, for GEBCO bathymetry: ``'lon'`` (default). - latitude_coordinate_name (Optional[str]): The name of the latitude coordinate in the bathymetry - dataset at ``bathymetry_path``. For example, for GEBCO bathymetry: ``'lat'`` (default). - vertical_coordinate_name (Optional[str]): The name of the vertical coordinate in the bathymetry - dataset at ``bathymetry_path``. For example, for GEBCO bathymetry: ``'elevation'`` (default). - output_dir: str | Path - The str or Path the write to file should write to. Defaults to the directory the script is running in. - write_to_file (Optional[bool]): Files saved to ``output_dir``. Defaults to ``True``. Must be set to true if using manual regridding methods with ESMF_regrid. - - Returns: - (``bathymetry_output``,``empty_bathy``) (tuple of Datasets): where ``bathymetry_output`` is the original bathymetry data with proper metadata and attributes and ``empty_bathy`` is a template for the regridder. - """ - coordinate_names = { - "xh": longitude_coordinate_name, - "yh": latitude_coordinate_name, - "depth": vertical_coordinate_name, - } - longitude_extent = ( - float(self._grid.qlon.min()), - float(self._grid.qlon.max()), - ) - latitude_extent = ( - float(self._grid.qlat.min()), - float(self._grid.qlat.max()), - ) - - bathymetry = xr.open_dataset(bathymetry_path, chunks="auto")[ - coordinate_names["depth"] - ] - - bathymetry = bathymetry.sel( - { - coordinate_names["yh"]: slice( - latitude_extent[0] - 0.5, latitude_extent[1] + 0.5 - ) - } # 0.5 degree latitude buffer (hardcoded) for regridding - ).astype("float") - - ## Check if the original bathymetry provided has a longitude extent that goes around the globe - ## to take care of the longitude seam when we slice out the regional domain. - - horizontal_resolution = ( - bathymetry[coordinate_names["xh"]][1] - - bathymetry[coordinate_names["xh"]][0] - ) - - horizontal_extent = ( - bathymetry[coordinate_names["xh"]][-1] - - bathymetry[coordinate_names["xh"]][0] - + horizontal_resolution - ) - - longitude_buffer = 0.5 # 0.5 degree longitude buffer (hardcoded) for regridding - - if np.isclose(horizontal_extent, 360): - ## longitude extent that goes around the globe -- use longitude_slicer - bathymetry = longitude_slicer( - bathymetry, - np.array(longitude_extent) - + np.array([-longitude_buffer, longitude_buffer]), - coordinate_names["xh"], - ) - else: - ## otherwise, slice normally - bathymetry = bathymetry.sel( - { - coordinate_names["xh"]: slice( - longitude_extent[0] - longitude_buffer, - longitude_extent[1] + longitude_buffer, - ) - } - ) - - bathymetry.attrs["missing_value"] = -1e20 # missing value expected by FRE tools - bathymetry_output = xr.Dataset({"depth": bathymetry}) - bathymetry.close() - - bathymetry_output = bathymetry_output.rename( - {coordinate_names["xh"]: "lon", coordinate_names["yh"]: "lat"} - ) - - bathymetry_output.depth.attrs["_FillValue"] = -1e20 - bathymetry_output.depth.attrs["units"] = "meters" - bathymetry_output.depth.attrs["standard_name"] = ( - "height_above_reference_ellipsoid" - ) - bathymetry_output.depth.attrs["long_name"] = "Elevation relative to sea level" - bathymetry_output.depth.attrs["coordinates"] = "lon lat" - - # Ensure the source bathymetry as a units attribute - if "units" not in bathymetry_output["lon"].attrs: - bathymetry_output["lon"].attrs["units"] = "degrees_east" - if "units" not in bathymetry_output["lat"].attrs: - bathymetry_output["lat"].attrs["units"] = "degrees_north" - - if write_to_file: - bathymetry_output.to_netcdf( - output_dir / "bathymetry_original.nc", - mode="w", - engine="netcdf4", - ) - - empty_bathy = xr.Dataset( - { - "lon": self._grid.tlon, - "lat": self._grid.tlat, - } - ) - - empty_bathy = empty_bathy.set_coords(("lon", "lat")) - empty_bathy["depth"] = xr.zeros_like(empty_bathy["lon"]) - empty_bathy.lon.attrs["units"] = "degrees_east" - empty_bathy.lon.attrs["_FillValue"] = 1e20 - empty_bathy.lat.attrs["units"] = "degrees_north" - empty_bathy.lat.attrs["_FillValue"] = 1e20 - empty_bathy.depth.attrs["units"] = "meters" - empty_bathy.depth.attrs["coordinates"] = "lon lat" - if write_to_file: - empty_bathy.to_netcdf( - output_dir / "bathymetry_unfinished.nc", - mode="w", - engine="netcdf4", - ) - empty_bathy.close() - return bathymetry_output, empty_bathy - def tidy_dataset( self, fill_channels=False, - positive_down=False, + is_input_positive_below_msl=False, vertical_coordinate_name="depth", bathymetry=None, output_dir=Path(""), @@ -1081,7 +938,7 @@ def tidy_dataset( fill_channels (Optional[bool]): Whether to fill in diagonal channels. This removes more narrow inlets, but can also connect extra islands to land. Default: ``False``. - positive_down (Optional[bool]): If ``False`` (default), assume that + is_input_positive_below_msl (Optional[bool]): If ``False`` (default), assume that bathymetry vertical coordinate is positive down, as is the case in GEBCO for example. bathymetry (Optional[xr.Dataset]): The bathymetry dataset to tidy up. If not provided, it will read the bathymetry from the file ``bathymetry_unfinished.nc`` in the input directory @@ -1111,7 +968,7 @@ def tidy_dataset( bathymetry.expand_dims("tiles", 0) - if not positive_down: + if not is_input_positive_below_msl: ## Ensure that coordinate is positive down! bathymetry["depth"] *= -1 diff --git a/mom6_forge/utils.py b/mom6_forge/utils.py index 0b789c06..348b97e3 100644 --- a/mom6_forge/utils.py +++ b/mom6_forge/utils.py @@ -430,6 +430,7 @@ def longitude_slicer(data, longitude_extent, longitude_coords): longitude_coords = [longitude_coords] for lon in longitude_coords: + central_longitude = np.mean(longitude_extent) ## Midpoint of target domain ## Find a corresponding value for the intended domain midpoint in our data. @@ -437,6 +438,7 @@ def longitude_slicer(data, longitude_extent, longitude_coords): lons = data[lon].data dlons = lons[1] - lons[0] + lon_span = lons[-1] - lons[0] + dlons assert np.allclose( np.diff(lons), dlons * np.ones(np.size(lons) - 1) @@ -456,7 +458,8 @@ def longitude_slicer(data, longitude_extent, longitude_coords): ## Number of indices between the data midpoint and the target midpoint. ## Sign indicates direction needed to shift. shift = int( - -(data[lon].shape[0] * (_central_longitude - central_data)) // 360 + -(data[lon].shape[0] * (_central_longitude - central_data)) + // lon_span ) ## Shift data so that the midpoint of the target domain is the middle of @@ -486,9 +489,9 @@ def longitude_slicer(data, longitude_extent, longitude_coords): ## Choose the number of lon points to take from the middle, including a buffer. ## Use this to index the new global dataset - num_lonpoints = ( + num_lonpoints = int( int(data[lon].shape[0] * (central_longitude - longitude_extent[0])) - // 360 + // lon_span ) data = new_data.isel( diff --git a/setup.py b/setup.py index 4ca2e1ca..763c5853 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="mom6_forge", # Replace with your own username - version="0.0.7", + version="0.1.0", author="Alper Altuntas", author_email="altuntas@ucar.edu", description="MOM6 simple grid and bathymetry generator", diff --git a/tests/test_utils.py b/tests/test_utils.py index 7747f712..2181f782 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -66,6 +66,48 @@ def test_longitude_slicer(): longitude_slicer(data, longitude_extent, "random_lon") +def test_longitude_slicers_regionally(): + nx, ny = 4, 14 + + latitude_extent = (2, 5) + longitude_extent = (-90, -70) + + dims = ["random_lat", "random_lon"] + + dlambda = (longitude_extent[1] - longitude_extent[0]) / 2 + + data = xr.DataArray( + np.random.random((ny, nx)), + dims=dims, + coords={ + "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), + "random_lon": np.linspace( + longitude_extent[0] - 2, longitude_extent[1] + 2, nx + ), + }, + ) + + # Regular regional + data_regular = longitude_slicer(data, longitude_extent, "random_lon") + data_east = longitude_slicer(data, (270, 290), "random_lon") + assert (data_regular == data_east).all() + + # Seam data + longitude_extent = (-5, 5) + data = xr.DataArray( + np.random.random((ny, nx)), + dims=dims, + coords={ + "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), + "random_lon": np.linspace( + longitude_extent[0] - 2, longitude_extent[1] + 2, nx + ), + }, + ) + data_regular = longitude_slicer(data, longitude_extent, "random_lon") + assert len(data_regular.random_lon) > 0 + + @pytest.mark.parametrize( ("v1", "v2", "v3", "v4", "true_area"), [ From b338fa32d831698708c98197643c3e8e54211b98 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 15:54:51 -0400 Subject: [PATCH 41/63] GIt --- .gitignore | 2 +- .vscode/settings.json | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 14ead66f..89d33a2c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ __pycache__/ docs/_build/ # C extensions *.so - +.vscode/ # Distribution / packaging .Python build/ diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9b388533..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file From d8896c0853576a3234f8bdf20262d37259f03d7d Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 15:56:57 -0400 Subject: [PATCH 42/63] Changes --- mom6_forge/grid.py | 2 -- mom6_forge/topo.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mom6_forge/grid.py b/mom6_forge/grid.py index f0b57a3a..6060b956 100644 --- a/mom6_forge/grid.py +++ b/mom6_forge/grid.py @@ -954,6 +954,4 @@ def get_esmf_ready_tracer_ds(self): ds.lon.attrs["_FillValue"] = 1e20 ds.lat.attrs["units"] = "degrees_north" ds.lat.attrs["_FillValue"] = 1e20 - ds.depth.attrs["units"] = "meters" - ds.depth.attrs["coordinates"] = "lon lat" return ds diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 880ed7e1..7563fb5e 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -834,6 +834,11 @@ def set_from_dataset( ) self.src_bathymetry_dataset = self.src.ds self.destination_bathymetry = self._grid.get_esmf_ready_tracer_ds() + self.destination_bathymetry["depth"] = np.zeros_like( + self.destination_bathymetry.tarea + ) + self.destination_bathymetry.depth.attrs["units"] = "meters" + self.destination_bathymetry.depth.attrs["coordinates"] = "lon lat" if write_to_file: self.src_bathymetry_dataset.to_netcdf(output_dir / "bathymetry_original.nc") self.destination_bathymetry.to_netcdf( @@ -906,6 +911,11 @@ def mpi_set_from_dataset( ) self.src_bathymetry_dataset = self.src.ds self.destination_bathymetry = self._grid.get_esmf_ready_tracer_ds() + self.destination_bathymetry["depth"] = np.zeros_like( + self.destination_bathymetry.tarea + ) + self.destination_bathymetry.depth.attrs["units"] = "meters" + self.destination_bathymetry.depth.attrs["coordinates"] = "lon lat" self.src_bathymetry_dataset.to_netcdf(output_dir / "bathymetry_original.nc") self.destination_bathymetry.to_netcdf(output_dir / "bathymetry_unfinished.nc") From 8390f69a2fdc25987c77c36da9c2543bba15d710 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 15:59:04 -0400 Subject: [PATCH 43/63] rm the tidy dataset positive down situation --- mom6_forge/topo.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 7563fb5e..c9aff22f 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -858,7 +858,6 @@ def set_from_dataset( # Set directly into self.depth in this function self.tidy_dataset( fill_channels=fill_channels, - is_input_positive_below_msl=is_input_positive_below_msl, vertical_coordinate_name="depth", bathymetry=self.regridded_bathy, output_dir=output_dir, @@ -926,11 +925,9 @@ def mpi_set_from_dataset( def tidy_dataset( self, fill_channels=False, - is_input_positive_below_msl=False, vertical_coordinate_name="depth", bathymetry=None, output_dir=Path(""), - write_to_file=True, longitude_coordinate_name="lon", latitude_coordinate_name="lat", ): @@ -978,10 +975,6 @@ def tidy_dataset( bathymetry.expand_dims("tiles", 0) - if not is_input_positive_below_msl: - ## Ensure that coordinate is positive down! - bathymetry["depth"] *= -1 - ## Make a land mask based on the bathymetry ocean_mask = xr.where(bathymetry.depth <= 0, 0, 1) land_mask = np.abs(ocean_mask - 1) From c494878b6c02eb565f4a5f390ad3225a26da91d0 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 16:57:47 -0400 Subject: [PATCH 44/63] Last Fixes --- mom6_forge/grid.py | 1 + mom6_forge/topo.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mom6_forge/grid.py b/mom6_forge/grid.py index 6060b956..55483997 100644 --- a/mom6_forge/grid.py +++ b/mom6_forge/grid.py @@ -954,4 +954,5 @@ def get_esmf_ready_tracer_ds(self): ds.lon.attrs["_FillValue"] = 1e20 ds.lat.attrs["units"] = "degrees_north" ds.lat.attrs["_FillValue"] = 1e20 + ds = ds.set_coords(["lat", "lon"]) return ds diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index c9aff22f..f17d4a9d 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -834,7 +834,7 @@ def set_from_dataset( ) self.src_bathymetry_dataset = self.src.ds self.destination_bathymetry = self._grid.get_esmf_ready_tracer_ds() - self.destination_bathymetry["depth"] = np.zeros_like( + self.destination_bathymetry["depth"] = xr.zeros_like( self.destination_bathymetry.tarea ) self.destination_bathymetry.depth.attrs["units"] = "meters" @@ -861,7 +861,6 @@ def set_from_dataset( vertical_coordinate_name="depth", bathymetry=self.regridded_bathy, output_dir=output_dir, - write_to_file=write_to_file, longitude_coordinate_name="lon", latitude_coordinate_name="lat", ) @@ -910,7 +909,7 @@ def mpi_set_from_dataset( ) self.src_bathymetry_dataset = self.src.ds self.destination_bathymetry = self._grid.get_esmf_ready_tracer_ds() - self.destination_bathymetry["depth"] = np.zeros_like( + self.destination_bathymetry["depth"] = xr.zeros_like( self.destination_bathymetry.tarea ) self.destination_bathymetry.depth.attrs["units"] = "meters" From 0f2a928a86b20a34fbe72e0abd241ff287c35f4c Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 17:39:15 -0400 Subject: [PATCH 45/63] Put the stats on the source --- mom6_forge/topo.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index d71e47ec..e3dcf205 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -797,13 +797,13 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): self.src is not None ), "Source bathymetry must be loaded to compute topo stats" src = self.src - if hasattr(self, "_stats") and isinstance(self._stats, xr.Dataset): + if hasattr(self.src, "_stats") and isinstance(self.src._stats, xr.Dataset): if ( - self._stats.attrs.get("nx_sub") == nx_sub - and self._stats.attrs.get("ny_sub") == ny_sub - and self._stats.attrs.get("mask_hmin") == mask_hmin + self.src._stats.attrs.get("nx_sub") == nx_sub + and self.src._stats.attrs.get("ny_sub") == ny_sub + and self.src._stats.attrs.get("mask_hmin") == mask_hmin ): - return self._stats + return self.src._stats # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( @@ -828,7 +828,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): D2_mean = np.nanmean(depth_ocean**2, axis=(-2, -1)) dims = ["ny", "nx"] - self._stats = xr.Dataset( + self.src._stats = xr.Dataset( { "OCN_FRAC": xr.DataArray( ocn_frac, @@ -868,7 +868,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): "mask_hmin": mask_hmin, }, ) - return self._stats + return self.src._stats def set_from_dataset( self, From 7cf2011ff527e7df808285bc9751a91561097360 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 17:53:51 -0400 Subject: [PATCH 46/63] Fix --- tests/test_topo_bathymetry_workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 214581c4..7b5bac2f 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -15,7 +15,7 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(topo, synthetic_bathy_file) + src = SourceBathy(topo, synthetic_bathy_file, depth_name="elevation") topo._src = src # Test with different sub-sampling densities From f58f36c302df4a9aa62408df7477b97c2228c897 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 17:58:46 -0400 Subject: [PATCH 47/63] pass --- tests/test_topo_bathymetry_workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 52a52800..e7a10056 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -61,7 +61,7 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(topo, synthetic_bathy_file) + src = SourceBathy(topo, synthetic_bathy_file, depth_name = "elevation") topo.src = src topo._compute_stats(nx_sub=2, ny_sub=2, mask_hmin=0.0) From 622e429390d1568036e81c2e96b4855de459729d Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 19 May 2026 18:00:00 -0400 Subject: [PATCH 48/63] Balc --- tests/test_topo_bathymetry_workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index e7a10056..95592104 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -61,7 +61,7 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(topo, synthetic_bathy_file, depth_name = "elevation") + src = SourceBathy(topo, synthetic_bathy_file, depth_name="elevation") topo.src = src topo._compute_stats(nx_sub=2, ny_sub=2, mask_hmin=0.0) From dab01380cf638325e35fedbfa37fb3fac77251ea Mon Sep 17 00:00:00 2001 From: manishvenu Date: Wed, 20 May 2026 11:06:44 -0400 Subject: [PATCH 49/63] Docs & Move Longitude Slicer --- mom6_forge/_source_bathy.py | 258 +++++++++++++++++++++++++++++++----- mom6_forge/topo.py | 2 +- mom6_forge/utils.py | 117 ---------------- tests/test_source_bathy.py | 94 ++++++++++++- tests/test_utils.py | 74 ----------- 5 files changed, 318 insertions(+), 227 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index 83ca3454..b8e0e52e 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -1,34 +1,49 @@ """Source bathymetry loader for mom6_forge. ``SourceBathy`` is a lightweight data container for a regional slice of a -source bathymetry dataset. Users who call pipeline -methods directly +source bathymetry dataset. Users who want to call pipeline methods directly should construct a ``SourceBathy`` explicitly:: from mom6_forge._source_bathy import SourceBathy - src = SourceBathy("gebco_2023.nc").slice_to_domain(topo) + src = SourceBathy( + topo, + "gebco_2023.nc", + lon_name="lon", + lat_name="lat", + depth_name="elevation", + ) """ import numpy as np import xarray as xr from pathlib import Path -from mom6_forge.utils import longitude_slicer class SourceBathy: """Regional slice of a source bathymetry dataset (e.g. GEBCO). - Holds the loaded, domain-clipped, ESMF prepped elevation DataArray together with its - coordinate-name metadata. Per-cell depth statistics are computed and - cached here so repeated calls with the same source file skip the - expensive sub-sampling step. + Holds the loaded, domain-clipped, ESMF-prepped bathymetry ``DataArray`` + together with its coordinate-name metadata. Parameters ---------- + topo : Topo + Target grid object. Only ``topo._grid.qlon`` and + ``topo._grid.qlat`` are used to determine the clipping extent. path : str or Path - lon_name : str — longitude coordinate name. Default ``"lon"``. - lat_name : str — latitude coordinate name. Default ``"lat"``. - depth_name : str — depth variable. Default ``"depth"``. + Path to the source bathymetry NetCDF file. + lon_name : str, optional + Longitude coordinate name in the source file. Default ``"lon"``. + lat_name : str, optional + Latitude coordinate name in the source file. Default ``"lat"``. + depth_name : str, optional + Depth variable name in the source file. Default ``"depth"``. + is_input_positive_below_msl : bool, optional + Whether depth values in the source file are positive below mean sea + level. If ``False`` the sign is flipped on load. Default ``True``. + buf : float, optional + Degree buffer added around the Q-grid bounding box when clipping the + source dataset. Default ``0.5``. """ def __init__( @@ -54,8 +69,27 @@ def __init__( # ------------------------------------------------------------------ def _rename_dims_and_format_ds(self, lon_name, lat_name, depth_name): - """Rename dimensions in the source dataset to match the provided names. This helps prep the dataset for ESMF regridding, which expects specific coordinate names.""" + """Rename source-file dimensions and set required depth attributes. + + Renames ``lon_name``, ``lat_name``, and ``depth_name`` to the + canonical names ``"lon"``, ``"lat"``, and ``"depth"`` so that + subsequent methods can rely on consistent coordinate names. Also + sets the ``missing_value``, ``_FillValue``, ``units``, and + ``coordinates`` attributes expected by FRE tools (some by ESMF as well), and assigns + ``units`` to the longitude/latitude coordinates if absent. + Sets ``self.lon_name``, ``self.lat_name``, and ``self.depth_name`` + to the renamed canonical strings. Mutates ``self._ds`` in place. + + Parameters + ---------- + lon_name : str + Original longitude coordinate name in ``self._ds``. + lat_name : str + Original latitude coordinate name in ``self._ds``. + depth_name : str + Original depth variable name in ``self._ds``. + """ self._ds = self._ds.rename( { lon_name: "lon", @@ -71,24 +105,31 @@ def _rename_dims_and_format_ds(self, lon_name, lat_name, depth_name): ) # missing value expected by FRE tools self._ds.depth.attrs["_FillValue"] = -1e20 self._ds.depth.attrs["units"] = "meters" - self._ds.depth.attrs["standard_name"] = "height_above_reference_ellipsoid" - self._ds.depth.attrs["long_name"] = "Elevation relative to sea level" self._ds.depth.attrs["coordinates"] = "lon lat" if "units" not in self._ds[self.lon_name].attrs: - self._ds[self.lon_name].attrs["units"] = "degrees_east" + self._ds[self.lon_name].attrs["units"] = "degrees" if "units" not in self._ds[self.lat_name].attrs: self._ds[self.lat_name].attrs["units"] = "degrees_north" def _slice_to_domain(self, topo, buf=0.5): - """Load and clip elevation to the topo grid extent plus ``buf`` degrees. + """Clip ``self._ds`` to the topo grid extent plus ``buf`` degrees. - Handles the global-longitude seam automatically. Mutates ``self`` - in place and returns ``self`` for chaining. + Handles the global-longitude seam and regional domains automatically. + Stores ``topo`` as ``self.topo`` and mutates ``self._ds`` in place. Parameters ---------- - topo : Topo — only ``topo._grid.qlon`` / ``topo._grid.qlat`` are used. - buf : float — degree buffer around the Q-grid bounding box. Default 0.5. + topo : Topo + Target grid object; only ``topo._grid.qlon`` and + ``topo._grid.qlat`` are used. + buf : float, optional + Degree buffer added around the Q-grid bounding box. + Default ``0.5``. + + Returns + ------- + xarray.Dataset + The clipped dataset (also stored as ``self._ds``). """ self.topo = topo lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max())) @@ -101,7 +142,6 @@ def _slice_to_domain(self, topo, buf=0.5): {self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)} ) - dlon = float(self._ds[self.lon_name][1] - self._ds[self.lon_name][0]) self._ds = longitude_slicer( self._ds, np.array(lon_extent) + np.array([-buf, buf]), @@ -110,9 +150,20 @@ def _slice_to_domain(self, topo, buf=0.5): return self._ds - def _ensure_depth_is_positive_below_msl(self, depth_positive): - """Ensure depth is positive-down. Mutates self in place.""" - if not depth_positive: + def _ensure_depth_is_positive_below_msl(self, is_input_positive_below_msl): + """Flip depth sign if the source convention is positive *above* MSL. + + Mutates ``self._ds`` in place. After this call the ``"depth"`` + variable is always positive below mean sea level (ocean > 0). + + Parameters + ---------- + is_input_positive_below_msl : bool + If ``True`` the source data are already positive below MSL and no + change is made. If ``False`` the sign of the depth variable is + inverted. + """ + if not is_input_positive_below_msl: self._ds[self.depth_name] = -self._ds[self.depth_name] # ------------------------------------------------------------------ @@ -121,27 +172,174 @@ def _ensure_depth_is_positive_below_msl(self, depth_positive): @property def lon(self): - """1-D longitude array.""" + """1-D longitude array of the clipped source dataset. + + Returns + ------- + numpy.ndarray + Longitude values in degrees. + """ return self.ds[self.lon_name].values @property def lat(self): - """1-D latitude array.""" + """1-D latitude array of the clipped source dataset. + + Returns + ------- + numpy.ndarray + Latitude values in degrees. + """ return self.ds[self.lat_name].values @property def depth(self): - """2-D depth array, depth positive (ocean > 0)""" + """2-D depth array, positive below mean sea level (ocean > 0). + + Returns + ------- + numpy.ndarray + Depth values in metres, shape ``(lat, lon)``. + """ return self.ds[self.depth_name].values @property def ds(self): - """Raw dataset with source coordinate names (positive-down).""" + """Clipped source dataset with canonical coordinate names. + + Depth is positive below mean sea level. + + Returns + ------- + xarray.Dataset + """ return self._ds def __repr__(self): - shape = self._da.shape if self._da is not None else "not loaded" return ( f"SourceBathy({self.path.name!r}, lon={self.lon_name!r}, " - f"lat={self.lat_name!r}, elevation={self.elevation_name!r}, shape={shape})" + f"lat={self.lat_name!r}, depth={self.depth_name!r}, shape={self.depth.shape})" + ) + + +def longitude_slicer(data, longitude_extent, longitude_coords): + """Slice a dataset in longitude, handling periodicity and domain seams. + + Correctly clips datasets whose longitude coordinate may use any + convention (e.g. ``[0, 360]`` or ``[-180, 180]``) and where the + requested ``longitude_extent`` may straddle the wrap-around seam. + + The algorithm proceeds in five steps: + + 1. Determine the integer multiple of 360° needed to shift the midpoint + of ``longitude_extent`` into the range covered by ``data``. + 2. Roll the dataset so that its centre aligns with the midpoint of the + target extent. + 3. Rebuild a monotonically increasing longitude coordinate that removes + the seam introduced by the roll. + 4. Index out the required number of longitude points symmetrically + around the new centre. + 5. Re-centre the coordinate values to match the target domain. + + Parameters + ---------- + data : xarray.Dataset + Global (or at least periodic) dataset to slice. The longitude + coordinate must be uniformly spaced. + longitude_extent : array-like of float + Target longitude bounds ``(west, east)`` in degrees, in increasing + order. + longitude_coords : str or list of str + Name(s) of the longitude coordinate(s) in ``data`` along which to + slice. + + Returns + ------- + xarray.Dataset + Dataset sliced to ``longitude_extent``, with the longitude + coordinate re-centred to match the target domain. + + Raises + ------ + AssertionError + If any named longitude coordinate is not uniformly spaced. + """ + + if isinstance(longitude_coords, str): + longitude_coords = [longitude_coords] + + for lon in longitude_coords: + + central_longitude = np.mean(longitude_extent) ## Midpoint of target domain + + ## Find a corresponding value for the intended domain midpoint in our data. + ## It's assumed that data has equally-spaced longitude values. + + lons = data[lon].data + dlons = lons[1] - lons[0] + lon_span = lons[-1] - lons[0] + dlons + + assert np.allclose( + np.diff(lons), dlons * np.ones(np.size(lons) - 1) + ), "provided longitude coordinate must be uniformly spaced" + + for i in range(-1, 2, 1): + if data[lon][0] <= central_longitude + 360 * i <= data[lon][-1]: + + ## Shifted version of target midpoint; e.g., could be -90 vs 270 + ## integer i keeps track of what how many multiples of 360 we need to shift entire + ## grid by to match central_longitude + _central_longitude = central_longitude + 360 * i + + ## Midpoint of the data + central_data = data[lon][data[lon].shape[0] // 2].values + + ## Number of indices between the data midpoint and the target midpoint. + ## Sign indicates direction needed to shift. + shift = int( + -(data[lon].shape[0] * (_central_longitude - central_data)) + // lon_span + ) + + ## Shift data so that the midpoint of the target domain is the middle of + ## the data for easy slicing. + new_data = data.roll({lon: 1 * shift}, roll_coords=True) + + ## Create a new longitude coordinate. + ## We'll modify this to remove any seams (i.e., jumps like -270 -> 90) + new_lon = new_data[lon].values.copy() + + ## Take the 'seam' of the data, and either backfill or forward fill based on + ## whether the data was shifted F or west + if shift > 0: + new_seam_index = shift + + new_lon[0:new_seam_index] -= 360 + + if shift < 0: + new_seam_index = data[lon].shape[0] + shift + + new_lon[new_seam_index:] += 360 + + ## new_lon is used to re-centre the midpoint to match that of target domain + new_lon -= i * 360 + + new_data = new_data.assign_coords({lon: new_lon}) + + ## Choose the number of lon points to take from the middle, including a buffer. + ## Use this to index the new global dataset + num_lonpoints = int( + int(data[lon].shape[0] * (central_longitude - longitude_extent[0])) + // lon_span + ) + + data = new_data.isel( + { + lon: slice( + data[lon].shape[0] // 2 - num_lonpoints, + data[lon].shape[0] // 2 + num_lonpoints, + ) + } ) + + return data diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index f17d4a9d..cfd375ab 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -7,7 +7,7 @@ from scipy import interpolate from scipy.ndimage import label, binary_fill_holes from scipy.spatial import cKDTree -from mom6_forge.utils import cell_area_rad, longitude_slicer +from mom6_forge.utils import cell_area_rad from mom6_forge.grid import Grid from mom6_forge.git_utils import get_domain_dir, get_repo from pathlib import Path diff --git a/mom6_forge/utils.py b/mom6_forge/utils.py index 348b97e3..f0507349 100644 --- a/mom6_forge/utils.py +++ b/mom6_forge/utils.py @@ -387,120 +387,3 @@ def fill_missing_data(idata, mask, maxiter=0, stabilizer=1.0e-14, tripole=False) x, info = scipy.sparse.linalg.bicg(A, b, maxiter=maxiter) new_data[missing_j, missing_i] = x return new_data - - -def longitude_slicer(data, longitude_extent, longitude_coords): - """ - Slices longitudes while handling periodicity and the 'seams', that is the - longitude values where the data wraps around in a global domain (for example, - longitudes are defined, usually, within domain [0, 360] or [-180, 180]). - - The algorithm works in five steps: - - - Determine whether we need to add or subtract 360 to get the middle of the - ``longitude_extent`` to lie within ``data``'s longitude range (hereby ``old_lon``). - - - Shift the dataset so that its midpoint matches the midpoint of - ``longitude_extent`` (up to a multiple of 360). Now, the modified ``old_lon`` - does not increase monotonically from West to East since the 'seam' - has moved. - - - Fix ``old_lon`` to make it monotonically increasing again. This uses - the information we have about the way the dataset was shifted/rolled. - - - Slice the ``data`` index-wise. We know that ``|longitude_extent[1] - longitude_extent[0]| / 360`` - multiplied by the number of discrete longitude points in the global input data gives - the number of longitude points in our slice, and we've already set the midpoint - to be the middle of the target domain. - - - Add back the correct multiple of 360 so the whole domain matches the target. - - Arguments: - data (xarray.Dataset): The global data you want to slice in longitude. - longitude_extent (Tuple[float, float]): The target longitudes (in degrees) - we want to slice to. Must be in increasing order. - longitude_coords (Union[str, list[str]): The name or list of names of the - longitude coordinates(s) in ``data``. - - Returns: - xarray.Dataset: The sliced ``data``. - """ - - if isinstance(longitude_coords, str): - longitude_coords = [longitude_coords] - - for lon in longitude_coords: - - central_longitude = np.mean(longitude_extent) ## Midpoint of target domain - - ## Find a corresponding value for the intended domain midpoint in our data. - ## It's assumed that data has equally-spaced longitude values. - - lons = data[lon].data - dlons = lons[1] - lons[0] - lon_span = lons[-1] - lons[0] + dlons - - assert np.allclose( - np.diff(lons), dlons * np.ones(np.size(lons) - 1) - ), "provided longitude coordinate must be uniformly spaced" - - for i in range(-1, 2, 1): - if data[lon][0] <= central_longitude + 360 * i <= data[lon][-1]: - - ## Shifted version of target midpoint; e.g., could be -90 vs 270 - ## integer i keeps track of what how many multiples of 360 we need to shift entire - ## grid by to match central_longitude - _central_longitude = central_longitude + 360 * i - - ## Midpoint of the data - central_data = data[lon][data[lon].shape[0] // 2].values - - ## Number of indices between the data midpoint and the target midpoint. - ## Sign indicates direction needed to shift. - shift = int( - -(data[lon].shape[0] * (_central_longitude - central_data)) - // lon_span - ) - - ## Shift data so that the midpoint of the target domain is the middle of - ## the data for easy slicing. - new_data = data.roll({lon: 1 * shift}, roll_coords=True) - - ## Create a new longitude coordinate. - ## We'll modify this to remove any seams (i.e., jumps like -270 -> 90) - new_lon = new_data[lon].values.copy() - - ## Take the 'seam' of the data, and either backfill or forward fill based on - ## whether the data was shifted F or west - if shift > 0: - new_seam_index = shift - - new_lon[0:new_seam_index] -= 360 - - if shift < 0: - new_seam_index = data[lon].shape[0] + shift - - new_lon[new_seam_index:] += 360 - - ## new_lon is used to re-centre the midpoint to match that of target domain - new_lon -= i * 360 - - new_data = new_data.assign_coords({lon: new_lon}) - - ## Choose the number of lon points to take from the middle, including a buffer. - ## Use this to index the new global dataset - num_lonpoints = int( - int(data[lon].shape[0] * (central_longitude - longitude_extent[0])) - // lon_span - ) - - data = new_data.isel( - { - lon: slice( - data[lon].shape[0] // 2 - num_lonpoints, - data[lon].shape[0] // 2 + num_lonpoints, - ) - } - ) - - return data diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index 97cdac93..3395c02a 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -5,7 +5,7 @@ import tempfile from pathlib import Path import xarray as xr -from mom6_forge._source_bathy import SourceBathy +from mom6_forge._source_bathy import SourceBathy, longitude_slicer @pytest.fixture @@ -23,7 +23,7 @@ def synthetic_bathy_file(): lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer - # Create synthetic elevation data (positive-up, like GEBCO) + # Create synthetic depth data (positive-up, like GEBCO) # Ocean is negative (water), land is positive elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep @@ -50,6 +50,17 @@ def synthetic_bathy_file(): Path(bathy_file).unlink() +def test_simple_source_bathy_calls(get_rect_topo, synthetic_bathy_file): + src = SourceBathy( + get_rect_topo, + synthetic_bathy_file, + lon_name="lon", + lat_name="lat", + depth_name="elevation", + ) + print(src, src.ds, src.lon, src.lat, src.depth) + + def test_source_bathy_initialization(synthetic_bathy_file, get_rect_topo): """Test SourceBathy initialization and coordinate names.""" src = SourceBathy( @@ -67,7 +78,7 @@ def test_source_bathy_initialization(synthetic_bathy_file, get_rect_topo): def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): - """Smoke test: load and slice elevation to topo domain.""" + """Smoke test: load and slice depth to topo domain.""" topo = get_rect_topo src = SourceBathy( @@ -88,7 +99,7 @@ def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): - """Test that elevation is converted to positive-down depth.""" + """Test that depth is converted to positive-down depth.""" topo = get_rect_topo src = SourceBathy( @@ -104,9 +115,82 @@ def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): # Verify no NaNs in the result assert not bool(np.isnan(depth).all()), "All depth values are NaN" - # Verify positive depth values for ocean (elevation is negative) + # Verify positive depth values for ocean (depth is negative) non_nan_values = depth[~np.isnan(depth)] assert len(non_nan_values) > 0, "No valid depth values" assert np.any(non_nan_values > 0), "Expected positive depth values for ocean" assert depth.shape == src.depth.shape + + +def test_longitude_slicer(): + with pytest.raises(AssertionError): + nx, ny, nt = 4, 14, 5 + + latitude_extent = (10, 20) + longitude_extent = (12, 18) + + dims = ["random_lat", "random_lon", "time"] + + dlambda = (longitude_extent[1] - longitude_extent[0]) / 2 + + data = xr.DataArray( + np.random.random((ny, nx, nt)), + dims=dims, + coords={ + "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), + "random_lon": np.array( + [ + longitude_extent[0], + longitude_extent[0] + 1.5 * dlambda, + longitude_extent[0] + 2.6 * dlambda, + longitude_extent[1], + ] + ), + "time": np.linspace(0, 1000, nt), + }, + ) + + longitude_slicer(data, longitude_extent, "random_lon") + + +def test_longitude_slicers_regionally(): + nx, ny = 4, 14 + + latitude_extent = (2, 5) + longitude_extent = (-90, -70) + + dims = ["random_lat", "random_lon"] + + dlambda = (longitude_extent[1] - longitude_extent[0]) / 2 + + data = xr.DataArray( + np.random.random((ny, nx)), + dims=dims, + coords={ + "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), + "random_lon": np.linspace( + longitude_extent[0] - 2, longitude_extent[1] + 2, nx + ), + }, + ) + + # Regular regional + data_regular = longitude_slicer(data, longitude_extent, "random_lon") + data_east = longitude_slicer(data, (270, 290), "random_lon") + assert (data_regular == data_east).all() + + # Seam data + longitude_extent = (-5, 5) + data = xr.DataArray( + np.random.random((ny, nx)), + dims=dims, + coords={ + "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), + "random_lon": np.linspace( + longitude_extent[0] - 2, longitude_extent[1] + 2, nx + ), + }, + ) + data_regular = longitude_slicer(data, longitude_extent, "random_lon") + assert len(data_regular.random_lon) > 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index 2181f782..de90aee8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,6 @@ from mom6_forge.utils import ( get_avg_resolution, get_avg_resolution_km, - longitude_slicer, ) from mom6_forge._supergrid import ( quadrilateral_area, @@ -35,79 +34,6 @@ def test_avg_resolution(): ), "Average resolution for tx2_3v2 should be around 40 km" -def test_longitude_slicer(): - with pytest.raises(AssertionError): - nx, ny, nt = 4, 14, 5 - - latitude_extent = (10, 20) - longitude_extent = (12, 18) - - dims = ["random_lat", "random_lon", "time"] - - dlambda = (longitude_extent[1] - longitude_extent[0]) / 2 - - data = xr.DataArray( - np.random.random((ny, nx, nt)), - dims=dims, - coords={ - "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), - "random_lon": np.array( - [ - longitude_extent[0], - longitude_extent[0] + 1.5 * dlambda, - longitude_extent[0] + 2.6 * dlambda, - longitude_extent[1], - ] - ), - "time": np.linspace(0, 1000, nt), - }, - ) - - longitude_slicer(data, longitude_extent, "random_lon") - - -def test_longitude_slicers_regionally(): - nx, ny = 4, 14 - - latitude_extent = (2, 5) - longitude_extent = (-90, -70) - - dims = ["random_lat", "random_lon"] - - dlambda = (longitude_extent[1] - longitude_extent[0]) / 2 - - data = xr.DataArray( - np.random.random((ny, nx)), - dims=dims, - coords={ - "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), - "random_lon": np.linspace( - longitude_extent[0] - 2, longitude_extent[1] + 2, nx - ), - }, - ) - - # Regular regional - data_regular = longitude_slicer(data, longitude_extent, "random_lon") - data_east = longitude_slicer(data, (270, 290), "random_lon") - assert (data_regular == data_east).all() - - # Seam data - longitude_extent = (-5, 5) - data = xr.DataArray( - np.random.random((ny, nx)), - dims=dims, - coords={ - "random_lat": np.linspace(latitude_extent[0], latitude_extent[1], ny), - "random_lon": np.linspace( - longitude_extent[0] - 2, longitude_extent[1] + 2, nx - ), - }, - ) - data_regular = longitude_slicer(data, longitude_extent, "random_lon") - assert len(data_regular.random_lon) > 0 - - @pytest.mark.parametrize( ("v1", "v2", "v3", "v4", "true_area"), [ From d2d93ec4f86341d33856e988cb2639e2fa23693f Mon Sep 17 00:00:00 2001 From: manishvenu Date: Wed, 20 May 2026 15:42:38 -0400 Subject: [PATCH 50/63] Review Comments --- mom6_forge/_source_bathy.py | 9 +++++++++ mom6_forge/topo.py | 9 ++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index b8e0e52e..f5e80986 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -283,8 +283,12 @@ def longitude_slicer(data, longitude_extent, longitude_coords): np.diff(lons), dlons * np.ones(np.size(lons) - 1) ), "provided longitude coordinate must be uniformly spaced" + is_longitude_extent_in_data = ( + False # This boolean checks if the 360 + i adjustment isn't found + ) for i in range(-1, 2, 1): if data[lon][0] <= central_longitude + 360 * i <= data[lon][-1]: + is_longitude_extent_in_data = True ## Shifted version of target midpoint; e.g., could be -90 vs 270 ## integer i keeps track of what how many multiples of 360 we need to shift entire @@ -333,6 +337,11 @@ def longitude_slicer(data, longitude_extent, longitude_coords): // lon_span ) + if not is_longitude_extent_in_data: + raise ValueError( + "The longitude of the data doesn't seem to include the longitude of the grid." + ) + data = new_data.isel( { lon: slice( diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index cfd375ab..90e83dbb 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -53,7 +53,7 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary", git=True) None # Binary ocean/land mask (None = no mask applied) ) self._min_depth = min_depth - self._src = None # cached SourceBathy; set by _set_src() + self._src = None # SourceBathy object; set by set_src() self.land_fillval = 0.0 # Depth value for land cells initial_command = MinDepthEditCommand( self, attr="min_depth", new_value=min_depth @@ -480,8 +480,7 @@ def set_src( is_input_positive_below_msl=False, buf=0.5, ): - """Set a :class:`SourceBathy`, creating and slicing a new one - only when the path or coordinate names differ from the current cache.""" + """Set a :class:`SourceBathy` into a class object called src""" self.src = SourceBathy( self, Path(bathymetry_path), @@ -944,8 +943,8 @@ def tidy_dataset( fill_channels (Optional[bool]): Whether to fill in diagonal channels. This removes more narrow inlets, but can also connect extra islands to land. Default: ``False``. - is_input_positive_below_msl (Optional[bool]): If ``False`` (default), assume that - bathymetry vertical coordinate is positive down, as is the case in GEBCO for example. + is_input_positive_below_msl (Optional[bool]): If ``True`` (default), assume that + bathymetry vertical coordinate is positive below sea level, GEBCO is negative below sea level, and would need for this to be False. bathymetry (Optional[xr.Dataset]): The bathymetry dataset to tidy up. If not provided, it will read the bathymetry from the file ``bathymetry_unfinished.nc`` in the input directory that was created by :func:`~config/regrid_dataset`. From 69fc8274f5959268e766a92970ac27cd9e82562d Mon Sep 17 00:00:00 2001 From: manishvenu Date: Wed, 20 May 2026 16:03:16 -0400 Subject: [PATCH 51/63] black --- mom6_forge/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mom6_forge/utils.py b/mom6_forge/utils.py index abddc375..bc6f8f2b 100644 --- a/mom6_forge/utils.py +++ b/mom6_forge/utils.py @@ -388,6 +388,7 @@ def fill_missing_data(idata, mask, maxiter=0, stabilizer=1.0e-14, tripole=False) new_data[missing_j, missing_i] = x return new_data + def compute_subsampling_factor(src_nj, src_ni, dst_nj, dst_ni): """ Compute the sub-sampling factors needed so that the super-sampled @@ -412,4 +413,4 @@ def compute_subsampling_factor(src_nj, src_ni, dst_nj, dst_ni): while ny_sub * dst_nj < src_nj: ny_sub += 1 - return ny_sub, nx_sub \ No newline at end of file + return ny_sub, nx_sub From e398bc3c156e4fb75502242588cd78b65fe91fc7 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Wed, 20 May 2026 16:07:04 -0400 Subject: [PATCH 52/63] remove extra fxtures --- tests/conftest.py | 2 +- tests/test_source_bathy.py | 42 -------------------------------------- tests/test_topo.py | 42 -------------------------------------- 3 files changed, 1 insertion(+), 85 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 06d41a91..f8cd9ae5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -313,7 +313,7 @@ def synthetic_bathy_file(): lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer - # Create synthetic elevation data (positive-up, like GEBCO) + # Create synthetic depth data (positive-up, like GEBCO) # Ocean is negative (water), land is positive elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py index 3395c02a..5440348e 100644 --- a/tests/test_source_bathy.py +++ b/tests/test_source_bathy.py @@ -8,48 +8,6 @@ from mom6_forge._source_bathy import SourceBathy, longitude_slicer -@pytest.fixture -def synthetic_bathy_file(): - """Create a temporary synthetic bathymetry NetCDF file for testing. - - Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). - """ - with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: - bathy_file = tmp.name - - # Create synthetic bathymetry covering the Panama region - # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 - # So we need lon 278-282, lat 7-10 with some buffer - lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer - lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer - - # Create synthetic depth data (positive-up, like GEBCO) - # Ocean is negative (water), land is positive - elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep - - # Add synthetic land masses (islands) - # Create an island around (280, 8.5) - lon_2d, lat_2d = np.meshgrid(lon, lat) - island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 - elevation[island_mask] = 200.0 # Synthetic island - - ds = xr.Dataset( - { - "elevation": (["lat", "lon"], elevation), - }, - coords={ - "lon": lon, - "lat": lat, - }, - ) - ds.to_netcdf(bathy_file) - - yield bathy_file - - # Cleanup - Path(bathy_file).unlink() - - def test_simple_source_bathy_calls(get_rect_topo, synthetic_bathy_file): src = SourceBathy( get_rect_topo, diff --git a/tests/test_topo.py b/tests/test_topo.py index c3809110..b5f046e7 100644 --- a/tests/test_topo.py +++ b/tests/test_topo.py @@ -7,48 +7,6 @@ from mom6_forge._source_bathy import SourceBathy -@pytest.fixture -def synthetic_bathy_file(): - """Create a temporary synthetic bathymetry NetCDF file for testing. - - Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). - """ - with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: - bathy_file = tmp.name - - # Create synthetic bathymetry covering the Panama region - # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 - # So we need lon 278-282, lat 7-10 with some buffer - lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer - lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer - - # Create synthetic elevation data (positive-up, like GEBCO) - # Ocean is negative (water), land is positive - elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep - - # Add synthetic land masses (islands) - # Create an island around (280, 8.5) - lon_2d, lat_2d = np.meshgrid(lon, lat) - island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 - elevation[island_mask] = 200.0 # Synthetic island - - ds = xr.Dataset( - { - "elevation": (["lat", "lon"], elevation), - }, - coords={ - "lon": lon, - "lat": lat, - }, - ) - ds.to_netcdf(bathy_file) - - yield bathy_file - - # Cleanup - Path(bathy_file).unlink() - - def test_topo_from_version_control(get_rect_topo): topo = get_rect_topo # this topo has a version control directory topo_from_version_control = Topo.from_version_control(topo.domain_dir) From 2c9b0fe0ba910c57b84187f3517aff4508668410 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Wed, 20 May 2026 16:07:52 -0400 Subject: [PATCH 53/63] Imports --- tests/test_topo.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_topo.py b/tests/test_topo.py index b5f046e7..332c1e74 100644 --- a/tests/test_topo.py +++ b/tests/test_topo.py @@ -1,10 +1,4 @@ -import numpy as np -import pytest -import tempfile -from pathlib import Path -import xarray as xr from mom6_forge.topo import * -from mom6_forge._source_bathy import SourceBathy def test_topo_from_version_control(get_rect_topo): From 7d0a5c1c76abfb09c973ce06a9f1026d232023e7 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Wed, 20 May 2026 16:16:08 -0400 Subject: [PATCH 54/63] Thing --- mom6_forge/topo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index a467905f..223fec72 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -7,7 +7,7 @@ from scipy import interpolate from scipy.ndimage import label, binary_fill_holes from scipy.spatial import cKDTree -from mom6_forge.utils import cell_area_rad, longitude_slicer +from mom6_forge.utils import cell_area_rad from mom6_forge.grid import Grid from mom6_forge.git_utils import get_domain_dir, get_repo from pathlib import Path From 61778af768d825a58127b85fd734a9c86fc6c9ee Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 21 May 2026 12:49:57 -0400 Subject: [PATCH 55/63] Fix --- mom6_forge/topo.py | 8 ++++---- tests/conftest.py | 2 +- tests/test_topo.py | 1 + tests/test_topo_bathymetry_workflows.py | 11 +++++++++-- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index fa4b1581..96ff4f55 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -870,20 +870,20 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): return self.src._stats def direct_stats_depth(self, statistic): - """Set the topo depth to a statistic from compute_topo_stats""" + """Set the topo depth to a statistic from compute_topo_stats (Must be called before this function to compute the stats).""" assert ( - self._stats is not None + self.src._stats is not None ), "Source bathymetry must have topo stats computed, please call _compute_stats first if you have not already" approved_list = [] - for key in self._stats: + for key in self.src._stats: if key.startswith("D_"): approved_list.append(key[2:]) assert ( statistic in approved_list ), f"Invalid statistic {statistic}, must be one of {approved_list}" - self.send_entire_depth_change_to_tcm(self._stats[f"D_{statistic}"]) + self.send_entire_depth_change_to_tcm(self.src._stats[f"D_{statistic}"]) def set_from_dataset( self, diff --git a/tests/conftest.py b/tests/conftest.py index 31b5d6dc..f8cd9ae5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -321,7 +321,7 @@ def synthetic_bathy_file(): # Create an island around (280, 8.5) lon_2d, lat_2d = np.meshgrid(lon, lat) island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 - elevation[island_mask] = -200.0 # Synthetic island + elevation[island_mask] = 200.0 # Synthetic island ds = xr.Dataset( { diff --git a/tests/test_topo.py b/tests/test_topo.py index 97bb7c4f..332c1e74 100644 --- a/tests/test_topo.py +++ b/tests/test_topo.py @@ -1,5 +1,6 @@ from mom6_forge.topo import * + def test_topo_from_version_control(get_rect_topo): topo = get_rect_topo # this topo has a version control directory topo_from_version_control = Topo.from_version_control(topo.domain_dir) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 95592104..375a51cc 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -61,7 +61,12 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): topo = get_rect_topo # Load source bathymetry and slice to topo domain - src = SourceBathy(topo, synthetic_bathy_file, depth_name="elevation") + src = SourceBathy( + topo, + synthetic_bathy_file, + depth_name="elevation", + is_input_positive_below_msl=False, + ) topo.src = src topo._compute_stats(nx_sub=2, ny_sub=2, mask_hmin=0.0) @@ -69,4 +74,6 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): topo.direct_stats_depth("mean") mask = ~np.isnan(topo.depth.values) - assert np.isclose(topo.depth.values[mask], topo._stats["D_mean"].values[mask]).all() + assert np.isclose( + topo.depth.values[mask], topo.src._stats["D_mean"].values[mask] + ).all() From b98b12949beb397bdbf9d19be048dbdecaa518d2 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 21 May 2026 15:36:24 -0400 Subject: [PATCH 56/63] Property --- mom6_forge/_source_bathy.py | 19 +++++++++++++++++++ mom6_forge/topo.py | 20 ++++++++++++-------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py index f5e80986..e9e6d896 100644 --- a/mom6_forge/_source_bathy.py +++ b/mom6_forge/_source_bathy.py @@ -57,6 +57,7 @@ def __init__( buf=0.5, ): self.path = Path(path) + self._computed_stats = None self._ds = xr.open_dataset(self.path, chunks="auto") self._rename_dims_and_format_ds( lon_name=lon_name, lat_name=lat_name, depth_name=depth_name @@ -215,6 +216,24 @@ def ds(self): """ return self._ds + @property + def stats(self): + """Per-cell depth statistics computed by ``Topo._compute_stats``. + + Returns ``None`` until ``_compute_stats`` has been called. + + Returns + ------- + xarray.Dataset or None + """ + return self._computed_stats + + @stats.setter + def stats(self, value): + if value is not None and not isinstance(value, xr.Dataset): + raise TypeError("stats must be an xarray.Dataset or None") + self._computed_stats = value + def __repr__(self): return ( f"SourceBathy({self.path.name!r}, lon={self.lon_name!r}, " diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 223fec72..06393dca 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -492,6 +492,10 @@ def set_src( ) return self.src + @property + def stats(self): + return self.src.stats if self.src is not None else None + def clear_user_mask(self): cmd = ClearMaskCommand( self, message="Clear manual mask" @@ -778,7 +782,7 @@ def set_bowl(self, max_depth, dedge, rad_earth=6.378e6, expdecay=400000.0): def _compute_stats(self, nx_sub, ny_sub, mask_hmin): """Compute per-cell depth statistics by uniform sub-sampling. - Results are cached on ``stats`` so a second call with the + Results are stored on ``stats`` so a second call with the same source file returns immediately without recomputation. (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) @@ -796,13 +800,13 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): self.src is not None ), "Source bathymetry must be loaded to compute topo stats" src = self.src - if hasattr(self.src, "_stats") and isinstance(self.src._stats, xr.Dataset): + if self.src.stats is not None: if ( - self.src._stats.attrs.get("nx_sub") == nx_sub - and self.src._stats.attrs.get("ny_sub") == ny_sub - and self.src._stats.attrs.get("mask_hmin") == mask_hmin + self.src.stats.attrs.get("nx_sub") == nx_sub + and self.src.stats.attrs.get("ny_sub") == ny_sub + and self.src.stats.attrs.get("mask_hmin") == mask_hmin ): - return self.src._stats + return self.src.stats # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( @@ -827,7 +831,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): D2_mean = np.nanmean(depth_ocean**2, axis=(-2, -1)) dims = ["ny", "nx"] - self.src._stats = xr.Dataset( + self.src.stats = xr.Dataset( { "OCN_FRAC": xr.DataArray( ocn_frac, @@ -867,7 +871,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): "mask_hmin": mask_hmin, }, ) - return self.src._stats + return self.src.stats def set_from_dataset( self, From 3ea0cca4fcf747516e12b2bf750b03f68457fe69 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 21 May 2026 15:42:09 -0400 Subject: [PATCH 57/63] One more --- mom6_forge/topo.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 06393dca..d2ac836a 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -799,18 +799,17 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): assert ( self.src is not None ), "Source bathymetry must be loaded to compute topo stats" - src = self.src - if self.src.stats is not None: - if ( - self.src.stats.attrs.get("nx_sub") == nx_sub - and self.src.stats.attrs.get("ny_sub") == ny_sub - and self.src.stats.attrs.get("mask_hmin") == mask_hmin - ): - return self.src.stats + if ( + self.stats is not None + and self.stats.attrs.get("nx_sub") == nx_sub + and self.stats.attrs.get("ny_sub") == ny_sub + and self.stats.attrs.get("mask_hmin") == mask_hmin + ): + return self.stats # Compute subsampling factor and generate sub-point grid ds = regrid_with_subsampling( - input_dataset=src.ds, + input_dataset=self.src.ds, qlon=self._grid.qlon.values, qlat=self._grid.qlat.values, nx_sub=nx_sub, @@ -818,7 +817,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): regridding_method="nearest_s2d", ) - depth_sub = ds[src.depth_name].values # (ny, nx, ny_sub, nx_sub) + depth_sub = ds[self.src.depth_name].values # (ny, nx, ny_sub, nx_sub) is_ocean = depth_sub > mask_hmin ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub) @@ -871,7 +870,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): "mask_hmin": mask_hmin, }, ) - return self.src.stats + return self.stats def set_from_dataset( self, From 6a6f99e4c5f156d0144496eb0051326e78e71aac Mon Sep 17 00:00:00 2001 From: manishvenu Date: Thu, 21 May 2026 16:15:13 -0400 Subject: [PATCH 58/63] Add Chl improvement --- mom6_forge/chl.py | 2 ++ mom6_forge/mapping.py | 12 +++++++++++- test_r.nc | Bin 0 -> 8658 bytes test_weights.nc | Bin 0 -> 8498 bytes tests/test_mapping.py | 19 ++++++++++++++++++- 5 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 test_r.nc create mode 100644 test_weights.nc diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py index 13e9bbb5..3d815562 100644 --- a/mom6_forge/chl.py +++ b/mom6_forge/chl.py @@ -108,6 +108,8 @@ def interpolate_and_fill_seawifs( nx_sub=nx_sub, ny_sub=ny_sub, regridding_method="bilinear", + subsampling_weights_path=output_path + / "chl_subsample_weights.nc", # Will reuse the weights file for each time step, so it only needs to be computed once ) q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 8d452da5..910b98d0 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -949,6 +949,7 @@ def regrid_dataset_via_xesmf( regridding_method=None, write_to_file=True, output_path=Path("regridded_dataset.nc"), + save_weights_path=None, ): """ Regrids the dataset given ``input_dataset`` which contains the original dataset and ``output_dataset`` which is a template for the regridded product. @@ -958,7 +959,7 @@ def regrid_dataset_via_xesmf( input_dataset (Xarray Dataset): original dataset with proper metadata and structure for ESMF regridding. output_dataset (Xarray Dataset): Template for the regridded dataset regridding_method: (Optional[str]) The type of regridding method to use. Defaults to bilinear write_to_file (Optional[bool]): Files saved to ``output_dir`` Defaults to ``True``. Must be set to true if using manual regridding methods with ESMF_regrid. - + save_weights_path (Optional[str]): If provided, saves or (if exists) reuses the xESMF regridding weights at the specified path. Returns: regridded_dataset (Xarray.Dataset): """ @@ -974,12 +975,15 @@ def regrid_dataset_via_xesmf( + f"Regridded size: {output_dataset.nbytes/1e6:.2f} Mb\n" ) + weights_exist = save_weights_path is not None and Path(save_weights_path).is_file() regridder = xe.Regridder( input_dataset, output_dataset, method=regridding_method, locstream_out=False, periodic=False, + filename=save_weights_path, + reuse_weights=weights_exist, ) dataset = regridder(input_dataset) @@ -1071,6 +1075,7 @@ def regrid_with_subsampling( nx_sub, ny_sub, regridding_method="nearest_s2d", + subsampling_weights_path=None, ): """ Regrids input_dataset to sub_sampled_grid to @@ -1083,6 +1088,10 @@ def regrid_with_subsampling( Corner coordinates of the destination grid. nx_sub, ny_sub : int Number of sub-points per cell (typically from compute_subsampling_factor). + regridding_method : str + Regridding method to use for mapping from input_dataset to sub-sampled points. Should be compatible with xESMF regridding methods, e.g., 'nearest_s2d' + subsampling_weights_path : str or Path, optional + If provided, saves or (if exists) reuses the xESMF regridding weights for the sub-sampling regridding at the specified path. Returns ------- regridded_dataset : xr.Dataset @@ -1112,6 +1121,7 @@ def regrid_with_subsampling( flat_output, regridding_method=regridding_method, write_to_file=False, + save_weights_path=subsampling_weights_path, ) # Reshape to 4D, keeping sub-points as their own dimension diff --git a/test_r.nc b/test_r.nc new file mode 100644 index 0000000000000000000000000000000000000000..348b70dd60abc9bfc2814876fa9f418d29fbaf37 GIT binary patch literal 8658 zcmeHMU2IfE6rSDNmfh|42PhVaKNqM`5SOLYDv9M%_lH)OUD-uZqS-9$tzFz7-Q7|Z zwM3$%NFoLd@FFo4?E_K6kF~@HeDJ}N5VcVP@du*fg9++G8$yEheD}_@*R`NAH6&(E z_Uz1_Ip@roZ@xR5%| z2q|q{c4mW1#dY|$33sMs=GRm*I+h(hS#cs3NgokLAI#fJ;?#;&ZjH#^N!RmX+yWCid;=&AN zOy%GlLLw4TZ^aesb>9v@g;g*Ia{EH9u^5|4Cam91>Th2a2*=tZ;dso~8VWG?I)aP> z%FTdq>#Bg^zk>U(=}Z+=`Xtn>UiA5cYBH74vxQVHtFG@Vs@Ys|VrWw$pGvGx>mHTM z(q@K&?3BNnxlvF&tQR*VJDU8dOe)E0i4SBXQH|n^E?W?8`Q5}&|9cEojS7dXU`>^* zX-pvutAVbvaDr?XdmJ<1LmNFCEzL(B5P_(2Wkg%yr?5I&0BxvPfG?NJCsWx(Q7`;? z8MNIzH>wyHx8KCMl9C`zFqKX}lt_2!Z2BZi0BJj2s>G5AD0)i1^y>M`mzgYBJMKus;6}^K)6|KtJ6Rd*ntD=TCjX0{e}&8fNpaQ+B;C(}ETEwm z1dj%SfoLG?3&ahAH&Y7GDVZoH3VLyZ>Ej2}U+E^+C)!fFi8Tx8M*FArO(j>}zqVX? z3y&hnzv9Y*FZUD;uFzJ{6HtZ&17eE_6a|a;EKJGgx{p?vUZb18d5W#{MM6dx_KifO z5C^pQlNQ@kY8Q!1?PfN{`fEf_wh{q}fJ8tdAQ6xVNCYGT5&?;TL_i`S5s(O6D+FGC zxAA=B(eJbyAFpv_&JAftx4khua&|}yb)J6jl5;@YsQ7N{c73l64hJtEId@E({p<66 zxz(R(cf1+Z-*&#EjSU<)7^Lx)TkbhMLiSgjXU>eAJ*IuS(fz?C8b2KSardd4e$uvo zeaq$<(c{{j{rTOGtv;^pI{4c89nL*khR?s4&;Jphzl+cR8_%EM`Jd+bFYx>)GN=;Murusf;mve?nDB!w?-?cj?$647GJO)BE5 zs9#Mg;?T#{iSR=oS2U={s;I}Rn9r)1+p4J7s;I}RsK=_P$Ev8)s`OKeL_i`S5s(N- q1papjT=iwx4*Zu+<;yVqb3;TcUgMZn#!6p?iB~ydeA#`0SN{Nr_yr#T literal 0 HcmV?d00001 diff --git a/test_weights.nc b/test_weights.nc new file mode 100644 index 0000000000000000000000000000000000000000..490de3f6203ef798fbd7a76de2ea25ec96805102 GIT binary patch literal 8498 zcmeHMO-vI}5T4zowU*^q6g4I;YGMMUX#uMj1kxWN)?z5p#KhEvE@4AywJl(R1j3EP z0}1i$K@Xk{CK@jq?X)zGMvPjeLDp!pfMFLZ^GNzoqaQJ-Z$UAlJ`~yqwzZb zNx$mz0n4T`C&_qM4*rG$%V#(Cda59~W3?yQp-$KaH!6O`!1J`0s|9-ds=%igV@Xs`IRw)vV z8yS1r$`$Q=&KR318M%CE>(#7Tu+6co6);>0g@F=MF@8Phlu$ZpmBup@JyCnw&Ojsb zK|pIb$>EqN&F$-^z|i_LhDIMIPSOuOwa~MvA%^vYuGYRqwv9z!GxfnjA14j4kA^zqda zGik|XQa{>Bcr8)?H2A$S-=5S=n+$t+r>EvA Date: Thu, 21 May 2026 21:30:14 -0400 Subject: [PATCH 59/63] Revert "Add Chl improvement" This reverts commit 6a6f99e4c5f156d0144496eb0051326e78e71aac. --- mom6_forge/chl.py | 2 -- mom6_forge/mapping.py | 12 +----------- test_r.nc | Bin 8658 -> 0 bytes test_weights.nc | Bin 8498 -> 0 bytes tests/test_mapping.py | 19 +------------------ 5 files changed, 2 insertions(+), 31 deletions(-) delete mode 100644 test_r.nc delete mode 100644 test_weights.nc diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py index 3d815562..13e9bbb5 100644 --- a/mom6_forge/chl.py +++ b/mom6_forge/chl.py @@ -108,8 +108,6 @@ def interpolate_and_fill_seawifs( nx_sub=nx_sub, ny_sub=ny_sub, regridding_method="bilinear", - subsampling_weights_path=output_path - / "chl_subsample_weights.nc", # Will reuse the weights file for each time step, so it only needs to be computed once ) q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 910b98d0..8d452da5 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -949,7 +949,6 @@ def regrid_dataset_via_xesmf( regridding_method=None, write_to_file=True, output_path=Path("regridded_dataset.nc"), - save_weights_path=None, ): """ Regrids the dataset given ``input_dataset`` which contains the original dataset and ``output_dataset`` which is a template for the regridded product. @@ -959,7 +958,7 @@ def regrid_dataset_via_xesmf( input_dataset (Xarray Dataset): original dataset with proper metadata and structure for ESMF regridding. output_dataset (Xarray Dataset): Template for the regridded dataset regridding_method: (Optional[str]) The type of regridding method to use. Defaults to bilinear write_to_file (Optional[bool]): Files saved to ``output_dir`` Defaults to ``True``. Must be set to true if using manual regridding methods with ESMF_regrid. - save_weights_path (Optional[str]): If provided, saves or (if exists) reuses the xESMF regridding weights at the specified path. + Returns: regridded_dataset (Xarray.Dataset): """ @@ -975,15 +974,12 @@ def regrid_dataset_via_xesmf( + f"Regridded size: {output_dataset.nbytes/1e6:.2f} Mb\n" ) - weights_exist = save_weights_path is not None and Path(save_weights_path).is_file() regridder = xe.Regridder( input_dataset, output_dataset, method=regridding_method, locstream_out=False, periodic=False, - filename=save_weights_path, - reuse_weights=weights_exist, ) dataset = regridder(input_dataset) @@ -1075,7 +1071,6 @@ def regrid_with_subsampling( nx_sub, ny_sub, regridding_method="nearest_s2d", - subsampling_weights_path=None, ): """ Regrids input_dataset to sub_sampled_grid to @@ -1088,10 +1083,6 @@ def regrid_with_subsampling( Corner coordinates of the destination grid. nx_sub, ny_sub : int Number of sub-points per cell (typically from compute_subsampling_factor). - regridding_method : str - Regridding method to use for mapping from input_dataset to sub-sampled points. Should be compatible with xESMF regridding methods, e.g., 'nearest_s2d' - subsampling_weights_path : str or Path, optional - If provided, saves or (if exists) reuses the xESMF regridding weights for the sub-sampling regridding at the specified path. Returns ------- regridded_dataset : xr.Dataset @@ -1121,7 +1112,6 @@ def regrid_with_subsampling( flat_output, regridding_method=regridding_method, write_to_file=False, - save_weights_path=subsampling_weights_path, ) # Reshape to 4D, keeping sub-points as their own dimension diff --git a/test_r.nc b/test_r.nc deleted file mode 100644 index 348b70dd60abc9bfc2814876fa9f418d29fbaf37..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8658 zcmeHMU2IfE6rSDNmfh|42PhVaKNqM`5SOLYDv9M%_lH)OUD-uZqS-9$tzFz7-Q7|Z zwM3$%NFoLd@FFo4?E_K6kF~@HeDJ}N5VcVP@du*fg9++G8$yEheD}_@*R`NAH6&(E z_Uz1_Ip@roZ@xR5%| z2q|q{c4mW1#dY|$33sMs=GRm*I+h(hS#cs3NgokLAI#fJ;?#;&ZjH#^N!RmX+yWCid;=&AN zOy%GlLLw4TZ^aesb>9v@g;g*Ia{EH9u^5|4Cam91>Th2a2*=tZ;dso~8VWG?I)aP> z%FTdq>#Bg^zk>U(=}Z+=`Xtn>UiA5cYBH74vxQVHtFG@Vs@Ys|VrWw$pGvGx>mHTM z(q@K&?3BNnxlvF&tQR*VJDU8dOe)E0i4SBXQH|n^E?W?8`Q5}&|9cEojS7dXU`>^* zX-pvutAVbvaDr?XdmJ<1LmNFCEzL(B5P_(2Wkg%yr?5I&0BxvPfG?NJCsWx(Q7`;? z8MNIzH>wyHx8KCMl9C`zFqKX}lt_2!Z2BZi0BJj2s>G5AD0)i1^y>M`mzgYBJMKus;6}^K)6|KtJ6Rd*ntD=TCjX0{e}&8fNpaQ+B;C(}ETEwm z1dj%SfoLG?3&ahAH&Y7GDVZoH3VLyZ>Ej2}U+E^+C)!fFi8Tx8M*FArO(j>}zqVX? z3y&hnzv9Y*FZUD;uFzJ{6HtZ&17eE_6a|a;EKJGgx{p?vUZb18d5W#{MM6dx_KifO z5C^pQlNQ@kY8Q!1?PfN{`fEf_wh{q}fJ8tdAQ6xVNCYGT5&?;TL_i`S5s(O6D+FGC zxAA=B(eJbyAFpv_&JAftx4khua&|}yb)J6jl5;@YsQ7N{c73l64hJtEId@E({p<66 zxz(R(cf1+Z-*&#EjSU<)7^Lx)TkbhMLiSgjXU>eAJ*IuS(fz?C8b2KSardd4e$uvo zeaq$<(c{{j{rTOGtv;^pI{4c89nL*khR?s4&;Jphzl+cR8_%EM`Jd+bFYx>)GN=;Murusf;mve?nDB!w?-?cj?$647GJO)BE5 zs9#Mg;?T#{iSR=oS2U={s;I}Rn9r)1+p4J7s;I}RsK=_P$Ev8)s`OKeL_i`S5s(N- q1papjT=iwx4*Zu+<;yVqb3;TcUgMZn#!6p?iB~ydeA#`0SN{Nr_yr#T diff --git a/test_weights.nc b/test_weights.nc deleted file mode 100644 index 490de3f6203ef798fbd7a76de2ea25ec96805102..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8498 zcmeHMO-vI}5T4zowU*^q6g4I;YGMMUX#uMj1kxWN)?z5p#KhEvE@4AywJl(R1j3EP z0}1i$K@Xk{CK@jq?X)zGMvPjeLDp!pfMFLZ^GNzoqaQJ-Z$UAlJ`~yqwzZb zNx$mz0n4T`C&_qM4*rG$%V#(Cda59~W3?yQp-$KaH!6O`!1J`0s|9-ds=%igV@Xs`IRw)vV z8yS1r$`$Q=&KR318M%CE>(#7Tu+6co6);>0g@F=MF@8Phlu$ZpmBup@JyCnw&Ojsb zK|pIb$>EqN&F$-^z|i_LhDIMIPSOuOwa~MvA%^vYuGYRqwv9z!GxfnjA14j4kA^zqda zGik|XQa{>Bcr8)?H2A$S-=5S=n+$t+r>EvA Date: Thu, 21 May 2026 21:37:00 -0400 Subject: [PATCH 60/63] Allow regrid_with_subsampling to take additional dims --- mom6_forge/chl.py | 48 ++++++++++++++++++++----------------------- mom6_forge/mapping.py | 4 ++-- tests/test_mapping.py | 41 ++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py index 13e9bbb5..b7f7a138 100644 --- a/mom6_forge/chl.py +++ b/mom6_forge/chl.py @@ -86,33 +86,29 @@ def interpolate_and_fill_seawifs( ) chlor_a = chla["CHL_A"] - # Iterate through time - for t in range(src_data.shape[0]): - - # Build source dataset for this timestep - src_ds = xr.Dataset( - { - "chlor_a": xr.DataArray( - src_data[t, ::-1, :].values, - dims=["lat", "lon"], - coords={"lat": src_lat, "lon": src_lon}, - ) - } - ) - - # Regrid to super-sampled sub-point grid and average back to model grid - q_sub = regrid_with_subsampling( - input_dataset=src_ds, - qlon=grid.qlon.values, - qlat=grid.qlat.values, - nx_sub=nx_sub, - ny_sub=ny_sub, - regridding_method="bilinear", - ) - q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values + # Regrid all timesteps at once + src_ds = xr.Dataset( + { + "chlor_a": xr.DataArray( + src_data[:, ::-1, :].values, + dims=["time", "lat", "lon"], + coords={"lat": src_lat, "lon": src_lon}, + ) + } + ) + q_sub = regrid_with_subsampling( + input_dataset=src_ds, + qlon=grid.qlon.values, + qlat=grid.qlat.values, + nx_sub=nx_sub, + ny_sub=ny_sub, + regridding_method="bilinear", + ) + q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values # (time, ny, nx) - # Fill any missing data - q = q_int * ocn_mask + # fill_missing_data is 2D-only, so iterate over time for that step + for t in range(src_data.shape[0]): + q = q_int[t] * ocn_mask q_nan = np.where((q == 0) | np.isnan(q), np.nan, q) chlor_a[t, :] = fill_missing_data(q_nan, ocn_mask) diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 8d452da5..282ddfdb 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -1118,10 +1118,10 @@ def regrid_with_subsampling( data_vars = {} for var in regridded_flat.data_vars: data = regridded_flat[var].values # (..., ny, nx*ny_sub*nx_sub) - reshaped = data.reshape(ny, nx, ny_sub, nx_sub) + reshaped = data.reshape(*data.shape[:-2], ny, nx, ny_sub, nx_sub) original_dims = regridded_flat[var].dims - new_dims = (*original_dims, "ny_sub", "nx_sub") + new_dims = (*original_dims[:-2], "ny", "nx", "ny_sub", "nx_sub") data_vars[var] = xr.DataArray( reshaped, diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 81407562..aee02951 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -84,3 +84,44 @@ def test_regrid_with_subsampling(get_simple_grid): assert np.allclose( ds["data"].values, expected_data ), "Regridded data does not match expected values." + + +def test_regrid_with_subsampling_time_dim(get_simple_grid): + nx_sub = ny_sub = 2 + grid = get_simple_grid + lon = [4 / 3, 5 / 3, 7 / 3, 8 / 3] + lat = [4 / 3, 5 / 3, 7 / 3, 8 / 3] + spatial_data = np.array( + [ + np.arange(1, 5, 1), + np.arange(1, 5, 1), + np.arange(1, 5, 1), + np.arange(1, 5, 1), + ], + dtype=float, + ) + nt = 2 + input_ds = xr.Dataset( + {"data": (["time", "lon", "lat"], np.stack([spatial_data] * nt))}, + coords={ + "lon": (["lon"], [x - 0.1 for x in lon]), + "lat": (["lat"], [x - 0.1 for x in lat]), + }, + ) + ds = regrid_with_subsampling( + input_ds, grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + assert ds["data"].shape == ( + nt, + 2, + 2, + 2, + 2, + ), "Output shape with time dim is incorrect." + expected_spatial = np.array( + [[[[1, 1], [2, 2]], [[1, 1], [2, 2]]], [[[3, 3], [4, 4]], [[3, 3], [4, 4]]]] + ) + for t in range(nt): + assert np.allclose( + ds["data"].values[t], expected_spatial + ), f"Regridded data at t={t} does not match expected values." From b0c2bb4f143fe02af803cae7f65a1e675f70afc7 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 22 May 2026 15:29:03 -0400 Subject: [PATCH 61/63] Add Saving Regridder --- mom6_forge/chl.py | 49 ++++++++++++++++++++++++------------------- mom6_forge/mapping.py | 19 ++++++++++------- mom6_forge/topo.py | 2 +- tests/test_mapping.py | 4 ++-- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py index b7f7a138..6b3adac2 100644 --- a/mom6_forge/chl.py +++ b/mom6_forge/chl.py @@ -86,29 +86,34 @@ def interpolate_and_fill_seawifs( ) chlor_a = chla["CHL_A"] - # Regrid all timesteps at once - src_ds = xr.Dataset( - { - "chlor_a": xr.DataArray( - src_data[:, ::-1, :].values, - dims=["time", "lat", "lon"], - coords={"lat": src_lat, "lon": src_lon}, - ) - } - ) - q_sub = regrid_with_subsampling( - input_dataset=src_ds, - qlon=grid.qlon.values, - qlat=grid.qlat.values, - nx_sub=nx_sub, - ny_sub=ny_sub, - regridding_method="bilinear", - ) - q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values # (time, ny, nx) - - # fill_missing_data is 2D-only, so iterate over time for that step + regridder = None for t in range(src_data.shape[0]): - q = q_int[t] * ocn_mask + + # Build source dataset for this timestep + src_ds = xr.Dataset( + { + "chlor_a": xr.DataArray( + src_data[t, ::-1, :].values, + dims=["lat", "lon"], + coords={"lat": src_lat, "lon": src_lon}, + ) + } + ) + + # Regrid to super-sampled sub-point grid and average back to model grid + q_sub, regridder = regrid_with_subsampling( + input_dataset=src_ds, + qlon=grid.qlon.values, + qlat=grid.qlat.values, + nx_sub=nx_sub, + ny_sub=ny_sub, + regridding_method="bilinear", + regridder=regridder, + ) + q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values + + # Fill any missing data + q = q_int * ocn_mask q_nan = np.where((q == 0) | np.isnan(q), np.nan, q) chlor_a[t, :] = fill_missing_data(q_nan, ocn_mask) diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 282ddfdb..e6a67fe9 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -1071,6 +1071,7 @@ def regrid_with_subsampling( nx_sub, ny_sub, regridding_method="nearest_s2d", + regridder=None, ): """ Regrids input_dataset to sub_sampled_grid to @@ -1107,12 +1108,16 @@ def regrid_with_subsampling( } ) - regridded_flat = regrid_dataset_via_xesmf( - input_dataset, - flat_output, - regridding_method=regridding_method, - write_to_file=False, - ) + if regridder is None: + regridder = xe.Regridder( + input_dataset, + flat_output, + method=regridding_method, + locstream_out=False, + periodic=False, + ) + + regridded_flat = regridder(input_dataset) # Reshape to 4D, keeping sub-points as their own dimension data_vars = {} @@ -1137,7 +1142,7 @@ def regrid_with_subsampling( coords["ny_sub"] = np.arange(ny_sub) coords["nx_sub"] = np.arange(nx_sub) - return xr.Dataset(data_vars, coords=coords, attrs=input_dataset.attrs) + return xr.Dataset(data_vars, coords=coords, attrs=input_dataset.attrs), regridder def main(args): diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index d2ac836a..454c1beb 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -808,7 +808,7 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): return self.stats # Compute subsampling factor and generate sub-point grid - ds = regrid_with_subsampling( + ds, _ = regrid_with_subsampling( input_dataset=self.src.ds, qlon=self._grid.qlon.values, qlat=self._grid.qlat.values, diff --git a/tests/test_mapping.py b/tests/test_mapping.py index aee02951..19736268 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -74,7 +74,7 @@ def test_regrid_with_subsampling(get_simple_grid): "lat": (["lat"], [x - 0.1 for x in lat]), }, ) - ds = regrid_with_subsampling( + ds, _ = regrid_with_subsampling( input_ds, grid.qlon.values, grid.qlat.values, nx_sub, ny_sub ) assert ds["data"].shape == (2, 2, 2, 2), "Output shape is incorrect." @@ -108,7 +108,7 @@ def test_regrid_with_subsampling_time_dim(get_simple_grid): "lat": (["lat"], [x - 0.1 for x in lat]), }, ) - ds = regrid_with_subsampling( + ds, _ = regrid_with_subsampling( input_ds, grid.qlon.values, grid.qlat.values, nx_sub, ny_sub ) assert ds["data"].shape == ( From e94fd96d6a014ce32afda00b2f5f08c145c30a57 Mon Sep 17 00:00:00 2001 From: manishvenu Date: Fri, 22 May 2026 16:46:21 -0400 Subject: [PATCH 62/63] Issue --- tests/test_topo_bathymetry_workflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 375a51cc..04260d0f 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -75,5 +75,5 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): mask = ~np.isnan(topo.depth.values) assert np.isclose( - topo.depth.values[mask], topo.src._stats["D_mean"].values[mask] + topo.depth.values[mask], topo.src.stats["D_mean"].values[mask] ).all() From 1f2b3bded7f519fd548c9dfb0b39e502c7d6a83b Mon Sep 17 00:00:00 2001 From: manishvenu Date: Tue, 26 May 2026 09:42:05 -0600 Subject: [PATCH 63/63] Rename direct_stats_depth to set_depth_from_stats, fix docstring --- mom6_forge/topo.py | 12 ++++++++++-- tests/test_topo_bathymetry_workflows.py | 7 +++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 733e29b5..789ff387 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -873,8 +873,16 @@ def _compute_stats(self, nx_sub, ny_sub, mask_hmin): ) return self.stats - def direct_stats_depth(self, statistic): - """Set the topo depth to a statistic from compute_topo_stats (Must be called before this function to compute the stats).""" + def set_depth_from_stats(self, statistic): + """ + Set the topo depth to a statistic computed by _compute_stats. + + Parameters + ---------- + statistic : str + Which depth statistic to use. Must be one of the "D_*" keys + in self.src.stats (e.g. "mean", "min", "max"). + """ assert ( self.src.stats is not None diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py index 04260d0f..9b2a9bc7 100644 --- a/tests/test_topo_bathymetry_workflows.py +++ b/tests/test_topo_bathymetry_workflows.py @@ -56,8 +56,8 @@ def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): assert stats2 is stats -def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): - """Test direct_stats_depth method which computes depth stats directly from source bathymetry without sub-sampling.""" +def test_set_depth_from_stats(get_rect_topo, synthetic_bathy_file): + """Test set_depth_from_stats sets topo depth to the chosen statistic from _compute_stats.""" topo = get_rect_topo # Load source bathymetry and slice to topo domain @@ -70,8 +70,7 @@ def test_direct_stats_depth(get_rect_topo, synthetic_bathy_file): topo.src = src topo._compute_stats(nx_sub=2, ny_sub=2, mask_hmin=0.0) - # Call direct_stats_depth - topo.direct_stats_depth("mean") + topo.set_depth_from_stats("mean") mask = ~np.isnan(topo.depth.values) assert np.isclose(