Add Pearson correlation for sort groupby (python) (#9166)
skirui-source authored Nov 30, 2021
1 parent 20d6723 commit 991136c
Showing 7 changed files with 314 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/cudf/source/api_docs/groupby.rst
@@ -59,6 +59,7 @@ Computations / descriptive stats
GroupBy.std
GroupBy.sum
GroupBy.var
GroupBy.corr

The following methods are available in both ``SeriesGroupBy`` and
``DataFrameGroupBy`` objects, but may differ slightly, usually in that
10 changes: 10 additions & 0 deletions docs/cudf/source/basics/groupby.rst
@@ -127,6 +127,13 @@ Aggregations on groups are supported via the ``agg`` method:
a
1 4 1 2.0
2 5 2 4.5
>>> df.groupby("a").corr(method="pearson")
b c
a
1 b 1.000000 0.866025
c 0.866025 1.000000
2 b 1.000000 1.000000
c 1.000000 1.000000

The following table summarizes the available aggregations and the types
that support them:
@@ -169,6 +176,9 @@ that support them:
+------------------------------------+-----------+------------+----------+---------------+--------+----------+------------+-----------+
| unique |✓|✓|✓|✓| | | | |
+------------------------------------+-----------+------------+----------+---------------+--------+----------+------------+-----------+
| corr |✓| | | | | | |✓|
+------------------------------------+-----------+------------+----------+---------------+--------+----------+------------+-----------+
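For reference, the ``corr`` example above is consistent with the small DataFrame used earlier on this page. Its exact construction sits outside this hunk, so the values below are inferred from the outputs shown (a b/c correlation of 0.866025 for group 1 and 1.0 for group 2), not copied from the doc source:

    >>> import cudf
    >>> df = cudf.DataFrame({"a": [1, 1, 1, 2, 2],
    ...                      "b": [1, 1, 2, 2, 3],
    ...                      "c": [1, 2, 3, 4, 5]})
    >>> df.groupby("a").corr(method="pearson")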


GroupBy apply
-------------
55 changes: 53 additions & 2 deletions python/cudf/cudf/_lib/aggregation.pyx
@@ -1,6 +1,6 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.

from enum import Enum
from enum import Enum, IntEnum

import numba
import numpy as np
@@ -30,6 +30,7 @@ from cudf._lib.types import Interpolation

cimport cudf._lib.cpp.aggregation as libcudf_aggregation
cimport cudf._lib.cpp.types as libcudf_types
from cudf._lib.cpp.aggregation cimport underlying_type_t_correlation_type

import cudf

@@ -57,6 +58,22 @@ class AggregationKind(Enum):
UNIQUE = libcudf_aggregation.aggregation.Kind.COLLECT_SET
PTX = libcudf_aggregation.aggregation.Kind.PTX
CUDA = libcudf_aggregation.aggregation.Kind.CUDA
CORRELATION = libcudf_aggregation.aggregation.Kind.CORRELATION


class CorrelationType(IntEnum):
PEARSON = (
<underlying_type_t_correlation_type>
libcudf_aggregation.correlation_type.PEARSON
)
KENDALL = (
<underlying_type_t_correlation_type>
libcudf_aggregation.correlation_type.KENDALL
)
SPEARMAN = (
<underlying_type_t_correlation_type>
libcudf_aggregation.correlation_type.SPEARMAN
)


cdef class Aggregation:
@@ -321,6 +338,22 @@ cdef class Aggregation:
))
return agg

@classmethod
def corr(cls, method, libcudf_types.size_type min_periods):
cdef Aggregation agg = cls()
cdef libcudf_aggregation.correlation_type c_method = (
<libcudf_aggregation.correlation_type> (
<underlying_type_t_correlation_type> (
CorrelationType[method.upper()]
)
)
)
agg.c_obj = move(
libcudf_aggregation.make_correlation_aggregation[aggregation](
c_method, min_periods
))
return agg
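As a plain-Python sketch of the pattern above: the user-facing method string is upper-cased, resolved through ``CorrelationType``, and the resulting integer is cast to the libcudf enum. The 0/1/2 values below are an assumption that the IntEnum mirrors the default numbering of ``cudf::correlation_type``:

    from enum import IntEnum

    class CorrelationType(IntEnum):
        # assumed to mirror cudf::correlation_type's default 0/1/2 numbering
        PEARSON = 0
        KENDALL = 1
        SPEARMAN = 2

    def resolve(method: str) -> int:
        # mirrors CorrelationType[method.upper()] in Aggregation.corr above
        return int(CorrelationType[method.upper()])

    print(resolve("pearson"))  # 0, the value handed to make_correlation_aggregation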

cdef class RollingAggregation:
"""A Cython wrapper for rolling window aggregations.
@@ -692,6 +725,24 @@ cdef class GroupbyAggregation:
)
return agg

@classmethod
def corr(cls, method, libcudf_types.size_type min_periods):
cdef GroupbyAggregation agg = cls()
cdef libcudf_aggregation.correlation_type c_method = (
<libcudf_aggregation.correlation_type> (
<underlying_type_t_correlation_type> (
CorrelationType[method.upper()]
)
)
)
agg.c_obj = move(
libcudf_aggregation.
make_correlation_aggregation[groupby_aggregation](
c_method, min_periods
))
return agg


cdef class GroupbyScanAggregation:
"""A Cython wrapper for groupby scan aggregations.
15 changes: 13 additions & 2 deletions python/cudf/cudf/_lib/cpp/aggregation.pxd
@@ -1,5 +1,5 @@
# Copyright (c) 2020, NVIDIA CORPORATION.

# Copyright (c) 2020-2021, NVIDIA CORPORATION.
from libc.stdint cimport int32_t
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.vector cimport vector
@@ -11,6 +11,7 @@ from cudf._lib.cpp.types cimport (
size_type,
)

ctypedef int32_t underlying_type_t_correlation_type

cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:

@@ -38,6 +39,8 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
COLLECT_SET 'cudf::aggregation::COLLECT_SET'
PTX 'cudf::aggregation::PTX'
CUDA 'cudf::aggregation::CUDA'
CORRELATION 'cudf::aggregation::CORRELATION'

Kind kind

cdef cppclass rolling_aggregation:
@@ -53,6 +56,11 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
CUDA 'cudf::udf_type::CUDA'
PTX 'cudf::udf_type::PTX'

ctypedef enum correlation_type:
PEARSON 'cudf::correlation_type::PEARSON'
KENDALL 'cudf::correlation_type::KENDALL'
SPEARMAN 'cudf::correlation_type::SPEARMAN'

cdef unique_ptr[T] make_sum_aggregation[T]() except +

cdef unique_ptr[T] make_product_aggregation[T]() except +
@@ -106,3 +114,6 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
udf_type type,
string user_defined_aggregator,
data_type output_type) except +

cdef unique_ptr[T] make_correlation_aggregation[T](
correlation_type type, size_type min_periods) except +
4 changes: 2 additions & 2 deletions python/cudf/cudf/_lib/groupby.pyx
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.

from collections import defaultdict

@@ -54,7 +54,7 @@ _CATEGORICAL_AGGS = {"COUNT", "SIZE", "NUNIQUE", "UNIQUE"}
_STRING_AGGS = {"COUNT", "SIZE", "MAX", "MIN", "NUNIQUE", "NTH", "COLLECT",
"UNIQUE"}
_LIST_AGGS = {"COLLECT"}
_STRUCT_AGGS = set()
_STRUCT_AGGS = {"CORRELATION"}
_INTERVAL_AGGS = set()
_DECIMAL_AGGS = {"COUNT", "SUM", "ARGMIN", "ARGMAX", "MIN", "MAX", "NUNIQUE",
"NTH", "COLLECT"}
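Adding ``CORRELATION`` to ``_STRUCT_AGGS`` reflects how the Python layer hands column pairs to libcudf: each pair is packed into a single struct column (via ``DataFrame.to_struct``, as ``GroupBy.corr`` does below), and correlation is the only aggregation this check permits on struct columns. A minimal, illustrative sketch of the packing step:

    import cudf

    df = cudf.DataFrame({"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]})

    # Pack the two numeric columns into one struct column; this is the shape
    # the CORRELATION groupby aggregation consumes.
    pair = cudf.DataFrame({"x": df["x"], "y": df["y"]}).to_struct()
    print(pair.dtype)  # a StructDtype with children "x" and "y"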
121 changes: 120 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
@@ -1,6 +1,7 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.

import collections
import itertools
import pickle
import warnings

@@ -13,7 +14,8 @@
from cudf._typing import DataFrameOrSeries
from cudf.api.types import is_list_like
from cudf.core.abc import Serializable
from cudf.core.column.column import arange
from cudf.core.column.column import arange, as_column
from cudf.core.multiindex import MultiIndex
from cudf.utils.utils import GetAttrGetItemMixin, cached_property


@@ -69,6 +71,8 @@ def __init__(
"""
self.obj = obj
self._as_index = as_index
self._by = by
self._level = level
self._sort = sort
self._dropna = dropna

@@ -777,6 +781,121 @@ def median(self):
"""Get the column-wise median of the values in each group."""
return self.agg("median")

def corr(self, method="pearson", min_periods=1):
"""
Compute pairwise correlation of columns, excluding NA/null values.

Parameters
----------
method : {"pearson", "kendall", "spearman"} or callable, default "pearson"
    Currently only the Pearson correlation coefficient is supported.
min_periods : int, optional
    Minimum number of observations required per pair of columns
    to have a valid result.

Returns
-------
DataFrame
    Correlation matrix.

Examples
--------
>>> import cudf
>>> gdf = cudf.DataFrame({
... "id": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
... "val1": [5, 4, 6, 4, 8, 7, 4, 5, 2],
... "val2": [4, 5, 6, 1, 2, 9, 8, 5, 1],
... "val3": [4, 5, 6, 1, 2, 9, 8, 5, 1]})
>>> gdf
id val1 val2 val3
0 a 5 4 4
1 a 4 5 5
2 a 6 6 6
3 b 4 1 1
4 b 8 2 2
5 b 7 9 9
6 c 4 8 8
7 c 5 5 5
8 c 2 1 1
>>> gdf.groupby("id").corr(method="pearson")
val1 val2 val3
id
a val1 1.000000 0.500000 0.500000
val2 0.500000 1.000000 1.000000
val3 0.500000 1.000000 1.000000
b val1 1.000000 0.385727 0.385727
val2 0.385727 1.000000 1.000000
val3 0.385727 1.000000 1.000000
c val1 1.000000 0.714575 0.714575
val2 0.714575 1.000000 1.000000
val3 0.714575 1.000000 1.000000
"""

if method.lower() not in ("pearson",):
raise NotImplementedError(
"Only pearson correlation is currently supported"
)

# create an expanded dataframe consisting of all combinations of the
# column pairs (each pair packed into a struct column) to be correlated,
# i.e. (('col1', 'col1'), ('col1', 'col2'), ('col2', 'col2'))
_cols = self.grouping.values.columns.tolist()
len_cols = len(_cols)

new_df_data = {}
for x, y in itertools.combinations_with_replacement(_cols, 2):
new_df_data[(x, y)] = cudf.DataFrame._from_data(
{"x": self.obj._data[x], "y": self.obj._data[y]}
).to_struct()
new_gb = cudf.DataFrame._from_data(new_df_data).groupby(
by=self.grouping.keys
)

try:
gb_corr = new_gb.agg(lambda x: x.corr(method, min_periods))
except RuntimeError as e:
if "Unsupported groupby reduction type-agg combination" in str(e):
raise TypeError(
"Correlation accepts only numerical column-pairs"
)
raise

# ensure that column-pair labels are arranged in ascending order
cols_list = [
(y, x) if i > j else (x, y)
for j, y in enumerate(_cols)
for i, x in enumerate(_cols)
]
cols_split = [
cols_list[i : i + len_cols]
for i in range(0, len(cols_list), len_cols)
]

# interleave: combine the correlation results for each column-pair
# into a single column
res = cudf.DataFrame._from_data(
{
x: gb_corr.loc[:, i].interleave_columns()
for i, x in zip(cols_split, _cols)
}
)

# build a MultiIndex for the correlation result, to match pandas
# groupby-corr behavior
unsorted_idx = gb_corr.index.repeat(len_cols)
idx_sort_order = unsorted_idx._get_sorted_inds()
sorted_idx = unsorted_idx._gather(idx_sort_order)
if len(gb_corr):
# TO-DO: Should the operation below be done on the CPU instead?
sorted_idx._data[None] = as_column(
cudf.Series(_cols).tile(len(gb_corr.index))
)
res.index = MultiIndex._from_data(sorted_idx._data)

return res

def var(self, ddof=1):
"""Compute the column-wise variance of the values in each group.
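To make the reshaping in ``GroupBy.corr`` above more concrete, here is a small self-contained sketch that uses only the standard library and numpy rather than cudf internals. It shows the column-pair expansion and a hand check of the Pearson value reported for group "a" in the docstring example:

    import itertools
    import numpy as np

    cols = ["val1", "val2", "val3"]

    # Pair expansion done before the groupby: every unordered pair of columns,
    # including each column paired with itself.
    pairs = list(itertools.combinations_with_replacement(cols, 2))
    print(pairs)
    # [('val1', 'val1'), ('val1', 'val2'), ('val1', 'val3'),
    #  ('val2', 'val2'), ('val2', 'val3'), ('val3', 'val3')]

    # Hand check of the group "a" entry from the docstring example:
    # val1 = [5, 4, 6], val2 = [4, 5, 6]  ->  Pearson r = 0.5
    val1 = np.array([5.0, 4.0, 6.0])
    val2 = np.array([4.0, 5.0, 6.0])
    r = np.corrcoef(val1, val2)[0, 1]
    print(round(r, 6))  # 0.5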
