Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions mom6_forge/_source_bathy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Source bathymetry loader for mom6_forge.

``SourceBathy`` is a lightweight data container for a regional slice of a
source bathymetry dataset. ``Topo._get_src()`` creates and caches one
automatically when ``set_from_dataset`` is called. Users who call pipeline
methods directly (e.g. ``high_res_regrid``, ``generate_mask_ocean_frac``)
should construct a ``SourceBathy`` explicitly::

from mom6_forge._source_bathy import SourceBathy
src = SourceBathy("gebco_2023.nc").slice_to_domain(topo)
"""

import numpy as np
import xarray as xr
from pathlib import Path
from mom6_forge.utils import longitude_slicer


class SourceBathy:
"""Regional slice of a source bathymetry dataset (e.g. GEBCO).

Holds the loaded, domain-clipped elevation DataArray together with its
coordinate-name metadata. Per-cell depth statistics are computed and
cached here so repeated calls with the same source file skip the
expensive sub-sampling step.

Parameters
----------
path : str or Path
lon_name : str — longitude coordinate name. Default ``"lon"``.
lat_name : str — latitude coordinate name. Default ``"lat"``.
elevation_name : str — elevation variable (positive-up). Default ``"elevation"``.
"""

def __init__(
self,
path,
lon_name="lon",
lat_name="lat",
elevation_name="elevation",
):
self.path = Path(path)
self.lon_name = lon_name
self.lat_name = lat_name
self.elevation_name = elevation_name
self._da = None # set by slice_to_domain
self._topo_stats = None # set by compute_topo_stats

# ------------------------------------------------------------------
# Loading
# ------------------------------------------------------------------

def slice_to_domain(self, topo, buf=0.5):
"""Load and clip elevation to the topo grid extent plus ``buf`` degrees.

Handles the global-longitude seam automatically. Mutates ``self``
in place and returns ``self`` for chaining.

