From ea82bf4f93427334f199c2e2cfd4d6e1feb6023b Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Mon, 28 Jun 2021 20:11:18 -0400 Subject: [PATCH] Propagate **kwargs through to as_*_column methods (#8618) Fixes #8616 Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) - https://github.com/brandon-b-miller URL: https://github.com/rapidsai/cudf/pull/8618 --- python/cudf/cudf/core/column/categorical.py | 6 +++--- python/cudf/cudf/core/column/column.py | 8 ++++---- python/cudf/cudf/core/column/datetime.py | 4 ++-- python/cudf/cudf/core/column/decimal.py | 6 +++--- python/cudf/cudf/core/column/numerical.py | 4 ++-- python/cudf/cudf/core/column/string.py | 6 ++++-- python/cudf/cudf/core/column/timedelta.py | 4 ++-- python/cudf/cudf/tests/test_dataframe.py | 11 +++++++++++ 8 files changed, 31 insertions(+), 18 deletions(-) diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index e2aa20cc948..135fb6e6f30 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -32,10 +32,10 @@ from cudf.core.dtypes import CategoricalDtype from cudf.utils.dtypes import ( is_categorical_dtype, + is_interval_dtype, is_mixed_with_object_dtype, min_signed_type, min_unsigned_type, - is_interval_dtype, ) if TYPE_CHECKING: @@ -1388,10 +1388,10 @@ def as_categorical_column( new_categories=dtype.categories, ordered=dtype.ordered ) - def as_numerical_column(self, dtype: Dtype) -> NumericalColumn: + def as_numerical_column(self, dtype: Dtype, **kwargs) -> NumericalColumn: return self._get_decategorized_column().as_numerical_column(dtype) - def as_string_column(self, dtype, format=None) -> StringColumn: + def as_string_column(self, dtype, format=None, **kwargs) -> StringColumn: return self._get_decategorized_column().as_string_column( dtype, format=format ) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 79d97a3dbe1..50367651146 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -877,7 +877,7 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool: def astype(self, dtype: Dtype, **kwargs) -> ColumnBase: if _is_non_decimal_numeric_dtype(dtype): - return self.as_numerical_column(dtype) + return self.as_numerical_column(dtype, **kwargs) elif is_categorical_dtype(dtype): return self.as_categorical_column(dtype, **kwargs) elif pandas_dtype(dtype).type in { @@ -901,7 +901,7 @@ def astype(self, dtype: Dtype, **kwargs) -> ColumnBase: elif np.issubdtype(dtype, np.timedelta64): return self.as_timedelta_column(dtype, **kwargs) else: - return self.as_numerical_column(dtype) + return self.as_numerical_column(dtype, **kwargs) def as_categorical_column(self, dtype, **kwargs) -> ColumnBase: if "ordered" in kwargs: @@ -947,7 +947,7 @@ def as_categorical_column(self, dtype, **kwargs) -> ColumnBase: ) def as_numerical_column( - self, dtype: Dtype + self, dtype: Dtype, **kwargs ) -> "cudf.core.column.NumericalColumn": raise NotImplementedError @@ -967,7 +967,7 @@ def as_timedelta_column( raise NotImplementedError def as_string_column( - self, dtype: Dtype, format=None + self, dtype: Dtype, format=None, **kwargs ) -> "cudf.core.column.StringColumn": raise NotImplementedError diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index b96a49c2514..150ce2c48ec 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -241,14 +241,14 @@ def as_timedelta_column( ) def as_numerical_column( - self, dtype: Dtype + self, dtype: Dtype, **kwargs ) -> "cudf.core.column.NumericalColumn": return cast( "cudf.core.column.NumericalColumn", self.as_numerical.astype(dtype) ) def as_string_column( - self, dtype: Dtype, format=None + self, dtype: Dtype, format=None, **kwargs ) -> "cudf.core.column.StringColumn": if format is None: format = _dtype_to_format_conversion.get( diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index b6bd2f18144..2f0ddb78987 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -21,8 +21,8 @@ from cudf.utils.dtypes import is_scalar from cudf.utils.utils import pa_mask_buffer_to_mask -from .numerical_base import NumericalBaseColumn from ...api.types import is_integer_dtype +from .numerical_base import NumericalBaseColumn class DecimalColumn(NumericalBaseColumn): @@ -161,12 +161,12 @@ def as_decimal_column( return libcudf.unary.cast(self, dtype) def as_numerical_column( - self, dtype: Dtype + self, dtype: Dtype, **kwargs ) -> "cudf.core.column.NumericalColumn": return libcudf.unary.cast(self, dtype) def as_string_column( - self, dtype: Dtype, format=None + self, dtype: Dtype, format=None, **kwargs ) -> "cudf.core.column.StringColumn": if len(self) > 0: return cpp_from_decimal(self) diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 17e0b6e454f..64a0780e9f9 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -208,7 +208,7 @@ def int2ip(self) -> "cudf.core.column.StringColumn": return libcudf.string_casting.int2ip(self) def as_string_column( - self, dtype: Dtype, format=None + self, dtype: Dtype, format=None, **kwargs ) -> "cudf.core.column.StringColumn": if len(self) > 0: return string._numeric_to_str_typecast_functions[ @@ -252,7 +252,7 @@ def as_decimal_column( ) -> "cudf.core.column.DecimalColumn": return libcudf.unary.cast(self, dtype) - def as_numerical_column(self, dtype: Dtype) -> NumericalColumn: + def as_numerical_column(self, dtype: Dtype, **kwargs) -> NumericalColumn: dtype = np.dtype(dtype) if dtype == self.dtype: return self diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index dd1c0c1e4ac..c1d98ac5600 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5118,7 +5118,7 @@ def str(self, parent: ParentType = None) -> StringMethods: return StringMethods(self, parent=parent) def as_numerical_column( - self, dtype: Dtype + self, dtype: Dtype, **kwargs ) -> "cudf.core.column.NumericalColumn": out_dtype = np.dtype(dtype) @@ -5195,7 +5195,9 @@ def as_decimal_column( ) -> "cudf.core.column.DecimalColumn": return cpp_to_decimal(self, dtype) - def as_string_column(self, dtype: Dtype, format=None) -> StringColumn: + def as_string_column( + self, dtype: Dtype, format=None, **kwargs + ) -> StringColumn: return self @property diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index b202838662c..a27c20cc50c 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -322,7 +322,7 @@ def fillna( return super().fillna(method=method) def as_numerical_column( - self, dtype: Dtype + self, dtype: Dtype, **kwargs ) -> "cudf.core.column.NumericalColumn": return cast( "cudf.core.column.NumericalColumn", self.as_numerical.astype(dtype) @@ -336,7 +336,7 @@ def as_datetime_column( ) def as_string_column( - self, dtype: Dtype, format=None + self, dtype: Dtype, format=None, **kwargs ) -> "cudf.core.column.StringColumn": if format is None: format = _dtype_to_format_conversion.get( diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 853cfe2e88e..a89b9b58e6e 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -736,6 +736,17 @@ def test_dataframe_astype(nelem): np.testing.assert_equal(df["a"].to_array(), df["b"].to_array()) +def test_astype_dict(): + gdf = cudf.DataFrame({"a": [1, 2, 3], "b": ["1", "2", "3"]}) + pdf = gdf.to_pandas() + + assert_eq(pdf.astype({"a": "str"}), gdf.astype({"a": "str"})) + assert_eq( + pdf.astype({"a": "str", "b": np.int64}), + gdf.astype({"a": "str", "b": np.int64}), + ) + + @pytest.mark.parametrize("nelem", [0, 100]) def test_index_astype(nelem): df = cudf.DataFrame()