Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up Columns.astype & cudf.dtype #15125

Merged
merged 16 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/cudf/cudf/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ def _is_pandas_nullable_extension_dtype(dtype_to_check) -> bool:
):
return True
elif isinstance(dtype_to_check, pd.CategoricalDtype):
if dtype_to_check.categories is None:
return False
return _is_pandas_nullable_extension_dtype(
dtype_to_check.categories.dtype
)
Expand Down
53 changes: 18 additions & 35 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@
min_scalar_type,
min_unsigned_type,
np_to_pa_dtype,
pandas_dtypes_alias_to_cudf_alias,
pandas_dtypes_to_np_dtypes,
)
from cudf.utils.utils import _array_ufunc, mask_dtype

Expand Down Expand Up @@ -974,42 +972,20 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
col = self.copy()
else:
col = self
if self.dtype == dtype:
return col
if _is_categorical_dtype(dtype):
if dtype == "category":
# TODO: Figure out why `cudf.dtype("category")`
# astype's different than just the string
return col.as_categorical_column(dtype)

if (
isinstance(dtype, str)
and dtype in pandas_dtypes_alias_to_cudf_alias
):
if cudf.get_option("mode.pandas_compatible"):
raise NotImplementedError("not supported")
else:
dtype = pandas_dtypes_alias_to_cudf_alias[dtype]
elif _is_pandas_nullable_extension_dtype(dtype) and cudf.get_option(
"mode.pandas_compatible"
elif dtype == "interval" and isinstance(
self.dtype, cudf.IntervalDtype
):
raise NotImplementedError("not supported")
else:
dtype = pandas_dtypes_to_np_dtypes.get(dtype, dtype)
if _is_non_decimal_numeric_dtype(dtype):
return col.as_numerical_column(dtype)
elif _is_categorical_dtype(dtype):
return col
was_object = dtype == object or dtype == np.dtype(object)
dtype = cudf.dtype(dtype)
if self.dtype == dtype:
return col
elif isinstance(dtype, CategoricalDtype):
return col.as_categorical_column(dtype)
elif cudf.dtype(dtype).type in {
np.str_,
np.object_,
str,
}:
if cudf.get_option("mode.pandas_compatible") and np.dtype(
dtype
).type in {np.object_}:
raise ValueError(
f"Casting to {dtype} is not supported, use "
"`.astype('str')` instead."
)
return col.as_string_column(dtype)
elif isinstance(dtype, IntervalDtype):
return col.as_interval_column(dtype)
elif isinstance(dtype, (ListDtype, StructDtype)):
Expand All @@ -1024,6 +1000,13 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
return col.as_datetime_column(dtype)
elif np.issubdtype(cast(Any, dtype), np.timedelta64):
return col.as_timedelta_column(dtype)
elif dtype.kind == "O":
if cudf.get_option("mode.pandas_compatible") and was_object:
raise ValueError(
f"Casting to {dtype} is not supported, use "
"`.astype('str')` instead."
)
return col.as_string_column(dtype)
else:
return col.as_numerical_column(dtype)

Expand Down
38 changes: 19 additions & 19 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,38 +42,38 @@ def dtype(arbitrary):
# next, try interpreting arbitrary as a NumPy dtype that we support:
try:
np_dtype = np.dtype(arbitrary)
if np_dtype.kind in ("OU"):
return np.dtype("object")
except TypeError:
pass
else:
if np_dtype not in cudf._lib.types.SUPPORTED_NUMPY_TO_LIBCUDF_TYPES:
if np_dtype.kind in set("OU"):
return np.dtype("object")
elif np_dtype not in cudf._lib.types.SUPPORTED_NUMPY_TO_LIBCUDF_TYPES:
raise TypeError(f"Unsupported type {np_dtype}")
return np_dtype

# use `pandas_dtype` to try and interpret
# `arbitrary` as a Pandas extension type.
# Return the corresponding NumPy/cuDF type.
pd_dtype = pd.api.types.pandas_dtype(arbitrary)
if cudf.get_option(
"mode.pandas_compatible"
) and cudf.api.types._is_pandas_nullable_extension_dtype(pd_dtype):
raise NotImplementedError("not supported")
try:
return dtype(pd_dtype.numpy_dtype)
except AttributeError:
if isinstance(pd_dtype, pd.CategoricalDtype):
return cudf.CategoricalDtype.from_pandas(pd_dtype)
if cudf.api.types._is_pandas_nullable_extension_dtype(pd_dtype):
if cudf.get_option("mode.pandas_compatible"):
raise NotImplementedError(
"Nullable types not supported in pandas compatibility mode"
)
elif isinstance(pd_dtype, pd.StringDtype):
return np.dtype("object")
elif isinstance(pd_dtype, pd.IntervalDtype):
return cudf.IntervalDtype.from_pandas(pd_dtype)
elif isinstance(pd_dtype, pd.DatetimeTZDtype):
return pd_dtype
else:
raise TypeError(
f"Cannot interpret {arbitrary} as a valid cuDF dtype"
)
return dtype(pd_dtype.numpy_dtype)
elif isinstance(pd_dtype, pd.core.dtypes.dtypes.NumpyEADtype):
return dtype(pd_dtype.numpy_dtype)
elif isinstance(pd_dtype, pd.CategoricalDtype):
return cudf.CategoricalDtype.from_pandas(pd_dtype)
elif isinstance(pd_dtype, pd.IntervalDtype):
return cudf.IntervalDtype.from_pandas(pd_dtype)
elif isinstance(pd_dtype, pd.DatetimeTZDtype):
return pd_dtype
else:
raise TypeError(f"Cannot interpret {arbitrary} as a valid cuDF dtype")


def _decode_type(
Expand Down
14 changes: 0 additions & 14 deletions python/cudf/cudf/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,11 @@
pd.StringDtype(): np.dtype("object"),
}

pandas_dtypes_alias_to_cudf_alias = {
"UInt8": "uint8",
"UInt16": "uint16",
"UInt32": "uint32",
"UInt64": "uint64",
"Int8": "int8",
"Int16": "int16",
"Int32": "int32",
"Int64": "int64",
"boolean": "bool",
}


np_dtypes_to_pandas_dtypes[np.dtype("float32")] = pd.Float32Dtype()
np_dtypes_to_pandas_dtypes[np.dtype("float64")] = pd.Float64Dtype()
pandas_dtypes_to_np_dtypes[pd.Float32Dtype()] = np.dtype("float32")
pandas_dtypes_to_np_dtypes[pd.Float64Dtype()] = np.dtype("float64")
pandas_dtypes_alias_to_cudf_alias["Float32"] = "float32"
pandas_dtypes_alias_to_cudf_alias["Float64"] = "float64"

SIGNED_INTEGER_TYPES = {"int8", "int16", "int32", "int64"}
UNSIGNED_TYPES = {"uint8", "uint16", "uint32", "uint64"}
Expand Down
Loading