Skip to content

Commit

Permalink
Fix DataFrame.reindex when column reindexing to MultiIndex/RangeIndex (
Browse files Browse the repository at this point in the history
…#14605)

1. `reindex(columns=cudf.MultiIndex)` raised an error because `set(cudf.MultiIndex)` was being called on the new columns.
2. `reindex(columns=range/RangeIndex)` did not necessarily retain a `RangeIndex` for the resulting columns.


Discovered while refactoring `DataFrame.__init__`; this further convinces me that the `column` should eventually be passed to `ColumnAccessor`.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #14605
  • Loading branch information
mroeschke authored Dec 13, 2023
1 parent a894ca0 commit 8136a16
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 16 deletions.
13 changes: 8 additions & 5 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2744,11 +2744,14 @@ def reindex(
else:
if columns is None:
columns = labels
df = (
self
if columns is None
else self[list(set(self._column_names) & set(columns))]
)
if columns is None:
df = self
else:
columns = as_index(columns)
intersection = self._data.to_pandas_index().intersection(
columns.to_pandas()
)
df = self.loc[:, intersection]

return df._reindex(
column_names=columns,
Expand Down
44 changes: 33 additions & 11 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ def _reindex(
Parameters
----------
columns_names : array-like
The list of columns to select from the Frame,
array-like of columns to select from the Frame,
if ``columns`` is a superset of ``Frame.columns`` new
columns are created.
dtypes : dict
Expand Down Expand Up @@ -2638,9 +2638,35 @@ def _reindex(
df = df.take(index.argsort(ascending=True).argsort())

index = index if index is not None else df.index
names = (
column_names if column_names is not None else list(df._data.names)
)

if column_names is None:
names = list(df._data.names)
level_names = self._data.level_names
multiindex = self._data.multiindex
rangeindex = self._data.rangeindex
elif isinstance(column_names, (pd.Index, cudf.Index)):
if isinstance(column_names, (pd.MultiIndex, cudf.MultiIndex)):
multiindex = True
if isinstance(column_names, cudf.MultiIndex):
names = list(iter(column_names.to_pandas()))
else:
names = list(iter(column_names))
rangeindex = False
else:
multiindex = False
names = column_names
if isinstance(names, cudf.Index):
names = names.to_pandas()
rangeindex = isinstance(
column_names, (pd.RangeIndex, cudf.RangeIndex)
)
level_names = tuple(column_names.names)
else:
names = column_names
level_names = None
multiindex = False
rangeindex = False

cols = {
name: (
df._data[name].copy(deep=deep)
Expand All @@ -2653,17 +2679,13 @@ def _reindex(
)
for name in names
}
if column_names is None:
level_names = self._data.level_names
elif isinstance(column_names, pd.Index):
level_names = tuple(column_names.names)
else:
level_names = None

result = self.__class__._from_data(
data=cudf.core.column_accessor.ColumnAccessor(
cols,
multiindex=self._data.multiindex,
multiindex=multiindex,
level_names=level_names,
rangeindex=rangeindex,
),
index=index,
)
Expand Down
45 changes: 45 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3428,6 +3428,51 @@ def test_series_string_reindex(copy):
)


@pytest.mark.parametrize("names", [None, ["a", "b"]])
@pytest.mark.parametrize("klass", [cudf.MultiIndex, pd.MultiIndex])
def test_reindex_multiindex_col_to_multiindex(names, klass):
    """Reindex the columns of a frame to a new MultiIndex.

    The target index is built with either ``cudf.MultiIndex`` or
    ``pd.MultiIndex`` and shares only ``("A", "one")`` with the source
    columns, so the ``("A", "three")`` column is expected to be null.
    """
    # NOTE(review): presumably pandas turns this list of tuples into a
    # MultiIndex (no tupleize_cols=False is passed here) -- confirm.
    source_columns = pd.Index([("A", "one"), ("A", "two")], dtype="object")
    gdf = cudf.from_pandas(pd.DataFrame([[1, 2]], columns=source_columns))

    target_columns = klass.from_tuples(
        [("A", "one"), ("A", "three")], names=names
    )
    result = gdf.reindex(columns=target_columns)
    expected = cudf.DataFrame([[1, None]], columns=target_columns)

    # (pandas2.0): check_dtype=False won't be needed
    # as None col will return object instead of float
    assert_eq(result, expected, check_dtype=False)


@pytest.mark.parametrize("names", [None, ["a", "b"]])
@pytest.mark.parametrize("klass", [cudf.MultiIndex, pd.MultiIndex])
def test_reindex_tuple_col_to_multiindex(names, klass):
    """Reindex tuple-labelled columns to a MultiIndex with the same tuples.

    The source columns are built with ``tupleize_cols=False``; the target
    is a ``cudf.MultiIndex`` or ``pd.MultiIndex`` made from the same two
    tuples, so both original columns are expected to carry over.
    """
    labels = [("A", "one"), ("A", "two")]
    source_columns = pd.Index(labels, dtype="object", tupleize_cols=False)
    gdf = cudf.from_pandas(pd.DataFrame([[1, 2]], columns=source_columns))

    target_columns = klass.from_tuples(
        [("A", "one"), ("A", "two")], names=names
    )
    result = gdf.reindex(columns=target_columns)
    expected = cudf.DataFrame([[1, 2]], columns=target_columns)

    assert_eq(result, expected)


@pytest.mark.parametrize("name", [None, "foo"])
@pytest.mark.parametrize("klass", [range, cudf.RangeIndex, pd.RangeIndex])
def test_reindex_columns_rangeindex_keeps_rangeindex(name, klass):
    """Reindex columns with a range-like and compare against a RangeIndex.

    ``klass(3)`` is used as the new columns. A name is attached only for
    the index classes (a builtin ``range`` is left unnamed), and the
    resulting ``.columns`` is compared with ``pd.RangeIndex(3)`` carrying
    the expected name.
    """
    new_columns = klass(3)
    if klass is range:
        exp_name = None
    else:
        new_columns.name = name
        exp_name = name

    gdf = cudf.DataFrame([[1, 2]])
    result = gdf.reindex(columns=new_columns).columns
    expected = pd.RangeIndex(3, name=exp_name)

    assert_eq(result, expected)


def test_to_frame(pdf, gdf):
assert_eq(pdf.x.to_frame(), gdf.x.to_frame())

Expand Down

0 comments on commit 8136a16

Please sign in to comment.