diff --git a/xarray/computation/computation.py b/xarray/computation/computation.py index 3df468f19a9..11b4fbbf18d 100644 --- a/xarray/computation/computation.py +++ b/xarray/computation/computation.py @@ -1000,18 +1000,28 @@ def _calc_idxminmax( # This will run argmin or argmax. index = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Handle chunked arrays (e.g. dask). - coord = array[dim]._variable.to_base_variable() - if is_chunked_array(array.data): - chunkmanager = get_chunked_array_type(array.data) - coord_array = chunkmanager.from_array( - array[dim].data, chunks=((array.sizes[dim],),) + coord_data = array[dim].data + if utils.is_allowed_extension_array(coord_data): + # Preserve extension-array-backed coordinates by reconstructing the + # selected labels directly instead of routing through Variable indexing. + data = duck_array_ops.reshape( + coord_data[duck_array_ops.ravel(index.data)], index.shape ) - coord = coord.copy(data=coord_array) + res = index.copy(data=data) + res.name = dim else: - coord = coord.copy(data=to_like_array(array[dim].data, array.data)) + # Handle chunked arrays (e.g. dask). + coord = array[dim]._variable.to_base_variable() + if is_chunked_array(array.data): + chunkmanager = get_chunked_array_type(array.data) + coord_array = chunkmanager.from_array( + coord_data, chunks=((array.sizes[dim],),) + ) + coord = coord.copy(data=coord_array) + else: + coord = coord.copy(data=to_like_array(coord_data, array.data)) - res = index._replace(coord[(index.variable,)]).rename(dim) + res = index._replace(coord[(index.variable,)]).rename(dim) if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8eb52046a31..f1e068a5a9b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5623,6 +5623,15 @@ def test_argmax_dim( assert_identical(result2[key], expected2[key]) +def test_idxmax_intervalindex_coord() -> None: + idx = pd.IntervalIndex.from_breaks([0, 1, 2, 3]) + da = xr.DataArray([False, True, True], dims=["z"], coords={"z": idx}) + + expected = xr.DataArray(idx[1], name="z") + + assert_identical(da.idxmax(), expected) + + @pytest.mark.parametrize( ["x", "minindex", "maxindex", "nanindex"], [