Skip to content

Commit

Permalink
Unify explode implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
vyasr committed Mar 2, 2022
1 parent 8b966c0 commit 119ca61
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 37 deletions.
7 changes: 1 addition & 6 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5925,7 +5925,7 @@ def explode(self, column, ignore_index=False):
Parameters
----------
column : str or tuple
column : str
Column to explode.
ignore_index : bool, default False
If True, the resulting index will be labeled 0, 1, …, n - 1.
Expand Down Expand Up @@ -5960,11 +5960,6 @@ def explode(self, column, ignore_index=False):
if column not in self._column_names:
raise KeyError(column)

if not is_list_dtype(self._data[column].dtype):
data = self._data.copy(deep=True)
idx = None if ignore_index else self._index.copy(deep=True)
return self.__class__._from_data(data, index=idx)

return super()._explode(column, ignore_index)

def pct_change(
Expand Down
24 changes: 0 additions & 24 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,30 +589,6 @@ def equals(self, other, **kwargs):
else:
return self._index.equals(other._index)

@annotate("FRAME_EXPLODE", color="green", domain="cudf_python")
def _explode(self, explode_column: Any, ignore_index: bool):
    """Explode one nested column; shared by `Series` and `DataFrame`.

    Rows of the remaining columns are repeated to line up with the
    exploded values. When ``ignore_index`` is set the original index is
    not exploded and a `RangeIndex` takes its place.

    Parameters
    ----------
    explode_column : Any
        Label of the column to explode; it is looked up in
        ``self._column_names``.
    ignore_index : bool
        Whether to discard the existing index instead of exploding it.

    Returns
    -------
    An object of the same class as ``self`` built from the libcudf result.
    """
    column_position = self._column_names.index(explode_column)

    # The index only takes part in the explode when the caller keeps it.
    keep_index = not ignore_index and self._index is not None
    if keep_index:
        # Shift past the index levels -- presumably they come before the
        # data columns in the table handed to libcudf (confirm there).
        column_position += self._index.nlevels

    exploded = libcudf.lists.explode_outer(
        self, column_position, ignore_index
    )
    result = self.__class__._from_data(*exploded)  # type: ignore

    # Carry the column-label metadata of the source frame over to the result.
    result._data.multiindex = self._data.multiindex
    result._data._level_names = self._data._level_names

    if keep_index:
        result.index.names = self._index.names
    return result

@annotate(
"FRAME_GET_COLUMNS_BY_LABEL", color="green", domain="cudf_python"
)
Expand Down
32 changes: 32 additions & 0 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
is_bool_dtype,
is_categorical_dtype,
is_integer_dtype,
is_list_dtype,
is_list_like,
)
from cudf.core.column import ColumnBase
Expand Down Expand Up @@ -2163,6 +2164,37 @@ def drop(
if not inplace:
return out

@annotate("INDEXED_FRAME_EXPLODE", color="green", domain="cudf_python")
def _explode(self, explode_column: Any, ignore_index: bool):
    # Shared implementation behind `explode` on `Series` and `DataFrame`.
    # The nested column named by `explode_column` is exploded and the rows
    # of every other column are repeated to match. With `ignore_index` set,
    # the original index is not exploded and a `RangeIndex` replaces it.
    frame_type = self.__class__

    # A column that is not list-typed has nothing to explode: return a deep
    # copy of the data, and of the index unless the caller wants it dropped
    # (in which case `index=None` is handed to `_from_data`).
    if not is_list_dtype(self._data[explode_column].dtype):
        copied_data = self._data.copy(deep=True)
        if ignore_index:
            copied_index = None
        else:
            copied_index = self._index.copy(deep=True)
        return frame_type._from_data(copied_data, index=copied_index)

    column_position = self._column_names.index(explode_column)

    # The index only takes part in the explode when the caller keeps it.
    keep_index = not ignore_index and self._index is not None
    if keep_index:
        # Shift past the index levels -- presumably they come before the
        # data columns in the table handed to libcudf (confirm there).
        column_position += self._index.nlevels

    exploded_data, exploded_index = libcudf.lists.explode_outer(
        self, column_position, ignore_index
    )

    # Rebuild the column accessor with the source frame's label metadata.
    accessor = ColumnAccessor(
        exploded_data,
        multiindex=self._data.multiindex,
        level_names=self._data._level_names,
    )
    result = frame_type._from_data(accessor, index=exploded_index)

    if keep_index:
        result.index.names = self._index.names
    return result


def _check_duplicate_level_names(specified, level_names):
"""Raise if any of `specified` has duplicates in `level_names`."""
Expand Down
9 changes: 2 additions & 7 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3150,7 +3150,7 @@ def explode(self, ignore_index=False):
Returns
-------
DataFrame
Series
Examples
--------
Expand All @@ -3172,12 +3172,7 @@ def explode(self, ignore_index=False):
3 5
dtype: int64
"""
if not is_list_dtype(self._column.dtype):
data = self._data.copy(deep=True)
idx = None if ignore_index else self._index.copy(deep=True)
return self.__class__._from_data(data, index=idx)

return super()._explode(self._column_names[0], ignore_index)
return super()._explode(self.name, ignore_index)

def pct_change(
self, periods=1, fill_method="ffill", limit=None, freq=None
Expand Down

0 comments on commit 119ca61

Please sign in to comment.