Skip to content

Commit

Permalink
Preserve DataFrame(columns=).columns dtype during empty-like construc…
Browse files Browse the repository at this point in the history
…tion (#14381)

`.column` used to always return `pd.Index([], dtype=object)` even if an empty-dtyped columns was passed into the DataFrame constructor e.g. `DatetimeIndex([])`. Needed to preserved some information about what column dtype was passed in so we can return a correctly type Index

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #14381
  • Loading branch information
mroeschke authored Nov 21, 2023
1 parent 947081f commit fcc8950
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
14 changes: 13 additions & 1 deletion python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from cudf.core import column

if TYPE_CHECKING:
from cudf._typing import Dtype
from cudf.core.column import ColumnBase


Expand Down Expand Up @@ -99,6 +100,9 @@ class ColumnAccessor(abc.MutableMapping):
rangeindex : bool, optional
Whether the keys should be returned as a RangeIndex
in `to_pandas_index` (default=False).
label_dtype : Dtype, optional
What dtype should be returned in `to_pandas_index`
(default=None).
"""

_data: "Dict[Any, ColumnBase]"
Expand All @@ -111,8 +115,10 @@ def __init__(
multiindex: bool = False,
level_names=None,
rangeindex: bool = False,
label_dtype: Dtype | None = None,
):
self.rangeindex = rangeindex
self.label_dtype = label_dtype
if data is None:
data = {}
# TODO: we should validate the keys of `data`
Expand All @@ -123,6 +129,7 @@ def __init__(
self.multiindex = multiindex
self._level_names = level_names
self.rangeindex = data.rangeindex
self.label_dtype = data.label_dtype
else:
# This code path is performance-critical for copies and should be
# modified with care.
Expand Down Expand Up @@ -292,7 +299,12 @@ def to_pandas_index(self) -> pd.Index:
self.names[0], self.names[-1] + diff, diff
)
return pd.RangeIndex(new_range, name=self.name)
result = pd.Index(self.names, name=self.name, tupleize_cols=False)
result = pd.Index(
self.names,
name=self.name,
tupleize_cols=False,
dtype=self.label_dtype,
)
return result

def insert(
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ def __init__(
rangeindex = isinstance(
columns, (range, pd.RangeIndex, cudf.RangeIndex)
)
label_dtype = getattr(columns, "dtype", None)
self._data = ColumnAccessor(
{
k: column.column_empty(
Expand All @@ -745,6 +746,7 @@ def __init__(
if isinstance(columns, pd.Index)
else None,
rangeindex=rangeindex,
label_dtype=label_dtype,
)
elif isinstance(data, ColumnAccessor):
raise TypeError(
Expand Down Expand Up @@ -995,12 +997,15 @@ def _init_from_list_like(self, data, index=None, columns=None):
self._data.rangeindex = isinstance(
columns, (range, pd.RangeIndex, cudf.RangeIndex)
)
self._data.label_dtype = getattr(columns, "dtype", None)

@_cudf_nvtx_annotate
def _init_from_dict_like(
self, data, index=None, columns=None, nan_as_null=None
):
label_dtype = None
if columns is not None:
label_dtype = getattr(columns, "dtype", None)
# remove all entries in data that are not in columns,
# inserting new empty columns for entries in columns that
# are not in data
Expand Down Expand Up @@ -1069,6 +1074,7 @@ def _init_from_dict_like(
if isinstance(columns, pd.Index)
else self._data._level_names
)
self._data.label_dtype = label_dtype

@classmethod
def _from_data(
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4566,6 +4566,17 @@ def test_dataframe_columns_returns_rangeindex_single_col():
assert_eq(result, expected)


@pytest.mark.parametrize("dtype", ["int64", "datetime64[ns]", "int8"])
@pytest.mark.parametrize("idx_data", [[], [1, 2]])
@pytest.mark.parametrize("data", [None, [], {}])
def test_dataframe_columns_empty_data_preserves_dtype(dtype, idx_data, data):
result = cudf.DataFrame(
data, columns=cudf.Index(idx_data, dtype=dtype)
).columns
expected = pd.Index(idx_data, dtype=dtype)
assert_eq(result, expected)


@pytest.mark.parametrize(
"data",
[
Expand Down

0 comments on commit fcc8950

Please sign in to comment.