Skip to content

Commit

Permalink
Compute precision for output scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Jarrett committed Apr 1, 2021
1 parent c9d83fd commit d31535b
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions python/cudf/cudf/_lib/reduce.pyx
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.

import cudf
from cudf.utils.dtypes import is_decimal_dtype
from cudf.core.dtypes import Decimal64Dtype
from cudf._lib.cpp.reduce cimport cpp_reduce, cpp_scan, scan_type, cpp_minmax
from cudf._lib.cpp.scalar.scalar cimport scalar
from cudf._lib.cpp.types cimport data_type, type_id
Expand Down Expand Up @@ -68,7 +69,14 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs):
c_out_dtype
))

py_result = DeviceScalar.from_unique_ptr(move(c_result), dtype=col_dtype)
if c_result.get()[0].type().id() == libcudf_types.type_id.DECIMAL64:
scale = -c_result.get()[0].type().scale()
precision = _reduce_precision(col_dtype, reduction_op, len(incol))
py_result = DeviceScalar.from_unique_ptr(
move(c_result), dtype=Decimal64Dtype(precision, scale)
)
else:
py_result = DeviceScalar.from_unique_ptr(move(c_result))
return py_result.value


Expand Down Expand Up @@ -131,3 +139,24 @@ def minmax(Column incol):
py_result_max = DeviceScalar.from_unique_ptr(move(c_result.second))

return cudf.Scalar(py_result_min), cudf.Scalar(py_result_max)


def _reduce_precision(dtype, op, nrows):
"""
Returns the result precision when performing the reduce
operation `op` for the given dtype and column size.
See: https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
""" # noqa: E501
p = dtype.precision
if op in ("min", "max"):
new_p = p
elif op == "sum":
new_p = p + nrows - 1
elif op == "product":
new_p = p * nrows + nrows - 1
elif op == "sum_of_squares":
new_p = 2 * p + nrows
else:
raise NotImplementedError()
return max(min(new_p, Decimal64Dtype.MAX_PRECISION), 0)

0 comments on commit d31535b

Please sign in to comment.