From 119ca616fab4e0c4ca14d27b55a05b45ed000ae5 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 1 Mar 2022 16:56:13 -0800 Subject: [PATCH] Unify explode implementations. --- python/cudf/cudf/core/dataframe.py | 7 +----- python/cudf/cudf/core/frame.py | 24 ------------------- python/cudf/cudf/core/indexed_frame.py | 32 ++++++++++++++++++++++++++ python/cudf/cudf/core/series.py | 9 ++------ 4 files changed, 35 insertions(+), 37 deletions(-) diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index d9eb938f0a2..c726fbfcace 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -5925,7 +5925,7 @@ def explode(self, column, ignore_index=False): Parameters ---------- - column : str or tuple + column : str Column to explode. ignore_index : bool, default False If True, the resulting index will be labeled 0, 1, …, n - 1. @@ -5960,11 +5960,6 @@ def explode(self, column, ignore_index=False): if column not in self._column_names: raise KeyError(column) - if not is_list_dtype(self._data[column].dtype): - data = self._data.copy(deep=True) - idx = None if ignore_index else self._index.copy(deep=True) - return self.__class__._from_data(data, index=idx) - return super()._explode(column, ignore_index) def pct_change( diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 7a9bc4625be..d54503318c6 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -589,30 +589,6 @@ def equals(self, other, **kwargs): else: return self._index.equals(other._index) - @annotate("FRAME_EXPLODE", color="green", domain="cudf_python") - def _explode(self, explode_column: Any, ignore_index: bool): - """Helper function for `explode` in `Series` and `Dataframe`, explodes - a specified nested column. Other columns' corresponding rows are - duplicated. If ignore_index is set, the original index is not exploded - and will be replaced with a `RangeIndex`. - """ - explode_column_num = self._column_names.index(explode_column) - if not ignore_index and self._index is not None: - explode_column_num += self._index.nlevels - - res = self.__class__._from_data( # type: ignore - *libcudf.lists.explode_outer( - self, explode_column_num, ignore_index - ) - ) - - res._data.multiindex = self._data.multiindex - res._data._level_names = self._data._level_names - - if not ignore_index and self._index is not None: - res.index.names = self._index.names - return res - @annotate( "FRAME_GET_COLUMNS_BY_LABEL", color="green", domain="cudf_python" ) diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index c617f4b53f5..cdb4e2c824d 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -24,6 +24,7 @@ is_bool_dtype, is_categorical_dtype, is_integer_dtype, + is_list_dtype, is_list_like, ) from cudf.core.column import ColumnBase @@ -2163,6 +2164,37 @@ def drop( if not inplace: return out + @annotate("INDEXED_FRAME_EXPLODE", color="green", domain="cudf_python") + def _explode(self, explode_column: Any, ignore_index: bool): + # Helper function for `explode` in `Series` and `Dataframe`, explodes a + # specified nested column. Other columns' corresponding rows are + # duplicated. If ignore_index is set, the original index is not + # exploded and will be replaced with a `RangeIndex`. + if not is_list_dtype(self._data[explode_column].dtype): + data = self._data.copy(deep=True) + idx = None if ignore_index else self._index.copy(deep=True) + return self.__class__._from_data(data, index=idx) + + explode_column_num = self._column_names.index(explode_column) + if not ignore_index and self._index is not None: + explode_column_num += self._index.nlevels + + data, index = libcudf.lists.explode_outer( + self, explode_column_num, ignore_index + ) + res = self.__class__._from_data( + ColumnAccessor( + data, + multiindex=self._data.multiindex, + level_names=self._data._level_names, + ), + index=index, + ) + + if not ignore_index and self._index is not None: + res.index.names = self._index.names + return res + def _check_duplicate_level_names(specified, level_names): """Raise if any of `specified` has duplicates in `level_names`.""" diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index bb3d8a4e221..2a809aeb30b 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -3150,7 +3150,7 @@ def explode(self, ignore_index=False): Returns ------- - DataFrame + Series Examples -------- @@ -3172,12 +3172,7 @@ def explode(self, ignore_index=False): 3 5 dtype: int64 """ - if not is_list_dtype(self._column.dtype): - data = self._data.copy(deep=True) - idx = None if ignore_index else self._index.copy(deep=True) - return self.__class__._from_data(data, index=idx) - - return super()._explode(self._column_names[0], ignore_index) + return super()._explode(self.name, ignore_index) def pct_change( self, periods=1, fill_method="ffill", limit=None, freq=None