Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Preserve Series/DataFrame subclasses through groupby operations #33884

Merged
merged 38 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
7da0703
Add tests to confirm groupby ops preserve subclasses
JBGreisman Apr 29, 2020
f6e1a89
Update SeriesGroupBy constructor calls to preserve subclassed Series
JBGreisman Apr 29, 2020
ee6de43
Fix concat() method to preserve subclassed DataFrames
JBGreisman Apr 29, 2020
3f9f4c4
Add GH28330 comment to concat.py
JBGreisman Apr 29, 2020
0112826
Preserve subclassing in DataFrame.idxmin() and DataFrame.idxmax() calls
JBGreisman Apr 29, 2020
422e702
GH28330 Fix GroupBy.size() to preserve subclassed types
JBGreisman Apr 29, 2020
53ac397
GH28330 Fix GroupBy.ngroup() to preserve subclassed types
JBGreisman Apr 29, 2020
2bc2520
GH28330 Fix GroupBy.cumcount() to preserve subclassed types
JBGreisman Apr 29, 2020
8d9a885
GH28330 Fix constructor calls to preserve subclasses through groupby()
JBGreisman Apr 29, 2020
e4d7fa8
Fix typo -- Series.constructor() to Series._constructor()
JBGreisman Apr 29, 2020
1dbe986
Remove DeprecationWarning due to empty Series construction
JBGreisman Apr 30, 2020
c998422
BUG: GH28330 Preserve subclassing with groupby operations
JBGreisman Apr 30, 2020
abdb861
BUG: GH28330 Preserve subclassing with groupby operations
JBGreisman Apr 30, 2020
d36ad6d
Merge remote-tracking branch 'upstream/master' into groupby-preserve-…
JBGreisman Apr 30, 2020
5b83062
Fix formatting of .py files with black
JBGreisman Apr 30, 2020
b6ea731
Removed trailing whitespace in doc/source/whatsnew/v1.1.0.rst
JBGreisman Apr 30, 2020
0cdf0ea
Update DataFrameGroupBy._cython_agg_blocks() to pass mypy
JBGreisman Apr 30, 2020
6e48e07
Merge remote-tracking branch 'upstream/master' into groupby-preserve-…
JBGreisman Apr 30, 2020
a70c21a
Remove unused import of typing.cast from pandas/core/groupby/generic.py
JBGreisman Apr 30, 2020
c03d459
Move tests to test_groupby_subclass.py
JBGreisman Apr 30, 2020
9e42c79
Add tests for DataFrame.idxmin() and DataFrame.idxmax() with subclasses
JBGreisman Apr 30, 2020
9fbc645
Add test to confirm concat() preserves subclassed types
JBGreisman Apr 30, 2020
5750d72
Update whatsnew entry bugfix
JBGreisman Apr 30, 2020
7f4c5a7
Merge remote-tracking branch 'upstream/master' into groupby-preserve-…
JBGreisman Apr 30, 2020
8eee73c
Revert unnecessary changes in GroupBy()
JBGreisman May 1, 2020
4b304c1
Fix test to expect Series from GroupBy.ngroup() and GroupBy.cumcount()
JBGreisman May 1, 2020
b3e039a
Fix formatting of groupby.py
JBGreisman May 1, 2020
b1118de
Merge remote-tracking branch 'upstream/master' into groupby-preserve-…
JBGreisman May 1, 2020
a490e38
Avoid DeprecationWarning by checking for instance of Series
JBGreisman May 1, 2020
5bcf9fa
Merge remote-tracking branch 'upstream/master' into groupby-preserve-…
JBGreisman May 1, 2020
0244b36
Remove unnecessary constructor call in DataFrameGroupBy._cython_agg_b…
JBGreisman May 3, 2020
dcd4692
Fix mypy static typing issue in DataFrameGroupBy._cython_agg_blocks()
JBGreisman May 3, 2020
a92c51b
Ensure consistent return types for GroupBy.size(), ngroup(), and cumc…
JBGreisman May 3, 2020
cf3b978
Revert DataFrameGroupBy._cython_agg_blocks() back to origin/master
JBGreisman May 11, 2020
f08cf59
Add GroupBy._constructor() to facilitate preserving subclassed types
JBGreisman May 11, 2020
d2a7de2
Change DataFrameGroupBy._transform_fast() to use _constructor property
JBGreisman May 12, 2020
37ea97f
Restructure GroupBy._constructor() to remove else statement
JBGreisman May 12, 2020
f1570da
Rename GroupBy._constructor property to GroupBy._obj_1d_constructor
JBGreisman May 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.groupby` where a ``ValueError`` would be raised when grouping by a categorical column with read-only categories and ``sort=False`` (:issue:`33410`)
- Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`)
- Bug in :meth:`Rolling.min` and :meth:`Rolling.max`: Growing memory usage after multiple calls when using a fixed window (:issue:`30726`)
- Bug in :meth:`GroupBy.agg`, :meth:`GroupBy.transform`, and :meth:`GroupBy.resample` where subclasses are not preserved (:issue:`28330`)

