From 074de5a9e7d758dc6e9c3e623e71dbf94a0cc52e Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Thu, 3 Nov 2022 12:48:22 -0700 Subject: [PATCH] remove overflow error in decimal --- python/cudf/cudf/core/column/decimal.py | 10 +++++++++- python/cudf/cudf/tests/test_decimal.py | 8 +++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 0beb07bb591..5ee9024a0d8 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -399,4 +399,12 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): # to try the next dtype continue - raise OverflowError("Maximum supported decimal type is Decimal128") + # Instead of raising an overflow error, we create a `Decimal128Dtype` + # with max possible scale & precision, see example of this demonstration + # here: https://learn.microsoft.com/en-us/sql/t-sql/data-types/ + # precision-scale-and-length-transact-sql?view=sql-server-ver16#examples + scale = min( + scale, cudf.Decimal128Dtype.MAX_PRECISION - (precision - scale) + ) + precision = min(cudf.Decimal128Dtype.MAX_PRECISION, max_precision) + return cudf.Decimal128Dtype(precision=precision, scale=scale) diff --git a/python/cudf/cudf/tests/test_decimal.py b/python/cudf/cudf/tests/test_decimal.py index c37381a3af9..c7174adf342 100644 --- a/python/cudf/cudf/tests/test_decimal.py +++ b/python/cudf/cudf/tests/test_decimal.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. import decimal from decimal import Decimal @@ -377,3 +377,9 @@ def test_decimal_invalid_precision(): with pytest.raises(pa.ArrowInvalid): _ = cudf.Series([Decimal("300")], dtype=cudf.Decimal64Dtype(2, 1)) + + +def test_decimal_overflow(): + s = cudf.Series([Decimal("0.0009384233522166997927180531650178250")]) + result = s * s + assert_eq(cudf.Decimal128Dtype(precision=38, scale=37), result.dtype)