Skip to content

Commit

Permalink
Enable casting to pandas nullable dtypes (rapidsai#8889)
Browse files: browse the repository at this point in the history
Fixes rapidsai#8885

When a cast to a pandas nullable dtype is requested, defer to the closest corresponding NumPy dtype.

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Ashwin Srinath (https://github.com/shwina)

URL: rapidsai#8889
  • Loading branch information
brandon-b-miller authored Jul 29, 2021
1 parent f5a8446 commit fa3ae36
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
5 changes: 5 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
get_time_unit,
min_unsigned_type,
np_to_pa_dtype,
pandas_dtypes_to_cudf_dtypes,
)
from cudf.utils.utils import mask_dtype

Expand Down Expand Up @@ -877,6 +878,10 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
raise NotImplementedError()

def astype(self, dtype: Dtype, **kwargs) -> ColumnBase:
if is_categorical_dtype(dtype):
return self.as_categorical_column(dtype, **kwargs)

dtype = pandas_dtypes_to_cudf_dtypes.get(dtype, dtype)
if _is_non_decimal_numeric_dtype(dtype):
return self.as_numerical_column(dtype, **kwargs)
elif is_categorical_dtype(dtype):
Expand Down
17 changes: 17 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3633,6 +3633,23 @@ def test_one_row_head():
assert_eq(head_pdf, head_gdf)


@pytest.mark.parametrize("dtype", ALL_TYPES)
@pytest.mark.parametrize(
    "np_dtype,pd_dtype",
    list(cudf.utils.dtypes.cudf_dtypes_to_pandas_dtypes.items()),
)
def test_series_astype_pandas_nullable(dtype, np_dtype, pd_dtype):
    """Check that ``astype`` gives equal results for paired dtypes.

    Each ``(np_dtype, pd_dtype)`` pair is one item of
    ``cudf.utils.dtypes.cudf_dtypes_to_pandas_dtypes``; casting a Series
    to either member of the pair must produce results that compare equal
    under ``assert_eq``.
    """
    # The input deliberately contains a ``None`` entry.
    series = cudf.Series([0, 1, None], dtype=dtype)

    via_numpy_dtype = series.astype(np_dtype)
    via_pandas_dtype = series.astype(pd_dtype)

    assert_eq(via_numpy_dtype, via_pandas_dtype)


@pytest.mark.parametrize("dtype", NUMERIC_TYPES)
@pytest.mark.parametrize("as_dtype", NUMERIC_TYPES)
def test_series_astype_numeric_to_numeric(dtype, as_dtype):
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/tests/test_joining.py
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ def test_categorical_typecast_inner_one_cat(dtype):
data = np.array([1, 2, 3], dtype=dtype)

left = make_categorical_dataframe(data)
right = left.astype(left["key"].dtype.categories)
right = left.astype(left["key"].dtype.categories.dtype)

result = left.merge(right, on="key", how="inner")
assert result["key"].dtype == left["key"].dtype.categories.dtype
Expand All @@ -1541,7 +1541,7 @@ def test_categorical_typecast_left_one_cat(dtype):
data = np.array([1, 2, 3], dtype=dtype)

left = make_categorical_dataframe(data)
right = left.astype(left["key"].dtype.categories)
right = left.astype(left["key"].dtype.categories.dtype)

result = left.merge(right, on="key", how="left")
assert result["key"].dtype == left["key"].dtype
Expand All @@ -1553,7 +1553,7 @@ def test_categorical_typecast_outer_one_cat(dtype):
data = np.array([1, 2, 3], dtype=dtype)

left = make_categorical_dataframe(data)
right = left.astype(left["key"].dtype.categories)
right = left.astype(left["key"].dtype.categories.dtype)

result = left.merge(right, on="key", how="outer")
assert result["key"].dtype == left["key"].dtype.categories.dtype
Expand Down

0 comments on commit fa3ae36

Please sign in to comment.