diff --git a/python/cudf/cudf/api/types.py b/python/cudf/cudf/api/types.py index a422eb82231..417d8b0922a 100644 --- a/python/cudf/cudf/api/types.py +++ b/python/cudf/cudf/api/types.py @@ -504,6 +504,8 @@ def _is_pandas_nullable_extension_dtype(dtype_to_check) -> bool: ): return True elif isinstance(dtype_to_check, pd.CategoricalDtype): + if dtype_to_check.categories is None: + return False return _is_pandas_nullable_extension_dtype( dtype_to_check.categories.dtype ) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 8941d111d02..ff1204b6178 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -90,8 +90,6 @@ min_scalar_type, min_unsigned_type, np_to_pa_dtype, - pandas_dtypes_alias_to_cudf_alias, - pandas_dtypes_to_np_dtypes, ) from cudf.utils.utils import _array_ufunc, mask_dtype @@ -974,42 +972,20 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase: col = self.copy() else: col = self - if self.dtype == dtype: - return col - if _is_categorical_dtype(dtype): + if dtype == "category": + # TODO: Figure out why `cudf.dtype("category")` + # astype's different than just the string return col.as_categorical_column(dtype) - - if ( - isinstance(dtype, str) - and dtype in pandas_dtypes_alias_to_cudf_alias - ): - if cudf.get_option("mode.pandas_compatible"): - raise NotImplementedError("not supported") - else: - dtype = pandas_dtypes_alias_to_cudf_alias[dtype] - elif _is_pandas_nullable_extension_dtype(dtype) and cudf.get_option( - "mode.pandas_compatible" + elif dtype == "interval" and isinstance( + self.dtype, cudf.IntervalDtype ): - raise NotImplementedError("not supported") - else: - dtype = pandas_dtypes_to_np_dtypes.get(dtype, dtype) - if _is_non_decimal_numeric_dtype(dtype): - return col.as_numerical_column(dtype) - elif _is_categorical_dtype(dtype): + return col + was_object = dtype == object or dtype == np.dtype(object) + dtype = cudf.dtype(dtype) + if self.dtype == dtype: + return col + elif isinstance(dtype, CategoricalDtype): return col.as_categorical_column(dtype) - elif cudf.dtype(dtype).type in { - np.str_, - np.object_, - str, - }: - if cudf.get_option("mode.pandas_compatible") and np.dtype( - dtype - ).type in {np.object_}: - raise ValueError( - f"Casting to {dtype} is not supported, use " - "`.astype('str')` instead." - ) - return col.as_string_column(dtype) elif isinstance(dtype, IntervalDtype): return col.as_interval_column(dtype) elif isinstance(dtype, (ListDtype, StructDtype)): @@ -1024,6 +1000,13 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase: return col.as_datetime_column(dtype) elif np.issubdtype(cast(Any, dtype), np.timedelta64): return col.as_timedelta_column(dtype) + elif dtype.kind == "O": + if cudf.get_option("mode.pandas_compatible") and was_object: + raise ValueError( + f"Casting to {dtype} is not supported, use " + "`.astype('str')` instead." + ) + return col.as_string_column(dtype) else: return col.as_numerical_column(dtype) diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 26d2ea3e992..c658701f851 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -42,12 +42,12 @@ def dtype(arbitrary): # next, try interpreting arbitrary as a NumPy dtype that we support: try: np_dtype = np.dtype(arbitrary) - if np_dtype.kind in ("OU"): - return np.dtype("object") except TypeError: pass else: - if np_dtype not in cudf._lib.types.SUPPORTED_NUMPY_TO_LIBCUDF_TYPES: + if np_dtype.kind in set("OU"): + return np.dtype("object") + elif np_dtype not in cudf._lib.types.SUPPORTED_NUMPY_TO_LIBCUDF_TYPES: raise TypeError(f"Unsupported type {np_dtype}") return np_dtype @@ -55,25 +55,25 @@ def dtype(arbitrary): # `arbitrary` as a Pandas extension type. # Return the corresponding NumPy/cuDF type. pd_dtype = pd.api.types.pandas_dtype(arbitrary) - if cudf.get_option( - "mode.pandas_compatible" - ) and cudf.api.types._is_pandas_nullable_extension_dtype(pd_dtype): - raise NotImplementedError("not supported") - try: - return dtype(pd_dtype.numpy_dtype) - except AttributeError: - if isinstance(pd_dtype, pd.CategoricalDtype): - return cudf.CategoricalDtype.from_pandas(pd_dtype) + if cudf.api.types._is_pandas_nullable_extension_dtype(pd_dtype): + if cudf.get_option("mode.pandas_compatible"): + raise NotImplementedError( + "Nullable types not supported in pandas compatibility mode" + ) elif isinstance(pd_dtype, pd.StringDtype): return np.dtype("object") - elif isinstance(pd_dtype, pd.IntervalDtype): - return cudf.IntervalDtype.from_pandas(pd_dtype) - elif isinstance(pd_dtype, pd.DatetimeTZDtype): - return pd_dtype else: - raise TypeError( - f"Cannot interpret {arbitrary} as a valid cuDF dtype" - ) + return dtype(pd_dtype.numpy_dtype) + elif isinstance(pd_dtype, pd.core.dtypes.dtypes.NumpyEADtype): + return dtype(pd_dtype.numpy_dtype) + elif isinstance(pd_dtype, pd.CategoricalDtype): + return cudf.CategoricalDtype.from_pandas(pd_dtype) + elif isinstance(pd_dtype, pd.IntervalDtype): + return cudf.IntervalDtype.from_pandas(pd_dtype) + elif isinstance(pd_dtype, pd.DatetimeTZDtype): + return pd_dtype + else: + raise TypeError(f"Cannot interpret {arbitrary} as a valid cuDF dtype") def _decode_type( diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index c8aca94ba19..3780fcc627e 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -74,25 +74,11 @@ pd.StringDtype(): np.dtype("object"), } -pandas_dtypes_alias_to_cudf_alias = { - "UInt8": "uint8", - "UInt16": "uint16", - "UInt32": "uint32", - "UInt64": "uint64", - "Int8": "int8", - "Int16": "int16", - "Int32": "int32", - "Int64": "int64", - "boolean": "bool", -} - np_dtypes_to_pandas_dtypes[np.dtype("float32")] = pd.Float32Dtype() np_dtypes_to_pandas_dtypes[np.dtype("float64")] = pd.Float64Dtype() pandas_dtypes_to_np_dtypes[pd.Float32Dtype()] = np.dtype("float32") pandas_dtypes_to_np_dtypes[pd.Float64Dtype()] = np.dtype("float64") -pandas_dtypes_alias_to_cudf_alias["Float32"] = "float32" -pandas_dtypes_alias_to_cudf_alias["Float64"] = "float64" SIGNED_INTEGER_TYPES = {"int8", "int16", "int32", "int64"} UNSIGNED_TYPES = {"uint8", "uint16", "uint32", "uint64"}