diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index ec3d8fa171b..a29bf4f5942 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -3,7 +3,7 @@ import copy import math from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload import numpy as np @@ -39,8 +39,6 @@ PostPersistCallable: Any # type: ignore[no-redef] # T_NamedArray = TypeVar("T_NamedArray", bound="NamedArray[T_DuckArray]") - DimsInput = Union[str, Iterable[Hashable]] - Dims = tuple[Hashable, ...] AttrsInput = Union[Mapping[Any, Any], None] @@ -75,7 +73,10 @@ def as_compatible_data( return cast(T_DuckArray, np.asarray(data)) -class NamedArray(Generic[T_DuckArray]): +T_Dim = TypeVar("T_Dim", bound=Hashable) + + +class NamedArray(Generic[T_Dim, T_DuckArray]): """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array. Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, @@ -84,20 +85,60 @@ class NamedArray(Generic[T_DuckArray]): __slots__ = ("_data", "_dims", "_attrs") _data: T_DuckArray - _dims: Dims + _dims: tuple[T_Dim, ...] _attrs: dict[Any, Any] | None + @overload + def __init__( + self: NamedArray[str, T_DuckArray], + dims: str, + data: T_DuckArray, + attrs: AttrsInput = None, + fastpath: bool = False, + ) -> None: + ... + + @overload + def __init__( + self: NamedArray[str, np.ndarray[Any, np.dtype[np.generic]]], + dims: str, + data: np.typing.ArrayLike, + attrs: AttrsInput = None, + fastpath: bool = False, + ) -> None: + ... + + @overload + def __init__( + self: NamedArray[T_Dim, T_DuckArray], + dims: Iterable[T_Dim], + data: T_DuckArray, + attrs: AttrsInput = None, + fastpath: bool = False, + ) -> None: + ... + + @overload + def __init__( + self: NamedArray[T_Dim, np.ndarray[Any, np.dtype[np.generic]]], + dims: Iterable[T_Dim], + data: np.typing.ArrayLike, + attrs: AttrsInput = None, + fastpath: bool = False, + ) -> None: + ... + def __init__( self, - dims: DimsInput, + dims: str | Iterable[T_Dim], data: T_DuckArray | np.typing.ArrayLike, attrs: AttrsInput = None, fastpath: bool = False, - ): + ) -> None: """ Parameters ---------- - dims : str or iterable of str + dims : str or iterable of hashable Name(s) of the dimension(s). data : T_DuckArray or np.typing.ArrayLike The actual data that populates the array. Should match the shape specified by `dims`. @@ -194,22 +235,22 @@ def nbytes(self) -> int: return self.size * self.dtype.itemsize @property - def dims(self) -> Dims: + def dims(self) -> tuple[T_Dim, ...]: """Tuple of dimension names with which this NamedArray is associated.""" return self._dims @dims.setter - def dims(self, value: DimsInput) -> None: + def dims(self, value: str | Iterable[T_Dim]) -> None: self._dims = self._parse_dimensions(value) - def _parse_dimensions(self, dims: DimsInput) -> Dims: - dims = (dims,) if isinstance(dims, str) else tuple(dims) - if len(dims) != self.ndim: + def _parse_dimensions(self, dims: str | Iterable[T_Dim]) -> tuple[T_Dim, ...]: + pdims = (dims,) if isinstance(dims, str) else tuple(dims) + if len(pdims) != self.ndim: raise ValueError( - f"dimensions {dims} must have the same length as the " + f"dimensions {pdims} must have the same length as the " f"number of data dimensions, ndim={self.ndim}" ) - return dims + return pdims # type: ignore[return-value] @property def attrs(self) -> dict[Any, Any]: @@ -397,7 +438,7 @@ def sizes(self) -> dict[Hashable, int]: def _replace( self, - dims: DimsInput | Default = _default, + dims: str | Iterable[T_Dim] | Default = _default, data: T_DuckArray | np.typing.ArrayLike | Default = _default, attrs: AttrsInput | Default = _default, ) -> Self: diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index ea1588bf554..00167805620 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -91,7 +91,7 @@ class CustomArrayIndexable(CustomArrayBase, xr.core.indexing.ExplicitlyIndexed): def test_properties() -> None: data = 0.5 * np.arange(10).reshape(2, 5) - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray(["x", "y"], data, {"key": "value"}) assert named_array.dims == ("x", "y") assert np.array_equal(named_array.data, data) @@ -104,7 +104,7 @@ def test_properties() -> None: def test_attrs() -> None: - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) assert named_array.attrs == {} named_array.attrs["key"] = "value" @@ -114,7 +114,7 @@ def test_attrs() -> None: def test_data(random_inputs: np.ndarray[Any, Any]) -> None: - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray(["x", "y", "z"], random_inputs) assert np.array_equal(named_array.data, random_inputs) with pytest.raises(ValueError): @@ -130,7 +130,7 @@ def test_data(random_inputs: np.ndarray[Any, Any]) -> None: ], ) def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None: - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray([], data) assert named_array.data == data assert named_array.dims == () @@ -142,7 +142,7 @@ def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None: def test_0d_object() -> None: - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray([], (10, 12, 12)) expected_data = np.empty((), dtype=object) expected_data[()] = (10, 12, 12) @@ -157,7 +157,7 @@ def test_0d_object() -> None: def test_0d_datetime() -> None: - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray([], np.datetime64("2000-01-01")) assert named_array.dtype == np.dtype("datetime64[D]") @@ -179,7 +179,7 @@ def test_0d_datetime() -> None: def test_0d_timedelta( timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] ) -> None: - named_array: NamedArray[np.ndarray[Any, np.dtype[np.timedelta64]]] + named_array: NamedArray[str, np.ndarray[Any, np.dtype[np.timedelta64]]] named_array = NamedArray([], timedelta) assert named_array.dtype == expected_dtype assert named_array.data == timedelta @@ -196,7 +196,7 @@ def test_0d_timedelta( ], ) def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None: - named_array: NamedArray[np.ndarray[Any, Any]] + named_array: NamedArray[str, np.ndarray[Any, Any]] named_array = NamedArray(dims, np.random.random(data_shape)) assert named_array.dims == tuple(dims) if raises: