Skip to content

Commit

Permalink
Fix duplicate names issue in MultiIndex.deserialize (#9258)
Browse files Browse the repository at this point in the history
Fixes: #9254 

This PR fixes `deserialize` in `cudf.MultiIndex` so that data is no longer corrupted when the index has duplicate names.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #9258
  • Loading branch information
galipremsagar authored Sep 20, 2021
1 parent 625810a commit 1fdd62f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,10 +977,10 @@ def deserialize(cls, header, frames):
)
df = cudf.DataFrame.deserialize(header["source_data"], frames)
obj = cls.from_frame(df)
obj._set_names(names)
return obj
return obj._set_names(names)
columns = column.deserialize_columns(header["columns"], frames)
return cls._from_data(dict(zip(names, columns)))
obj = cls._from_data(dict(zip(range(0, len(names)), columns)))
return obj._set_names(names)

def __getitem__(self, index):
match = self.take(index)
Expand Down
31 changes: 31 additions & 0 deletions python/cudf/cudf/tests/test_multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"""
import itertools
import operator
import pickle
import re
from io import BytesIO

import cupy as cp
import numpy as np
Expand Down Expand Up @@ -1553,3 +1555,32 @@ def test_multiIndex_duplicate_names():
)

assert_eq(gi, pi)


@pytest.mark.parametrize(
    "names",
    [
        ["a", "b", "c"],
        [None, None, None],
        ["aa", "aa", "aa"],
        ["bb", "aa", "aa"],
        None,
    ],
)
def test_pickle_rountrip_multiIndex(names):
    """Check that a DataFrame with a multi-column index survives pickling.

    The index is built from the "one", "two" and "three" columns and its
    names are then overwritten with ``names``. The parametrization covers
    distinct names, all-``None`` names, fully and partially duplicated
    names, and ``names=None`` itself. The frame is pickled into an
    in-memory buffer, loaded back, and compared with the original.
    """
    column_data = {
        "one": [1, 2, 3],
        "two": [True, False, True],
        "three": ["ab", "cd", "ef"],
        "four": [0.2, 0.1, -10.2],
    }
    expected_df = cudf.DataFrame(column_data).set_index(
        ["one", "two", "three"]
    )
    expected_df.index.names = names

    # Serialize into an in-memory file object, then rewind it so that
    # pickle.load reads from the start of the stream.
    buffer = BytesIO()
    pickle.dump(expected_df, buffer)
    buffer.seek(0)
    actual_df = pickle.load(buffer)

    assert_eq(expected_df, actual_df)

0 comments on commit 1fdd62f

Please sign in to comment.