diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 24a4b17ca3c..02231f26f61 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -71,6 +71,7 @@ min_unsigned_type, np_to_pa_dtype, pandas_dtypes_to_cudf_dtypes, + pandas_dtypes_alias_to_cudf_alias, ) from cudf.utils.utils import mask_dtype @@ -881,14 +882,13 @@ def astype(self, dtype: Dtype, **kwargs) -> ColumnBase: if is_categorical_dtype(dtype): return self.as_categorical_column(dtype, **kwargs) - dtype = pandas_dtypes_to_cudf_dtypes.get(dtype, dtype) + dtype = ( + pandas_dtypes_alias_to_cudf_alias.get(dtype, dtype) + if isinstance(dtype, str) + else pandas_dtypes_to_cudf_dtypes.get(dtype, dtype) + ) if _is_non_decimal_numeric_dtype(dtype): - try: - return self.as_numerical_column(dtype, **kwargs) - except TypeError: - return self.as_numerical_column( - self.convert_alias(dtype), **kwargs - ) + return self.as_numerical_column(dtype, **kwargs) elif is_categorical_dtype(dtype): return self.as_categorical_column(dtype, **kwargs) elif pandas_dtype(dtype).type in { @@ -968,20 +968,6 @@ def as_numerical_column( ) -> "cudf.core.column.NumericalColumn": raise NotImplementedError - def convert_alias(self, dtype): - aliases = { - "UInt8": "uint8", - "UInt16": "uint16", - "UInt32": "uint32", - "UInt64": "uint64", - "Int8": "int8", - "Int16": "int16", - "Int32": "int32", - "Int64": "int64", - "boolean": "bool", - } - return aliases[dtype] - def as_datetime_column( self, dtype: Dtype, **kwargs ) -> "cudf.core.column.DatetimeColumn": diff --git a/python/cudf/cudf/tests/test_column.py b/python/cudf/cudf/tests/test_column.py index 11b2e4bc9f9..761b2f32f18 100644 --- a/python/cudf/cudf/tests/test_column.py +++ b/python/cudf/cudf/tests/test_column.py @@ -495,6 +495,8 @@ def test_concatenate_large_column_strings(): ("Int32", "int32"), ("Int64", "int64"), ("boolean", "bool"), + ("Float32", "float32"), + ("Float64", "float64"), ], ) @pytest.mark.parametrize( diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index e1ae87e5089..46bd1b449c4 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -92,11 +92,25 @@ pd.StringDtype(): np.dtype("object"), } +pandas_dtypes_alias_to_cudf_alias = { + "UInt8": "uint8", + "UInt16": "uint16", + "UInt32": "uint32", + "UInt64": "uint64", + "Int8": "int8", + "Int16": "int16", + "Int32": "int32", + "Int64": "int64", + "boolean": "bool", +} + if PANDAS_GE_120: cudf_dtypes_to_pandas_dtypes[np.dtype("float32")] = pd.Float32Dtype() cudf_dtypes_to_pandas_dtypes[np.dtype("float64")] = pd.Float64Dtype() pandas_dtypes_to_cudf_dtypes[pd.Float32Dtype()] = np.dtype("float32") pandas_dtypes_to_cudf_dtypes[pd.Float64Dtype()] = np.dtype("float64") + pandas_dtypes_alias_to_cudf_alias["Float32"] = "float32" + pandas_dtypes_alias_to_cudf_alias["Float64"] = "float64" SIGNED_INTEGER_TYPES = {"int8", "int16", "int32", "int64"} UNSIGNED_TYPES = {"uint8", "uint16", "uint32", "uint64"}