From 1fdd62f4f593512addf7d98a07650fd2aab02021 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Mon, 20 Sep 2021 18:48:17 -0500 Subject: [PATCH] Fix duplicate names issue in `MultiIndex.deserialize ` (#9258) Fixes: #9254 This PR fixes `deserialize` in `cudf.MultiIndex` so that there is no data-corruption happening when there are duplicate names. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/9258 --- python/cudf/cudf/core/multiindex.py | 6 ++--- python/cudf/cudf/tests/test_multiindex.py | 31 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index 84566b4627c..fba857694e8 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -977,10 +977,10 @@ def deserialize(cls, header, frames): ) df = cudf.DataFrame.deserialize(header["source_data"], frames) obj = cls.from_frame(df) - obj._set_names(names) - return obj + return obj._set_names(names) columns = column.deserialize_columns(header["columns"], frames) - return cls._from_data(dict(zip(names, columns))) + obj = cls._from_data(dict(zip(range(0, len(names)), columns))) + return obj._set_names(names) def __getitem__(self, index): match = self.take(index) diff --git a/python/cudf/cudf/tests/test_multiindex.py b/python/cudf/cudf/tests/test_multiindex.py index 465cf36e1f3..981ab8b63b9 100644 --- a/python/cudf/cudf/tests/test_multiindex.py +++ b/python/cudf/cudf/tests/test_multiindex.py @@ -5,7 +5,9 @@ """ import itertools import operator +import pickle import re +from io import BytesIO import cupy as cp import numpy as np @@ -1553,3 +1555,32 @@ def test_multiIndex_duplicate_names(): ) assert_eq(gi, pi) + + +@pytest.mark.parametrize( + "names", + [ + ["a", "b", "c"], + [None, None, None], + ["aa", "aa", "aa"], + ["bb", "aa", "aa"], + None, + ], +) +def test_pickle_rountrip_multiIndex(names): + df = cudf.DataFrame( + { + "one": [1, 2, 3], + "two": [True, False, True], + "three": ["ab", "cd", "ef"], + "four": [0.2, 0.1, -10.2], + } + ) + expected_df = df.set_index(["one", "two", "three"]) + expected_df.index.names = names + local_file = BytesIO() + + pickle.dump(expected_df, local_file) + local_file.seek(0) + actual_df = pickle.load(local_file) + assert_eq(expected_df, actual_df)