Skip to content

Commit

Permalink
Fix various issues with replace API and add support in datetime a…
Browse files Browse the repository at this point in the history
…nd `timedelta` columns (#17331)

This PR:

- [x] Adds support for `find_and_replace` in `DatetimeColumn` and `TimeDeltaColumn`, so that calling `.replace` on a series or dataframe containing these columns no longer errors and replaces the values correctly.
- [x] Fixes various type-combination edge cases that were previously handled incorrectly, and updates the stale tests associated with them.
- [x] Adds a small multi-row parquet file to the pytest data; this file is what uncovered these bugs.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17331
  • Loading branch information
galipremsagar authored Nov 15, 2024
1 parent d67d017 commit d475dca
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 55 deletions.
1 change: 1 addition & 0 deletions python/cudf/cudf/core/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@


PANDAS_GE_210 = PANDAS_VERSION >= version.parse("2.1.0")
PANDAS_GT_214 = PANDAS_VERSION > version.parse("2.1.4")
PANDAS_GE_220 = PANDAS_VERSION >= version.parse("2.2.0")
PANDAS_LT_300 = PANDAS_VERSION < version.parse("3.0.0")
21 changes: 20 additions & 1 deletion python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.utils.dtypes import _get_base_dtype
from cudf.utils.utils import _all_bools_with_nulls
from cudf.utils.utils import (
_all_bools_with_nulls,
_datetime_timedelta_find_and_replace,
)

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -630,6 +633,22 @@ def quantile(
)
return result.astype(self.dtype)

def find_and_replace(
    self,
    to_replace: ColumnBase,
    replacement: ColumnBase,
    all_nan: bool = False,
) -> DatetimeColumn:
    """Return a column with ``to_replace`` values swapped for ``replacement``.

    The work is delegated to ``_datetime_timedelta_find_and_replace``;
    this method only forwards its arguments (with ``self`` as the
    original column) and narrows the declared return type.

    Parameters
    ----------
    to_replace : ColumnBase
        Values to search for.
    replacement : ColumnBase
        Values to substitute.
    all_nan : bool, default False
        Forwarded unchanged to the helper.
    """
    replaced = _datetime_timedelta_find_and_replace(
        original_column=self,
        to_replace=to_replace,
        replacement=replacement,
        all_nan=all_nan,
    )
    # `cast` only informs the type checker (assuming this is `typing.cast`).
    return cast(DatetimeColumn, replaced)

def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
reflect, op = self._check_reflected_op(op)
other = self._wrap_binop_normalization(other)
Expand Down
47 changes: 34 additions & 13 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,24 +511,41 @@ def find_and_replace(
):
return self.copy()

to_replace_col = _normalize_find_and_replace_input(
self.dtype, to_replace
)
try:
to_replace_col = _normalize_find_and_replace_input(
self.dtype, to_replace
)
except TypeError:
# If `to_replace` cannot be normalized to the current dtype,
# then no value of `to_replace` can be present in self,
# so there is no point in proceeding further.
return self.copy()

if all_nan:
replacement_col = column.as_column(replacement, dtype=self.dtype)
else:
replacement_col = _normalize_find_and_replace_input(
self.dtype, replacement
)
try:
replacement_col = _normalize_find_and_replace_input(
self.dtype, replacement
)
except TypeError:
# Some floating-point values can never be converted into signed or unsigned
# integers. For those cases, we just need a column of `replacement` constructed
# with its own type for the final type determination in the `find_common_type`
# call below.
replacement_col = column.as_column(
replacement,
dtype=self.dtype if len(replacement) <= 0 else None,
)
common_type = find_common_type(
(to_replace_col.dtype, replacement_col.dtype, self.dtype)
)
if len(replacement_col) == 1 and len(to_replace_col) > 1:
replacement_col = column.as_column(
replacement[0], length=len(to_replace_col), dtype=self.dtype
replacement[0], length=len(to_replace_col), dtype=common_type
)
elif len(replacement_col) == 1 and len(to_replace_col) == 0:
return self.copy()
common_type = find_common_type(
(to_replace_col.dtype, replacement_col.dtype, self.dtype)
)
replaced = self.astype(common_type)
df = cudf.DataFrame._from_data(
{
Expand Down Expand Up @@ -718,6 +735,8 @@ def _normalize_find_and_replace_input(
if isinstance(col_to_normalize, list):
if normalized_column.null_count == len(normalized_column):
normalized_column = normalized_column.astype(input_column_dtype)
if normalized_column.can_cast_safely(input_column_dtype):
return normalized_column.astype(input_column_dtype)
col_to_normalize_dtype = min_column_type(
normalized_column, input_column_dtype
)
Expand All @@ -728,7 +747,7 @@ def _normalize_find_and_replace_input(
if np.isinf(col_to_normalize[0]):
return normalized_column
col_to_normalize_casted = np.array(col_to_normalize[0]).astype(
input_column_dtype
col_to_normalize_dtype
)

if not np.isnan(col_to_normalize_casted) and (
Expand All @@ -739,8 +758,8 @@ def _normalize_find_and_replace_input(
f"{col_to_normalize[0]} "
f"to {input_column_dtype.name}"
)
else:
col_to_normalize_dtype = input_column_dtype
if normalized_column.can_cast_safely(col_to_normalize_dtype):
return normalized_column.astype(col_to_normalize_dtype)
elif hasattr(col_to_normalize, "dtype"):
col_to_normalize_dtype = col_to_normalize.dtype
else:
Expand All @@ -755,6 +774,8 @@ def _normalize_find_and_replace_input(
f"{col_to_normalize_dtype.name} "
f"to {input_column_dtype.name}"
)
if not normalized_column.can_cast_safely(input_column_dtype):
return normalized_column
return normalized_column.astype(input_column_dtype)


Expand Down
53 changes: 51 additions & 2 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from cudf.core.buffer import Buffer, acquire_spill_lock
from cudf.core.column import ColumnBase, column, string
from cudf.utils.dtypes import np_to_pa_dtype
from cudf.utils.utils import _all_bools_with_nulls
from cudf.utils.utils import (
_all_bools_with_nulls,
_datetime_timedelta_find_and_replace,
)

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -95,7 +98,7 @@ def __init__(
size = data.size // dtype.itemsize
size = size - offset
if len(children) != 0:
raise ValueError("TimedeltaColumn must have no children.")
raise ValueError("TimeDeltaColumn must have no children.")
super().__init__(
data=data,
size=size,
Expand Down Expand Up @@ -306,6 +309,52 @@ def as_timedelta_column(self, dtype: Dtype) -> TimeDeltaColumn:
return self
return libcudf.unary.cast(self, dtype=dtype)

def find_and_replace(
    self,
    to_replace: ColumnBase,
    replacement: ColumnBase,
    all_nan: bool = False,
) -> TimeDeltaColumn:
    """Return a column with ``to_replace`` values swapped for ``replacement``.

    The work is delegated to ``_datetime_timedelta_find_and_replace``;
    this method only forwards its arguments (with ``self`` as the
    original column) and narrows the declared return type.

    Parameters
    ----------
    to_replace : ColumnBase
        Values to search for.
    replacement : ColumnBase
        Values to substitute.
    all_nan : bool, default False
        Forwarded unchanged to the helper.
    """
    replaced = _datetime_timedelta_find_and_replace(
        original_column=self,
        to_replace=to_replace,
        replacement=replacement,
        all_nan=all_nan,
    )
    # `cast` only informs the type checker (assuming this is `typing.cast`).
    return cast(TimeDeltaColumn, replaced)

def can_cast_safely(self, to_dtype: Dtype) -> bool:
    """Return whether this column can be cast to ``to_dtype`` without overflow.

    For a timedelta target (dtype kind ``"m"``) the column's minimum and
    maximum are compared against the largest magnitude representable as an
    int64 count of the target resolution, expressed in this column's own
    resolution. Both ends of the range are checked: the maximum against
    the upper bound and the minimum against the (negated) lower bound.

    Casting to ``int64`` (the underlying representation) or to object
    dtype (string) is always reported as safe; every other target dtype
    is reported as unsafe.

    Parameters
    ----------
    to_dtype : Dtype
        Target dtype; must expose ``.kind``.

    Returns
    -------
    bool
    """
    if to_dtype.kind == "m":  # type: ignore[union-attr]
        to_res, _ = np.datetime_data(to_dtype)
        self_res, _ = np.datetime_data(self.dtype)

        max_dist = np.timedelta64(
            self.max().astype(np.int64, copy=False), self_res
        )
        min_dist = np.timedelta64(
            self.min().astype(np.int64, copy=False), self_res
        )

        self_delta_dtype = np.timedelta64(0, self_res).dtype

        # Largest value of the target resolution, converted to this
        # column's resolution so it is comparable with min/max above.
        # NOTE(review): when `to_res` is coarser than `self_res` this
        # conversion multiplies and may overflow int64 -- confirm.
        upper_bound = np.timedelta64(
            np.iinfo(np.int64).max, to_res
        ).astype(self_delta_dtype)
        # int64 min is reserved for NaT, so the usable range is symmetric;
        # negating the (already truncated) upper bound keeps the lower
        # bound conservative.
        lower_bound = -upper_bound

        # The previous check compared `min_dist` against the *upper* bound,
        # which is implied by the `max_dist` check and left large negative
        # timedeltas undetected.
        return bool(lower_bound <= min_dist and max_dist <= upper_bound)

    # can safely cast to representation, or string
    return bool(
        to_dtype == cudf.dtype("int64") or to_dtype == cudf.dtype("O")
    )

def mean(self, skipna=None) -> pd.Timedelta:
return pd.Timedelta(
cast(
Expand Down
Binary file not shown.
119 changes: 86 additions & 33 deletions python/cudf/cudf/tests/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from cudf.core._compat import (
PANDAS_CURRENT_SUPPORTED_VERSION,
PANDAS_GE_220,
PANDAS_GT_214,
PANDAS_VERSION,
)
from cudf.core.dtypes import Decimal32Dtype, Decimal64Dtype, Decimal128Dtype
Expand Down Expand Up @@ -116,8 +117,10 @@ def test_series_replace():
sr6 = sr1.replace([0, 1], [5, 6])
assert_eq(a6, sr6.to_numpy())

with pytest.raises(TypeError):
sr1.replace([0, 1], [5.5, 6.5])
assert_eq(
sr1.replace([0, 1], [5.5, 6.5]),
sr1.to_pandas().replace([0, 1], [5.5, 6.5]),
)

# Series input
a8 = np.array([5, 5, 5, 3, 4])
Expand Down Expand Up @@ -160,8 +163,10 @@ def test_series_replace_with_nulls():
assert_eq(a6, sr6.to_numpy())

sr1 = cudf.Series([0, 1, 2, 3, 4, None])
with pytest.raises(TypeError):
sr1.replace([0, 1], [5.5, 6.5]).fillna(-10)
assert_eq(
sr1.replace([0, 1], [5.5, 6.5]).fillna(-10),
sr1.to_pandas().replace([0, 1], [5.5, 6.5]).fillna(-10),
)

# Series input
a8 = np.array([-10, -10, -10, 3, 4, -10])
Expand Down Expand Up @@ -967,30 +972,37 @@ def test_series_multiple_times_with_nulls():
@pytest.mark.parametrize(
"replacement", [128, 128.0, 128.5, 32769, 32769.0, 32769.5]
)
def test_numeric_series_replace_dtype(series_dtype, replacement):
def test_numeric_series_replace_dtype(request, series_dtype, replacement):
request.applymarker(
pytest.mark.xfail(
condition=PANDAS_GT_214
and (
(
series_dtype == "int8"
and replacement in {128, 128.0, 32769, 32769.0}
)
or (
series_dtype == "int16" and replacement in {32769, 32769.0}
)
),
reason="Pandas throws an AssertionError for these "
"cases and asks us to log a bug, they are trying to "
"avoid a RecursionError which cudf will not run into",
)
)
psr = pd.Series([0, 1, 2, 3, 4, 5], dtype=series_dtype)
sr = cudf.from_pandas(psr)

numpy_replacement = np.array(replacement).astype(sr.dtype)[()]
can_replace = numpy_replacement == replacement
expect = psr.replace(1, replacement)
got = sr.replace(1, replacement)

# Both Scalar
if not can_replace:
with pytest.raises(TypeError):
sr.replace(1, replacement)
else:
expect = psr.replace(1, replacement).astype(psr.dtype)
got = sr.replace(1, replacement)
assert_eq(expect, got)
assert_eq(expect, got)

# to_replace is a list, replacement is a scalar
if not can_replace:
with pytest.raises(TypeError):
sr.replace([2, 3], replacement)
else:
expect = psr.replace([2, 3], replacement).astype(psr.dtype)
got = sr.replace([2, 3], replacement)
assert_eq(expect, got)
expect = psr.replace([2, 3], replacement)
got = sr.replace([2, 3], replacement)

assert_eq(expect, got)

# If to_replace is a scalar and replacement is a list
with pytest.raises(TypeError):
Expand All @@ -1001,17 +1013,9 @@ def test_numeric_series_replace_dtype(series_dtype, replacement):
sr.replace([0, 1], [replacement])

# Both lists of equal length
if (
np.dtype(type(replacement)).kind == "f" and sr.dtype.kind in {"i", "u"}
) or (not can_replace):
with pytest.raises(TypeError):
sr.replace([2, 3], [replacement, replacement])
else:
expect = psr.replace([2, 3], [replacement, replacement]).astype(
psr.dtype
)
got = sr.replace([2, 3], [replacement, replacement])
assert_eq(expect, got)
expect = psr.replace([2, 3], [replacement, replacement])
got = sr.replace([2, 3], [replacement, replacement])
assert_eq(expect, got)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1392,3 +1396,52 @@ def test_replace_with_index_objects():
result = cudf.Series([1, 2]).replace(cudf.Index([1]), cudf.Index([2]))
expected = pd.Series([1, 2]).replace(pd.Index([1]), pd.Index([2]))
assert_eq(result, expected)


# Example test function for datetime series replace
def test_replace_datetime_series():
    """``Series.replace`` on datetime data must match the pandas result."""
    # Value to look for and the value that takes its place.
    old_value = pd.Timestamp("2021-01-02")
    new_value = pd.Timestamp("2021-01-10")

    # Reference result computed with pandas.
    expected = pd.Series(pd.date_range("20210101", periods=5)).replace(
        old_value, new_value
    )

    # Same data and same replacement through cudf.
    got = cudf.Series(pd.date_range("20210101", periods=5)).replace(
        old_value, new_value
    )

    assert_eq(expected, got)


# Example test function for timedelta series replace
def test_replace_timedelta_series():
    """``Series.replace`` on timedelta data must match the pandas result."""
    # Value to look for and the value that takes its place.
    old_value = pd.Timedelta("2 days")
    new_value = pd.Timedelta("10 days")

    # Reference result computed with pandas.
    expected = pd.Series(pd.timedelta_range("1 days", periods=5)).replace(
        old_value, new_value
    )

    # Same data and same replacement through cudf.
    got = cudf.Series(pd.timedelta_range("1 days", periods=5)).replace(
        old_value, new_value
    )

    assert_eq(expected, got)


def test_replace_multiple_rows(datadir):
    """Replacing +/-inf with NaN in the multi-row parquet fixture must
    give the same frame in cudf as in pandas (dtypes are not compared).
    """
    parquet_path = datadir / "parquet" / "replace_multiple_rows.parquet"
    expected = pd.read_parquet(parquet_path)
    got = cudf.read_parquet(parquet_path)

    # Apply the identical in-place replacement to both frames,
    # pandas first and then cudf.
    for frame in (expected, got):
        frame.replace([np.inf, -np.inf], np.nan, inplace=True)

    assert_eq(expected, got, check_dtype=False)
18 changes: 12 additions & 6 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,19 @@ def min_column_type(x, expected_type):
if x.null_count == len(x):
return x.dtype

if x.dtype.kind == "f":
return get_min_float_dtype(x)

elif cudf.dtype(expected_type).kind in "iu":
max_bound_dtype = np.min_scalar_type(x.max())
min_bound_dtype = np.min_scalar_type(x.min())
min_value, max_value = x.min(), x.max()
either_is_inf = np.isinf(min_value) or np.isinf(max_value)
expected_type = cudf.dtype(expected_type)
if not either_is_inf and expected_type.kind in "i":
max_bound_dtype = min_signed_type(max_value)
min_bound_dtype = min_signed_type(min_value)
result_type = np.promote_types(max_bound_dtype, min_bound_dtype)
elif not either_is_inf and expected_type.kind in "u":
max_bound_dtype = min_unsigned_type(max_value)
min_bound_dtype = min_unsigned_type(min_value)
result_type = np.promote_types(max_bound_dtype, min_bound_dtype)
elif x.dtype.kind == "f":
return get_min_float_dtype(x)
else:
result_type = x.dtype

Expand Down
Loading

0 comments on commit d475dca

Please sign in to comment.