Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions mom6_forge/channel_width.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
A barebones set of classes to manage channel width constraints for MOM6 grids. It is effectively a list that can be applied on top of the bathymetry, and is not git-backed.
The channels represent a separate concern and can be independently shared, or managed outside the topography versioning system.
It is a basically a dict of info for each channel and then a list wrapped with write/load functions.
"""

from dataclasses import dataclass
from typing import List, Optional
from pathlib import Path


@dataclass
class ChannelWidth:
"""Single channel width constraint for MOM6 grid"""

component: str # 'U_width' or 'V_width'
lon1: float
lon2: float
lat1: float
lat2: float
width: float # meters
place: str # comment/location name

def __post_init__(self):
"""Validate component is U_width or V_width"""
if self.component not in ("U_width", "V_width"):
raise ValueError(
f"component must be 'U_width' or 'V_width', got '{self.component}'"
)


class ChannelWidthList:
"""
Manages list of channel width constraints.

Note: Channel widths are NOT git-backed. These are additive configuration constraints
applied on top of the bathymetry, not edits to the bathymetry itself. They represent
a separate concern and can be independently shared, or managed outside the
topography versioning system.
"""

FMT_OUT = "{0:s}, {1:8.2f}, {2:8.2f}, {3:8.2f}, {4:8.2f}, {5:10.1f} ! {6:s}\n"

def __init__(self, filepath: Optional[str | Path] = None):
self.channels: List[ChannelWidth] = []
if filepath is not None:
self.load(filepath)

def add(self, channel: ChannelWidth):
"""Add a channel width constraint"""
self.channels.append(channel)

def get_all(self) -> List[ChannelWidth]:
"""Get all channels"""
return self.channels

def write(self, filepath: str | Path):
"""Persist to ASCII file"""
with open(filepath, "w") as f:
for ch in self.channels:
line = self.FMT_OUT.format(
ch.component, ch.lon1, ch.lon2, ch.lat1, ch.lat2, ch.width, ch.place
)
f.write(line)

def load(self, filepath: str | Path):
"""Load from ASCII file"""
filepath = Path(filepath)
if filepath.exists():
with open(filepath) as f:
for line in f:
line = line.strip()
if not line:
continue
# Parse: "U_width, -6.50, -4.75, 35.60, 36.30, 12000.0 ! St. of Gibralter"
parts = line.split("!")
comment = parts[1].strip() if len(parts) > 1 else ""

vals = parts[0].split(",")
component = vals[0].strip()
lon1 = float(vals[1].strip())
lon2 = float(vals[2].strip())
lat1 = float(vals[3].strip())
lat2 = float(vals[4].strip())
width = float(vals[5].strip())

self.channels.append(
ChannelWidth(
component=component,
lon1=lon1,
lon2=lon2,
lat1=lat1,
lat2=lat2,
width=width,
place=comment,
)
)
63 changes: 58 additions & 5 deletions mom6_forge/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mom6_forge.edit_command import *
from mom6_forge.command_manager import TopoCommandManager, CommandType
from mom6_forge.mapping import regrid_dataset_via_xesmf
from mom6_forge.channel_width import ChannelWidthList
import regionmask


Expand All @@ -22,7 +23,14 @@ class Topo:
Bathymetry Generator for MOM6 grids (mom6_forge.grid.Grid).
"""

def __init__(self, grid, min_depth, version_control_dir="TopoLibrary", git=True):
def __init__(
self,
grid,
min_depth,
channel_widths=None,
version_control_dir="TopoLibrary",
git=True,
):
"""
MOM6 Simpler Models bathymetry constructor.

Expand All @@ -32,6 +40,8 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary", git=True)
horizontal grid instance for which the bathymetry is to be created.
min_depth: float
Minimum water column depth. Columns with shallow depths are to be masked out.
channel_widths: str | Path | ChannelWidthList, optional
Channel width constraints. Can be a filepath to load from, a ChannelWidthList object, or None.
version_control_dir: str, optional
Directory in which to store version-controlled bathymetry data. Defaults to
"TopoLibrary". Ignored if git is False (version control is no longer used)
Expand All @@ -56,6 +66,19 @@ def __init__(self, grid, min_depth, version_control_dir="TopoLibrary", git=True)
initial_command = MinDepthEditCommand(
self, attr="min_depth", new_value=min_depth
)

# Initialize channel widths
if channel_widths is None:
self.channel_widths = ChannelWidthList()
elif isinstance(channel_widths, ChannelWidthList):
self.channel_widths = channel_widths
else:
# Assume it's a filepath
self.channel_widths = ChannelWidthList(filepath=channel_widths)

initial_command = MinDepthEditCommand(
self, attr="min_depth", new_value=min_depth
)
if git:

# Create a folder to store bathymetry objects in
Expand Down Expand Up @@ -103,21 +126,26 @@ def __getitem__(self, slices):

new_grid = self._grid[slices]
new_topo = Topo(
new_grid, self._min_depth, git=self.has_version_control
new_grid,
self._min_depth,
git=self.has_version_control,
channel_widths=self.channel_widths,
) # Create new topo with the same version control setting
if self._depth is not None:
new_topo._depth = self._depth[slices]
return new_topo

@classmethod
def from_version_control(cls, folder_path: str | Path):
def from_version_control(cls, folder_path: str | Path, channel_widths=None):
"""
Create a bathymetry object from an existing version-controlled bathymetry folder.

