diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index e3d88424b8a..1159a0257e8 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -294,8 +294,10 @@ def _binop_precision(l_dtype, r_dtype, op): p1, p2 = l_dtype.precision, r_dtype.precision s1, s2 = l_dtype.scale, r_dtype.scale if op in ("add", "sub"): - return max(s1, s2) + max(p1 - s1, p2 - s2) + 1 + result = max(s1, s2) + max(p1 - s1, p2 - s2) + 1 elif op in ("mul", "div"): - return p1 + p2 + 1 + result = p1 + p2 + 1 else: raise NotImplementedError() + + return min(result, cudf.Decimal64Dtype.MAX_PRECISION) diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 833b41636e3..40df234580c 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1772,11 +1772,11 @@ def _decimal_series(input, dtype): ( operator.add, ["100", "200"], - cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Decimal64Dtype(scale=-2, precision=17), ["0.1", "0.2"], cudf.Decimal64Dtype(scale=3, precision=4), ["100.1", "200.2"], - cudf.Decimal64Dtype(scale=3, precision=9), + cudf.Decimal64Dtype(scale=3, precision=18), ), ( operator.sub, @@ -1799,11 +1799,11 @@ def _decimal_series(input, dtype): ( operator.sub, ["100", "200"], - cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Decimal64Dtype(scale=-2, precision=10), ["0.1", "0.2"], - cudf.Decimal64Dtype(scale=3, precision=4), + cudf.Decimal64Dtype(scale=6, precision=10), ["99.9", "199.8"], - cudf.Decimal64Dtype(scale=3, precision=9), + cudf.Decimal64Dtype(scale=6, precision=18), ), ( operator.mul, @@ -1853,11 +1853,11 @@ def _decimal_series(input, dtype): ( operator.truediv, ["132.86", "15.25"], - cudf.Decimal64Dtype(scale=4, precision=6), + cudf.Decimal64Dtype(scale=4, precision=14), ["2.34", "8.50"], - cudf.Decimal64Dtype(scale=2, precision=4), + cudf.Decimal64Dtype(scale=2, precision=8), ["56.77", "1.79"], - cudf.Decimal64Dtype(scale=2, precision=11), + cudf.Decimal64Dtype(scale=2, precision=18), ), ( operator.add, @@ -1907,11 +1907,11 @@ def _decimal_series(input, dtype): ( operator.mul, ["100", "200"], - cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Decimal64Dtype(scale=-2, precision=10), ["0.1", None], - cudf.Decimal64Dtype(scale=3, precision=4), + cudf.Decimal64Dtype(scale=3, precision=12), ["10.0", None], - cudf.Decimal64Dtype(scale=1, precision=8), + cudf.Decimal64Dtype(scale=1, precision=18), ), ( operator.eq,