diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 35588725655..78e5b17301a 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -5486,14 +5486,18 @@ def from_arrow(cls, table): return out @_cudf_nvtx_annotate - def to_arrow(self, preserve_index=True): + def to_arrow(self, preserve_index=None): """ Convert to a PyArrow Table. Parameters ---------- - preserve_index : bool, default True - whether index column and its meta data needs to be saved or not + preserve_index : bool, optional + whether index column and its meta data needs to be saved + or not. The default of None will store the index as a + column, except for a RangeIndex which is stored as + metadata only. Setting preserve_index to True will force + a RangeIndex to be materialized. Returns ------- @@ -5524,34 +5528,35 @@ def to_arrow(self, preserve_index=True): data = self.copy(deep=False) index_descr = [] - if preserve_index: - if isinstance(self.index, cudf.RangeIndex): + write_index = preserve_index is not False + keep_range_index = write_index and preserve_index is None + index = self.index + if write_index: + if isinstance(index, cudf.RangeIndex) and keep_range_index: descr = { "kind": "range", - "name": self.index.name, - "start": self.index._start, - "stop": self.index._stop, + "name": index.name, + "start": index._start, + "stop": index._stop, "step": 1, } else: - if isinstance(self.index, MultiIndex): + if isinstance(index, cudf.RangeIndex): + index = index._as_int_index() + index.name = "__index_level_0__" + if isinstance(index, MultiIndex): gen_names = tuple( - f"level_{i}" - for i, _ in enumerate(self.index._data.names) + f"level_{i}" for i, _ in enumerate(index._data.names) ) else: gen_names = ( - self.index.names - if self.index.name is not None - else ("index",) + index.names if index.name is not None else ("index",) ) - for gen_name, col_name in zip( - gen_names, self.index._data.names - ): + for gen_name, col_name in zip(gen_names, index._data.names): data._insert( data.shape[1], gen_name, - self.index._data[col_name], + index._data[col_name], ) descr = gen_names[0] index_descr.append(descr) @@ -5561,7 +5566,7 @@ def to_arrow(self, preserve_index=True): columns_to_convert=[self[col] for col in self._data.names], df=self, column_names=out.schema.names, - index_levels=[self.index], + index_levels=[index], index_descriptors=index_descr, preserve_index=preserve_index, types=out.schema.types,