Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def array_equivalent_object(
right: npt.NDArray[np.object_],
) -> bool: ...
def has_infs(arr: np.ndarray) -> bool: ... # const floating[:]
def has_nans(arr: np.ndarray) -> bool: ... # const floating[:]
def all_nans(arr: np.ndarray) -> bool: ... # const floating[:]
def array_equivalent_float(
left: np.ndarray,
right: np.ndarray,
) -> bool: ... # const floating[:]
def array_equivalent_bytes(
left: np.ndarray,
right: np.ndarray,
) -> bool: ...
def has_only_ints_or_nan(arr: np.ndarray) -> bool: ... # const floating[:]
def get_reverse_indexer(
indexer: np.ndarray, # const intp_t[:]
Expand Down
114 changes: 114 additions & 0 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ from cython cimport (
Py_ssize_t,
floating,
)
from libc.string cimport memcmp

from pandas._config import using_string_dtype

Expand Down Expand Up @@ -497,6 +498,119 @@ def has_infs(const floating[:] arr) -> bool:
return ret


@cython.wraparound(False)
@cython.boundscheck(False)
def has_nans(const floating[:] arr) -> bool:
    """
    Faster equivalent to ``np.isnan(arr).any()``; exits on the first NaN found.

    Parameters
    ----------
    arr : const floating[:]
        1-D memoryview of float32/float64 values.

    Returns
    -------
    bool
        True if any element is NaN; False for an empty array.
    """
    cdef:
        Py_ssize_t i, n = len(arr)
        Py_ssize_t n4 = n & ~3  # round down to multiple of 4
        bint found = False

    with nogil:
        # Main loop unrolled by 4. ``x != x`` is the standard NaN test;
        # bitwise ``|`` (rather than ``or``) avoids short-circuit branches,
        # which can help the C compiler keep the comparisons branch-free.
        for i in range(0, n4, 4):
            if (
                (arr[i] != arr[i])
                | (arr[i + 1] != arr[i + 1])
                | (arr[i + 2] != arr[i + 2])
                | (arr[i + 3] != arr[i + 3])
            ):
                found = True
                break
        if not found:
            # Tail loop for the remaining n % 4 elements.
            for i in range(n4, n):
                if arr[i] != arr[i]:
                    found = True
                    break
    return found


@cython.wraparound(False)
@cython.boundscheck(False)
def all_nans(const floating[:] arr) -> bool:
    """
    Faster equivalent to ``np.isnan(arr).all()``; exits on the first non-NaN found.

    Parameters
    ----------
    arr : const floating[:]
        1-D memoryview of float32/float64 values.

    Returns
    -------
    bool
        True if every element is NaN. An empty array returns True,
        matching ``np.isnan(arr).all()``.
    """
    cdef:
        Py_ssize_t i, n = len(arr)
        Py_ssize_t n4 = n & ~3  # round down to multiple of 4
        bint found_non_nan = False

    with nogil:
        # Main loop unrolled by 4. ``x == x`` is True exactly for non-NaN
        # values, so any True in the group means the array is not all-NaN;
        # bitwise ``|`` keeps the group check branch-free.
        for i in range(0, n4, 4):
            if (
                (arr[i] == arr[i])
                | (arr[i + 1] == arr[i + 1])
                | (arr[i + 2] == arr[i + 2])
                | (arr[i + 3] == arr[i + 3])
            ):
                found_non_nan = True
                break
        if not found_non_nan:
            # Tail loop for the remaining n % 4 elements.
            for i in range(n4, n):
                if arr[i] == arr[i]:
                    found_non_nan = True
                    break
    return not found_non_nan


@cython.wraparound(False)
@cython.boundscheck(False)
def array_equivalent_float(const floating[:] left,
                           const floating[:] right) -> bool:
    """
    Faster equivalent to ``((left == right) | (isnan(left) & isnan(right))).all()``;
    exits on the first mismatch. Caller is responsible for checking shapes match.

    Parameters
    ----------
    left, right : const floating[:]
        1-D memoryviews of the same floating dtype. Only ``len(left)``
        elements are compared, so the caller must ensure equal lengths.

    Returns
    -------
    bool
        True if the arrays are element-wise equal, treating NaNs in
        matching positions as equal.
    """
    cdef:
        Py_ssize_t i, n = len(left)
        floating lval, rval
        bint mismatch = False

    with nogil:
        for i in range(n):
            lval = left[i]
            rval = right[i]
            if lval != rval:
                # Unequal values are still equivalent when both are NaN;
                # ``x != x`` is True only for NaN.
                if not (lval != lval and rval != rval):
                    mismatch = True
                    break
    return not mismatch


def array_equivalent_bytes(left, right) -> bool:
    """
    Faster equivalent to ``np.array_equal(left, right)`` via ``memcmp`` on
    C-contiguous inputs.

    Not safe for dtypes where distinct bit patterns can represent the same
    value (e.g. floats with -0.0/+0.0 or NaN) or for arrays that contain
    object pointers; the caller is responsible for gating on dtype.

    Parameters
    ----------
    left, right : array-like
        Coerced with ``np.asarray``.

    Returns
    -------
    bool
    """
    cdef:
        Py_ssize_t nbytes
        int ndim, idx
        ndarray left_arr, right_arr

    left_arr = np.asarray(left)
    right_arr = np.asarray(right)

    # Raw-byte comparison is only meaningful for identical dtypes; with
    # mismatched itemsizes, memcmp over left's nbytes could read past the
    # end of right's buffer, and equal-size mixed dtypes would compare
    # incompatible bit patterns. Fall back to numpy's value comparison.
    if left_arr.dtype != right_arr.dtype:
        return np.array_equal(left_arr, right_arr)

    ndim = cnp.PyArray_NDIM(left_arr)
    if ndim != cnp.PyArray_NDIM(right_arr):
        return False
    for idx in range(ndim):
        if cnp.PyArray_DIM(left_arr, idx) != cnp.PyArray_DIM(right_arr, idx):
            return False
    if not (cnp.PyArray_IS_C_CONTIGUOUS(left_arr)
            and cnp.PyArray_IS_C_CONTIGUOUS(right_arr)):
        # memcmp needs one flat byte run; strided views take the slow path.
        return np.array_equal(left_arr, right_arr)
    nbytes = cnp.PyArray_NBYTES(left_arr)
    if nbytes == 0:
        # Shapes already match, so two empty buffers are trivially equal.
        return True
    return memcmp(cnp.PyArray_DATA(left_arr), cnp.PyArray_DATA(right_arr),
                  <size_t>nbytes) == 0


@cython.boundscheck(False)
@cython.wraparound(False)
def has_only_ints_or_nan(const floating[:] arr) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2626,7 +2626,7 @@ def equals(self, other: object) -> bool:
return False
elif self._categories_match_up_to_permutation(other):
other = self._encode_with_my_categories(other)
return np.array_equal(self._codes, other._codes)
return lib.array_equivalent_bytes(self._codes, other._codes)
return False

def _accumulate(self, name: str, skipna: bool = True, **kwargs) -> Self:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,7 +1946,7 @@ def _validate_frequency(cls, index, freq: BaseOffset, **kwargs) -> None:
unit=index.unit,
**kwargs,
)
if not np.array_equal(index.asi8, on_freq.asi8):
if not lib.array_equivalent_bytes(index.asi8, on_freq.asi8):
raise ValueError
except ValueError as err:
if "non-fixed" in str(err):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ def equals(self, other) -> bool:

