diff --git a/mom6_forge/_source_bathy.py b/mom6_forge/_source_bathy.py new file mode 100644 index 00000000..f94abd8e --- /dev/null +++ b/mom6_forge/_source_bathy.py @@ -0,0 +1,123 @@ +"""Source bathymetry loader for mom6_forge. + +``SourceBathy`` is a lightweight data container for a regional slice of a +source bathymetry dataset. ``Topo._get_src()`` creates and caches one +automatically when ``set_from_dataset`` is called. Users who call pipeline +methods directly (e.g. ``high_res_regrid``, ``generate_mask_ocean_frac``) +should construct a ``SourceBathy`` explicitly:: + + from mom6_forge._source_bathy import SourceBathy + src = SourceBathy("gebco_2023.nc").slice_to_domain(topo) +""" + +import numpy as np +import xarray as xr +from pathlib import Path +from mom6_forge.utils import longitude_slicer + + +class SourceBathy: + """Regional slice of a source bathymetry dataset (e.g. GEBCO). + + Holds the loaded, domain-clipped elevation DataArray together with its + coordinate-name metadata. Per-cell depth statistics are computed and + cached here so repeated calls with the same source file skip the + expensive sub-sampling step. + + Parameters + ---------- + path : str or Path + lon_name : str — longitude coordinate name. Default ``"lon"``. + lat_name : str — latitude coordinate name. Default ``"lat"``. + elevation_name : str — elevation variable (positive-up). Default ``"elevation"``. + """ + + def __init__( + self, + path, + lon_name="lon", + lat_name="lat", + elevation_name="elevation", + ): + self.path = Path(path) + self.lon_name = lon_name + self.lat_name = lat_name + self.elevation_name = elevation_name + self._da = None # set by slice_to_domain + self._topo_stats = None # set by compute_topo_stats + + # ------------------------------------------------------------------ + # Loading + # ------------------------------------------------------------------ + + def slice_to_domain(self, topo, buf=0.5): + """Load and clip elevation to the topo grid extent plus ``buf`` degrees. + + Handles the global-longitude seam automatically. Mutates ``self`` + in place and returns ``self`` for chaining. + + Parameters + ---------- + topo : Topo — only ``topo._grid.qlon`` / ``topo._grid.qlat`` are used. + buf : float — degree buffer around the Q-grid bounding box. Default 0.5. + """ + lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max())) + lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max())) + + ds_src = xr.open_dataset(self.path, chunks="auto") + + ds = ds_src.sel( + {self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)} + ) + + dlon = float(ds[self.lon_name][1] - ds[self.lon_name][0]) + total_lon = float(ds[self.lon_name][-1] - ds[self.lon_name][0] + dlon) + if np.isclose(total_lon, 360): + ds = longitude_slicer( + ds, + np.array(lon_extent) + np.array([-buf, buf]), + self.lon_name, + ) + else: + ds = ds.sel( + {self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)} + ) + self._ds = ds + self._da = ds[self.elevation_name].load() + return self + + # ------------------------------------------------------------------ + # Accessors + # ------------------------------------------------------------------ + + @property + def lon(self): + """1-D longitude array.""" + return self._da[self.lon_name].values + + @property + def lat(self): + """1-D latitude array.""" + return self._da[self.lat_name].values + + @property + def depth(self): + """2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src).""" + return -self._da.values.astype(float) + + @property + def da(self): + """Raw elevation DataArray with source coordinate names (positive-up).""" + return self._da + + @property + def ds(self): + """Raw dataset with source coordinate names (positive-up).""" + return self._ds + + def __repr__(self): + shape = self._da.shape if self._da is not None else "not loaded" + return ( + f"SourceBathy({self.path.name!r}, lon={self.lon_name!r}, " + f"lat={self.lat_name!r}, elevation={self.elevation_name!r}, shape={shape})" + ) diff --git a/mom6_forge/chl.py b/mom6_forge/chl.py index 98e8d40c..13e9bbb5 100644 --- a/mom6_forge/chl.py +++ b/mom6_forge/chl.py @@ -6,7 +6,8 @@ from datetime import datetime from pathlib import Path from os.path import isfile -from mom6_forge.utils import fill_missing_data +from mom6_forge.utils import fill_missing_data, compute_subsampling_factor +from mom6_forge.mapping import regrid_with_subsampling def interpolate_and_fill_seawifs( @@ -64,7 +65,7 @@ def interpolate_and_fill_seawifs( (np.arange(src_ni) + 0.5) / src_ni ) * 360.0 + src_x0 # Recompute as doubles - spr_lat, spr_lon = grid.tlat.values, grid.tlon.values + ny_sub, nx_sub = compute_subsampling_factor(src_nj, src_ni, ocn_nj, ocn_ni) # Set output path if output_path is None: @@ -88,21 +89,27 @@ def interpolate_and_fill_seawifs( # Iterate through time for t in range(src_data.shape[0]): - # Bilinearly interpolate the source data onto the super-sampled grid - # adj lon to -180 to 180 - adj_lon = spr_lon - 360 - data = src_data[t, ::-1, :].values - - src = {"lon": src_lon, "lat": src_lat} - dst = {"lon": adj_lon, "lat": spr_lat} - regridder = xe.Regridder( - src, - dst, - "bilinear", - filename="bilin_weights.nc", - reuse_weights=isfile("bilin_weights.nc"), + # Build source dataset for this timestep + src_ds = xr.Dataset( + { + "chlor_a": xr.DataArray( + src_data[t, ::-1, :].values, + dims=["lat", "lon"], + coords={"lat": src_lat, "lon": src_lon}, + ) + } + ) + + # Regrid to super-sampled sub-point grid and average back to model grid + q_sub = regrid_with_subsampling( + input_dataset=src_ds, + qlon=grid.qlon.values, + qlat=grid.qlat.values, + nx_sub=nx_sub, + ny_sub=ny_sub, + regridding_method="bilinear", ) - q_int = regridder(data) + q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values # Fill any missing data q = q_int * ocn_mask diff --git a/mom6_forge/mapping.py b/mom6_forge/mapping.py index 065c7e97..6cc6da86 100755 --- a/mom6_forge/mapping.py +++ b/mom6_forge/mapping.py @@ -994,6 +994,148 @@ def regrid_dataset_via_xesmf( return dataset +def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub): + """ + Given corner coordinates of a low-res grid (qlon/qlat on the q-point staggering), + return sub-sampled lon/lat arrays of shape (ny, nx, ny_sub, nx_sub). (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) + + Parameters + ---------- + qlon, qlat : np.ndarray shape (ny+1, nx+1) + Corner (q-point) coordinates of the low-res grid. + nx_sub, ny_sub : int + Number of sub-points per cell in each direction. + + Returns + ------- + sub_lon, sub_lat : np.ndarray shape (ny, nx, ny_sub, nx_sub) + """ + assert type(qlon) == type(qlat) == np.ndarray, "qlon and qlat must be numpy arrays" + + SW_lon = qlon[:-1, :-1] + SW_lat = qlat[:-1, :-1] + SE_lon = qlon[:-1, 1:] + SE_lat = qlat[:-1, 1:] + NE_lon = qlon[1:, 1:] + NE_lat = qlat[1:, 1:] + NW_lon = qlon[1:, :-1] + NW_lat = qlat[1:, :-1] + + # Fix antimeridian-straddling cells + def _fix_period(lon, ref): + lon = np.where(lon - ref > 270, lon - 360, lon) + lon = np.where(lon - ref < -270, lon + 360, lon) + return lon + + SW_lon = _fix_period(SW_lon, NE_lon) + SE_lon = _fix_period(SE_lon, NE_lon) + NW_lon = _fix_period(NW_lon, NE_lon) + + ifrac = (np.arange(1, nx_sub + 1) / (nx_sub + 1)).astype(float) + jfrac = (np.arange(1, ny_sub + 1) / (ny_sub + 1)).astype(float) + + i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] # (1,1,1,nx_sub) + j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] # (1,1,ny_sub,1) + + # Broadcast all corners to (ny, nx, 1, 1), + SW_lon = SW_lon[:, :, np.newaxis, np.newaxis] + SE_lon = SE_lon[:, :, np.newaxis, np.newaxis] + NE_lon = NE_lon[:, :, np.newaxis, np.newaxis] + NW_lon = NW_lon[:, :, np.newaxis, np.newaxis] + SW_lat = SW_lat[:, :, np.newaxis, np.newaxis] + SE_lat = SE_lat[:, :, np.newaxis, np.newaxis] + NE_lat = NE_lat[:, :, np.newaxis, np.newaxis] + NW_lat = NW_lat[:, :, np.newaxis, np.newaxis] + + sub_lon = ( + (1 - i_) * (1 - j_) * SW_lon + + i_ * (1 - j_) * SE_lon + + i_ * j_ * NE_lon + + (1 - i_) * j_ * NW_lon + ) + sub_lat = ( + (1 - i_) * (1 - j_) * SW_lat + + i_ * (1 - j_) * SE_lat + + i_ * j_ * NE_lat + + (1 - i_) * j_ * NW_lat + ) + + return sub_lon, sub_lat + + +def regrid_with_subsampling( + input_dataset, + qlon, + qlat, + nx_sub, + ny_sub, + regridding_method="nearest_s2d", +): + """ + Regrids input_dataset to sub_sampled_grid to + properly analyze high-res source data into each coarse cell. (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) + + Parameters + ---------- + input_dataset : xr.Dataset + qlon, qlat : np.ndarray shape (ny+1, nx+1) + Corner coordinates of the destination grid. + nx_sub, ny_sub : int + Number of sub-points per cell (typically from compute_subsampling_factor). + Returns + ------- + regridded_dataset : xr.Dataset + Regridded dataset with dimensions (..., ny, nx, ny_sub, nx_sub), where the sub-sampling points are kept as separate dimensions. (User should perform stats calc) + """ + ny, nx = qlon.shape[0] - 1, qlon.shape[1] - 1 + + # Build the (ny, nx, ny_sub, nx_sub) sub-point grid + sub_lon, sub_lat = _make_subgrid_points(qlon, qlat, nx_sub, ny_sub) + + # Flatten to (ny, nx*ny_sub*nx_sub) so xesmf sees a 2D destination + flat_lon = sub_lon.reshape(ny, nx * ny_sub * nx_sub) + flat_lat = sub_lat.reshape(ny, nx * ny_sub * nx_sub) + + flat_output = xr.Dataset( + { + "lon": xr.DataArray(flat_lon, dims=["ny", "nx"]), + "lat": xr.DataArray(flat_lat, dims=["ny", "nx"]), + } + ) + + regridded_flat = regrid_dataset_via_xesmf( + input_dataset, + flat_output, + regridding_method=regridding_method, + write_to_file=False, + ) + + # Reshape to 4D, keeping sub-points as their own dimension + data_vars = {} + for var in regridded_flat.data_vars: + data = regridded_flat[var].values # (..., ny, nx*ny_sub*nx_sub) + reshaped = data.reshape(ny, nx, ny_sub, nx_sub) + + original_dims = regridded_flat[var].dims + new_dims = (*original_dims, "ny_sub", "nx_sub") + + data_vars[var] = xr.DataArray( + reshaped, + dims=new_dims, + attrs=regridded_flat[var].attrs, + ) + + coords = { + k: v + for k, v in regridded_flat.coords.items() + if "ny" not in v.dims and "nx" not in v.dims + } + coords["ny_sub"] = np.arange(ny_sub) + coords["nx_sub"] = np.arange(nx_sub) + + return xr.Dataset(data_vars, coords=coords, attrs=input_dataset.attrs) + + def main(args): if args.parallel: diff --git a/mom6_forge/topo.py b/mom6_forge/topo.py index 78d19951..5823cee3 100644 --- a/mom6_forge/topo.py +++ b/mom6_forge/topo.py @@ -11,7 +11,8 @@ from pathlib import Path from mom6_forge.edit_command import * from mom6_forge.command_manager import TopoCommandManager, CommandType -from mom6_forge.mapping import regrid_dataset_via_xesmf +from mom6_forge.mapping import regrid_dataset_via_xesmf, regrid_with_subsampling +from mom6_forge._source_bathy import SourceBathy class Topo: @@ -38,6 +39,7 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary"): attrs={"units": "m"}, ) # Initialize depth with NaNs self._min_depth = min_depth + self._src = None # cached SourceBathy; set by _get_src() if version_control_dir is None: raise ValueError( @@ -217,26 +219,6 @@ def umask(self): return umask - @property - def umask(self): - """ - Ocean domain mask on U grid. 1 if ocean, 0 if land. - """ - tmask = self.tmask - - # Create empty mask DataArray for umask - umask = xr.DataArray( - np.ones(self._grid.ulat.shape, dtype=int), - dims=["yh", "xq"], - attrs={"name": "U mask"}, - ) - - # Fill umask with mask values - umask[:, :-1] &= tmask.values # h-point translates to the left u-point - umask[:, 1:] &= tmask.values # h-point translates to the right u-point - - return umask - @property def vmask(self): """ @@ -311,6 +293,31 @@ def supergridmask(self): supergridmask[1::2, 1::2] = self.tmask.values return supergridmask + def _get_src( + self, + bathymetry_path, + longitude_coordinate_name, + latitude_coordinate_name, + vertical_coordinate_name, + ): + """Return a cached :class:`SourceBathy`, creating and slicing a new one + only when the path or coordinate names differ from the current cache.""" + path = Path(bathymetry_path) + if ( + self._src is None + or self._src.path != path + or self._src.lon_name != longitude_coordinate_name + or self._src.lat_name != latitude_coordinate_name + or self._src.elevation_name != vertical_coordinate_name + ): + self._src = SourceBathy( + path, + longitude_coordinate_name, + latitude_coordinate_name, + vertical_coordinate_name, + ).slice_to_domain(self) + return self._src + def point_is_ocean(self, lons, lats): """ Given a list of coordinates, return a list of booleans indicating if the coordinates are in the ocean (True) or land (False) @@ -564,6 +571,90 @@ def set_bowl(self, max_depth, dedge, rad_earth=6.378e6, expdecay=400000.0): # Save to object (Build TCM Object) self.send_entire_depth_change_to_tcm(new_values) + def _compute_topo_stats(self, nx_sub, ny_sub, mask_hmin): + """Compute per-cell depth statistics by uniform sub-sampling. + + Results are cached on ``src._topo_stats`` so a second call with the + same source file returns immediately without recomputation. + (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python) + + Parameters + ---------- + src : SourceBathy (part of class) + nx_sub, ny_sub : int + mask_hmin : float + + Returns + ------- + xr.Dataset — ``OCN_FRAC``, ``D_mean``, ``D_min``, ``D_max``, ``D2_mean``. + """ + assert ( + self._src is not None + ), "Source bathymetry must be loaded to compute topo stats" + src = self._src + if src._topo_stats is not None: + return src._topo_stats + + # Compute subsampling factor and generate sub-point grid + ds = regrid_with_subsampling( + input_dataset=src.ds, + qlon=self._grid.qlon.values, + qlat=self._grid.qlat.values, + nx_sub=nx_sub, + ny_sub=ny_sub, + regridding_method="nearest_s2d", + ) + + depth_sub = ds[src.elevation_name].values # (ny, nx, ny_sub, nx_sub) + + is_ocean = depth_sub > mask_hmin + ocn_frac = is_ocean.sum(axis=(-2, -1)) / (nx_sub * ny_sub) + + depth_ocean = np.where(is_ocean, depth_sub, np.nan) + with np.errstate(all="ignore"): + D_mean = np.nanmean(depth_ocean, axis=(-2, -1)) + D_min = np.nanmin(depth_ocean, axis=(-2, -1)) + D_max = np.nanmax(depth_ocean, axis=(-2, -1)) + D2_mean = np.nanmean(depth_ocean**2, axis=(-2, -1)) + + dims = ["ny", "nx"] + src._topo_stats = xr.Dataset( + { + "OCN_FRAC": xr.DataArray( + ocn_frac, + dims=dims, + attrs={ + "long_name": "ocean fraction from sub-sampling", + "units": "1", + }, + ), + "D_mean": xr.DataArray( + D_mean, + dims=dims, + attrs={"long_name": "mean ocean depth in cell", "units": "m"}, + ), + "D_min": xr.DataArray( + D_min, + dims=dims, + attrs={"long_name": "minimum ocean depth in cell", "units": "m"}, + ), + "D_max": xr.DataArray( + D_max, + dims=dims, + attrs={"long_name": "maximum ocean depth in cell", "units": "m"}, + ), + "D2_mean": xr.DataArray( + D2_mean, + dims=dims, + attrs={ + "long_name": "mean squared ocean depth in cell", + "units": "m2", + }, + ), + } + ) + return src._topo_stats + def set_from_dataset( self, bathymetry_path, diff --git a/mom6_forge/topo_editor.py b/mom6_forge/topo_editor.py index 8a87288e..1a45d27c 100644 --- a/mom6_forge/topo_editor.py +++ b/mom6_forge/topo_editor.py @@ -6,6 +6,7 @@ from matplotlib.ticker import MaxNLocator from mom6_forge.edit_command import * from mom6_forge.git_utils import * +from mom6_forge.utils import normalize_deg class TopoEditor(widgets.HBox): @@ -121,6 +122,24 @@ def format_coord(x, y): self._interactive_plot = widgets.HBox( children=(self.fig.canvas,), layout={"border_left": "1px solid grey"} ) + # Stats overlay text box (hidden by default) + self._stats_text = self.ax.text( + 0.02, + 0.98, + "", + transform=self.ax.transAxes, + fontsize=9, + verticalalignment="top", + bbox=dict( + boxstyle="round,pad=0.4", + facecolor="white", + alpha=0.82, + edgecolor="#888888", + linewidth=0.8, + ), + visible=False, + zorder=10, + ) def construct_control_panel(self): """ @@ -163,6 +182,15 @@ def construct_control_panel(self): layout={"width": "80%"}, style={"description_width": "auto"}, ) + self._set_to_mean_button = widgets.Button( + description="Mean", disabled=True, layout={"width": "30%"} + ) + self._set_to_max_button = widgets.Button( + description="Max", disabled=True, layout={"width": "30%"} + ) + self._set_to_min_button = widgets.Button( + description="Min", disabled=True, layout={"width": "30%"} + ) # --- Basin editing widgets --- self._basin_specifier_toggle = widgets.Button( @@ -235,13 +263,35 @@ def construct_control_panel(self): self._min_depth_specifier, ] ) - cell_editing_section = widgets.VBox( - [ - widgets.HTML("