Reshaping
^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -8524,7 +8524,7 @@ def idxmin(self, axis=0, skipna=True) -> Series:
indices = nanops.nanargmin(self.values, axis=axis, skipna=skipna)
index = self._get_axis(axis)
result = [index[i] if i >= 0 else np.nan for i in indices]
return Series(result, index=self._get_agg_axis(axis))
return self._constructor_sliced(result, index=self._get_agg_axis(axis))

def idxmax(self, axis=0, skipna=True) -> Series:
"""
Expand Down Expand Up @@ -8591,7 +8591,7 @@ def idxmax(self, axis=0, skipna=True) -> Series:
indices = nanops.nanargmax(self.values, axis=axis, skipna=skipna)
index = self._get_axis(axis)
result = [index[i] if i >= 0 else np.nan for i in indices]
return Series(result, index=self._get_agg_axis(axis))
return self._constructor_sliced(result, index=self._get_agg_axis(axis))

def _get_agg_axis(self, axis_num: int) -> Index:
"""
Expand Down
78 changes: 47 additions & 31 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _aggregate_multiple_funcs(self, arg):
# let higher level handle
return results

return DataFrame(results, columns=columns)
return self.obj._constructor_expanddim(results, columns=columns)

def _wrap_series_output(
self, output: Mapping[base.OutputKey, Union[Series, np.ndarray]], index: Index,
Expand Down Expand Up @@ -359,10 +359,12 @@ def _wrap_series_output(

result: Union[Series, DataFrame]
if len(output) > 1:
result = DataFrame(indexed_output, index=index)
result = self.obj._constructor_expanddim(indexed_output, index=index)
result.columns = columns
else:
result = Series(indexed_output[0], index=index, name=columns[0])
result = self.obj._constructor(
indexed_output[0], index=index, name=columns[0]
)

return result

Expand Down Expand Up @@ -421,7 +423,9 @@ def _wrap_transformed_output(
def _wrap_applied_output(self, keys, values, not_indexed_same=False):
if len(keys) == 0:
# GH #6265
return Series([], name=self._selection_name, index=keys, dtype=np.float64)
return self.obj._constructor(
[], name=self._selection_name, index=keys, dtype=np.float64
)

def _get_index() -> Index:
if self.grouper.nkeys > 1:
Expand All @@ -433,7 +437,9 @@ def _get_index() -> Index:
if isinstance(values[0], dict):
# GH #823 #24880
index = _get_index()
result = self._reindex_output(DataFrame(values, index=index))
result = self._reindex_output(
self.obj._constructor_expanddim(values, index=index)
)
# if self.observed is False,
# keep all-NaN rows created while re-indexing
result = result.stack(dropna=self.observed)
Expand All @@ -447,7 +453,9 @@ def _get_index() -> Index:
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
else:
# GH #6265 #24880
result = Series(data=values, index=_get_index(), name=self._selection_name)
result = self.obj._constructor(
data=values, index=_get_index(), name=self._selection_name
)
return self._reindex_output(result)

def _aggregate_named(self, func, *args, **kwargs):
Expand Down Expand Up @@ -527,7 +535,7 @@ def _transform_general(

result = concat(results).sort_index()
else:
result = Series(dtype=np.float64)
result = self.obj._constructor(dtype=np.float64)

# we will only try to coerce the result type if
# we have a numeric dtype, as these are *always* user-defined funcs
Expand All @@ -550,7 +558,7 @@ def _transform_fast(self, result, func_nm: str) -> Series:
out = algorithms.take_1d(result._values, ids)
if cast:
out = maybe_cast_result(out, self.obj, how=func_nm)
return Series(out, index=self.obj.index, name=self.obj.name)
return self.obj._constructor(out, index=self.obj.index, name=self.obj.name)

def filter(self, func, dropna=True, *args, **kwargs):
"""
Expand Down Expand Up @@ -651,7 +659,7 @@ def nunique(self, dropna: bool = True) -> Series:
res, out = np.zeros(len(ri), dtype=out.dtype), res
res[ids[idx]] = out

result = Series(res, index=ri, name=self._selection_name)
result = self.obj._constructor(res, index=ri, name=self._selection_name)
return self._reindex_output(result, fill_value=0)

@doc(Series.describe)
Expand Down Expand Up @@ -753,7 +761,7 @@ def value_counts(

if is_integer_dtype(out):
out = ensure_int64(out)
return Series(out, index=mi, name=self._selection_name)
return self.obj._constructor(out, index=mi, name=self._selection_name)

# for compat. with libgroupby.value_counts need to ensure every
# bin is present at every index level, null filled with zeros
Expand Down Expand Up @@ -785,7 +793,7 @@ def build_codes(lev_codes: np.ndarray) -> np.ndarray:

if is_integer_dtype(out):
out = ensure_int64(out)
return Series(out, index=mi, name=self._selection_name)
return self.obj._constructor(out, index=mi, name=self._selection_name)

def count(self) -> Series:
"""
Expand All @@ -804,7 +812,7 @@ def count(self) -> Series:
minlength = ngroups or 0
out = np.bincount(ids[mask], minlength=minlength)

result = Series(
result = self.obj._constructor(
out,
index=self.grouper.result_index,
name=self._selection_name,
Expand Down Expand Up @@ -1202,11 +1210,11 @@ def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
if cannot_agg:
result_columns = result_columns.drop(cannot_agg)

return DataFrame(result, columns=result_columns)
return self.obj._constructor(result, columns=result_columns)

def _wrap_applied_output(self, keys, values, not_indexed_same=False):
if len(keys) == 0:
return DataFrame(index=keys)
return self.obj._constructor(index=keys)

key_names = self.grouper.names

Expand All @@ -1216,7 +1224,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
if first_not_none is None:
# GH9684. If all values are None, then this will throw an error.
# We'd prefer it return an empty dataframe.
return DataFrame()
return self.obj._constructor()
elif isinstance(first_not_none, DataFrame):
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
elif self.grouper.groupings is not None:
Expand Down Expand Up @@ -1247,13 +1255,13 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

# make Nones an empty object
if first_not_none is None:
return DataFrame()
return self.obj._constructor()
elif isinstance(first_not_none, NDFrame):

# this is to silence a DeprecationWarning
# TODO: Remove when default dtype of empty Series is object
kwargs = first_not_none._construct_axes_dict()
if first_not_none._constructor is Series:
if isinstance(first_not_none, Series):
backup = create_series_with_explicit_dtype(
**kwargs, dtype_if_empty=object
)
Expand Down Expand Up @@ -1320,7 +1328,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
or isinstance(key_index, MultiIndex)
):
stacked_values = np.vstack([np.asarray(v) for v in values])
result = DataFrame(
result = self.obj._constructor(
stacked_values, index=key_index, columns=index
)
else:
Expand All @@ -1337,15 +1345,17 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
result.columns = index
elif isinstance(v, ABCSeries):
stacked_values = np.vstack([np.asarray(v) for v in values])
result = DataFrame(
result = self.obj._constructor(
stacked_values.T, index=v.index, columns=key_index
)
else:
# GH#1738: values is list of arrays of unequal lengths
# fall through to the outer else clause
# TODO: sure this is right? we used to do this
# after raising AttributeError above
return Series(values, index=key_index, name=self._selection_name)
return self.obj._constructor_sliced(
values, index=key_index, name=self._selection_name
)

# if we have date/time like in the original, then coerce dates
# as we are stacking can easily have object dtypes here
Expand All @@ -1362,7 +1372,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
# self._selection_name not passed through to Series as the
# result should not take the name of original selection
# of columns
return Series(values, index=key_index)
return self.obj._constructor_sliced(values, index=key_index)

else:
# Handle cases like BinGrouper
Expand Down Expand Up @@ -1396,7 +1406,9 @@ def _transform_general(
if cache_key not in NUMBA_FUNC_CACHE:
NUMBA_FUNC_CACHE[cache_key] = numba_func
# Return the result as a DataFrame for concatenation later
res = DataFrame(res, index=group.index, columns=group.columns)
res = self.obj._constructor(
res, index=group.index, columns=group.columns
)
else:
# Try slow path and fast path.
try:
Expand All @@ -1419,7 +1431,7 @@ def _transform_general(
r.columns = group.columns
r.index = group.index
else:
r = DataFrame(
r = self.obj._constructor(
np.concatenate([res.values] * len(group.index)).reshape(
group.shape
),
Expand Down Expand Up @@ -1495,7 +1507,9 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm)
output.append(res)

return DataFrame._from_arrays(output, columns=result.columns, index=obj.index)
return self.obj._constructor._from_arrays(
output, columns=result.columns, index=obj.index
)

def _define_paths(self, func, *args, **kwargs):
if isinstance(func, str):
Expand Down Expand Up @@ -1557,7 +1571,7 @@ def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame:
if len(output) < len(obj.columns):
columns = columns.take(inds)

return DataFrame(output, index=obj.index, columns=columns)
return self.obj._constructor(output, index=obj.index, columns=columns)

def filter(self, func, dropna=True, *args, **kwargs):
"""
Expand Down Expand Up @@ -1672,9 +1686,11 @@ def _wrap_frame_output(self, result, obj) -> DataFrame:
result_index = self.grouper.levels[0]

if self.axis == 0:
return DataFrame(result, index=obj.columns, columns=result_index).T
return self.obj._constructor(
result, index=obj.columns, columns=result_index
).T
else:
return DataFrame(result, index=obj.index, columns=result_index)
return self.obj._constructor(result, index=obj.index, columns=result_index)

def _get_data_to_aggregate(self) -> BlockManager:
obj = self._obj_with_exclusions
Expand Down Expand Up @@ -1718,7 +1734,7 @@ def _wrap_aggregated_output(
indexed_output = {key.position: val for key, val in output.items()}
columns = Index(key.label for key in output)

result = DataFrame(indexed_output)
result = self.obj._constructor(indexed_output)
result.columns = columns

if not self.as_index:
Expand Down Expand Up @@ -1751,7 +1767,7 @@ def _wrap_transformed_output(
indexed_output = {key.position: val for key, val in output.items()}
columns = Index(key.label for key in output)

result = DataFrame(indexed_output)
result = self.obj._constructor(indexed_output)
result.columns = columns
result.index = self.obj.index

Expand All @@ -1761,14 +1777,14 @@ def _wrap_agged_blocks(self, blocks: "Sequence[Block]", items: Index) -> DataFra
if not self.as_index:
index = np.arange(blocks[0].values.shape[-1])
mgr = BlockManager(blocks, axes=[items, index])
result = DataFrame(mgr)
result = self.obj._constructor(mgr)

self._insert_inaxis_grouper_inplace(result)
result = result._consolidate()
else:
index = self.grouper.result_index
mgr = BlockManager(blocks, axes=[items, index])
result = DataFrame(mgr)
result = self.obj._constructor(mgr)

if self.axis == 1:
result = result.T
Expand Down
19 changes: 15 additions & 4 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,14 @@ class GroupBy(_GroupBy[FrameOrSeries]):
more
"""

@property
def _obj_1d_constructor(self) -> Type["Series"]:
# GH28330 preserve subclassed Series/DataFrames
if isinstance(self.obj, DataFrame):
return self.obj._constructor_sliced
assert isinstance(self.obj, Series)
return self.obj._constructor

def _bool_agg(self, val_test, skipna):
"""
Shared func to call any / all Cython GroupBy implementations.
Expand Down Expand Up @@ -1420,8 +1428,11 @@ def size(self):
"""
result = self.grouper.size()

if isinstance(self.obj, Series):
result.name = self.obj.name
# GH28330 preserve subclassed Series/DataFrames through calls
if issubclass(self.obj._constructor, Series):
result = self._obj_1d_constructor(result, name=self.obj.name)
else:
result = self._obj_1d_constructor(result)
return self._reindex_output(result, fill_value=0)

@classmethod
Expand Down Expand Up @@ -2116,7 +2127,7 @@ def ngroup(self, ascending: bool = True):
"""
with _group_selection_context(self):
index = self._selected_obj.index
result = Series(self.grouper.group_info[0], index)
result = self._obj_1d_constructor(self.grouper.group_info[0], index)
if not ascending:
result = self.ngroups - 1 - result
return result
Expand Down Expand Up @@ -2178,7 +2189,7 @@ def cumcount(self, ascending: bool = True):
with _group_selection_context(self):
index = self._selected_obj.index
cumcounts = self._cumcount_array(ascending=ascending)
return Series(cumcounts, index)
return self._obj_1d_constructor(cumcounts, index)

@Substitution(name="groupby")
@Appender(_common_see_also)
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def get_result(self):
# combine as columns in a frame
else:
data = dict(zip(range(len(self.objs)), self.objs))
cons = DataFrame

# GH28330 Preserves subclassed objects through concat
cons = self.objs[0]._constructor_expanddim

index, columns = self.new_axes
df = cons(data, index=index)
Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/frame/test_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,17 @@ def test_subclassed_boolean_reductions(self, all_boolean_reductions):
df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
result = getattr(df, all_boolean_reductions)()
assert isinstance(result, tm.SubclassedSeries)

def test_idxmin_preserves_subclass(self):
# GH 28330

df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
result = df.idxmin()
assert isinstance(result, tm.SubclassedSeries)

def test_idxmax_preserves_subclass(self):
# GH 28330

df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
result = df.idxmax()
assert isinstance(result, tm.SubclassedSeries)
Loading