Skip to content

Commit

Permalink
Merge pull request #14807 from rapidsai/branch-24.02
Browse files Browse the repository at this point in the history
Forward-merge branch-24.02 to branch-24.04
  • Loading branch information
GPUtester authored Jan 19, 2024
2 parents 8dd6c7a + 51ecef3 commit 3e9056f
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 29 deletions.
10 changes: 4 additions & 6 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,9 +958,7 @@ def distinct_count(self, dropna: bool = True) -> int:
def can_cast_safely(self, to_dtype: Dtype) -> bool:
raise NotImplementedError()

def astype(
self, dtype: Dtype, copy: bool = False, format: str | None = None
) -> ColumnBase:
def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
if copy:
col = self.copy()
else:
Expand Down Expand Up @@ -1000,7 +998,7 @@ def astype(
f"Casting to {dtype} is not supported, use "
"`.astype('str')` instead."
)
return col.as_string_column(dtype, format=format)
return col.as_string_column(dtype)
elif isinstance(dtype, (ListDtype, StructDtype)):
if not col.dtype == dtype:
raise NotImplementedError(
Expand All @@ -1012,9 +1010,9 @@ def astype(
elif isinstance(dtype, cudf.core.dtypes.DecimalDtype):
return col.as_decimal_column(dtype)
elif np.issubdtype(cast(Any, dtype), np.datetime64):
return col.as_datetime_column(dtype, format=format)
return col.as_datetime_column(dtype)
elif np.issubdtype(cast(Any, dtype), np.timedelta64):
return col.as_timedelta_column(dtype, format=format)
return col.as_timedelta_column(dtype)
else:
return col.as_numerical_column(dtype)

Expand Down
10 changes: 8 additions & 2 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Callable,
Dict,
List,
Literal,
MutableMapping,
Optional,
Set,
Expand Down Expand Up @@ -1774,7 +1775,12 @@ def _concat(

return out

def astype(self, dtype, copy=False, errors="raise", **kwargs):
def astype(
self,
dtype,
copy: bool = False,
errors: Literal["raise", "ignore"] = "raise",
):
if is_dict_like(dtype):
if len(set(dtype.keys()) - set(self._data.names)) > 0:
raise KeyError(
Expand All @@ -1783,7 +1789,7 @@ def astype(self, dtype, copy=False, errors="raise", **kwargs):
)
else:
dtype = {cc: dtype for cc in self._data.names}
return super().astype(dtype, copy, errors, **kwargs)
return super().astype(dtype, copy, errors)

def _clean_renderable_dataframe(self, output):
"""
Expand Down
13 changes: 5 additions & 8 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,11 @@ def __len__(self):
return self._num_rows

@_cudf_nvtx_annotate
def astype(self, dtype, copy=False, **kwargs):
result_data = {}
for col_name, col in self._data.items():
dt = dtype.get(col_name, col.dtype)
if not is_dtype_equal(dt, col.dtype):
result_data[col_name] = col.astype(dt, copy=copy, **kwargs)
else:
result_data[col_name] = col.copy() if copy else col
def astype(self, dtype, copy: bool = False):
result_data = {
col_name: col.astype(dtype.get(col_name, col.dtype), copy=copy)
for col_name, col in self._data.items()
}

return ColumnAccessor._create_unsafe(
data=result_data,
Expand Down
11 changes: 8 additions & 3 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Callable,
Dict,
List,
Literal,
MutableMapping,
Optional,
Tuple,
Expand Down Expand Up @@ -3736,7 +3737,12 @@ def _append(

return cudf.concat(to_concat, ignore_index=ignore_index, sort=sort)

def astype(self, dtype, copy=False, errors="raise", **kwargs):
def astype(
self,
dtype,
copy: bool = False,
errors: Literal["raise", "ignore"] = "raise",
):
"""Cast the object to the given dtype.
Parameters
Expand All @@ -3757,7 +3763,6 @@ def astype(self, dtype, copy=False, errors="raise", **kwargs):
- ``raise`` : allow exceptions to be raised
- ``ignore`` : suppress exceptions. On error return original
object.
**kwargs : extra arguments to pass on to the constructor
Returns
-------
Expand Down Expand Up @@ -3848,7 +3853,7 @@ def astype(self, dtype, copy=False, errors="raise", **kwargs):
raise ValueError("invalid error value specified")

try:
data = super().astype(dtype, copy, **kwargs)
data = super().astype(dtype, copy)
except Exception as e:
if errors == "raise":
raise e
Expand Down
10 changes: 8 additions & 2 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
Any,
Dict,
Literal,
MutableMapping,
Optional,
Sequence,
Expand Down Expand Up @@ -2141,7 +2142,12 @@ def nullmask(self):
return cudf.Series(self._column.nullmask)

@_cudf_nvtx_annotate
def astype(self, dtype, copy=False, errors="raise", **kwargs):
def astype(
self,
dtype,
copy: bool = False,
errors: Literal["raise", "ignore"] = "raise",
):
if is_dict_like(dtype):
if len(dtype) > 1 or self.name not in dtype:
raise KeyError(
Expand All @@ -2150,7 +2156,7 @@ def astype(self, dtype, copy=False, errors="raise", **kwargs):
)
else:
dtype = {self.name: dtype}
return super().astype(dtype, copy, errors, **kwargs)
return super().astype(dtype, copy, errors)

@_cudf_nvtx_annotate
def sort_index(self, axis=0, *args, **kwargs):
Expand Down
19 changes: 11 additions & 8 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.

import array as arr
import contextlib
Expand Down Expand Up @@ -5114,15 +5114,18 @@ def test_df_astype_to_categorical_ordered(ordered):


@pytest.mark.parametrize(
"dtype,args",
[(dtype, {}) for dtype in ALL_TYPES]
+ [("category", {"ordered": True}), ("category", {"ordered": False})],
"dtype",
[dtype for dtype in ALL_TYPES]
+ [
cudf.CategoricalDtype(ordered=True),
cudf.CategoricalDtype(ordered=False),
],
)
def test_empty_df_astype(dtype, args):
def test_empty_df_astype(dtype):
df = cudf.DataFrame()
kwargs = {}
kwargs.update(args)
assert_eq(df, df.astype(dtype=dtype, **kwargs))
result = df.astype(dtype=dtype)
assert_eq(df, result)
assert_eq(df.to_pandas().astype(dtype=dtype), result)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3e9056f

Please sign in to comment.