From 5f8c6d338266c38d45bf7c619b3b2da078e1513a Mon Sep 17 00:00:00 2001
From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
Date: Tue, 12 Dec 2023 09:13:16 -0800
Subject: [PATCH] Fix astype/fillna not maintaining column subclass and types

---
 python/cudf/cudf/core/column_accessor.py |  4 ++++
 python/cudf/cudf/core/frame.py           |  4 ++++
 python/cudf/cudf/tests/test_dataframe.py | 30 ++++++++++++++++++++++++
 3 files changed, 38 insertions(+)

diff --git a/python/cudf/cudf/core/column_accessor.py b/python/cudf/cudf/core/column_accessor.py
index b106b8bbb02..021d4994613 100644
--- a/python/cudf/cudf/core/column_accessor.py
+++ b/python/cudf/cudf/core/column_accessor.py
@@ -157,6 +157,8 @@ def _create_unsafe(
         data: Dict[Any, ColumnBase],
         multiindex: bool = False,
         level_names=None,
+        rangeindex: bool = False,
+        label_dtype: Dtype | None = None,
     ) -> ColumnAccessor:
         # create a ColumnAccessor without verifying column
         # type or size
@@ -164,6 +166,8 @@ def _create_unsafe(
         obj._data = data
         obj.multiindex = multiindex
         obj._level_names = level_names
+        obj.rangeindex = rangeindex
+        obj.label_dtype = label_dtype
         return obj
 
     def __iter__(self):
diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py
index b2f0651d576..e1b2f7d674d 100644
--- a/python/cudf/cudf/core/frame.py
+++ b/python/cudf/cudf/core/frame.py
@@ -280,6 +280,8 @@ def astype(self, dtype, copy=False, **kwargs):
             data=result_data,
             multiindex=self._data.multiindex,
             level_names=self._data.level_names,
+            rangeindex=self._data.rangeindex,
+            label_dtype=self._data.label_dtype,
         )
 
     @_cudf_nvtx_annotate
@@ -876,6 +878,8 @@ def fillna(
                     data=filled_data,
                     multiindex=self._data.multiindex,
                     level_names=self._data.level_names,
+                    rangeindex=self._data.rangeindex,
+                    label_dtype=self._data.label_dtype,
                 )
             ),
             inplace=inplace,
diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py
index 2a0edf09079..334b63ae13d 100644
--- a/python/cudf/cudf/tests/test_dataframe.py
+++ b/python/cudf/cudf/tests/test_dataframe.py
@@ -4643,6 +4643,36 @@ def test_dataframe_columns_empty_data_preserves_dtype(dtype, idx_data, data):
     assert_eq(result, expected)
 
 
+@pytest.mark.parametrize("dtype", ["int64", "datetime64[ns]", "int8"])
+def test_dataframe_astype_preserves_column_dtype(dtype):
+    result = cudf.DataFrame([1], columns=cudf.Index([1], dtype=dtype))
+    result = result.astype(np.int32).columns
+    expected = pd.Index([1], dtype=dtype)
+    assert_eq(result, expected)
+
+
+def test_dataframe_astype_preserves_column_rangeindex():
+    result = cudf.DataFrame([1], columns=range(1))
+    result = result.astype(np.int32).columns
+    expected = pd.RangeIndex(1)
+    assert_eq(result, expected)
+
+
+@pytest.mark.parametrize("dtype", ["int64", "datetime64[ns]", "int8"])
+def test_dataframe_fillna_preserves_column_dtype(dtype):
+    result = cudf.DataFrame([1, None], columns=cudf.Index([1], dtype=dtype))
+    result = result.fillna(2).columns
+    expected = pd.Index([1], dtype=dtype)
+    assert_eq(result, expected)
+
+
+def test_dataframe_fillna_preserves_column_rangeindex():
+    result = cudf.DataFrame([1, None], columns=range(1))
+    result = result.fillna(2).columns
+    expected = pd.RangeIndex(1)
+    assert_eq(result, expected)
+
+
 @pytest.mark.parametrize(
     "data",
     [