diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index a01242d957d..547c14cdc99 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -23,6 +23,7 @@ from cudf.api.types import is_integer, is_list_like, is_object_dtype from cudf.core import column from cudf.core._base_index import _return_get_indexer_result +from cudf.core.column_accessor import ColumnAccessor from cudf.core.frame import Frame from cudf.core.index import ( BaseIndex, @@ -446,45 +447,26 @@ def __repr__(self): ) preprocess = self.take(indices) else: - preprocess = self.copy(deep=False) - - if any(col.has_nulls() for col in preprocess._data.columns): - preprocess_df = preprocess.to_frame(index=False) - for name, col in preprocess._data.items(): - if isinstance( - col, - ( - column.datetime.DatetimeColumn, - column.timedelta.TimeDeltaColumn, - ), - ): - preprocess_df[name] = col.astype("str").fillna( - str(cudf.NaT) - ) + preprocess = self - tuples_list = list( - zip( - *list( - map(lambda val: pd.NA if val is None else val, col) - for col in preprocess_df.to_arrow() - .to_pydict() - .values() - ) - ) - ) + arrays = [] + for name, col in zip(self.names, preprocess._columns): + try: + pd_idx = col.to_pandas(nullable=True) + except NotImplementedError: + pd_idx = col.to_pandas(nullable=False) + pd_idx.name = name + arrays.append(pd_idx) - preprocess = preprocess.to_pandas(nullable=True) - preprocess.values[:] = tuples_list - else: - preprocess = preprocess.to_pandas(nullable=True) + preprocess_pd = pd.MultiIndex.from_arrays(arrays) - output = repr(preprocess) + output = repr(preprocess_pd) output_prefix = self.__class__.__name__ + "(" output = output.lstrip(output_prefix) lines = output.split("\n") if len(lines) > 1: - if "length=" in lines[-1] and len(self) != len(preprocess): + if "length=" in lines[-1] and len(self) != len(preprocess_pd): last_line = lines[-1] length_index = last_line.index("length=") last_line = last_line[:length_index] + f"length={len(self)})" @@ -1022,42 +1004,32 @@ def to_frame(self, index=True, name=no_default, allow_duplicates=False): a c a c b d b d """ - # TODO: Currently this function makes a shallow copy, which is - # incorrect. We want to make a deep copy, otherwise further - # modifications of the resulting DataFrame will affect the MultiIndex. if name is no_default: column_names = [ level if name is None else name for level, name in enumerate(self.names) ] + elif not is_list_like(name): + raise TypeError( + "'name' must be a list / sequence of column names." + ) + elif len(name) != len(self.levels): + raise ValueError( + "'name' should have the same length as " + "number of levels on index." + ) else: - if not is_list_like(name): - raise TypeError( - "'name' must be a list / sequence of column names." - ) - if len(name) != len(self.levels): - raise ValueError( - "'name' should have the same length as " - "number of levels on index." - ) column_names = name - all_none_names = None - if not ( - all_none_names := all(x is None for x in column_names) - ) and len(column_names) != len(set(column_names)): + if len(column_names) != len(set(column_names)): raise ValueError("Duplicate column names are not allowed") - df = cudf.DataFrame._from_data( - data=self._data, - columns=column_names - if name is not no_default and not all_none_names - else None, + ca = ColumnAccessor( + dict(zip(column_names, (col.copy() for col in self._columns))), + verify=False, + ) + return cudf.DataFrame._from_data( + data=ca, index=self if index else None ) - - if index: - df = df.set_index(self) - - return df @_cudf_nvtx_annotate def get_level_values(self, level): @@ -1243,7 +1215,7 @@ def values(self): @classmethod @_cudf_nvtx_annotate - def from_frame(cls, df, names=None): + def from_frame(cls, df: pd.DataFrame | cudf.DataFrame, names=None): """ Make a MultiIndex from a DataFrame. diff --git a/python/cudf/cudf/tests/test_repr.py b/python/cudf/cudf/tests/test_repr.py index 8f65bd26bd1..193d64a9e7f 100644 --- a/python/cudf/cudf/tests/test_repr.py +++ b/python/cudf/cudf/tests/test_repr.py @@ -1210,7 +1210,7 @@ def test_multiindex_repr(pmi, max_seq_items): .index, textwrap.dedent( """ - MultiIndex([('abc', 'NaT', 0.345), + MultiIndex([('abc', NaT, 0.345), ( , '0 days 00:00:00.000000001', ), ('xyz', '0 days 00:00:00.000000002', 100.0), ( , '0 days 00:00:00.000000003', 10.0)], @@ -1252,10 +1252,10 @@ def test_multiindex_repr(pmi, max_seq_items): .index, textwrap.dedent( """ - MultiIndex([('NaT', ), - ('NaT', ), - ('NaT', ), - ('NaT', )], + MultiIndex([(NaT, ), + (NaT, ), + (NaT, ), + (NaT, )], names=['b', 'a']) """ ),