Skip to content

Commit

Permalink
Fix astype/fillna not maintaining column subclass and types
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Dec 12, 2023
1 parent 0fa80ec commit 5f8c6d3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,17 @@ def _create_unsafe(
data: Dict[Any, ColumnBase],
multiindex: bool = False,
level_names=None,
rangeindex: bool = False,
label_dtype: Dtype | None = None,
) -> ColumnAccessor:
# create a ColumnAccessor without verifying column
# type or size
obj = cls()
obj._data = data
obj.multiindex = multiindex
obj._level_names = level_names
obj.rangeindex = rangeindex
obj.label_dtype = label_dtype
return obj

def __iter__(self):
Expand Down
4 changes: 4 additions & 0 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ def astype(self, dtype, copy=False, **kwargs):
data=result_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
)

@_cudf_nvtx_annotate
Expand Down Expand Up @@ -876,6 +878,8 @@ def fillna(
data=filled_data,
multiindex=self._data.multiindex,
level_names=self._data.level_names,
rangeindex=self._data.rangeindex,
label_dtype=self._data.label_dtype,
)
),
inplace=inplace,
Expand Down
30 changes: 30 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4643,6 +4643,36 @@ def test_dataframe_columns_empty_data_preserves_dtype(dtype, idx_data, data):
assert_eq(result, expected)


@pytest.mark.parametrize("dtype", ["int64", "datetime64[ns]", "int8"])
def test_dataframe_astype_preserves_column_dtype(dtype):
result = cudf.DataFrame([1], columns=cudf.Index([1], dtype=dtype))
result = result.astype(np.int32).columns
expected = pd.Index([1], dtype=dtype)
assert_eq(result, expected)


def test_dataframe_astype_preserves_column_rangeindex():
result = cudf.DataFrame([1], columns=range(1))
result = result.astype(np.int32).columns
expected = pd.RangeIndex(1)
assert_eq(result, expected)


@pytest.mark.parametrize("dtype", ["int64", "datetime64[ns]", "int8"])
def test_dataframe_fillna_preserves_column_dtype(dtype):
result = cudf.DataFrame([1, None], columns=cudf.Index([1], dtype=dtype))
result = result.fillna(2).columns
expected = pd.Index([1], dtype=dtype)
assert_eq(result, expected)


def test_dataframe_fillna_preserves_column_rangeindex():
result = cudf.DataFrame([1, None], columns=range(1))
result = result.fillna(2).columns
expected = pd.RangeIndex(1)
assert_eq(result, expected)


@pytest.mark.parametrize(
"data",
[
Expand Down

0 comments on commit 5f8c6d3

Please sign in to comment.