From fcc89503c1f1e15ec287519959013adcf2bf8a52 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 21 Nov 2023 05:19:11 -1000 Subject: [PATCH] Preserve DataFrame(columns=).columns dtype during empty-like construction (#14381) `.columns` used to always return `pd.Index([], dtype=object)` even if an empty but typed columns object, e.g. `DatetimeIndex([])`, was passed into the DataFrame constructor. We need to preserve some information about what column dtype was passed in so we can return a correctly typed Index. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/cudf/pull/14381 --- python/cudf/cudf/core/column_accessor.py | 14 +++++++++++++- python/cudf/cudf/core/dataframe.py | 6 ++++++ python/cudf/cudf/tests/test_dataframe.py | 11 +++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column_accessor.py b/python/cudf/cudf/core/column_accessor.py index 93105b4a252..b106b8bbb02 100644 --- a/python/cudf/cudf/core/column_accessor.py +++ b/python/cudf/cudf/core/column_accessor.py @@ -27,6 +27,7 @@ from cudf.core import column if TYPE_CHECKING: + from cudf._typing import Dtype from cudf.core.column import ColumnBase @@ -99,6 +100,9 @@ class ColumnAccessor(abc.MutableMapping): rangeindex : bool, optional Whether the keys should be returned as a RangeIndex in `to_pandas_index` (default=False). + label_dtype : Dtype, optional + What dtype should be returned in `to_pandas_index` + (default=None). 
""" _data: "Dict[Any, ColumnBase]" @@ -111,8 +115,10 @@ def __init__( multiindex: bool = False, level_names=None, rangeindex: bool = False, + label_dtype: Dtype | None = None, ): self.rangeindex = rangeindex + self.label_dtype = label_dtype if data is None: data = {} # TODO: we should validate the keys of `data` @@ -123,6 +129,7 @@ def __init__( self.multiindex = multiindex self._level_names = level_names self.rangeindex = data.rangeindex + self.label_dtype = data.label_dtype else: # This code path is performance-critical for copies and should be # modified with care. @@ -292,7 +299,12 @@ def to_pandas_index(self) -> pd.Index: self.names[0], self.names[-1] + diff, diff ) return pd.RangeIndex(new_range, name=self.name) - result = pd.Index(self.names, name=self.name, tupleize_cols=False) + result = pd.Index( + self.names, + name=self.name, + tupleize_cols=False, + dtype=self.label_dtype, + ) return result def insert( diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index fd4a15a3391..43ae9b9e81e 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -734,6 +734,7 @@ def __init__( rangeindex = isinstance( columns, (range, pd.RangeIndex, cudf.RangeIndex) ) + label_dtype = getattr(columns, "dtype", None) self._data = ColumnAccessor( { k: column.column_empty( @@ -745,6 +746,7 @@ def __init__( if isinstance(columns, pd.Index) else None, rangeindex=rangeindex, + label_dtype=label_dtype, ) elif isinstance(data, ColumnAccessor): raise TypeError( @@ -995,12 +997,15 @@ def _init_from_list_like(self, data, index=None, columns=None): self._data.rangeindex = isinstance( columns, (range, pd.RangeIndex, cudf.RangeIndex) ) + self._data.label_dtype = getattr(columns, "dtype", None) @_cudf_nvtx_annotate def _init_from_dict_like( self, data, index=None, columns=None, nan_as_null=None ): + label_dtype = None if columns is not None: + label_dtype = getattr(columns, "dtype", None) # remove all entries in data that 
are not in columns, # inserting new empty columns for entries in columns that # are not in data @@ -1069,6 +1074,7 @@ def _init_from_dict_like( if isinstance(columns, pd.Index) else self._data._level_names ) + self._data.label_dtype = label_dtype @classmethod def _from_data( diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 5677f97408a..74165731683 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -4566,6 +4566,17 @@ def test_dataframe_columns_returns_rangeindex_single_col(): assert_eq(result, expected) +@pytest.mark.parametrize("dtype", ["int64", "datetime64[ns]", "int8"]) +@pytest.mark.parametrize("idx_data", [[], [1, 2]]) +@pytest.mark.parametrize("data", [None, [], {}]) +def test_dataframe_columns_empty_data_preserves_dtype(dtype, idx_data, data): + result = cudf.DataFrame( + data, columns=cudf.Index(idx_data, dtype=dtype) + ).columns + expected = pd.Index(idx_data, dtype=dtype) + assert_eq(result, expected) + + @pytest.mark.parametrize( "data", [