diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index d449d52927e..a5e49b026f3 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -70,6 +70,7 @@ get_time_unit, min_unsigned_type, np_to_pa_dtype, + pandas_dtypes_to_cudf_dtypes, ) from cudf.utils.utils import mask_dtype @@ -877,6 +878,10 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool: raise NotImplementedError() def astype(self, dtype: Dtype, **kwargs) -> ColumnBase: + if is_categorical_dtype(dtype): + return self.as_categorical_column(dtype, **kwargs) + + dtype = pandas_dtypes_to_cudf_dtypes.get(dtype, dtype) if _is_non_decimal_numeric_dtype(dtype): return self.as_numerical_column(dtype, **kwargs) elif is_categorical_dtype(dtype): diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index ecd31afd9e8..9acf6783095 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -3633,6 +3633,23 @@ def test_one_row_head(): assert_eq(head_pdf, head_gdf) +@pytest.mark.parametrize("dtype", ALL_TYPES) +@pytest.mark.parametrize( + "np_dtype,pd_dtype", + [ + tuple(item) + for item in cudf.utils.dtypes.cudf_dtypes_to_pandas_dtypes.items() + ], +) +def test_series_astype_pandas_nullable(dtype, np_dtype, pd_dtype): + source = cudf.Series([0, 1, None], dtype=dtype) + + expect = source.astype(np_dtype) + got = source.astype(pd_dtype) + + assert_eq(expect, got) + + @pytest.mark.parametrize("dtype", NUMERIC_TYPES) @pytest.mark.parametrize("as_dtype", NUMERIC_TYPES) def test_series_astype_numeric_to_numeric(dtype, as_dtype): diff --git a/python/cudf/cudf/tests/test_joining.py b/python/cudf/cudf/tests/test_joining.py index 7b56f864272..4ae7c40ead8 100644 --- a/python/cudf/cudf/tests/test_joining.py +++ b/python/cudf/cudf/tests/test_joining.py @@ -1529,7 +1529,7 @@ def test_categorical_typecast_inner_one_cat(dtype): data = np.array([1, 2, 3], dtype=dtype) left = make_categorical_dataframe(data) - right = left.astype(left["key"].dtype.categories) + right = left.astype(left["key"].dtype.categories.dtype) result = left.merge(right, on="key", how="inner") assert result["key"].dtype == left["key"].dtype.categories.dtype @@ -1541,7 +1541,7 @@ def test_categorical_typecast_left_one_cat(dtype): data = np.array([1, 2, 3], dtype=dtype) left = make_categorical_dataframe(data) - right = left.astype(left["key"].dtype.categories) + right = left.astype(left["key"].dtype.categories.dtype) result = left.merge(right, on="key", how="left") assert result["key"].dtype == left["key"].dtype @@ -1553,7 +1553,7 @@ def test_categorical_typecast_outer_one_cat(dtype): data = np.array([1, 2, 3], dtype=dtype) left = make_categorical_dataframe(data) - right = left.astype(left["key"].dtype.categories) + right = left.astype(left["key"].dtype.categories.dtype) result = left.merge(right, on="key", how="outer") assert result["key"].dtype == left["key"].dtype.categories.dtype