Skip to content

Commit

Permalink
Fixes Bears-R-Us#3762: Fix dataframe groupby aggregations when keys c…
Browse files Browse the repository at this point in the history
…ontain `NaN`s (Bears-R-Us#3766)

* Fixes Bears-R-Us#3762: Fix dataframe groupby aggregations when keys contain `NaN`s

This PR (fixes Bears-R-Us#3762) using dataframe groupby with keys that contain `NaN`s would cause the aggregations to fail. To resolve this, we mask out the values that belong to the `NaN` segment

* updated in response to PR feedback

---------

Co-authored-by: Tess Hayes <[email protected]>
  • Loading branch information
stress-tess and stress-tess authored Sep 18, 2024
1 parent 530f198 commit e5723b3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
44 changes: 40 additions & 4 deletions arkouda/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import random
from collections import UserDict
from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from warnings import warn

Expand All @@ -24,10 +25,11 @@
from arkouda.join import inner_join
from arkouda.numpy import cast as akcast
from arkouda.numpy import cumsum, where
from arkouda.numpy.dtypes import bigint
from arkouda.numpy.dtypes import _is_dtype_in_union, bigint
from arkouda.numpy.dtypes import bool_ as akbool
from arkouda.numpy.dtypes import float64 as akfloat64
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import numeric_scalars
from arkouda.numpy.dtypes import uint64 as akuint64
from arkouda.pdarrayclass import RegistrationError, pdarray
from arkouda.pdarraycreation import arange, array, create_pdarray, full, zeros
Expand Down Expand Up @@ -105,13 +107,47 @@ class DataFrameGroupBy:
"""

def __init__(self, gb, df, gb_key_names=None, as_index=True):

self.gb = gb
self.df = df
self.gb_key_names = gb_key_names
self.as_index = as_index
for attr in ["nkeys", "permutation", "unique_keys", "segments"]:
setattr(self, attr, getattr(gb, attr))

self.dropna = self.gb.dropna
self.where_not_nan = None
self.all_non_nan = False

if self.dropna:
from arkouda import all as ak_all
from arkouda import isnan

# calculate ~isnan on each key then & them all together
# keep up with if they're all_non_nan, so we can skip indexing later
key_cols = (
[df[k] for k in gb_key_names] if isinstance(gb_key_names, List) else [df[gb_key_names]]
)
where_key_not_nan = [
~isnan(col)
for col in key_cols
if isinstance(col, pdarray) and _is_dtype_in_union(col.dtype, numeric_scalars)
]

if len(where_key_not_nan) == 0:
# if empty then none of the keys are pdarray, so non are nan
self.all_non_nan = True
else:
self.where_not_nan = reduce(lambda x, y: x & y, where_key_not_nan)
self.all_non_nan = ak_all(self.where_not_nan)

def _get_df_col(self, c):
# helper function to mask out the values where the keys are nan when dropna is True
if not self.dropna or self.all_non_nan:
return self.df.data[c]
else:
return self.df.data[c][self.where_not_nan]

@classmethod
def _make_aggop(cls, opname):
numerical_dtypes = [akfloat64, akint64, akuint64]
Expand Down Expand Up @@ -148,18 +184,18 @@ def aggop(self, colnames=None):
if isinstance(colnames, List):
if isinstance(self.gb_key_names, str):
return DataFrame(
{c: self.gb.aggregate(self.df.data[c], opname)[1] for c in colnames},
{c: self.gb.aggregate(self._get_df_col(c), opname)[1] for c in colnames},
index=Index(self.gb.unique_keys, name=self.gb_key_names),
)
elif isinstance(self.gb_key_names, list) and len(self.gb_key_names) == 1:
return DataFrame(
{c: self.gb.aggregate(self.df.data[c], opname)[1] for c in colnames},
{c: self.gb.aggregate(self._get_df_col(c), opname)[1] for c in colnames},
index=Index(self.gb.unique_keys, name=self.gb_key_names[0]),
)
elif isinstance(self.gb_key_names, list):
column_dict = dict(zip(self.gb_key_names, self.unique_keys))
for c in colnames:
column_dict[c] = self.gb.aggregate(self.df.data[c], opname)[1]
column_dict[c] = self.gb.aggregate(self._get_df_col(c), opname)[1]
return DataFrame(column_dict)
else:
return None
Expand Down
39 changes: 34 additions & 5 deletions tests/dataframe_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import glob
import itertools
import os
import tempfile

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -652,16 +650,47 @@ def test_gb_aggregations_example_numeric_types(self, agg):
pd_result = getattr(pd_df.groupby(group_on), agg)()
assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.parametrize("agg", ["count", "max", "mean", "median", "min", "std", "sum", "var"])
def test_gb_aggregations_with_nans(self, agg):
def test_gb_aggregations_with_nans(self, agg, dropna):
df = self.build_ak_df_with_nans()
# @TODO handle bool columns correctly
df.drop("bools", axis=1, inplace=True)
pd_df = df.to_pandas()

group_on = ["key1", "key2"]
ak_result = getattr(df.groupby(group_on), agg)()
pd_result = getattr(pd_df.groupby(group_on, as_index=False), agg)()
ak_result = getattr(df.groupby(group_on, dropna=dropna), agg)()
pd_result = getattr(pd_df.groupby(group_on, as_index=False, dropna=dropna), agg)()
assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

# TODO aggregations of string columns not currently supported (even for count)
df.drop("key1", axis=1, inplace=True)
df.drop("key2", axis=1, inplace=True)
pd_df = df.to_pandas()

group_on = ["nums1", "nums2"]
ak_result = getattr(df.groupby(group_on, dropna=dropna), agg)()
pd_result = getattr(pd_df.groupby(group_on, as_index=False, dropna=dropna), agg)()
assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

# TODO aggregation mishandling NaN see issue #3765
df.drop("nums2", axis=1, inplace=True)
pd_df = df.to_pandas()
group_on = "nums1"
ak_result = getattr(df.groupby(group_on, dropna=dropna), agg)()
pd_result = getattr(pd_df.groupby(group_on, dropna=dropna), agg)()
assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

@pytest.mark.parametrize("dropna", [True, False])
def test_count_nan_bug(self, dropna):
# verify reproducer for #3762 is fixed
df = ak.DataFrame({"A": [1, 2, 2, np.nan], "B": [3, 4, 5, 6], "C": [1, np.nan, 2, 3]})
ak_result = df.groupby("A", dropna=dropna).count()
pd_result = df.to_pandas().groupby("A", dropna=dropna).count()
assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

ak_result = df.groupby(["A", "C"], as_index=False, dropna=dropna).count()
pd_result = df.to_pandas().groupby(["A", "C"], as_index=False, dropna=dropna).count()
assert_frame_equal(ak_result.to_pandas(retain_index=True), pd_result)

def test_gb_aggregations_return_dataframe(self):
Expand Down

0 comments on commit e5723b3

Please sign in to comment.