Parameters
----------
topo : Topo — only ``topo._grid.qlon`` / ``topo._grid.qlat`` are used.
buf : float — degree buffer around the Q-grid bounding box. Default 0.5.
"""
lon_extent = (float(topo._grid.qlon.min()), float(topo._grid.qlon.max()))
lat_extent = (float(topo._grid.qlat.min()), float(topo._grid.qlat.max()))

ds_src = xr.open_dataset(self.path, chunks="auto")

ds = ds_src.sel(
{self.lat_name: slice(lat_extent[0] - buf, lat_extent[1] + buf)}
)

dlon = float(ds[self.lon_name][1] - ds[self.lon_name][0])
total_lon = float(ds[self.lon_name][-1] - ds[self.lon_name][0] + dlon)
if np.isclose(total_lon, 360):
ds = longitude_slicer(
ds,
np.array(lon_extent) + np.array([-buf, buf]),
self.lon_name,
)
else:
ds = ds.sel(
{self.lon_name: slice(lon_extent[0] - buf, lon_extent[1] + buf)}
)
self._ds = ds
self._da = ds[self.elevation_name].load()
return self

# ------------------------------------------------------------------
# Accessors
# ------------------------------------------------------------------

@property
def lon(self):
"""1-D longitude array."""
return self._da[self.lon_name].values

@property
def lat(self):
"""1-D latitude array."""
return self._da[self.lat_name].values

@property
def depth(self):
"""2-D depth array, positive-down (ocean > 0), shape (ny_src, nx_src)."""
return -self._da.values.astype(float)

@property
def da(self):
"""Raw elevation DataArray with source coordinate names (positive-up)."""
return self._da

@property
def ds(self):
"""Raw dataset with source coordinate names (positive-up)."""
return self._ds

def __repr__(self):
shape = self._da.shape if self._da is not None else "not loaded"
return (
f"SourceBathy({self.path.name!r}, lon={self.lon_name!r}, "
f"lat={self.lat_name!r}, elevation={self.elevation_name!r}, shape={shape})"
)
39 changes: 23 additions & 16 deletions mom6_forge/chl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from datetime import datetime
from pathlib import Path
from os.path import isfile
from mom6_forge.utils import fill_missing_data
from mom6_forge.utils import fill_missing_data, compute_subsampling_factor
from mom6_forge.mapping import regrid_with_subsampling


def interpolate_and_fill_seawifs(
Expand Down Expand Up @@ -64,7 +65,7 @@ def interpolate_and_fill_seawifs(
(np.arange(src_ni) + 0.5) / src_ni
) * 360.0 + src_x0 # Recompute as doubles

spr_lat, spr_lon = grid.tlat.values, grid.tlon.values
ny_sub, nx_sub = compute_subsampling_factor(src_nj, src_ni, ocn_nj, ocn_ni)

# Set output path
if output_path is None:
Expand All @@ -88,21 +89,27 @@ def interpolate_and_fill_seawifs(
# Iterate through time
for t in range(src_data.shape[0]):

# Bilinearly interpolate the source data onto the super-sampled grid
# adj lon to -180 to 180
adj_lon = spr_lon - 360
data = src_data[t, ::-1, :].values

src = {"lon": src_lon, "lat": src_lat}
dst = {"lon": adj_lon, "lat": spr_lat}
regridder = xe.Regridder(
src,
dst,
"bilinear",
filename="bilin_weights.nc",
reuse_weights=isfile("bilin_weights.nc"),
# Build source dataset for this timestep
src_ds = xr.Dataset(
{
"chlor_a": xr.DataArray(
src_data[t, ::-1, :].values,
dims=["lat", "lon"],
coords={"lat": src_lat, "lon": src_lon},
)
}
)

# Regrid to super-sampled sub-point grid and average back to model grid
q_sub = regrid_with_subsampling(
input_dataset=src_ds,
qlon=grid.qlon.values,
qlat=grid.qlat.values,
nx_sub=nx_sub,
ny_sub=ny_sub,
regridding_method="bilinear",
)
q_int = regridder(data)
q_int = q_sub["chlor_a"].mean(dim=["ny_sub", "nx_sub"]).values

# Fill any missing data
q = q_int * ocn_mask
Expand Down
142 changes: 142 additions & 0 deletions mom6_forge/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,148 @@ def regrid_dataset_via_xesmf(
return dataset


def _make_subgrid_points(qlon, qlat, nx_sub, ny_sub):
"""
Given corner coordinates of a low-res grid (qlon/qlat on the q-point staggering),
return sub-sampled lon/lat arrays of shape (ny, nx, ny_sub, nx_sub). (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python)

Parameters
----------
qlon, qlat : np.ndarray shape (ny+1, nx+1)
Corner (q-point) coordinates of the low-res grid.
nx_sub, ny_sub : int
Number of sub-points per cell in each direction.

