Skip to content

Commit

Permalink
BUG: Period.__eq__ numpy scalar (#44182 (comment)) (#44285)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause authored Nov 4, 2021
1 parent 669acb4 commit a3bcbf8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pandas/_libs/tslibs/period.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1657,8 +1657,12 @@ cdef class _Period(PeriodMixin):
elif other is NaT:
return _nat_scalar_rules[op]
elif util.is_array(other):
# in particular ndarray[object]; see test_pi_cmp_period
return np.array([PyObject_RichCompare(self, x, op) for x in other])
# GH#44285
if cnp.PyArray_IsZeroDim(other):
return PyObject_RichCompare(self, other.item(), op)
else:
# in particular ndarray[object]; see test_pi_cmp_period
return np.array([PyObject_RichCompare(self, x, op) for x in other])
return NotImplemented

def __hash__(self):
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/arithmetic/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def test_pi_cmp_period(self):
result = idx.values.reshape(10, 2) < idx[10]
tm.assert_numpy_array_equal(result, exp.reshape(10, 2))

# Tests Period.__richcmp__ against ndarray[object, ndim=0]
result = idx < np.array(idx[10])
tm.assert_numpy_array_equal(result, exp)

# TODO: moved from test_datetime64; de-duplicate with version below
def test_parr_cmp_period_scalar2(self, box_with_array):
xbox = get_expected_box(box_with_array)
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/scalar/period/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,16 @@ def test_period_cmp_nat(self):
assert not left <= right
assert not left >= right

@pytest.mark.parametrize(
"zerodim_arr, expected",
((np.array(0), False), (np.array(Period("2000-01", "M")), True)),
)
def test_comparison_numpy_zerodim_arr(self, zerodim_arr, expected):
p = Period("2000-01", "M")

assert (p == zerodim_arr) is expected
assert (zerodim_arr == p) is expected


class TestArithmetic:
def test_sub_delta(self):
Expand Down

0 comments on commit a3bcbf8

Please sign in to comment.