Skip to content

Commit

Permalink
Preserve index name in reindex (#13917)
Browse files Browse the repository at this point in the history
Fixes: #13900 

This PR fixes an issue with the `reindex` API, where the `name` of the index was lost after reindexing. The behavior now matches pandas: the new index's name is used if it exists; otherwise the old index name is preserved.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #13917
  • Loading branch information
galipremsagar authored Aug 18, 2023
1 parent b798a70 commit 263a85d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
27 changes: 25 additions & 2 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,10 +2582,12 @@ def _reindex(

df = self
if index is not None:
index = cudf.core.index.as_index(index)
index = cudf.core.index.as_index(
index, name=getattr(index, "name", self._index.name)
)

idx_dtype_match = (df.index.nlevels == index.nlevels) and all(
left_dtype == right_dtype
_is_same_dtype(left_dtype, right_dtype)
for left_dtype, right_dtype in zip(
(col.dtype for col in df.index._data.columns),
(col.dtype for col in index._data.columns),
Expand Down Expand Up @@ -5405,3 +5407,24 @@ def _drop_rows_by_labels(
res = obj.to_frame(name="tmp").join(key_df, how="leftanti")["tmp"]
res.name = obj.name
return res


def _is_same_dtype(lhs_dtype, rhs_dtype):
    """Return whether two column dtypes count as matching for `_reindex`.

    The dtypes match when they compare equal, or when exactly one of
    them is categorical and the dtype of its categories compares equal
    to the other dtype.
    """
    if lhs_dtype == rhs_dtype:
        return True

    lhs_is_cat = bool(is_categorical_dtype(lhs_dtype))
    rhs_is_cat = bool(is_categorical_dtype(rhs_dtype))
    if lhs_is_cat == rhs_is_cat:
        # Either neither side is categorical, or both are categorical
        # but already compared unequal above.
        return False

    # Exactly one side is categorical: compare the dtype of its
    # categories against the plain dtype on the other side.
    if lhs_is_cat:
        cat_dtype, plain_dtype = lhs_dtype, rhs_dtype
    else:
        cat_dtype, plain_dtype = rhs_dtype, lhs_dtype
    return bool(cat_dtype.categories.dtype == plain_dtype)
28 changes: 28 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10288,3 +10288,31 @@ def test_dataframe_mixed_dtype_error(dtype):
pdf = pd.Series([1, 2, 3], dtype=dtype).to_frame().astype(object)
with pytest.raises(TypeError):
cudf.from_pandas(pdf)


@pytest.mark.parametrize(
    "index_data,name",
    [
        ([10, 13], "a"),
        ([30, 40, 20], "b"),
        (["ef"], "c"),
        ([2, 3], "Z"),
    ],
)
def test_dataframe_reindex_with_index_names(index_data, name):
    """Compare cudf and pandas `reindex` results, index names included.

    When ``name`` is one of the frame's columns, that column becomes the
    index before reindexing; otherwise the frame keeps its default
    index. The frame is then reindexed twice: once with a named
    ``cudf.Index`` and once with the raw ``index_data`` list.
    """
    source = cudf.DataFrame(
        {
            "a": [10, 12, 13],
            "b": [20, 30, 40],
            "c": cudf.Series(["ab", "cd", "ef"], dtype="category"),
        }
    )
    gdf = source.set_index(name) if name in source.columns else source
    pdf = gdf.to_pandas()

    # Reindex with an Index object that carries its own name.
    named_index = cudf.Index(index_data, name=name)
    assert_eq(gdf.reindex(named_index), pdf.reindex(named_index.to_pandas()))

    # Reindex with a plain list, which has no name of its own.
    assert_eq(gdf.reindex(index_data), pdf.reindex(index_data))

0 comments on commit 263a85d

Please sign in to comment.