Returns
-------
sub_lon, sub_lat : np.ndarray shape (ny, nx, ny_sub, nx_sub)
"""
assert type(qlon) == type(qlat) == np.ndarray, "qlon and qlat must be numpy arrays"

SW_lon = qlon[:-1, :-1]
SW_lat = qlat[:-1, :-1]
SE_lon = qlon[:-1, 1:]
SE_lat = qlat[:-1, 1:]
NE_lon = qlon[1:, 1:]
NE_lat = qlat[1:, 1:]
NW_lon = qlon[1:, :-1]
NW_lat = qlat[1:, :-1]

# Fix antimeridian-straddling cells
def _fix_period(lon, ref):
lon = np.where(lon - ref > 270, lon - 360, lon)
lon = np.where(lon - ref < -270, lon + 360, lon)
return lon

SW_lon = _fix_period(SW_lon, NE_lon)
SE_lon = _fix_period(SE_lon, NE_lon)
NW_lon = _fix_period(NW_lon, NE_lon)

ifrac = (np.arange(1, nx_sub + 1) / (nx_sub + 1)).astype(float)
jfrac = (np.arange(1, ny_sub + 1) / (ny_sub + 1)).astype(float)

i_ = ifrac[np.newaxis, np.newaxis, np.newaxis, :] # (1,1,1,nx_sub)
j_ = jfrac[np.newaxis, np.newaxis, :, np.newaxis] # (1,1,ny_sub,1)

# Broadcast all corners to (ny, nx, 1, 1),
SW_lon = SW_lon[:, :, np.newaxis, np.newaxis]
SE_lon = SE_lon[:, :, np.newaxis, np.newaxis]
NE_lon = NE_lon[:, :, np.newaxis, np.newaxis]
NW_lon = NW_lon[:, :, np.newaxis, np.newaxis]
SW_lat = SW_lat[:, :, np.newaxis, np.newaxis]
SE_lat = SE_lat[:, :, np.newaxis, np.newaxis]
NE_lat = NE_lat[:, :, np.newaxis, np.newaxis]
NW_lat = NW_lat[:, :, np.newaxis, np.newaxis]

sub_lon = (
(1 - i_) * (1 - j_) * SW_lon
+ i_ * (1 - j_) * SE_lon
+ i_ * j_ * NE_lon
+ (1 - i_) * j_ * NW_lon
)
sub_lat = (
(1 - i_) * (1 - j_) * SW_lat
+ i_ * (1 - j_) * SE_lat
+ i_ * j_ * NE_lat
+ (1 - i_) * j_ * NW_lat
)

return sub_lon, sub_lat


def regrid_with_subsampling(
input_dataset,
qlon,
qlat,
nx_sub,
ny_sub,
regridding_method="nearest_s2d",
):
"""
Regrids input_dataset to sub_sampled_grid to
properly analyze high-res source data into each coarse cell. (Originally created by Frank Bryan in Fortran for NCAR/tx2_3, reimplemented in Python)

Parameters
----------
input_dataset : xr.Dataset
qlon, qlat : np.ndarray shape (ny+1, nx+1)
Corner coordinates of the destination grid.
nx_sub, ny_sub : int
Number of sub-points per cell (typically from compute_subsampling_factor).
Returns
-------
regridded_dataset : xr.Dataset
Regridded dataset with dimensions (..., ny, nx, ny_sub, nx_sub), where the sub-sampling points are kept as separate dimensions. (User should perform stats calc)
"""
ny, nx = qlon.shape[0] - 1, qlon.shape[1] - 1

# Build the (ny, nx, ny_sub, nx_sub) sub-point grid
sub_lon, sub_lat = _make_subgrid_points(qlon, qlat, nx_sub, ny_sub)

# Flatten to (ny, nx*ny_sub*nx_sub) so xesmf sees a 2D destination
flat_lon = sub_lon.reshape(ny, nx * ny_sub * nx_sub)
flat_lat = sub_lat.reshape(ny, nx * ny_sub * nx_sub)

flat_output = xr.Dataset(
{
"lon": xr.DataArray(flat_lon, dims=["ny", "nx"]),
"lat": xr.DataArray(flat_lat, dims=["ny", "nx"]),
}
)

regridded_flat = regrid_dataset_via_xesmf(
input_dataset,
flat_output,
regridding_method=regridding_method,
write_to_file=False,
)

# Reshape to 4D, keeping sub-points as their own dimension
data_vars = {}
for var in regridded_flat.data_vars:
data = regridded_flat[var].values # (..., ny, nx*ny_sub*nx_sub)
reshaped = data.reshape(ny, nx, ny_sub, nx_sub)

original_dims = regridded_flat[var].dims
new_dims = (*original_dims, "ny_sub", "nx_sub")

data_vars[var] = xr.DataArray(
reshaped,
dims=new_dims,
attrs=regridded_flat[var].attrs,
)

coords = {
k: v
for k, v in regridded_flat.coords.items()
if "ny" not in v.dims and "nx" not in v.dims
}
coords["ny_sub"] = np.arange(ny_sub)
coords["nx_sub"] = np.arange(nx_sub)

return xr.Dataset(data_vars, coords=coords, attrs=input_dataset.attrs)


def main(args):

if args.parallel:
Expand Down
Loading
Loading