diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py
new file mode 100644
index 00000000..f94abd8e
--- /dev/null
+++ b/mom6_forge/_source_bathy.py
@@ -0,0 +1,123 @@
+"""Source bathymetry loader for mom6_forge.
+
+``SourceBathy`` is a lightweight data container for a regional slice of a
+source bathymetry dataset. ``Topo._get_src()`` creates and caches one
+automatically when ``set_from_dataset`` is called. Users who call pipeline
+methods directly (e.g. ``high_res_regrid``, ``generate_mask_ocean_frac``)
+should construct a ``SourceBathy`` explicitly::
+
+ from mom6_forge._source_bathy import SourceBathy
+ src = SourceBathy("gebco_2023.nc").slice_to_domain(topo)
+"""
+
+import numpy as np
+import xarray as xr
+from pathlib import Path
+from mom6_forge.utils import longitude_slicer
+
+
+class SourceBathy:
+ """Regional slice of a source bathymetry dataset (e.g. GEBCO).
+
+ Holds the loaded, domain-clipped elevation DataArray together with its
+ coordinate-name metadata. Per-cell depth statistics are computed and
+ cached here so repeated calls with the same source file skip the
+ expensive sub-sampling step.
+
+ Parameters
+ ----------
+ path : str or Path
+ lon_name : str — longitude coordinate name. Default ``"lon"``.
+ lat_name : str — latitude coordinate name. Default ``"lat"``.
+ elevation_name : str — elevation variable (positive-up). Default ``"elevation"``.
+ """
+
+ def __init__(
+ self,
+ path,
+ lon_name="lon",
+ lat_name="lat",
+ elevation_name="elevation",
+ ):
+ self.path = Path(path)
+ self.lon_name = lon_name
+ self.lat_name = lat_name
+ self.elevation_name = elevation_name
+ self._da = None # set by slice_to_domain
+ self._topo_stats = None # set by compute_topo_stats
+
+ # ------------------------------------------------------------------
+ # Loading
+ # ------------------------------------------------------------------
+
+ def slice_to_domain(self, topo, buf=0.5):
+ """Load and clip elevation to the topo grid extent plus ``buf`` degrees.
+
+ Handles the global-longitude seam automatically. Mutates ``self``
+ in place and returns ``self`` for chaining.
+
+ Parameters
+ ----------
+ topo : Topo — only ``topo._grid.qlon`` / ``topo._grid.qlat`` are used.
+ buf : float — degree buffer around the Q-grid bounding box. Default 0.5.
+ """
+ lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max()))
+ lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max()))
+
+ ds_src = xr.open_dataset(self.path, chunks="auto")
+
+ ds = ds_src.sel(
+ {self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)}
+ )
+
+ dlon = float(ds[self.lon_name][1] - ds[self.lon_name][0])
+ total_lon = float(ds[self.lon_name][-1] - ds[self.lon_name][0] + dlon)
+ if np.isclose(total_lon, 360):
+ ds = longitude_slicer(
+ ds,
+ np.array(lon_extent) + np.array([-buf, buf]),
+ self.lon_name,
+ )
+ else:
+ ds = ds.sel(
+ {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)}
+ )
+ self._ds = ds
+ self._da = ds[self.elevation_name].load()
+ return self
+
+ # ------------------------------------------------------------------
+ # Accessors
+ # ------------------------------------------------------------------
+
+ @property
+ def lon(self):
+ """1-D longitude array."""
+ return self._da[self.lon_name].values
+
+ @property
+ def lat(self):
+ """1-D latitude array."""
+ return self._da[self.lat_name].values
+
+ @property
+ def depth(self):
+ """2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src)."""
+ return -self._da.values.astype(float)
+
+ @property
+ def da(self):
+ """Raw elevation DataArray with source coordinate names (positive-up)."""
+ return self._da
+
+ @property
+ def ds(self):
+ """Raw dataset with source coordinate names (positive-up)."""
+ return self._ds
+
+ def __repr__(self):
+ shape = self._da.shape if self._da is not None else "not loaded"
+ return (
+ f"SourceBathy({self.path.name!r}, lon={self.lon_name!r}, "
+ f"lat={self.lat_name!r}, elevation={self.elevation_name!r}, shape={shape})"
+ )
diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py
index 98e8d40c..13e9bbb5 100644
--- a/mom6_forge/chl.py
+++ b/mom6_forge/chl.py
@@ -6,7 +6,8 @@
from datetime import datetime
from pathlib import Path
from os.path import isfile
-from mom6_forge.utils import fill_missing_data
+from mom6_forge.utils import fill_missing_data, compute_subsampling_factor
+from mom6_forge.mapping import regrid_with_subsampling
def interpolate_and_fill_seawifs(
@@ -64,7 +65,7 @@ def interpolate_and_fill_seawifs(
(np.arange(src_ni) + 0.5) / src_ni
) * 360.0 + src_x0 # Recompute as doubles
- spr_lat, spr_lon = grid.tlat.values, grid.tlon.values
+ ny_sub, nx_sub = compute_subsampling_factor(src_nj, src_ni, ocn_nj, ocn_ni)
# Set output path
if output_path is None:
@@ -88,21 +89,27 @@ def interpolate_and_fill_seawifs(
# Iterate through time
for t in range(src_data.shape[0]):
- # Bilinearly interpolate the source data onto the super-sampled grid
- # adj lon to -180 to 180
- adj_lon = spr_lon - 360
- data = src_data[t, ::-1, :].values
-
- src = {"lon": src_lon, "lat": src_lat}
- dst = {"lon": adj_lon, "lat": spr_lat}
- regridder = xe.Regridder(
- src,
- dst,
- "bilinear",
- filename="bilin_weights.nc",
- reuse_weights=isfile("bilin_weights.nc"),
+ # Build source dataset for this timestep
+ src_ds = xr.Dataset(
+ {
+ "chlor_a": xr.DataArray(
+ src_data[t, ::-1, :].values,
+ dims=["lat", "lon"],
+ coords={"lat": src_lat, "lon": src_lon},
+ )
+ }
+ )
+
+ # Regrid to super-sampled sub-point grid and average back to model grid
+ q_sub = regrid_with_subsampling(
+ input_dataset=src_ds,
+ qlon=grid.qlon.values,
+ qlat=grid.qlat.values,
+ nx_sub=nx_sub,
+ ny_sub=ny_sub,
+ regridding_method="bilinear",
)
- q_int = regridder(data)
+ q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values
# Fill any missing data
q = q_int * ocn_mask
diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py
index 065c7e97..6cc6da86 100755
--- a/mom6_forge/mapping.py
+++ b/mom6_forge/mapping.py
@@ -994,6 +994,148 @@ def regrid_dataset_via_xesmf(
return dataset
+def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub):
+ """
+ Given corner coordinates of a low-res grid (qlon/qlat on the q-point staggering),
+ return sub-sampled lon/lat arrays of shape (ny, nx, ny_sub, nx_sub). (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python)
+
+ Parameters
+ ----------
+ qlon, qlat : np.ndarray shape (ny+1, nx+1)
+ Corner (q-point) coordinates of the low-res grid.
+ nx_sub, ny_sub : int
+ Number of sub-points per cell in each direction.
+
+ Returns
+ -------
+ sub_lon, sub_lat : np.ndarray shape (ny, nx, ny_sub, nx_sub)
+ """
+ assert type(qlon) == type(qlat) == np.ndarray, "qlon and qlat must be numpy arrays"
+
+ SW_lon = qlon[:-1, :-1]
+ SW_lat = qlat[:-1, :-1]
+ SE_lon = qlon[:-1, 1:]
+ SE_lat = qlat[:-1, 1:]
+ NE_lon = qlon[1:, 1:]
+ NE_lat = qlat[1:, 1:]
+ NW_lon = qlon[1:, :-1]
+ NW_lat = qlat[1:, :-1]
+
+ # Fix antimeridian-straddling cells
+ def _fix_period(lon, ref):
+ lon = np.where(lon - ref > 270, lon - 360, lon)
+ lon = np.where(lon - ref < -270, lon + 360, lon)
+ return lon
+
+ SW_lon = _fix_period(SW_lon, NE_lon)
+ SE_lon = _fix_period(SE_lon, NE_lon)
+ NW_lon = _fix_period(NW_lon, NE_lon)
+
+ ifrac = (np.arange(1, nx_sub + 1) / (nx_sub + 1)).astype(float)
+ jfrac = (np.arange(1, ny_sub + 1) / (ny_sub + 1)).astype(float)
+
+ i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] # (1,1,1,nx_sub)
+ j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] # (1,1,ny_sub,1)
+
+ # Broadcast all corners to (ny, nx, 1, 1),
+ SW_lon = SW_lon[:, :, np.newaxis, np.newaxis]
+ SE_lon = SE_lon[:, :, np.newaxis, np.newaxis]
+ NE_lon = NE_lon[:, :, np.newaxis, np.newaxis]
+ NW_lon = NW_lon[:, :, np.newaxis, np.newaxis]
+ SW_lat = SW_lat[:, :, np.newaxis, np.newaxis]
+ SE_lat = SE_lat[:, :, np.newaxis, np.newaxis]
+ NE_lat = NE_lat[:, :, np.newaxis, np.newaxis]
+ NW_lat = NW_lat[:, :, np.newaxis, np.newaxis]
+
+ sub_lon = (
+ (1 - i_) * (1 - j_) * SW_lon
+ + i_ * (1 - j_) * SE_lon
+ + i_ * j_ * NE_lon
+ + (1 - i_) * j_ * NW_lon
+ )
+ sub_lat = (
+ (1 - i_) * (1 - j_) * SW_lat
+ + i_ * (1 - j_) * SE_lat
+ + i_ * j_ * NE_lat
+ + (1 - i_) * j_ * NW_lat
+ )
+
+ return sub_lon, sub_lat
+
+
+def regrid_with_subsampling(
+ input_dataset,
+ qlon,
+ qlat,
+ nx_sub,
+ ny_sub,
+ regridding_method="nearest_s2d",
+):
+ """
+ Regrids input_dataset to sub_sampled_grid to
+ properly analyze high-res source data into each coarse cell. (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python)
+
+ Parameters
+ ----------
+ input_dataset : xr.Dataset
+ qlon, qlat : np.ndarray shape (ny+1, nx+1)
+ Corner coordinates of the destination grid.
+ nx_sub, ny_sub : int
+ Number of sub-points per cell (typically from compute_subsampling_factor).
+ Returns
+ -------
+ regridded_dataset : xr.Dataset
+ Regridded dataset with dimensions (..., ny, nx, ny_sub, nx_sub), where the sub-sampling points are kept as separate dimensions. (User should perform stats calc)
+ """
+ ny, nx = qlon.shape[0] - 1, qlon.shape[1] - 1
+
+ # Build the (ny, nx, ny_sub, nx_sub) sub-point grid
+ sub_lon, sub_lat = _make_subgrid_points(qlon, qlat, nx_sub, ny_sub)
+
+ # Flatten to (ny, nx*ny_sub*nx_sub) so xesmf sees a 2D destination
+ flat_lon = sub_lon.reshape(ny, nx * ny_sub * nx_sub)
+ flat_lat = sub_lat.reshape(ny, nx * ny_sub * nx_sub)
+
+ flat_output = xr.Dataset(
+ {
+ "lon": xr.DataArray(flat_lon, dims=["ny", "nx"]),
+ "lat": xr.DataArray(flat_lat, dims=["ny", "nx"]),
+ }
+ )
+
+ regridded_flat = regrid_dataset_via_xesmf(
+ input_dataset,
+ flat_output,
+ regridding_method=regridding_method,
+ write_to_file=False,
+ )
+
+ # Reshape to 4D, keeping sub-points as their own dimension
+ data_vars = {}
+ for var in regridded_flat.data_vars:
+ data = regridded_flat[var].values # (..., ny, nx*ny_sub*nx_sub)
+ reshaped = data.reshape(ny, nx, ny_sub, nx_sub)
+
+ original_dims = regridded_flat[var].dims
+ new_dims = (*original_dims, "ny_sub", "nx_sub")
+
+ data_vars[var] = xr.DataArray(
+ reshaped,
+ dims=new_dims,
+ attrs=regridded_flat[var].attrs,
+ )
+
+ coords = {
+ k: v
+ for k, v in regridded_flat.coords.items()
+ if "ny" not in v.dims and "nx" not in v.dims
+ }
+ coords["ny_sub"] = np.arange(ny_sub)
+ coords["nx_sub"] = np.arange(nx_sub)
+
+ return xr.Dataset(data_vars, coords=coords, attrs=input_dataset.attrs)
+
+
def main(args):
if args.parallel:
diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py
index 78d19951..5823cee3 100644
--- a/mom6_forge/topo.py
+++ b/mom6_forge/topo.py
@@ -11,7 +11,8 @@
from pathlib import Path
from mom6_forge.edit_command import *
from mom6_forge.command_manager import TopoCommandManager, CommandType
-from mom6_forge.mapping import regrid_dataset_via_xesmf
+from mom6_forge.mapping import regrid_dataset_via_xesmf, regrid_with_subsampling
+from mom6_forge._source_bathy import SourceBathy
class Topo:
@@ -38,6 +39,7 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary"):
attrs={"units": "m"},
) # Initialize depth with NaNs
self._min_depth = min_depth
+ self._src = None # cached SourceBathy; set by _get_src()
if version_control_dir is None:
raise ValueError(
@@ -217,26 +219,6 @@ def umask(self):
return umask
- @property
- def umask(self):
- """
- Ocean domain mask on U grid. 1 if ocean, 0 if land.
- """
- tmask = self.tmask
-
- # Create empty mask DataArray for umask
- umask = xr.DataArray(
- np.ones(self._grid.ulat.shape, dtype=int),
- dims=["yh", "xq"],
- attrs={"name": "U mask"},
- )
-
- # Fill umask with mask values
- umask[:, :-1] &= tmask.values # h-point translates to the left u-point
- umask[:, 1:] &= tmask.values # h-point translates to the right u-point
-
- return umask
-
@property
def vmask(self):
"""
@@ -311,6 +293,31 @@ def supergridmask(self):
supergridmask[1::2, 1::2] = self.tmask.values
return supergridmask
+ def _get_src(
+ self,
+ bathymetry_path,
+ longitude_coordinate_name,
+ latitude_coordinate_name,
+ vertical_coordinate_name,
+ ):
+ """Return a cached :class:`SourceBathy`, creating and slicing a new one
+ only when the path or coordinate names differ from the current cache."""
+ path = Path(bathymetry_path)
+ if (
+ self._src is None
+ or self._src.path != path
+ or self._src.lon_name != longitude_coordinate_name
+ or self._src.lat_name != latitude_coordinate_name
+ or self._src.elevation_name != vertical_coordinate_name
+ ):
+ self._src = SourceBathy(
+ path,
+ longitude_coordinate_name,
+ latitude_coordinate_name,
+ vertical_coordinate_name,
+ ).slice_to_domain(self)
+ return self._src
+
def point_is_ocean(self, lons, lats):
"""
Given a list of coordinates, return a list of booleans indicating if the coordinates are in the ocean (True) or land (False)
@@ -564,6 +571,90 @@ def set_bowl(self, max_depth, dedge, rad_earth=6.378e6, expdecay=400000.0):
# Save to object (Build TCM Object)
self.send_entire_depth_change_to_tcm(new_values)
+ def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin):
+ """Compute per-cell depth statistics by uniform sub-sampling.
+
+ Results are cached on ``src._topo_stats`` so a second call with the
+ same source file returns immediately without recomputation.
+ (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python)
+
+ Parameters
+ ----------
+ src : SourceBathy (part of class)
+ nx_sub, ny_sub : int
+ mask_hmin : float
+
+ Returns
+ -------
+ xr.Dataset — ``OCN_FRAC``, ``D_mean``, ``D_min``, ``D_max``, ``D2_mean``.
+ """
+ assert (
+ self._src is not None
+ ), "Source bathymetry must be loaded to compute topo stats"
+ src = self._src
+ if src._topo_stats is not None:
+ return src._topo_stats
+
+ # Compute subsampling factor and generate sub-point grid
+ ds = regrid_with_subsampling(
+ input_dataset=src.ds,
+ qlon=self._grid.qlon.values,
+ qlat=self._grid.qlat.values,
+ nx_sub=nx_sub,
+ ny_sub=ny_sub,
+ regridding_method="nearest_s2d",
+ )
+
+ depth_sub = ds[src.elevation_name].values # (ny, nx, ny_sub, nx_sub)
+
+ is_ocean = depth_sub > mask_hmin
+ ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub)
+
+ depth_ocean = np.where(is_ocean, depth_sub, np.nan)
+ with np.errstate(all="ignore"):
+ D_mean = np.nanmean(depth_ocean, axis=(-2, -1))
+ D_min = np.nanmin(depth_ocean, axis=(-2, -1))
+ D_max = np.nanmax(depth_ocean, axis=(-2, -1))
+ D2_mean = np.nanmean(depth_ocean**2, axis=(-2, -1))
+
+ dims = ["ny", "nx"]
+ src._topo_stats = xr.Dataset(
+ {
+ "OCN_FRAC": xr.DataArray(
+ ocn_frac,
+ dims=dims,
+ attrs={
+ "long_name": "ocean fraction from sub-sampling",
+ "units": "1",
+ },
+ ),
+ "D_mean": xr.DataArray(
+ D_mean,
+ dims=dims,
+ attrs={"long_name": "mean ocean depth in cell", "units": "m"},
+ ),
+ "D_min": xr.DataArray(
+ D_min,
+ dims=dims,
+ attrs={"long_name": "minimum ocean depth in cell", "units": "m"},
+ ),
+ "D_max": xr.DataArray(
+ D_max,
+ dims=dims,
+ attrs={"long_name": "maximum ocean depth in cell", "units": "m"},
+ ),
+ "D2_mean": xr.DataArray(
+ D2_mean,
+ dims=dims,
+ attrs={
+ "long_name": "mean squared ocean depth in cell",
+ "units": "m2",
+ },
+ ),
+ }
+ )
+ return src._topo_stats
+
def set_from_dataset(
self,
bathymetry_path,
diff --git a/mom6_forge/topo_editor.py b/mom6_forge/topo_editor.py
index 8a87288e..1a45d27c 100644
--- a/mom6_forge/topo_editor.py
+++ b/mom6_forge/topo_editor.py
@@ -6,6 +6,7 @@
from matplotlib.ticker import MaxNLocator
from mom6_forge.edit_command import *
from mom6_forge.git_utils import *
+from mom6_forge.utils import normalize_deg
class TopoEditor(widgets.HBox):
@@ -121,6 +122,24 @@ def format_coord(x, y):
self._interactive_plot = widgets.HBox(
children=(self.fig.canvas,), layout={"border_left": "1px solid grey"}
)
+ # Stats overlay text box (hidden by default)
+ self._stats_text = self.ax.text(
+ 0.02,
+ 0.98,
+ "",
+ transform=self.ax.transAxes,
+ fontsize=9,
+ verticalalignment="top",
+ bbox=dict(
+ boxstyle="round,pad=0.4",
+ facecolor="white",
+ alpha=0.82,
+ edgecolor="#888888",
+ linewidth=0.8,
+ ),
+ visible=False,
+ zorder=10,
+ )
def construct_control_panel(self):
"""
@@ -163,6 +182,15 @@ def construct_control_panel(self):
layout={"width": "80%"},
style={"description_width": "auto"},
)
+ self._set_to_mean_button = widgets.Button(
+ description="Mean", disabled=True, layout={"width": "30%"}
+ )
+ self._set_to_max_button = widgets.Button(
+ description="Max", disabled=True, layout={"width": "30%"}
+ )
+ self._set_to_min_button = widgets.Button(
+ description="Min", disabled=True, layout={"width": "30%"}
+ )
# --- Basin editing widgets ---
self._basin_specifier_toggle = widgets.Button(
@@ -235,13 +263,35 @@ def construct_control_panel(self):
self._min_depth_specifier,
]
)
- cell_editing_section = widgets.VBox(
- [
- widgets.HTML("
Cell Editing
"),
- self._selected_cell_label,
- self._depth_specifier,
- ]
- )
+ cell_editing_section_children = [
+ widgets.HTML("Cell Editing
"),
+ self._selected_cell_label,
+ self._depth_specifier,
+ ]
+
+ # Only add stats section if statistics are available
+ has_stats = (
+ hasattr(self.topo._src, "_topo_stats")
+ and self.topo._src._topo_stats is not None
+ )
+ if has_stats:
+ cell_editing_section_children.extend(
+ [
+ widgets.HTML(
+ "Set to statistic:
"
+ ),
+ widgets.HBox(
+ [
+ self._set_to_mean_button,
+ self._set_to_max_button,
+ self._set_to_min_button,
+ ],
+ layout={"justify_content": "space-between"},
+ ),
+ ]
+ )
+
+ cell_editing_section = widgets.VBox(cell_editing_section_children)
basin_section = widgets.VBox(
[
widgets.HTML("Basin Selector
"),
@@ -333,6 +383,108 @@ def trigger_refresh(self):
self._min_depth_specifier.value = self.topo.min_depth
self.update_undo_redo_buttons()
+ def _draw_cell_stats(self, visible):
+ """Draw per-cell stat annotations directly on each visible cell."""
+ # Clear previous annotations
+ if hasattr(self, "_cell_stat_texts"):
+ for t in self._cell_stat_texts:
+ t.remove()
+ self._cell_stat_texts = []
+
+ if (
+ not hasattr(self.topo._src, "_topo_stats")
+ or self.topo._src._topo_stats is None
+ ):
+ return
+
+ ds = self.topo._src._topo_stats
+ js, is_ = np.where(visible)
+
+ qlon = self.topo._grid.qlon.data
+ qlat = self.topo._grid.qlat.data
+
+ for idx in range(len(js)):
+ j, i = js[idx], is_[idx]
+
+ # Cell centre
+ cx = 0.25 * (
+ qlon[j, i] + qlon[j, i + 1] + qlon[j + 1, i] + qlon[j + 1, i + 1]
+ )
+ cy = 0.25 * (
+ qlat[j, i] + qlat[j, i + 1] + qlat[j + 1, i] + qlat[j + 1, i + 1]
+ )
+
+ lines = []
+ for var in ds.data_vars:
+ if var == "D2_mean":
+ continue # Skip this variable for now since it's not super informative
+ label = var
+ if var.startswith("D_"):
+ label = var[2:] # Get rid of the "D_" prefix for display purposes
+ val = ds[var].data[j, i]
+ units = ds[var].attrs.get("units", "")
+ if units == "1": # Unitless
+ units = ""
+ lines.append(f"{label}: {val:.2f} {units}")
+
+ t = self.ax.text(
+ cx,
+ cy,
+ "\n".join(lines),
+ fontsize=7,
+ ha="center",
+ va="center",
+ transform=ccrs.PlateCarree(),
+ bbox=dict(
+ boxstyle="round,pad=0.2",
+ facecolor="none",
+ alpha=0.7,
+ edgecolor="none",
+ ),
+ zorder=11,
+ )
+ self._cell_stat_texts.append(t)
+
+ self.fig.canvas.draw_idle()
+
+ def _on_zoom_change(self, ax):
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+
+ qlon = self.topo._grid.qlon.data
+ qlat = self.topo._grid.qlat.data
+
+ # Check if all four corners of each cell are within the view limits
+ visible = (
+ (qlon[:-1, :-1] >= normalize_deg(xlim[0]))
+ & (qlon[:-1, :-1] <= normalize_deg(xlim[1]))
+ & (qlon[:-1, 1:] >= normalize_deg(xlim[0]))
+ & (qlon[:-1, 1:] <= normalize_deg(xlim[1]))
+ & (qlon[1:, :-1] >= normalize_deg(xlim[0]))
+ & (qlon[1:, :-1] <= normalize_deg(xlim[1]))
+ & (qlon[1:, 1:] >= normalize_deg(xlim[0]))
+ & (qlon[1:, 1:] <= normalize_deg(xlim[1]))
+ & (qlat[:-1, :-1] >= ylim[0])
+ & (qlat[:-1, :-1] <= ylim[1])
+ & (qlat[:-1, 1:] >= ylim[0])
+ & (qlat[:-1, 1:] <= ylim[1])
+ & (qlat[1:, :-1] >= ylim[0])
+ & (qlat[1:, :-1] <= ylim[1])
+ & (qlat[1:, 1:] >= ylim[0])
+ & (qlat[1:, 1:] <= ylim[1])
+ )
+
+ # Always clear old annotations first
+ if hasattr(self, "_cell_stat_texts"):
+ for t in self._cell_stat_texts:
+ t.remove()
+ self._cell_stat_texts = []
+
+ if visible.sum() <= 40:
+ self._draw_cell_stats(visible)
+ else:
+ self.fig.canvas.draw_idle()
+
def _select_cell(self, i, j):
"""Select a cell in the topography grid and update the UI accordingly."""
# Remove old patch if it exists
@@ -383,6 +535,19 @@ def _select_cell(self, i, j):
if hasattr(self, "_depth_specifier"):
self._depth_specifier.disabled = False
self._depth_specifier.value = self.topo.depth.data[j, i]
+
+ # Enable statistic buttons if statistics are available
+ has_stats = (
+ hasattr(self.topo._src, "_topo_stats")
+ and self.topo._src._topo_stats is not None
+ )
+ for btn in [
+ self._set_to_mean_button,
+ self._set_to_max_button,
+ self._set_to_min_button,
+ ]:
+ btn.disabled = not has_stats
+
if hasattr(self, "_basin_specifier"):
label = self.topo.basintmask.data[j, i]
self._basin_specifier.value = f"Basin Label Number: {str(label)}"
@@ -405,7 +570,9 @@ def construct_observances(self):
# Double click event for cell selection on the plot
self.fig.canvas.mpl_connect("button_press_event", self.on_double_click)
-
+ # Zoom-dependent stats overlay
+ self.ax.callbacks.connect("xlim_changed", self._on_zoom_change)
+ self.ax.callbacks.connect("ylim_changed", self._on_zoom_change)
# Min depth change observer
self._min_depth_specifier.observe(
self.on_min_depth_change, names="value", type="change"
@@ -420,6 +587,11 @@ def construct_observances(self):
self.on_depth_change, names="value", type="change"
)
+ # Statistic buttons
+ self._set_to_mean_button.on_click(self.set_depth_to_mean)
+ self._set_to_max_button.on_click(self.set_depth_to_max)
+ self._set_to_min_button.on_click(self.set_depth_to_min)
+
# Undo/Redo/Reset buttons
self._undo_button.on_click(self.undo_last_edit)
self._redo_button.on_click(self.redo_last_edit)
@@ -496,6 +668,37 @@ def on_depth_change(self, change):
self.apply_edit(cmd)
self.update_undo_redo_buttons()
+ def _get_statistic_value(self, stat_name):
+ """Get a statistic value for the selected cell."""
+ if self._selected_cell is None or not hasattr(self.topo._src, "_topo_stats"):
+ return None
+
+ i, j, _ = self._selected_cell
+ ds = self.topo._src._topo_stats
+
+ if ds is None or stat_name not in ds.data_vars:
+ return None
+
+ return float(ds[stat_name].data[j, i])
+
+ def set_depth_to_mean(self, b):
+ """Set the selected cell's depth to the mean value."""
+ val = self._get_statistic_value("D_mean")
+ if val is not None:
+ self._depth_specifier.value = val
+
+ def set_depth_to_max(self, b):
+ """Set the selected cell's depth to the max value."""
+ val = self._get_statistic_value("D_max")
+ if val is not None:
+ self._depth_specifier.value = val
+
+ def set_depth_to_min(self, b):
+ """Set the selected cell's depth to the min value."""
+ val = self._get_statistic_value("D_min")
+ if val is not None:
+ self._depth_specifier.value = val
+
def on_git_create_branch(self, b):
"""Create a new git branch"""
name = self._git_branch_name.value.strip()
diff --git a/mom6_forge/utils.py b/mom6_forge/utils.py
index adc09bca..1fcbce39 100644
--- a/mom6_forge/utils.py
+++ b/mom6_forge/utils.py
@@ -697,3 +697,30 @@ def quadrilateral_area(v1, v2, v3, v4):
a4 = angle_between(v4, v1, v3)
return (a1 + a2 + a3 + a4 - 2 * np.pi) * R**2
+
+
+def compute_subsampling_factor(src_nj, src_ni, dst_nj, dst_ni):
+ """
+ Compute the sub-sampling factors needed so that the super-sampled
+ destination grid has at least as many points as the source grid.
+
+ Parameters
+ ----------
+ src_nj, src_ni : int
+ Source grid dimensions.
+ dst_nj, dst_ni : int
+ Destination grid dimensions.
+
+ Returns
+ -------
+ ny_sub, nx_sub : int
+ """
+ nx_sub = 1
+ while nx_sub * dst_ni < src_ni:
+ nx_sub += 1
+
+ ny_sub = 1
+ while ny_sub * dst_nj < src_nj:
+ ny_sub += 1
+
+ return ny_sub, nx_sub
diff --git a/setup.py b/setup.py
index 4b4ca999..facab1e4 100644
--- a/setup.py
+++ b/setup.py
@@ -37,5 +37,6 @@
"gitpython>=3.1,<3.2.0",
"cartopy>=0.23,<0.30",
"xesmf>=0.8.10,<1.0.0",
+ "dask>=2026.3.0,<2026.4.0",
],
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 0b83f72e..36161409 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,11 @@
import numpy as np
import pytest
+import tempfile
+from pathlib import Path
+import xarray as xr
from mom6_forge.grid import Grid
from mom6_forge.topo import Topo
+from mom6_forge._source_bathy import SourceBathy
@pytest.fixture
@@ -252,3 +256,98 @@ def get_rect_topo(get_rect_grid, tmp_path):
topo = Topo(get_rect_grid, min_depth=0, version_control_dir=tmp_path)
topo.set_flat(1000)
return topo
+
+
+@pytest.fixture
+def synthetic_bathy_file():
+ """Create a temporary synthetic bathymetry NetCDF file for testing.
+
+ Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid().
+ """
+ with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp:
+ bathy_file = tmp.name
+
+ # Create synthetic bathymetry covering the Panama region
+ # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3
+ # So we need lon 278-282, lat 7-10 with some buffer
+ lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer
+ lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer
+
+ # Create synthetic elevation data (positive-up, like GEBCO)
+ # Ocean is negative (water), land is positive
+ elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep
+
+ # Add synthetic land masses (islands)
+ # Create an island around (280, 8.5)
+ lon_2d, lat_2d = np.meshgrid(lon, lat)
+ island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5
+ elevation[island_mask] = 200.0 # Synthetic island
+
+ ds = xr.Dataset(
+ {
+ "elevation": (["lat", "lon"], elevation),
+ },
+ coords={
+ "lon": lon,
+ "lat": lat,
+ },
+ )
+ ds.to_netcdf(bathy_file)
+
+ yield bathy_file
+
+ # Cleanup
+ Path(bathy_file).unlink()
+
+
+@pytest.fixture
+def get_simple_grid():
+ grid = Grid(
+ resolution=1, # in degrees
+ xstart=1, # min longitude in [0, 360]
+ lenx=2, # longitude extent in degrees
+ ystart=1, # min latitude in [-90, 90]
+ leny=2, # latitude extent in degrees
+ name="panama1",
+ )
+ return grid
+
+
+@pytest.fixture
+def get_simple_global_grid():
+ grid = Grid(
+ resolution=1, # in degrees
+ xstart=0, # min longitude in [0, 360]
+ lenx=360, # longitude extent in degrees
+ ystart=-90, # min latitude in [-90, 90]
+ leny=180, # latitude extent in degrees
+ name="panama1",
+ cyclic_x=True,
+ )
+ return grid
+
+
+@pytest.fixture
+def get_PM_seam_grid():
+ grid = Grid(
+ resolution=1, # in degrees
+ xstart=359, # min longitude in [0, 360]
+ lenx=2, # longitude extent in degrees
+ ystart=-1, # min latitude in [-90, 90]
+ leny=2, # latitude extent in degrees
+ name="panama1",
+ )
+ return grid
+
+
+@pytest.fixture
+def get_dateline_seam_grid():
+ grid = Grid(
+ resolution=1, # in degrees
+ xstart=-1, # min longitude in [0, 360]
+ lenx=2, # longitude extent in degrees
+ ystart=-1, # min latitude in [-90, 90]
+ leny=2, # latitude extent in degrees
+ name="panama1",
+ )
+ return grid
diff --git a/tests/test_mapping.py b/tests/test_mapping.py
new file mode 100644
index 00000000..81407562
--- /dev/null
+++ b/tests/test_mapping.py
@@ -0,0 +1,86 @@
+from mom6_forge.mapping import _make_subgrid_points, regrid_with_subsampling
+import numpy as np
+import xarray as xr
+
+
+def test_make_subgrid_points(get_simple_grid):
+ # Test with a simple 2x2 grid and 2 sub-points per cell
+ nx_sub = ny_sub = 2
+ grid = get_simple_grid
+ sub_lon, sub_lat = _make_subgrid_points(
+ grid.qlon.values, grid.qlat.values, nx_sub, ny_sub
+ )
+
+ expected_sub_lon = np.array(
+ [
+ [[[4 / 3, 5 / 3], [4 / 3, 5 / 3]], [[7 / 3, 8 / 3], [7 / 3, 8 / 3]]],
+ [[[4 / 3, 5 / 3], [4 / 3, 5 / 3]], [[7 / 3, 8 / 3], [7 / 3, 8 / 3]]],
+ ]
+ )
+ expected_sub_lat = np.array(
+ [
+ [[[4 / 3, 4 / 3], [5 / 3, 5 / 3]], [[4 / 3, 4 / 3], [5 / 3, 5 / 3]]],
+ [[[7 / 3, 7 / 3], [8 / 3, 8 / 3]], [[7 / 3, 7 / 3], [8 / 3, 8 / 3]]],
+ ]
+ )
+
+ assert np.allclose(
+ sub_lon, expected_sub_lon
+ ), "Sub-grid longitudes do not match expected values."
+ assert np.allclose(
+ sub_lat, expected_sub_lat
+ ), "Sub-grid latitudes do not match expected values."
+
+
+def test_smoke_seams_and_global_make_subgrid_points(
+ get_dateline_seam_grid, get_PM_seam_grid, get_simple_global_grid
+):
+ # Test with a simple 2x2 grid and 2 sub-points per cell
+ nx_sub = ny_sub = 2
+ grid = get_dateline_seam_grid
+ sub_lon, sub_lat = _make_subgrid_points(
+ grid.qlon.values, grid.qlat.values, nx_sub, ny_sub
+ )
+ grid = get_PM_seam_grid
+ sub_lon, sub_lat = _make_subgrid_points(
+ grid.qlon.values, grid.qlat.values, nx_sub, ny_sub
+ )
+ grid = get_simple_global_grid
+ sub_lon, sub_lat = _make_subgrid_points(
+ grid.qlon.values, grid.qlat.values, nx_sub, ny_sub
+ )
+
+
+def test_regrid_with_subsampling(get_simple_grid):
+ # Test with a simple 2x2 grid and 2 sub-points per cell with data that lands exactly on the sub points (subtracted by 0.1 to show snapping to sub points)
+ nx_sub = ny_sub = 2
+ grid = get_simple_grid
+ lon = [4 / 3, 5 / 3, 7 / 3, 8 / 3]
+ lat = [4 / 3, 5 / 3, 7 / 3, 8 / 3]
+ input_ds = xr.Dataset(
+ {
+ "data": (
+ ["lon", "lat"],
+ [
+ np.arange(1, 5, 1),
+ np.arange(1, 5, 1),
+ np.arange(1, 5, 1),
+ np.arange(1, 5, 1),
+ ],
+ )
+ },
+ coords={
+ "lon": (["lon"], [x - 0.1 for x in lon]),
+ "lat": (["lat"], [x - 0.1 for x in lat]),
+ },
+ )
+ ds = regrid_with_subsampling(
+ input_ds, grid.qlon.values, grid.qlat.values, nx_sub, ny_sub
+ )
+ assert ds["data"].shape == (2, 2, 2, 2), "Output shape is incorrect."
+ expected_data = np.array(
+ [[[[1, 1], [2, 2]], [[1, 1], [2, 2]]], [[[3, 3], [4, 4]], [[3, 3], [4, 4]]]]
+ )
+ assert np.allclose(
+ ds["data"].values, expected_data
+ ), "Regridded data does not match expected values."
diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py
new file mode 100644
index 00000000..a96b6b3f
--- /dev/null
+++ b/tests/test_source_bathy.py
@@ -0,0 +1,105 @@
+"""Smoke test for SourceBathy loader."""
+
+import numpy as np
+import pytest
+import tempfile
+from pathlib import Path
+import xarray as xr
+from mom6_forge._source_bathy import SourceBathy
+
+
+@pytest.fixture
+def synthetic_bathy_file():
+ """Create a temporary synthetic bathymetry NetCDF file for testing.
+
+ Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid().
+ """
+ with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp:
+ bathy_file = tmp.name
+
+ # Create synthetic bathymetry covering the Panama region
+ # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3
+ # So we need lon 278-282, lat 7-10 with some buffer
+ lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer
+ lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer
+
+ # Create synthetic elevation data (positive-up, like GEBCO)
+ # Ocean is negative (water), land is positive
+ elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep
+
+ # Add synthetic land masses (islands)
+ # Create an island around (280, 8.5)
+ lon_2d, lat_2d = np.meshgrid(lon, lat)
+ island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5
+ elevation[island_mask] = 200.0 # Synthetic island
+
+ ds = xr.Dataset(
+ {
+ "elevation": (["lat", "lon"], elevation),
+ },
+ coords={
+ "lon": lon,
+ "lat": lat,
+ },
+ )
+ ds.to_netcdf(bathy_file)
+
+ yield bathy_file
+
+ # Cleanup
+ Path(bathy_file).unlink()
+
+
+def test_source_bathy_initialization(synthetic_bathy_file):
+ """Test SourceBathy initialization and coordinate names."""
+ src = SourceBathy(
+ synthetic_bathy_file,
+ lon_name="lon",
+ lat_name="lat",
+ elevation_name="elevation",
+ )
+
+ assert src.path == Path(synthetic_bathy_file)
+ assert src.lon_name == "lon"
+ assert src.lat_name == "lat"
+ assert src.elevation_name == "elevation"
+ assert src._da is None # Not loaded yet
+
+
+def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file):
+ """Smoke test: load and slice elevation to topo domain."""
+ topo = get_rect_topo
+
+ src = SourceBathy(synthetic_bathy_file)
+ src.slice_to_domain(topo, buf=0.4)
+
+ # Verify data was loaded
+ assert src._da is not None
+ assert src.lon is not None
+ assert src.lat is not None
+
+ # Verify shape makes sense
+ assert len(src.lon) > 0, f"Expected lon data, got empty array"
+ assert len(src.lat) > 0, f"Expected lat data, got empty array"
+ assert src._da.shape == (len(src.lat), len(src.lon))
+
+
+def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file):
+ """Test that elevation is converted to positive-down depth."""
+ topo = get_rect_topo
+
+ src = SourceBathy(synthetic_bathy_file)
+ src.slice_to_domain(topo, buf=0.5)
+
+ # Get depth and verify sign conversion
+ depth = src.depth
+
+ # Verify no NaNs in the result
+ assert not bool(np.isnan(depth).all()), "All depth values are NaN"
+
+ # Verify positive depth values for ocean (elevation is negative)
+ non_nan_values = depth[~np.isnan(depth)]
+ assert len(non_nan_values) > 0, "No valid depth values"
+ assert np.any(non_nan_values > 0), "Expected positive depth values for ocean"
+
+ assert depth.shape == src.da.shape
diff --git a/tests/test_topo.py b/tests/test_topo.py
index 5da340d8..7e89b0e2 100644
--- a/tests/test_topo.py
+++ b/tests/test_topo.py
@@ -1,4 +1,52 @@
+import numpy as np
+import pytest
+import tempfile
+from pathlib import Path
+import xarray as xr
from mom6_forge.topo import *
+from mom6_forge._source_bathy import SourceBathy
+
+
+@pytest.fixture
+def synthetic_bathy_file():
+ """Create a temporary synthetic bathymetry NetCDF file for testing.
+
+ Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid().
+ """
+ with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp:
+ bathy_file = tmp.name
+
+ # Create synthetic bathymetry covering the Panama region
+ # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3
+ # So we need lon 278-282, lat 7-10 with some buffer
+ lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer
+ lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer
+
+ # Create synthetic elevation data (positive-up, like GEBCO)
+ # Ocean is negative (water), land is positive
+ elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep
+
+ # Add synthetic land masses (islands)
+ # Create an island around (280, 8.5)
+ lon_2d, lat_2d = np.meshgrid(lon, lat)
+ island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5
+ elevation[island_mask] = 200.0 # Synthetic island
+
+ ds = xr.Dataset(
+ {
+ "elevation": (["lat", "lon"], elevation),
+ },
+ coords={
+ "lon": lon,
+ "lat": lat,
+ },
+ )
+ ds.to_netcdf(bathy_file)
+
+ yield bathy_file
+
+ # Cleanup
+ Path(bathy_file).unlink()
def test_topo_from_version_control(get_rect_topo):
diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py
new file mode 100644
index 00000000..5a5dbf30
--- /dev/null
+++ b/tests/test_topo_bathymetry_workflows.py
@@ -0,0 +1,57 @@
+import numpy as np
+import pytest
+from mom6_forge.topo import *
+from mom6_forge._source_bathy import SourceBathy
+
+
+def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file):
+ """Test _compute_topo_stats: per-cell depth statistics via xesmf nearest neighbor regridding.
+
+ This test validates the refactored _compute_topo_stats method which:
+ - Generates sub-points within each grid cell
+ - Uses xesmf nearest_s2d regridding to snap sub-points to nearest source data
+ - Computes per-cell statistics (OCN_FRAC, D_mean, D_min, D_max, D2_mean)
+ """
+ topo = get_rect_topo
+
+ # Load source bathymetry and slice to topo domain
+ src = SourceBathy(synthetic_bathy_file)
+ src.slice_to_domain(topo, buf=0.5)
+ topo._src = src
+
+ # Test with different sub-sampling densities
+ for nx_sub, ny_sub in [(2, 2), (3, 3)]:
+ # Call _compute_topo_stats
+ stats = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0)
+
+ # Verify output is a Dataset with expected variables
+ assert isinstance(stats, xr.Dataset)
+ required_vars = ["OCN_FRAC", "D_mean", "D_min", "D_max", "D2_mean"]
+ for var in required_vars:
+ assert var in stats.data_vars, f"Missing {var} in output"
+
+ # Verify shapes match topo grid
+ expected_shape = (topo.depth.shape[0], topo.depth.shape[1])
+ assert stats["OCN_FRAC"].shape == expected_shape
+ assert stats["D_mean"].shape == expected_shape
+ assert stats["D_min"].shape == expected_shape
+ assert stats["D_max"].shape == expected_shape
+ assert stats["D2_mean"].shape == expected_shape
+
+ # Verify OCN_FRAC is between 0 and 1
+ assert (stats["OCN_FRAC"] >= 0).all()
+ assert (stats["OCN_FRAC"] <= 1).all()
+
+ # Verify D_min <= D_mean <= D_max
+ ocean_cells = stats["OCN_FRAC"].values > 0
+ assert (
+ stats["D_min"].values[ocean_cells] <= stats["D_mean"].values[ocean_cells]
+ ).all()
+ assert (
+ stats["D_mean"].values[ocean_cells] <= stats["D_max"].values[ocean_cells]
+ ).all()
+
+ # Verify caching: second call should return cached result
+ stats2 = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0)
+ # Should be the exact same object (cached)
+ assert stats2 is stats