diff --git a/docs/cudf/source/api_docs/groupby.rst b/docs/cudf/source/api_docs/groupby.rst
index cf08d1d791b..575d7442cdf 100644
--- a/docs/cudf/source/api_docs/groupby.rst
+++ b/docs/cudf/source/api_docs/groupby.rst
@@ -59,6 +59,7 @@ Computations / descriptive stats
    GroupBy.std
    GroupBy.sum
    GroupBy.var
+   GroupBy.corr
 
 The following methods are available in both ``SeriesGroupBy`` and
 ``DataFrameGroupBy`` objects, but may differ slightly, usually in that
diff --git a/docs/cudf/source/basics/groupby.rst b/docs/cudf/source/basics/groupby.rst
index 04c4d42fa2a..f3269768025 100644
--- a/docs/cudf/source/basics/groupby.rst
+++ b/docs/cudf/source/basics/groupby.rst
@@ -127,6 +127,13 @@ Aggregations on groups is supported via the ``agg`` method:
     a
     1  4  1  2.0
     2  5  2  4.5
+    >>> df.groupby("a").corr(method="pearson")
+              b         c
+    a
+    1 b  1.000000  0.866025
+      c  0.866025  1.000000
+    2 b  1.000000  1.000000
+      c  1.000000  1.000000
 
 The following table summarizes the available aggregations and the types
 that support them:
@@ -169,6 +176,9 @@ that support them:
     +------------------------------------+-----------+------------+----------+---------------+--------+----------+------------+-----------+
     | unique                             | ✅        | ✅         | ✅       | ✅            |        |          |            |           |
     +------------------------------------+-----------+------------+----------+---------------+--------+----------+------------+-----------+
+    | corr                               | ✅        |            |          |               |        |          |            | ✅        |
+    +------------------------------------+-----------+------------+----------+---------------+--------+----------+------------+-----------+
+
 
 GroupBy apply
 -------------
diff --git a/python/cudf/cudf/_lib/aggregation.pyx b/python/cudf/cudf/_lib/aggregation.pyx
index 4f703724cef..68f7101b6ee 100644
--- a/python/cudf/cudf/_lib/aggregation.pyx
+++ b/python/cudf/cudf/_lib/aggregation.pyx
@@ -1,6 +1,6 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
 
-from enum import Enum
+from enum import Enum, IntEnum
 
 import numba
 import numpy as np
@@ -30,6 +30,7 @@ from cudf._lib.types import Interpolation
 
 cimport cudf._lib.cpp.aggregation as libcudf_aggregation
 cimport cudf._lib.cpp.types as libcudf_types
+from cudf._lib.cpp.aggregation cimport underlying_type_t_correlation_type
 
 import cudf
 
@@ -57,6 +58,22 @@ class AggregationKind(Enum):
     UNIQUE = libcudf_aggregation.aggregation.Kind.COLLECT_SET
     PTX = libcudf_aggregation.aggregation.Kind.PTX
     CUDA = libcudf_aggregation.aggregation.Kind.CUDA
+    CORRELATION = libcudf_aggregation.aggregation.Kind.CORRELATION
+
+
+class CorrelationType(IntEnum):
+    PEARSON = (
+        <underlying_type_t_correlation_type>
+        libcudf_aggregation.correlation_type.PEARSON
+    )
+    KENDALL = (
+        <underlying_type_t_correlation_type>
+        libcudf_aggregation.correlation_type.KENDALL
+    )
+    SPEARMAN = (
+        <underlying_type_t_correlation_type>
+        libcudf_aggregation.correlation_type.SPEARMAN
+    )
 
 
 cdef class Aggregation:
@@ -321,6 +338,22 @@ cdef class Aggregation:
         ))
         return agg
 
+    @classmethod
+    def corr(cls, method, libcudf_types.size_type min_periods):
+        cdef Aggregation agg = cls()
+        cdef libcudf_aggregation.correlation_type c_method = (
+            <libcudf_aggregation.correlation_type> (
+                <underlying_type_t_correlation_type> (
+                    CorrelationType[method.upper()]
+                )
+            )
+        )
+        agg.c_obj = move(
+            libcudf_aggregation.make_correlation_aggregation[aggregation](
+                c_method, min_periods
+            ))
+        return agg
+
 
 cdef class RollingAggregation:
     """A Cython wrapper for rolling window aggregations.
@@ -692,6 +725,24 @@ cdef class GroupbyAggregation:
         )
         return agg
 
+    @classmethod
+    def corr(cls, method, libcudf_types.size_type min_periods):
+        cdef GroupbyAggregation agg = cls()
+        cdef libcudf_aggregation.correlation_type c_method = (
+            <libcudf_aggregation.correlation_type> (
+                <underlying_type_t_correlation_type> (
+                    CorrelationType[method.upper()]
+                )
+            )
+        )
+        agg.c_obj = move(
+            libcudf_aggregation.
+            make_correlation_aggregation[groupby_aggregation](
+                c_method, min_periods
+            ))
+        return agg
+
+
 cdef class GroupbyScanAggregation:
     """A Cython wrapper for groupby scan aggregations.
 
diff --git a/python/cudf/cudf/_lib/cpp/aggregation.pxd b/python/cudf/cudf/_lib/cpp/aggregation.pxd
index 13bfa49057c..3982b4fecbb 100644
--- a/python/cudf/cudf/_lib/cpp/aggregation.pxd
+++ b/python/cudf/cudf/_lib/cpp/aggregation.pxd
@@ -1,5 +1,5 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
-
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
+from libc.stdint cimport int32_t
 from libcpp.memory cimport unique_ptr
 from libcpp.string cimport string
 from libcpp.vector cimport vector
@@ -11,6 +11,7 @@ from cudf._lib.cpp.types cimport (
     size_type,
 )
 
+ctypedef int32_t underlying_type_t_correlation_type
 
 cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
 
@@ -38,6 +39,8 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
         COLLECT_SET 'cudf::aggregation::COLLECT_SET'
         PTX 'cudf::aggregation::PTX'
         CUDA 'cudf::aggregation::CUDA'
+        CORRELATION 'cudf::aggregation::CORRELATION'
+
         Kind kind
 
     cdef cppclass rolling_aggregation:
@@ -53,6 +56,11 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
         CUDA 'cudf::udf_type::CUDA'
         PTX 'cudf::udf_type::PTX'
 
+    ctypedef enum correlation_type:
+        PEARSON 'cudf::correlation_type::PEARSON'
+        KENDALL 'cudf::correlation_type::KENDALL'
+        SPEARMAN 'cudf::correlation_type::SPEARMAN'
+
     cdef unique_ptr[T] make_sum_aggregation[T]() except +
 
     cdef unique_ptr[T] make_product_aggregation[T]() except +
@@ -106,3 +114,6 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil:
         udf_type type,
         string user_defined_aggregator,
         data_type output_type) except +
+
+    cdef unique_ptr[T] make_correlation_aggregation[T](
+        correlation_type type, size_type min_periods) except +
diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx
index 0968d22d465..314542c9549 100644
--- a/python/cudf/cudf/_lib/groupby.pyx
+++ b/python/cudf/cudf/_lib/groupby.pyx
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
 
 from collections import defaultdict
 
@@ -54,7 +54,7 @@ _CATEGORICAL_AGGS = {"COUNT", "SIZE", "NUNIQUE", "UNIQUE"}
 _STRING_AGGS = {"COUNT", "SIZE", "MAX", "MIN", "NUNIQUE", "NTH", "COLLECT",
                 "UNIQUE"}
 _LIST_AGGS = {"COLLECT"}
-_STRUCT_AGGS = set()
+_STRUCT_AGGS = {"CORRELATION"}
 _INTERVAL_AGGS = set()
 _DECIMAL_AGGS = {"COUNT", "SUM", "ARGMIN", "ARGMAX", "MIN", "MAX", "NUNIQUE",
                 "NTH", "COLLECT"}
diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py
index 7f9f61ed3fd..f1d622362e2 100644
--- a/python/cudf/cudf/core/groupby/groupby.py
+++ b/python/cudf/cudf/core/groupby/groupby.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2020-2021, NVIDIA CORPORATION.
 
 import collections
+import itertools
 import pickle
 import warnings
 
@@ -13,7 +14,8 @@
 from cudf._typing import DataFrameOrSeries
 from cudf.api.types import is_list_like
 from cudf.core.abc import Serializable
-from cudf.core.column.column import arange
+from cudf.core.column.column import arange, as_column
+from cudf.core.multiindex import MultiIndex
 from cudf.utils.utils import GetAttrGetItemMixin, cached_property
 
 
@@ -69,6 +71,8 @@ def __init__(
         """
         self.obj = obj
         self._as_index = as_index
+        self._by = by
+        self._level = level
         self._sort = sort
         self._dropna = dropna
 
@@ -777,6 +781,121 @@ def median(self):
         """Get the column-wise median of the values in each group."""
         return self.agg("median")
 
+    def corr(self, method="pearson", min_periods=1):
+        """
+        Compute pairwise correlation of columns, excluding NA/null values.
+
+        Parameters
+        ----------
+        method : {"pearson", "kendall", "spearman"} or callable,
+            default "pearson". Currently only the Pearson correlation
+            coefficient is supported.
+
+        min_periods : int, optional
+            Minimum number of observations required per pair of columns
+            to have a valid result.
+
+        Returns
+        -------
+        DataFrame
+            Correlation matrix.
+
+        Examples
+        --------
+        >>> import cudf
+        >>> gdf = cudf.DataFrame({
+        ...     "id": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
+        ...     "val1": [5, 4, 6, 4, 8, 7, 4, 5, 2],
+        ...     "val2": [4, 5, 6, 1, 2, 9, 8, 5, 1],
+        ...     "val3": [4, 5, 6, 1, 2, 9, 8, 5, 1]})
+        >>> gdf
+          id  val1  val2  val3
+        0  a     5     4     4
+        1  a     4     5     5
+        2  a     6     6     6
+        3  b     4     1     1
+        4  b     8     2     2
+        5  b     7     9     9
+        6  c     4     8     8
+        7  c     5     5     5
+        8  c     2     1     1
+        >>> gdf.groupby("id").corr(method="pearson")
+                    val1      val2      val3
+        id
+        a   val1  1.000000  0.500000  0.500000
+            val2  0.500000  1.000000  1.000000
+            val3  0.500000  1.000000  1.000000
+        b   val1  1.000000  0.385727  0.385727
+            val2  0.385727  1.000000  1.000000
+            val3  0.385727  1.000000  1.000000
+        c   val1  1.000000  0.714575  0.714575
+            val2  0.714575  1.000000  1.000000
+            val3  0.714575  1.000000  1.000000
+        """
+
+        if not method.lower() in ("pearson",):
+            raise NotImplementedError(
+                "Only pearson correlation is currently supported"
+            )
+
+        # create an expanded dataframe consisting of all combinations of
+        # the struct column-pairs to be correlated,
+        # i.e. (('col1', 'col1'), ('col1', 'col2'), ('col2', 'col2'))
+        _cols = self.grouping.values.columns.tolist()
+        len_cols = len(_cols)
+
+        new_df_data = {}
+        for x, y in itertools.combinations_with_replacement(_cols, 2):
+            new_df_data[(x, y)] = cudf.DataFrame._from_data(
+                {"x": self.obj._data[x], "y": self.obj._data[y]}
+            ).to_struct()
+        new_gb = cudf.DataFrame._from_data(new_df_data).groupby(
+            by=self.grouping.keys
+        )
+
+        try:
+            gb_corr = new_gb.agg(lambda x: x.corr(method, min_periods))
+        except RuntimeError as e:
+            if "Unsupported groupby reduction type-agg combination" in str(e):
+                raise TypeError(
+                    "Correlation accepts only numerical column-pairs"
+                )
+            raise
+
+        # ensure that column-pair labels are arranged in ascending order
+        cols_list = [
+            (y, x) if i > j else (x, y)
+            for j, y in enumerate(_cols)
+            for i, x in enumerate(_cols)
+        ]
+        cols_split = [
+            cols_list[i : i + len_cols]
+            for i in range(0, len(cols_list), len_cols)
+        ]
+
+        # interleave: combine the correlation results for each column-pair
+        # into a single column
+        res = cudf.DataFrame._from_data(
+            {
+                x: gb_corr.loc[:, i].interleave_columns()
+                for i, x in zip(cols_split, _cols)
+            }
+        )
+
+        # create a MultiIndex for the correlation result, to match
+        # pandas behavior
+        unsorted_idx = gb_corr.index.repeat(len_cols)
+        idx_sort_order = unsorted_idx._get_sorted_inds()
+        sorted_idx = unsorted_idx._gather(idx_sort_order)
+        if len(gb_corr):
+            # TO-DO: Should the operation below be done on the CPU instead?
+            sorted_idx._data[None] = as_column(
+                cudf.Series(_cols).tile(len(gb_corr.index))
+            )
+        res.index = MultiIndex._from_data(sorted_idx._data)
+
+        return res
+
     def var(self, ddof=1):
         """Compute the column-wise variance of the values in each group.
 
diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py
index d07caef11d5..d555b5c4033 100644
--- a/python/cudf/cudf/tests/test_dataframe.py
+++ b/python/cudf/cudf/tests/test_dataframe.py
@@ -8924,3 +8924,118 @@ def test_frame_series_where_other(data):
     expected = gdf.where(gdf["b"] == 1, 0)
     actual = pdf.where(pdf["b"] == 1, 0)
     assert_eq(expected, actual)
+
+
+@pytest.mark.parametrize(
+    "data, gkey",
+    [
+        (
+            {
+                "id": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
+                "val1": [5, 4, 6, 4, 8, 7, 4, 5, 2],
+                "val2": [4, 5, 6, 1, 2, 9, 8, 5, 1],
+                "val3": [4, 5, 6, 1, 2, 9, 8, 5, 1],
+            },
+            ["id", "val1", "val2"],
+        ),
+        (
+            {
+                "id": [0] * 4 + [1] * 3,
+                "a": [10, 3, 4, 2, -3, 9, 10],
+                "b": [10, 23, -4, 2, -3, 9, 19],
+            },
+            ["id", "a"],
+        ),
+        (
+            {
+                "id": ["a", "a", "b", "b", "c", "c"],
+                "val": [None, None, None, None, None, None],
+            },
+            ["id"],
+        ),
+        (
+            {
+                "id": ["a", "a", "b", "b", "c", "c"],
+                "val1": [None, 4, 6, 8, None, 2],
+                "val2": [4, 5, None, 2, 9, None],
+            },
+            ["id"],
+        ),
+        ({"id": [1.0], "val1": [2.0], "val2": [3.0]}, ["id"]),
+    ],
+)
+@pytest.mark.parametrize(
+    "min_per", [0, 1, 2, 3, 4],
+)
+def test_pearson_corr_passing(data, gkey, min_per):
+    gdf = cudf.DataFrame(data)
+    pdf = gdf.to_pandas()
+
+    actual = gdf.groupby(gkey).corr(method="pearson", min_periods=min_per)
+    expected = pdf.groupby(gkey).corr(method="pearson", min_periods=min_per)
+
+    assert_eq(expected, actual)
+
+
+@pytest.mark.parametrize("method", ["kendall", "spearman"])
+def test_pearson_corr_unsupported_methods(method):
+    gdf = cudf.DataFrame(
+        {
+            "id": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
+            "val1": [5, 4, 6, 4, 8, 7, 4, 5, 2],
+            "val2": [4, 5, 6, 1, 2, 9, 8, 5, 1],
+            "val3": [4, 5, 6, 1, 2, 9, 8, 5, 1],
+        }
+    )
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Only pearson correlation is currently supported",
+    ):
+        gdf.groupby("id").corr(method)
+
+
+def test_pearson_corr_empty_columns():
+    gdf = cudf.DataFrame(columns=["id", "val1", "val2"])
+    pdf = gdf.to_pandas()
+
+    actual = gdf.groupby("id").corr("pearson")
+    expected = pdf.groupby("id").corr("pearson")
+
+    assert_eq(
+        expected, actual, check_dtype=False, check_index_type=False,
+    )
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        {
+            "id": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
+            "val1": ["v", "n", "k", "l", "m", "i", "y", "r", "w"],
+            "val2": ["d", "d", "d", "e", "e", "e", "f", "f", "f"],
+        },
+        {
+            "id": ["a", "a", "a", "b", "b", "b", "c", "c", "c"],
+            "val1": [1, 1, 1, 2, 2, 2, 3, 3, 3],
+            "val2": ["d", "d", "d", "e", "e", "e", "f", "f", "f"],
+        },
+    ],
+)
+@pytest.mark.parametrize("gkey", ["id", "val1", "val2"])
+def test_pearson_corr_invalid_column_types(data, gkey):
+    with pytest.raises(
+        TypeError, match="Correlation accepts only numerical column-pairs",
+    ):
+        cudf.DataFrame(data).groupby(gkey).corr("pearson")
+
+
+def test_pearson_corr_multiindex_dataframe():
+    gdf = cudf.DataFrame(
+        {"a": [1, 1, 2, 2], "b": [1, 1, 2, 3], "c": [2, 3, 4, 5]}
+    ).set_index(["a", "b"])
+
+    actual = gdf.groupby(level="a").corr("pearson")
+    expected = gdf.to_pandas().groupby(level="a").corr("pearson")
+
+    assert_eq(expected, actual)
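
A minimal usage sketch of the ``GroupBy.corr`` API added by this patch, mirroring the docstring example above. It assumes a cudf build that includes these changes; only ``method="pearson"`` is accepted:

    import cudf

    # One grouping key ("id") and two numeric value columns.
    gdf = cudf.DataFrame(
        {
            "id": ["a", "a", "a", "b", "b", "b"],
            "val1": [5, 4, 6, 4, 8, 7],
            "val2": [4, 5, 6, 1, 2, 9],
        }
    )

    # Pairwise Pearson correlation of the value columns within each "id" group;
    # the result is indexed by (group key, column label), matching pandas.
    print(gdf.groupby("id").corr(method="pearson", min_periods=1))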