From 118a70a1406feb30a9fd57e905b134d8ab10151d Mon Sep 17 00:00:00 2001 From: Chris Jarrett Date: Tue, 30 Mar 2021 22:59:44 -0700 Subject: [PATCH 1/4] Enable basic reductions for decimal columns --- python/cudf/cudf/_lib/reduce.pyx | 17 ++++----- python/cudf/cudf/core/column/decimal.py | 43 ++++++++++++++++++++++- python/cudf/cudf/tests/test_reductions.py | 42 +++++++++++++++++++++- 3 files changed, 90 insertions(+), 12 deletions(-) diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index 2185cb089a7..6ee4f39d64b 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -1,6 +1,7 @@ # Copyright (c) 2020, NVIDIA CORPORATION. import cudf +from cudf.utils.dtypes import is_decimal_dtype from cudf._lib.cpp.reduce cimport cpp_reduce, cpp_scan, scan_type, cpp_minmax from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.cpp.types cimport data_type, type_id @@ -9,12 +10,15 @@ from cudf._lib.cpp.column.column cimport column from cudf._lib.scalar cimport DeviceScalar from cudf._lib.column cimport Column from cudf._lib.types import np_to_cudf_types -from cudf._lib.types cimport underlying_type_t_type_id +from cudf._lib.types cimport underlying_type_t_type_id, dtype_to_data_type from cudf._lib.aggregation cimport make_aggregation, aggregation from libcpp.memory cimport unique_ptr from libcpp.utility cimport move, pair import numpy as np +cimport cudf._lib.cpp.types as libcudf_types +from cudf._lib. + def reduce(reduction_op, Column incol, dtype=None, **kwargs): """ @@ -32,7 +36,7 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): """ col_dtype = incol.dtype - if reduction_op in ['sum', 'sum_of_squares', 'product']: + if reduction_op in ['sum', 'sum_of_squares', 'product'] and not is_decimal_dtype(col_dtype): col_dtype = np.find_common_type([col_dtype], [np.uint64]) col_dtype = col_dtype if dtype is None else dtype @@ -41,15 +45,8 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): cdef unique_ptr[aggregation] c_agg = move(make_aggregation( reduction_op, kwargs )) - cdef type_id tid = ( - ( - ( - np_to_cudf_types[np.dtype(col_dtype)] - ) - ) - ) - cdef data_type c_out_dtype = data_type(tid) + cdef data_type c_out_dtype = dtype_to_data_type(col_dtype) # check empty case if len(incol) <= incol.null_count: diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 96e09a5abb5..f3b85d2d75d 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -17,7 +17,7 @@ from_decimal as cpp_from_decimal, ) from cudf.core.column import as_column - +import decimal class DecimalColumn(ColumnBase): @classmethod @@ -96,6 +96,47 @@ def as_string_column( "cudf.core.column.StringColumn", as_column([], dtype="object") ) + def reduce(self, op: str, skipna: bool = None, **kwargs) -> decimal.Decimal: + min_count = kwargs.pop("min_count", 0) + preprocessed = self._process_for_reduction( + skipna=skipna, min_count=min_count + ) + if isinstance(preprocessed, ColumnBase): + return libcudf.reduce.reduce(op, preprocessed, **kwargs) + else: + return cast(self.dtype, preprocessed) + + def sum( + self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 + ) -> decimal.Decimal: + return self.reduce( + "sum", skipna=skipna, dtype=dtype, min_count=min_count + ) + + def product( + self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 + ) -> decimal.Decimal: + return self.reduce( + "product", skipna=skipna, dtype=dtype, min_count=min_count + ) + + def mean(self, skipna: bool = None, dtype: Dtype = decimal.Decimal + ) -> decimal.Decimal: + return self.reduce("mean", skipna=skipna, dtype=dtype) + + def var( + self, skipna: bool = None, ddof: int = 1, dtype: Dtype = decimal.Decimal + ) -> decimal.Decimal: + return self.reduce("var", skipna=skipna, dtype=dtype, ddof=ddof) + + def std( + self, skipna: bool = None, ddof: int = 1, dtype: Dtype = decimal.Decimal + ) -> decimal.Decimal: + return self.reduce("std", skipna=skipna, dtype=dtype, ddof=ddof) + + def sum_of_squares(self, dtype: Dtype = None) -> decimal.Decimal: + return libcudf.reduce.reduce("sum_of_squares", self, dtype=dtype) + def _binop_scale(l_dtype, r_dtype, op): # This should at some point be hooked up to libcudf's diff --git a/python/cudf/cudf/tests/test_reductions.py b/python/cudf/cudf/tests/test_reductions.py index 80a2e89bf46..ae1d2c1e19b 100644 --- a/python/cudf/cudf/tests/test_reductions.py +++ b/python/cudf/cudf/tests/test_reductions.py @@ -7,12 +7,14 @@ from itertools import product import numpy as np +import pandas as pd import pytest +from decimal import Decimal import cudf from cudf.core import Series from cudf.tests import utils -from cudf.tests.utils import NUMERIC_TYPES, gen_rand +from cudf.tests.utils import NUMERIC_TYPES, gen_rand, assert_eq params_dtype = NUMERIC_TYPES @@ -49,6 +51,18 @@ def test_sum_string(): assert got == expected +@pytest.mark.parametrize( + "dtype", + [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], +) +@pytest.mark.parametrize("nelem", params_sizes) +def test_sum_decimal(dtype, nelem): + data = gen_rand("int64",nelem)/100 + expected = pd.Series(data).sum() + got = cudf.Series([str(x) for x in data]).astype(dtype).sum() + + assert_eq(Decimal(expected), got) + @pytest.mark.parametrize("dtype,nelem", params) def test_product(dtype, nelem): @@ -106,6 +120,19 @@ def test_min(dtype, nelem): assert expect == got +@pytest.mark.parametrize( + "dtype", + [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], +) +@pytest.mark.parametrize("nelem", params_sizes) +def test_min_decimal(dtype, nelem): + data = gen_rand("int64",nelem)/100 + expected = pd.Series(data).min() + got = cudf.Series([str(x) for x in data]).astype(dtype).min() + + assert_eq(Decimal(expected), got) + + @pytest.mark.parametrize("dtype,nelem", params) def test_max(dtype, nelem): dtype = np.dtype(dtype).type @@ -118,6 +145,19 @@ def test_max(dtype, nelem): assert expect == got +@pytest.mark.parametrize( + "dtype", + [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], +) +@pytest.mark.parametrize("nelem", params_sizes) +def test_max_decimal(dtype, nelem): + data = gen_rand("int64",nelem)/100 + expected = pd.Series(data).max() + got = cudf.Series([str(x) for x in data]).astype(dtype).max() + + assert_eq(Decimal(expected), got) + + @pytest.mark.parametrize("nelem", params_sizes) def test_sum_masked(nelem): dtype = np.float64 From c88b10e0e1760c5ac65525d20e8edeb3c9bfe0d5 Mon Sep 17 00:00:00 2001 From: Chris Jarrett Date: Tue, 30 Mar 2021 23:32:43 -0700 Subject: [PATCH 2/4] Fix style --- python/cudf/cudf/_lib/reduce.pyx | 6 ++++-- python/cudf/cudf/core/column/decimal.py | 20 +++++++++++++++----- python/cudf/cudf/tests/test_reductions.py | 8 +++++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index 6ee4f39d64b..4579530599e 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -17,7 +17,6 @@ from libcpp.utility cimport move, pair import numpy as np cimport cudf._lib.cpp.types as libcudf_types -from cudf._lib. def reduce(reduction_op, Column incol, dtype=None, **kwargs): @@ -36,7 +35,10 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): """ col_dtype = incol.dtype - if reduction_op in ['sum', 'sum_of_squares', 'product'] and not is_decimal_dtype(col_dtype): + if ( + reduction_op in ['sum', 'sum_of_squares', 'product'] + and not is_decimal_dtype(col_dtype) + ): col_dtype = np.find_common_type([col_dtype], [np.uint64]) col_dtype = col_dtype if dtype is None else dtype diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index f3b85d2d75d..38ebd7dedfc 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -19,6 +19,7 @@ from cudf.core.column import as_column import decimal + class DecimalColumn(ColumnBase): @classmethod def from_arrow(cls, data: pa.Array): @@ -96,7 +97,9 @@ def as_string_column( "cudf.core.column.StringColumn", as_column([], dtype="object") ) - def reduce(self, op: str, skipna: bool = None, **kwargs) -> decimal.Decimal: + def reduce( + self, op: str, skipna: bool = None, **kwargs + ) -> decimal.Decimal: min_count = kwargs.pop("min_count", 0) preprocessed = self._process_for_reduction( skipna=skipna, min_count=min_count @@ -104,7 +107,7 @@ def reduce(self, op: str, skipna: bool = None, **kwargs) -> decimal.Decimal: if isinstance(preprocessed, ColumnBase): return libcudf.reduce.reduce(op, preprocessed, **kwargs) else: - return cast(self.dtype, preprocessed) + return preprocessed def sum( self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 @@ -120,17 +123,24 @@ def product( "product", skipna=skipna, dtype=dtype, min_count=min_count ) - def mean(self, skipna: bool = None, dtype: Dtype = decimal.Decimal + def mean( + self, skipna: bool = None, dtype: Dtype = decimal.Decimal ) -> decimal.Decimal: return self.reduce("mean", skipna=skipna, dtype=dtype) def var( - self, skipna: bool = None, ddof: int = 1, dtype: Dtype = decimal.Decimal + self, + skipna: bool = None, + ddof: int = 1, + dtype: Dtype = decimal.Decimal, ) -> decimal.Decimal: return self.reduce("var", skipna=skipna, dtype=dtype, ddof=ddof) def std( - self, skipna: bool = None, ddof: int = 1, dtype: Dtype = decimal.Decimal + self, + skipna: bool = None, + ddof: int = 1, + dtype: Dtype = decimal.Decimal, ) -> decimal.Decimal: return self.reduce("std", skipna=skipna, dtype=dtype, ddof=ddof) diff --git a/python/cudf/cudf/tests/test_reductions.py b/python/cudf/cudf/tests/test_reductions.py index ae1d2c1e19b..ff98560b9e4 100644 --- a/python/cudf/cudf/tests/test_reductions.py +++ b/python/cudf/cudf/tests/test_reductions.py @@ -13,6 +13,7 @@ import cudf from cudf.core import Series +from cudf.core.dtypes import Decimal64Dtype from cudf.tests import utils from cudf.tests.utils import NUMERIC_TYPES, gen_rand, assert_eq @@ -51,13 +52,14 @@ def test_sum_string(): assert got == expected + @pytest.mark.parametrize( "dtype", [Decimal64Dtype(6, 3), Decimal64Dtype(10, 6), Decimal64Dtype(16, 7)], ) @pytest.mark.parametrize("nelem", params_sizes) def test_sum_decimal(dtype, nelem): - data = gen_rand("int64",nelem)/100 + data = gen_rand("int64", nelem) / 100 expected = pd.Series(data).sum() got = cudf.Series([str(x) for x in data]).astype(dtype).sum() @@ -126,7 +128,7 @@ def test_min(dtype, nelem): ) @pytest.mark.parametrize("nelem", params_sizes) def test_min_decimal(dtype, nelem): - data = gen_rand("int64",nelem)/100 + data = gen_rand("int64", nelem) / 100 expected = pd.Series(data).min() got = cudf.Series([str(x) for x in data]).astype(dtype).min() @@ -151,7 +153,7 @@ def test_max(dtype, nelem): ) @pytest.mark.parametrize("nelem", params_sizes) def test_max_decimal(dtype, nelem): - data = gen_rand("int64",nelem)/100 + data = gen_rand("int64", nelem) / 100 expected = pd.Series(data).max() got = cudf.Series([str(x) for x in data]).astype(dtype).max() From c9d83fd54bff61c1334bfe350489006781cab694 Mon Sep 17 00:00:00 2001 From: Chris Jarrett Date: Wed, 31 Mar 2021 17:13:16 -0700 Subject: [PATCH 3/4] Implement tests for product and sum of squares --- python/cudf/cudf/_lib/reduce.pyx | 2 +- python/cudf/cudf/core/column/decimal.py | 37 ++++------------ python/cudf/cudf/tests/test_reductions.py | 53 ++++++++++++++++++----- 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index 4579530599e..f6a8232f395 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -68,7 +68,7 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): c_out_dtype )) - py_result = DeviceScalar.from_unique_ptr(move(c_result)) + py_result = DeviceScalar.from_unique_ptr(move(c_result), dtype=col_dtype) return py_result.value diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 52701e7a06a..7204aebbf19 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -104,9 +104,7 @@ def as_string_column( "cudf.core.column.StringColumn", as_column([], dtype="object") ) - def reduce( - self, op: str, skipna: bool = None, **kwargs - ) -> decimal.Decimal: + def reduce(self, op: str, skipna: bool = None, **kwargs) -> Decimal: min_count = kwargs.pop("min_count", 0) preprocessed = self._process_for_reduction( skipna=skipna, min_count=min_count @@ -118,41 +116,24 @@ def reduce( def sum( self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> decimal.Decimal: + ) -> Decimal: return self.reduce( "sum", skipna=skipna, dtype=dtype, min_count=min_count ) def product( self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 - ) -> decimal.Decimal: + ) -> Decimal: return self.reduce( "product", skipna=skipna, dtype=dtype, min_count=min_count ) - def mean( - self, skipna: bool = None, dtype: Dtype = decimal.Decimal - ) -> decimal.Decimal: - return self.reduce("mean", skipna=skipna, dtype=dtype) - - def var( - self, - skipna: bool = None, - ddof: int = 1, - dtype: Dtype = decimal.Decimal, - ) -> decimal.Decimal: - return self.reduce("var", skipna=skipna, dtype=dtype, ddof=ddof) - - def std( - self, - skipna: bool = None, - ddof: int = 1, - dtype: Dtype = decimal.Decimal, - ) -> decimal.Decimal: - return self.reduce("std", skipna=skipna, dtype=dtype, ddof=ddof) - - def sum_of_squares(self, dtype: Dtype = None) -> decimal.Decimal: - return libcudf.reduce.reduce("sum_of_squares", self, dtype=dtype) + def sum_of_squares( + self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 + ) -> Decimal: + return self.reduce( + "sum_of_squares", skipna=skipna, dtype=dtype, min_count=min_count + ) def _binop_scale(l_dtype, r_dtype, op): diff --git a/python/cudf/cudf/tests/test_reductions.py b/python/cudf/cudf/tests/test_reductions.py index ff98560b9e4..c998f308417 100644 --- a/python/cudf/cudf/tests/test_reductions.py +++ b/python/cudf/cudf/tests/test_reductions.py @@ -59,11 +59,12 @@ def test_sum_string(): ) @pytest.mark.parametrize("nelem", params_sizes) def test_sum_decimal(dtype, nelem): - data = gen_rand("int64", nelem) / 100 - expected = pd.Series(data).sum() - got = cudf.Series([str(x) for x in data]).astype(dtype).sum() + data = [str(x) for x in gen_rand("int64", nelem) / 100] - assert_eq(Decimal(expected), got) + expected = pd.Series([Decimal(x) for x in data]).sum() + got = cudf.Series(data).astype(dtype).sum() + + assert_eq(expected, got) @pytest.mark.parametrize("dtype,nelem", params) @@ -86,6 +87,19 @@ def test_product(dtype, nelem): np.testing.assert_approx_equal(expect, got, significant=significant) +@pytest.mark.parametrize( + "dtype", + [Decimal64Dtype(6, 2), Decimal64Dtype(8, 4), Decimal64Dtype(10, 5)], +) +def test_product_decimal(dtype): + data = [str(x) for x in gen_rand("int8", 3) / 10] + + expected = pd.Series([Decimal(x) for x in data]).product() + got = cudf.Series(data).astype(dtype).product() + + assert_eq(expected, got) + + accuracy_for_dtype = {np.float64: 6, np.float32: 5} @@ -110,6 +124,19 @@ def test_sum_of_squares(dtype, nelem): ) +@pytest.mark.parametrize( + "dtype", + [Decimal64Dtype(6, 2), Decimal64Dtype(8, 4), Decimal64Dtype(10, 5)], +) +def test_sum_of_squares_decimal(dtype): + data = [str(x) for x in gen_rand("int8", 3) / 10] + + expected = pd.Series([Decimal(x) for x in data]).pow(2).sum() + got = cudf.Series(data).astype(dtype).sum_of_squares() + + assert_eq(expected, got) + + @pytest.mark.parametrize("dtype,nelem", params) def test_min(dtype, nelem): dtype = np.dtype(dtype).type @@ -128,11 +155,12 @@ def test_min(dtype, nelem): ) @pytest.mark.parametrize("nelem", params_sizes) def test_min_decimal(dtype, nelem): - data = gen_rand("int64", nelem) / 100 - expected = pd.Series(data).min() - got = cudf.Series([str(x) for x in data]).astype(dtype).min() + data = [str(x) for x in gen_rand("int64", nelem) / 100] - assert_eq(Decimal(expected), got) + expected = pd.Series([Decimal(x) for x in data]).min() + got = cudf.Series(data).astype(dtype).min() + + assert_eq(expected, got) @pytest.mark.parametrize("dtype,nelem", params) @@ -153,11 +181,12 @@ def test_max(dtype, nelem): ) @pytest.mark.parametrize("nelem", params_sizes) def test_max_decimal(dtype, nelem): - data = gen_rand("int64", nelem) / 100 - expected = pd.Series(data).max() - got = cudf.Series([str(x) for x in data]).astype(dtype).max() + data = [str(x) for x in gen_rand("int64", nelem) / 100] + + expected = pd.Series([Decimal(x) for x in data]).max() + got = cudf.Series(data).astype(dtype).max() - assert_eq(Decimal(expected), got) + assert_eq(expected, got) @pytest.mark.parametrize("nelem", params_sizes) From d31535bd903d59edaeba714c1ef7a0a8022dded3 Mon Sep 17 00:00:00 2001 From: Chris Jarrett Date: Thu, 1 Apr 2021 10:13:49 -0700 Subject: [PATCH 4/4] Compute precision for output scalar --- python/cudf/cudf/_lib/reduce.pyx | 33 ++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index f6a8232f395..62013ea88ae 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -1,7 +1,8 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2021, NVIDIA CORPORATION. import cudf from cudf.utils.dtypes import is_decimal_dtype +from cudf.core.dtypes import Decimal64Dtype from cudf._lib.cpp.reduce cimport cpp_reduce, cpp_scan, scan_type, cpp_minmax from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.cpp.types cimport data_type, type_id @@ -68,7 +69,14 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): c_out_dtype )) - py_result = DeviceScalar.from_unique_ptr(move(c_result), dtype=col_dtype) + if c_result.get()[0].type().id() == libcudf_types.type_id.DECIMAL64: + scale = -c_result.get()[0].type().scale() + precision = _reduce_precision(col_dtype, reduction_op, len(incol)) + py_result = DeviceScalar.from_unique_ptr( + move(c_result), dtype=Decimal64Dtype(precision, scale) + ) + else: + py_result = DeviceScalar.from_unique_ptr(move(c_result)) return py_result.value @@ -131,3 +139,24 @@ def minmax(Column incol): py_result_max = DeviceScalar.from_unique_ptr(move(c_result.second)) return cudf.Scalar(py_result_min), cudf.Scalar(py_result_max) + + +def _reduce_precision(dtype, op, nrows): + """ + Returns the result precision when performing the reduce + operation `op` for the given dtype and column size. + + See: https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql + """ # noqa: E501 + p = dtype.precision + if op in ("min", "max"): + new_p = p + elif op == "sum": + new_p = p + nrows - 1 + elif op == "product": + new_p = p * nrows + nrows - 1 + elif op == "sum_of_squares": + new_p = 2 * p + nrows + else: + raise NotImplementedError() + return max(min(new_p, Decimal64Dtype.MAX_PRECISION), 0)