Parameters
----------
folder_path: str | Path
Path to an existing bathymetry folder created by mom6_forge with version control enabled.
channel_widths: str | Path | ChannelWidthList, optional
Channel width constraints. Can be a filepath to load from, a ChannelWidthList object, or None.
"""

folder_path = Path(folder_path)
Expand All @@ -130,7 +158,10 @@ def from_version_control(cls, folder_path: str | Path):

# Create the topo object
topo = Topo(
grid, 0.0, version_control_dir=folder_path.parent
grid,
0.0,
version_control_dir=folder_path.parent,
channel_widths=channel_widths,
) # Because we hash the grid, the correct domain will be selected

# Reapply any changes
Expand All @@ -148,6 +179,7 @@ def from_topo_file(
varname="depth",
version_control_dir="TopoLibrary",
git=True,
channel_widths=None,
):
"""
Create a bathymetry object from an existing topog file.
Expand All @@ -164,9 +196,19 @@ def from_topo_file(
Name of the variable representing ocean depth in the dataset. Default is "depth".
git: bool, optional
Passed through to Topo.__init__. See Topo docstring for details.
version_control_dir: str, optional
Directory for version control. Default is "TopoLibrary".
channel_widths: str | Path | ChannelWidthList, optional
Channel width constraints. Can be a filepath to load from, a ChannelWidthList object, or None.
"""

topo = cls(grid, min_depth, version_control_dir=version_control_dir, git=git)
topo = cls(
grid,
min_depth,
version_control_dir=version_control_dir,
channel_widths=channel_widths,
git=git,
)
if topo.tcm is not None:
topo.tcm.reapply_changes()
topo.set_depth_via_topog_file(topo_file_path, varname)
Expand Down Expand Up @@ -1533,8 +1575,19 @@ def write_topo(self, file_path, title=None):
Path to TOPO_FILE to be written.
title: str, optional
File title.

Note
----
If channel_widths is not empty, remember to also write those constraints using
channel_widths.write(channel_file_path).
"""

if self.channel_widths.get_all():
print(
"Note: Channel widths are defined. Remember to write them with "
"channel_widths.write(filepath)"
)

ds = self.gen_topo_ds(title=title)
ds.to_netcdf(file_path, format="NETCDF3_64BIT")

Expand Down
86 changes: 86 additions & 0 deletions tests/test_channel_width.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from mom6_forge.channel_width import ChannelWidth, ChannelWidthList


def test_channel_width_validation():
"""Test that ChannelWidth validates component is U_width or V_width"""
# Valid cases
ch_u = ChannelWidth(
component="U_width",
lon1=-6.5,
lon2=-4.75,
lat1=35.6,
lat2=36.3,
width=12000.0,
place="St. of Gibralter",
)
assert ch_u.component == "U_width"

ch_v = ChannelWidth(
component="V_width",
lon1=28.75,
lon2=29.5,
lat1=41.1,
lat2=41.3,
width=5000.0,
place="Bosphorus",
)
assert ch_v.component == "V_width"

# Invalid component
with pytest.raises(ValueError, match="component must be 'U_width' or 'V_width'"):
ChannelWidth(
component="X_width",
lon1=0.0,
lon2=1.0,
lat1=0.0,
lat2=1.0,
width=1000.0,
place="Invalid",
)


def test_channel_width_list_write_load(tmp_path):
"""Test write and load roundtrip for ChannelWidthList"""
# Create a list with some channels
channels = ChannelWidthList()
channels.add(
ChannelWidth(
component="U_width",
lon1=-6.5,
lon2=-4.75,
lat1=35.6,
lat2=36.3,
width=12000.0,
place="St. of Gibralter",
)
)
channels.add(
ChannelWidth(
component="V_width",
lon1=28.75,
lon2=29.5,
lat1=41.1,
lat2=41.3,
width=5000.0,
place="Bosphorus",
)
)

# Write to file
output_file = tmp_path / "channels.txt"
channels.write(output_file)
assert output_file.exists()

# Load from file
loaded_channels = ChannelWidthList(filepath=output_file)
assert len(loaded_channels.get_all()) == 2

# Verify content
all_channels = loaded_channels.get_all()
assert all_channels[0].component == "U_width"
assert all_channels[0].width == 12000.0
assert all_channels[0].place == "St. of Gibralter"
assert all_channels[1].component == "V_width"
assert all_channels[1].width == 5000.0
assert all_channels[1].place == "Bosphorus"
Loading