Cell Editing

"), - self._selected_cell_label, - self._depth_specifier, - ] - ) + cell_editing_section_children = [ + widgets.HTML("

Cell Editing

"), + self._selected_cell_label, + self._depth_specifier, + ] + + # Only add stats section if statistics are available + has_stats = ( + hasattr(self.topo._src, "_topo_stats") + and self.topo._src._topo_stats is not None + ) + if has_stats: + cell_editing_section_children.extend( + [ + widgets.HTML( + "

Set to statistic:

" + ), + widgets.HBox( + [ + self._set_to_mean_button, + self._set_to_max_button, + self._set_to_min_button, + ], + layout={"justify_content": "space-between"}, + ), + ] + ) + + cell_editing_section = widgets.VBox(cell_editing_section_children) basin_section = widgets.VBox( [ widgets.HTML("

Basin Selector

"), @@ -333,6 +383,108 @@ def trigger_refresh(self): self._min_depth_specifier.value = self.topo.min_depth self.update_undo_redo_buttons() + def _draw_cell_stats(self, visible): + """Draw per-cell stat annotations directly on each visible cell.""" + # Clear previous annotations + if hasattr(self, "_cell_stat_texts"): + for t in self._cell_stat_texts: + t.remove() + self._cell_stat_texts = [] + + if ( + not hasattr(self.topo._src, "_topo_stats") + or self.topo._src._topo_stats is None + ): + return + + ds = self.topo._src._topo_stats + js, is_ = np.where(visible) + + qlon = self.topo._grid.qlon.data + qlat = self.topo._grid.qlat.data + + for idx in range(len(js)): + j, i = js[idx], is_[idx] + + # Cell centre + cx = 0.25 * ( + qlon[j, i] + qlon[j, i + 1] + qlon[j + 1, i] + qlon[j + 1, i + 1] + ) + cy = 0.25 * ( + qlat[j, i] + qlat[j, i + 1] + qlat[j + 1, i] + qlat[j + 1, i + 1] + ) + + lines = [] + for var in ds.data_vars: + if var == "D2_mean": + continue # Skip this variable for now since it's not super informative + label = var + if var.startswith("D_"): + label = var[2:] # Get rid of the "D_" prefix for display purposes + val = ds[var].data[j, i] + units = ds[var].attrs.get("units", "") + if units == "1": # Unitless + units = "" + lines.append(f"{label}: {val:.2f} {units}") + + t = self.ax.text( + cx, + cy, + "\n".join(lines), + fontsize=7, + ha="center", + va="center", + transform=ccrs.PlateCarree(), + bbox=dict( + boxstyle="round,pad=0.2", + facecolor="none", + alpha=0.7, + edgecolor="none", + ), + zorder=11, + ) + self._cell_stat_texts.append(t) + + self.fig.canvas.draw_idle() + + def _on_zoom_change(self, ax): + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + qlon = self.topo._grid.qlon.data + qlat = self.topo._grid.qlat.data + + # Check if all four corners of each cell are within the view limits + visible = ( + (qlon[:-1, :-1] >= normalize_deg(xlim[0])) + & (qlon[:-1, :-1] <= normalize_deg(xlim[1])) + & (qlon[:-1, 1:] >= normalize_deg(xlim[0])) + & (qlon[:-1, 1:] <= normalize_deg(xlim[1])) + & (qlon[1:, :-1] >= normalize_deg(xlim[0])) + & (qlon[1:, :-1] <= normalize_deg(xlim[1])) + & (qlon[1:, 1:] >= normalize_deg(xlim[0])) + & (qlon[1:, 1:] <= normalize_deg(xlim[1])) + & (qlat[:-1, :-1] >= ylim[0]) + & (qlat[:-1, :-1] <= ylim[1]) + & (qlat[:-1, 1:] >= ylim[0]) + & (qlat[:-1, 1:] <= ylim[1]) + & (qlat[1:, :-1] >= ylim[0]) + & (qlat[1:, :-1] <= ylim[1]) + & (qlat[1:, 1:] >= ylim[0]) + & (qlat[1:, 1:] <= ylim[1]) + ) + + # Always clear old annotations first + if hasattr(self, "_cell_stat_texts"): + for t in self._cell_stat_texts: + t.remove() + self._cell_stat_texts = [] + + if visible.sum() <= 40: + self._draw_cell_stats(visible) + else: + self.fig.canvas.draw_idle() + def _select_cell(self, i, j): """Select a cell in the topography grid and update the UI accordingly.""" # Remove old patch if it exists @@ -383,6 +535,19 @@ def _select_cell(self, i, j): if hasattr(self, "_depth_specifier"): self._depth_specifier.disabled = False self._depth_specifier.value = self.topo.depth.data[j, i] + + # Enable statistic buttons if statistics are available + has_stats = ( + hasattr(self.topo._src, "_topo_stats") + and self.topo._src._topo_stats is not None + ) + for btn in [ + self._set_to_mean_button, + self._set_to_max_button, + self._set_to_min_button, + ]: + btn.disabled = not has_stats + if hasattr(self, "_basin_specifier"): label = self.topo.basintmask.data[j, i] self._basin_specifier.value = f"Basin Label Number: {str(label)}" @@ -405,7 +570,9 @@ def construct_observances(self): # Double click event for cell selection on the plot self.fig.canvas.mpl_connect("button_press_event", self.on_double_click) - + # Zoom-dependent stats overlay + self.ax.callbacks.connect("xlim_changed", self._on_zoom_change) + self.ax.callbacks.connect("ylim_changed", self._on_zoom_change) # Min depth change observer self._min_depth_specifier.observe( self.on_min_depth_change, names="value", type="change" @@ -420,6 +587,11 @@ def construct_observances(self): self.on_depth_change, names="value", type="change" ) + # Statistic buttons + self._set_to_mean_button.on_click(self.set_depth_to_mean) + self._set_to_max_button.on_click(self.set_depth_to_max) + self._set_to_min_button.on_click(self.set_depth_to_min) + # Undo/Redo/Reset buttons self._undo_button.on_click(self.undo_last_edit) self._redo_button.on_click(self.redo_last_edit) @@ -496,6 +668,37 @@ def on_depth_change(self, change): self.apply_edit(cmd) self.update_undo_redo_buttons() + def _get_statistic_value(self, stat_name): + """Get a statistic value for the selected cell.""" + if self._selected_cell is None or not hasattr(self.topo._src, "_topo_stats"): + return None + + i, j, _ = self._selected_cell + ds = self.topo._src._topo_stats + + if ds is None or stat_name not in ds.data_vars: + return None + + return float(ds[stat_name].data[j, i]) + + def set_depth_to_mean(self, b): + """Set the selected cell's depth to the mean value.""" + val = self._get_statistic_value("D_mean") + if val is not None: + self._depth_specifier.value = val + + def set_depth_to_max(self, b): + """Set the selected cell's depth to the max value.""" + val = self._get_statistic_value("D_max") + if val is not None: + self._depth_specifier.value = val + + def set_depth_to_min(self, b): + """Set the selected cell's depth to the min value.""" + val = self._get_statistic_value("D_min") + if val is not None: + self._depth_specifier.value = val + def on_git_create_branch(self, b): """Create a new git branch""" name = self._git_branch_name.value.strip() diff --git a/mom6_forge/utils.py b/mom6_forge/utils.py index adc09bca..1fcbce39 100644 --- a/mom6_forge/utils.py +++ b/mom6_forge/utils.py @@ -697,3 +697,30 @@ def quadrilateral_area(v1, v2, v3, v4): a4 = angle_between(v4, v1, v3) return (a1 + a2 + a3 + a4 - 2 * np.pi) * R**2 + + +def compute_subsampling_factor(src_nj, src_ni, dst_nj, dst_ni): + """ + Compute the sub-sampling factors needed so that the super-sampled + destination grid has at least as many points as the source grid. + + Parameters + ---------- + src_nj, src_ni : int + Source grid dimensions. + dst_nj, dst_ni : int + Destination grid dimensions. + + Returns + ------- + ny_sub, nx_sub : int + """ + nx_sub = 1 + while nx_sub * dst_ni < src_ni: + nx_sub += 1 + + ny_sub = 1 + while ny_sub * dst_nj < src_nj: + ny_sub += 1 + + return ny_sub, nx_sub diff --git a/setup.py b/setup.py index 4b4ca999..facab1e4 100644 --- a/setup.py +++ b/setup.py @@ -37,5 +37,6 @@ "gitpython>=3.1,<3.2.0", "cartopy>=0.23,<0.30", "xesmf>=0.8.10,<1.0.0", + "dask>=2026.3.0,<2026.4.0", ], ) diff --git a/tests/conftest.py b/tests/conftest.py index 0b83f72e..36161409 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,11 @@ import numpy as np import pytest +import tempfile +from pathlib import Path +import xarray as xr from mom6_forge.grid import Grid from mom6_forge.topo import Topo +from mom6_forge._source_bathy import SourceBathy @pytest.fixture @@ -252,3 +256,98 @@ def get_rect_topo(get_rect_grid, tmp_path): topo = Topo(get_rect_grid, min_depth=0, version_control_dir=tmp_path) topo.set_flat(1000) return topo + + +@pytest.fixture +def synthetic_bathy_file(): + """Create a temporary synthetic bathymetry NetCDF file for testing. + + Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). + """ + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + bathy_file = tmp.name + + # Create synthetic bathymetry covering the Panama region + # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 + # So we need lon 278-282, lat 7-10 with some buffer + lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer + lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer + + # Create synthetic elevation data (positive-up, like GEBCO) + # Ocean is negative (water), land is positive + elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + + # Add synthetic land masses (islands) + # Create an island around (280, 8.5) + lon_2d, lat_2d = np.meshgrid(lon, lat) + island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 + elevation[island_mask] = 200.0 # Synthetic island + + ds = xr.Dataset( + { + "elevation": (["lat", "lon"], elevation), + }, + coords={ + "lon": lon, + "lat": lat, + }, + ) + ds.to_netcdf(bathy_file) + + yield bathy_file + + # Cleanup + Path(bathy_file).unlink() + + +@pytest.fixture +def get_simple_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=1, # min longitude in [0, 360] + lenx=2, # longitude extent in degrees + ystart=1, # min latitude in [-90, 90] + leny=2, # latitude extent in degrees + name="panama1", + ) + return grid + + +@pytest.fixture +def get_simple_global_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=0, # min longitude in [0, 360] + lenx=360, # longitude extent in degrees + ystart=-90, # min latitude in [-90, 90] + leny=180, # latitude extent in degrees + name="panama1", + cyclic_x=True, + ) + return grid + + +@pytest.fixture +def get_PM_seam_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=359, # min longitude in [0, 360] + lenx=2, # longitude extent in degrees + ystart=-1, # min latitude in [-90, 90] + leny=2, # latitude extent in degrees + name="panama1", + ) + return grid + + +@pytest.fixture +def get_dateline_seam_grid(): + grid = Grid( + resolution=1, # in degrees + xstart=-1, # min longitude in [0, 360] + lenx=2, # longitude extent in degrees + ystart=-1, # min latitude in [-90, 90] + leny=2, # latitude extent in degrees + name="panama1", + ) + return grid diff --git a/tests/test_mapping.py b/tests/test_mapping.py new file mode 100644 index 00000000..81407562 --- /dev/null +++ b/tests/test_mapping.py @@ -0,0 +1,86 @@ +from mom6_forge.mapping import _make_subgrid_points, regrid_with_subsampling +import numpy as np +import xarray as xr + + +def test_make_subgrid_points(get_simple_grid): + # Test with a simple 2x2 grid and 2 sub-points per cell + nx_sub = ny_sub = 2 + grid = get_simple_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + + expected_sub_lon = np.array( + [ + [[[4 / 3, 5 / 3], [4 / 3, 5 / 3]], [[7 / 3, 8 / 3], [7 / 3, 8 / 3]]], + [[[4 / 3, 5 / 3], [4 / 3, 5 / 3]], [[7 / 3, 8 / 3], [7 / 3, 8 / 3]]], + ] + ) + expected_sub_lat = np.array( + [ + [[[4 / 3, 4 / 3], [5 / 3, 5 / 3]], [[4 / 3, 4 / 3], [5 / 3, 5 / 3]]], + [[[7 / 3, 7 / 3], [8 / 3, 8 / 3]], [[7 / 3, 7 / 3], [8 / 3, 8 / 3]]], + ] + ) + + assert np.allclose( + sub_lon, expected_sub_lon + ), "Sub-grid longitudes do not match expected values." + assert np.allclose( + sub_lat, expected_sub_lat + ), "Sub-grid latitudes do not match expected values." + + +def test_smoke_seams_and_global_make_subgrid_points( + get_dateline_seam_grid, get_PM_seam_grid, get_simple_global_grid +): + # Test with a simple 2x2 grid and 2 sub-points per cell + nx_sub = ny_sub = 2 + grid = get_dateline_seam_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + grid = get_PM_seam_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + grid = get_simple_global_grid + sub_lon, sub_lat = _make_subgrid_points( + grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + + +def test_regrid_with_subsampling(get_simple_grid): + # Test with a simple 2x2 grid and 2 sub-points per cell with data that lands exactly on the sub points (subtracted by 0.1 to show snapping to sub points) + nx_sub = ny_sub = 2 + grid = get_simple_grid + lon = [4 / 3, 5 / 3, 7 / 3, 8 / 3] + lat = [4 / 3, 5 / 3, 7 / 3, 8 / 3] + input_ds = xr.Dataset( + { + "data": ( + ["lon", "lat"], + [ + np.arange(1, 5, 1), + np.arange(1, 5, 1), + np.arange(1, 5, 1), + np.arange(1, 5, 1), + ], + ) + }, + coords={ + "lon": (["lon"], [x - 0.1 for x in lon]), + "lat": (["lat"], [x - 0.1 for x in lat]), + }, + ) + ds = regrid_with_subsampling( + input_ds, grid.qlon.values, grid.qlat.values, nx_sub, ny_sub + ) + assert ds["data"].shape == (2, 2, 2, 2), "Output shape is incorrect." + expected_data = np.array( + [[[[1, 1], [2, 2]], [[1, 1], [2, 2]]], [[[3, 3], [4, 4]], [[3, 3], [4, 4]]]] + ) + assert np.allclose( + ds["data"].values, expected_data + ), "Regridded data does not match expected values." diff --git a/tests/test_source_bathy.py b/tests/test_source_bathy.py new file mode 100644 index 00000000..a96b6b3f --- /dev/null +++ b/tests/test_source_bathy.py @@ -0,0 +1,105 @@ +"""Smoke test for SourceBathy loader.""" + +import numpy as np +import pytest +import tempfile +from pathlib import Path +import xarray as xr +from mom6_forge._source_bathy import SourceBathy + + +@pytest.fixture +def synthetic_bathy_file(): + """Create a temporary synthetic bathymetry NetCDF file for testing. + + Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). + """ + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + bathy_file = tmp.name + + # Create synthetic bathymetry covering the Panama region + # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 + # So we need lon 278-282, lat 7-10 with some buffer + lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer + lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer + + # Create synthetic elevation data (positive-up, like GEBCO) + # Ocean is negative (water), land is positive + elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + + # Add synthetic land masses (islands) + # Create an island around (280, 8.5) + lon_2d, lat_2d = np.meshgrid(lon, lat) + island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 + elevation[island_mask] = 200.0 # Synthetic island + + ds = xr.Dataset( + { + "elevation": (["lat", "lon"], elevation), + }, + coords={ + "lon": lon, + "lat": lat, + }, + ) + ds.to_netcdf(bathy_file) + + yield bathy_file + + # Cleanup + Path(bathy_file).unlink() + + +def test_source_bathy_initialization(synthetic_bathy_file): + """Test SourceBathy initialization and coordinate names.""" + src = SourceBathy( + synthetic_bathy_file, + lon_name="lon", + lat_name="lat", + elevation_name="elevation", + ) + + assert src.path == Path(synthetic_bathy_file) + assert src.lon_name == "lon" + assert src.lat_name == "lat" + assert src.elevation_name == "elevation" + assert src._da is None # Not loaded yet + + +def test_source_bathy_slice_to_domain(get_rect_topo, synthetic_bathy_file): + """Smoke test: load and slice elevation to topo domain.""" + topo = get_rect_topo + + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.4) + + # Verify data was loaded + assert src._da is not None + assert src.lon is not None + assert src.lat is not None + + # Verify shape makes sense + assert len(src.lon) > 0, f"Expected lon data, got empty array" + assert len(src.lat) > 0, f"Expected lat data, got empty array" + assert src._da.shape == (len(src.lat), len(src.lon)) + + +def test_source_bathy_depth_conversion(get_rect_topo, synthetic_bathy_file): + """Test that elevation is converted to positive-down depth.""" + topo = get_rect_topo + + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.5) + + # Get depth and verify sign conversion + depth = src.depth + + # Verify no NaNs in the result + assert not bool(np.isnan(depth).all()), "All depth values are NaN" + + # Verify positive depth values for ocean (elevation is negative) + non_nan_values = depth[~np.isnan(depth)] + assert len(non_nan_values) > 0, "No valid depth values" + assert np.any(non_nan_values > 0), "Expected positive depth values for ocean" + + assert depth.shape == src.da.shape diff --git a/tests/test_topo.py b/tests/test_topo.py index 5da340d8..7e89b0e2 100644 --- a/tests/test_topo.py +++ b/tests/test_topo.py @@ -1,4 +1,52 @@ +import numpy as np +import pytest +import tempfile +from pathlib import Path +import xarray as xr from mom6_forge.topo import * +from mom6_forge._source_bathy import SourceBathy + + +@pytest.fixture +def synthetic_bathy_file(): + """Create a temporary synthetic bathymetry NetCDF file for testing. + + Covers the Panama region (278-282°E, 7-10°N) to match get_rect_grid(). + """ + with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as tmp: + bathy_file = tmp.name + + # Create synthetic bathymetry covering the Panama region + # get_rect_grid uses xstart=278, lenx=4, ystart=7, leny=3 + # So we need lon 278-282, lat 7-10 with some buffer + lon = np.linspace(276, 284, 80) # Cover 278-282 with buffer + lat = np.linspace(5, 12, 70) # Cover 7-10 with buffer + + # Create synthetic elevation data (positive-up, like GEBCO) + # Ocean is negative (water), land is positive + elevation = np.full((len(lat), len(lon)), -500.0) # Ocean baseline = 500m deep + + # Add synthetic land masses (islands) + # Create an island around (280, 8.5) + lon_2d, lat_2d = np.meshgrid(lon, lat) + island_mask = (lon_2d - 280) ** 2 + (lat_2d - 8.5) ** 2 < 0.5 + elevation[island_mask] = 200.0 # Synthetic island + + ds = xr.Dataset( + { + "elevation": (["lat", "lon"], elevation), + }, + coords={ + "lon": lon, + "lat": lat, + }, + ) + ds.to_netcdf(bathy_file) + + yield bathy_file + + # Cleanup + Path(bathy_file).unlink() def test_topo_from_version_control(get_rect_topo): diff --git a/tests/test_topo_bathymetry_workflows.py b/tests/test_topo_bathymetry_workflows.py new file mode 100644 index 00000000..5a5dbf30 --- /dev/null +++ b/tests/test_topo_bathymetry_workflows.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest +from mom6_forge.topo import * +from mom6_forge._source_bathy import SourceBathy + + +def test_compute_topo_stats(get_rect_topo, synthetic_bathy_file): + """Test _compute_topo_stats: per-cell depth statistics via xesmf nearest neighbor regridding. + + This test validates the refactored _compute_topo_stats method which: + - Generates sub-points within each grid cell + - Uses xesmf nearest_s2d regridding to snap sub-points to nearest source data + - Computes per-cell statistics (OCN_FRAC, D_mean, D_min, D_max, D2_mean) + """ + topo = get_rect_topo + + # Load source bathymetry and slice to topo domain + src = SourceBathy(synthetic_bathy_file) + src.slice_to_domain(topo, buf=0.5) + topo._src = src + + # Test with different sub-sampling densities + for nx_sub, ny_sub in [(2, 2), (3, 3)]: + # Call _compute_topo_stats + stats = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) + + # Verify output is a Dataset with expected variables + assert isinstance(stats, xr.Dataset) + required_vars = ["OCN_FRAC", "D_mean", "D_min", "D_max", "D2_mean"] + for var in required_vars: + assert var in stats.data_vars, f"Missing {var} in output" + + # Verify shapes match topo grid + expected_shape = (topo.depth.shape[0], topo.depth.shape[1]) + assert stats["OCN_FRAC"].shape == expected_shape + assert stats["D_mean"].shape == expected_shape + assert stats["D_min"].shape == expected_shape + assert stats["D_max"].shape == expected_shape + assert stats["D2_mean"].shape == expected_shape + + # Verify OCN_FRAC is between 0 and 1 + assert (stats["OCN_FRAC"] >= 0).all() + assert (stats["OCN_FRAC"] <= 1).all() + + # Verify D_min <= D_mean <= D_max + ocean_cells = stats["OCN_FRAC"].values > 0 + assert ( + stats["D_min"].values[ocean_cells] <= stats["D_mean"].values[ocean_cells] + ).all() + assert ( + stats["D_mean"].values[ocean_cells] <= stats["D_max"].values[ocean_cells] + ).all() + + # Verify caching: second call should return cached result + stats2 = topo._compute_topo_stats(nx_sub=nx_sub, ny_sub=ny_sub, mask_hmin=0.0) + # Should be the exact same object (cached) + assert stats2 is stats