# GH#44382 if e.g. self[1] is np.nan and other[1] is pd.NA, we are NOT
# equal.
if not np.array_equal(self._mask, other._mask):
if not lib.array_equivalent_bytes(self._mask, other._mask):
return False

left = self._data[~self._mask]
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _from_sequence(

arrdata = np.asarray(scalars)
if arrdata.dtype.kind == "f" and len(arrdata) > 0:
if not np.isnan(arrdata).all():
if not lib.all_nans(arrdata):
raise TypeError(
"PeriodArray does not allow floating point in construction"
)
Expand Down
24 changes: 21 additions & 3 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def array_equivalent(
# TODO: fastpath for pandas' StringDtype
return _array_equivalent_object(left, right, strict_nan)
else:
return np.array_equal(left, right)
return lib.array_equivalent_bytes(left, right)

# Slow path when we allow comparing different dtypes.
# Object arrays can contain None, NaN and NaT.
Expand Down Expand Up @@ -477,15 +477,33 @@ def array_equivalent(
) and left.dtype != right.dtype:
return False

if left.dtype == right.dtype and left.dtype.kind != "V":
return lib.array_equivalent_bytes(left, right)
return np.array_equal(left, right)


def _array_equivalent_float(left: np.ndarray, right: np.ndarray) -> bool:
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())
if left.dtype.kind == "c":
if not (left.flags.c_contiguous and right.flags.c_contiguous):
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())
# View complex as float pairs (complex128 -> float64, complex64 -> float32)
float_dtype = np.finfo(left.dtype).dtype
left = left.view(float_dtype)
right = right.view(float_dtype)
if left.ndim > 1:
if left.flags.f_contiguous and right.flags.f_contiguous:
# .T is a C-contiguous view of an F-contiguous array
left = left.T
right = right.T
if not (left.flags.c_contiguous and right.flags.c_contiguous):
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())
left = left.ravel()
right = right.ravel()
return lib.array_equivalent_float(left, right)


def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray) -> bool:
    """
    Check equivalence of datetime64/timedelta64 arrays by comparing their
    underlying int64 (i8) representations byte-for-byte.

    NaT is stored as a fixed i8 sentinel, so byte equality treats matching
    NaTs as equal — unlike NaN for floats.
    """
    # NOTE: the stale ``np.array_equal`` return line from the old diff was
    # removed; it shadowed this fastpath, making it unreachable.
    return lib.array_equivalent_bytes(left.view("i8"), right.view("i8"))


def _array_equivalent_object(
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def equals(self, other: Any) -> bool:
if type(self) != type(other):
return False
elif self.dtype == other.dtype:
return np.array_equal(self.asi8, other.asi8)
return lib.array_equivalent_bytes(self.asi8, other.asi8)
elif (self.dtype.kind == "M" and self.tz == other.tz) or self.dtype.kind == "m": # type: ignore[attr-defined]
# different units, otherwise matching
try:
Expand All @@ -379,7 +379,7 @@ def equals(self, other: Any) -> bool:
except (OutOfBoundsDatetime, OutOfBoundsTimedelta):
return False
else:
return np.array_equal(left.view("i8"), right.view("i8"))
return lib.array_equivalent_bytes(left.view("i8"), right.view("i8"))
return False

def __contains__(self, key: Any) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4360,14 +4360,14 @@ def equals(self, other: object) -> bool:
other_codes = other.codes[i]
self_mask = self_codes == -1
other_mask = other_codes == -1
if not np.array_equal(self_mask, other_mask):
if not lib.array_equivalent_bytes(self_mask, other_mask):
return False
self_level = self.levels[i]
other_level = other.levels[i]
new_codes = recode_for_categories(
other_codes, other_level, self_level, copy=False
)
if not np.array_equal(self_codes, new_codes):
if not lib.array_equivalent_bytes(self_codes, new_codes):
return False
if not self_level[:0].equals(other_level[:0]):
# e.g. Int64 != int64
Expand Down
Loading