Skip to content

Commit

Permalink
Allow where() to work with a Series and other=cudf.NA (#9019)
Browse files Browse the repository at this point in the history
Fixes #8969.

Duplicate of #8977 - some of the checks are erroring and I'm seeing strange messages about the git commits, so I'm re-opening the PR here to see if that fixes it.

Authors:
  - Sarah Yurick (https://github.com/sarahyurick)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)

URL: #9019
  • Loading branch information
sarahyurick authored Aug 11, 2021
1 parent 4968a96 commit 7461b20
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
16 changes: 12 additions & 4 deletions python/cudf/cudf/core/_internals/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def _normalize_scalars(col: ColumnBase, other: ScalarLike) -> ScalarLike:
f"{type(other).__name__} to {col.dtype.name}"
)

return cudf.Scalar(other, dtype=col.dtype if other is None else None)
return cudf.Scalar(
other, dtype=col.dtype if other in {None, cudf.NA} else None
)


def _check_and_cast_columns_with_other(
Expand Down Expand Up @@ -234,9 +236,15 @@ def where(

if isinstance(frame, DataFrame):
if hasattr(cond, "__cuda_array_interface__"):
cond = DataFrame(
cond, columns=frame._column_names, index=frame.index
)
if isinstance(cond, Series):
cond = DataFrame(
{name: cond for name in frame._column_names},
index=frame.index,
)
else:
cond = DataFrame(
cond, columns=frame._column_names, index=frame.index
)
elif (
hasattr(cond, "__array_interface__")
and cond.__array_interface__["shape"] != frame.shape
Expand Down
20 changes: 20 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8731,3 +8731,23 @@ def test_frame_series_where():
expected = gdf.where(gdf.notna(), gdf.mean())
actual = pdf.where(pdf.notna(), pdf.mean(), axis=1)
assert_eq(expected, actual)


@pytest.mark.parametrize(
"data", [{"a": [1, 2, 3], "b": [1, 1, 0]}],
)
def test_frame_series_where_other(data):
gdf = cudf.DataFrame(data)
pdf = gdf.to_pandas()

expected = gdf.where(gdf["b"] == 1, cudf.NA)
actual = pdf.where(pdf["b"] == 1, pd.NA)
assert_eq(
actual.fillna(-1).values,
expected.fillna(-1).values,
check_dtype=False,
)

expected = gdf.where(gdf["b"] == 1, 0)
actual = pdf.where(pdf["b"] == 1, 0)
assert_eq(expected, actual)
2 changes: 2 additions & 0 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,8 @@ def _can_cast(from_dtype, to_dtype):
`np.can_cast` but with some special handling around
cudf specific dtypes.
"""
if from_dtype in {None, cudf.NA}:
return True
if isinstance(from_dtype, type):
from_dtype = np.dtype(from_dtype)
if isinstance(to_dtype, type):
Expand Down

0 comments on commit 7461b20

Please sign in to comment.