From 263a85d70edbf08232beb3286c1a2d0f08afe76e Mon Sep 17 00:00:00 2001
From: GALI PREM SAGAR
Date: Fri, 18 Aug 2023 17:49:38 -0500
Subject: [PATCH] Preserve index `name` in `reindex` (#13917)

Fixes: #13900

This PR fixes an issue with `reindex` API, where `name` of the index being reindexed upon was lost. This PR fixes it to match pandas by using the new index name if it exists or preserving the old name.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: https://github.com/rapidsai/cudf/pull/13917
---
 python/cudf/cudf/core/indexed_frame.py   | 27 +++++++++++++++++++++--
 python/cudf/cudf/tests/test_dataframe.py | 28 ++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py
index 51a2d085d00..8e6cdbb2787 100644
--- a/python/cudf/cudf/core/indexed_frame.py
+++ b/python/cudf/cudf/core/indexed_frame.py
@@ -2582,10 +2582,12 @@ def _reindex(
 
         df = self
         if index is not None:
-            index = cudf.core.index.as_index(index)
+            index = cudf.core.index.as_index(
+                index, name=getattr(index, "name", self._index.name)
+            )
 
             idx_dtype_match = (df.index.nlevels == index.nlevels) and all(
-                left_dtype == right_dtype
+                _is_same_dtype(left_dtype, right_dtype)
                 for left_dtype, right_dtype in zip(
                     (col.dtype for col in df.index._data.columns),
                     (col.dtype for col in index._data.columns),
@@ -5405,3 +5407,24 @@ def _drop_rows_by_labels(
             res = obj.to_frame(name="tmp").join(key_df, how="leftanti")["tmp"]
             res.name = obj.name
         return res
+
+
+def _is_same_dtype(lhs_dtype, rhs_dtype):
+    # `_reindex` helper: dtypes match if equal, or if exactly one side is
+    # categorical and its categories' dtype equals the other side's dtype.
+    if lhs_dtype == rhs_dtype:
+        return True
+    elif (
+        is_categorical_dtype(lhs_dtype)
+        and not is_categorical_dtype(rhs_dtype)
+        and lhs_dtype.categories.dtype == rhs_dtype
+    ):
+        return True
+    elif (
+        is_categorical_dtype(rhs_dtype)
+        and not is_categorical_dtype(lhs_dtype)
+        and rhs_dtype.categories.dtype == lhs_dtype
+    ):
+        return True
+    else:
+        return False
diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py
index 2e1e20dee40..0501874ecda 100644
--- a/python/cudf/cudf/tests/test_dataframe.py
+++ b/python/cudf/cudf/tests/test_dataframe.py
@@ -10288,3 +10288,31 @@ def test_dataframe_mixed_dtype_error(dtype):
     pdf = pd.Series([1, 2, 3], dtype=dtype).to_frame().astype(object)
     with pytest.raises(TypeError):
         cudf.from_pandas(pdf)
+
+
+@pytest.mark.parametrize(
+    "index_data,name",
+    [([10, 13], "a"), ([30, 40, 20], "b"), (["ef"], "c"), ([2, 3], "Z")],
+)
+def test_dataframe_reindex_with_index_names(index_data, name):
+    gdf = cudf.DataFrame(
+        {
+            "a": [10, 12, 13],
+            "b": [20, 30, 40],
+            "c": cudf.Series(["ab", "cd", "ef"], dtype="category"),
+        }
+    )
+    if name in gdf.columns:
+        gdf = gdf.set_index(name)
+    pdf = gdf.to_pandas()
+
+    gidx = cudf.Index(index_data, name=name)
+    actual = gdf.reindex(gidx)
+    expected = pdf.reindex(gidx.to_pandas())
+
+    assert_eq(actual, expected)
+
+    actual = gdf.reindex(index_data)
+    expected = pdf.reindex(index_data)
+
+    assert_eq(actual, expected)