Skip to content

Commit

Permalink
Remove overflow error during decimal binops (#12063)
Browse files Browse the repository at this point in the history
Fixes: #11337 

- [x] This PR removes raising of an overflow error and rather let's the data overflow similar to what we do with other numeric dtypes.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Bradley Dice (https://github.com/bdice)

URL: #12063
  • Loading branch information
galipremsagar authored Nov 4, 2022
1 parent 0278485 commit b1c2520
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
10 changes: 9 additions & 1 deletion python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,12 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op):
# to try the next dtype
continue

raise OverflowError("Maximum supported decimal type is Decimal128")
# Instead of raising an overflow error, we create a `Decimal128Dtype`
# with max possible scale & precision, see example of this demonstration
# here: https://learn.microsoft.com/en-us/sql/t-sql/data-types/
# precision-scale-and-length-transact-sql?view=sql-server-ver16#examples
scale = min(
scale, cudf.Decimal128Dtype.MAX_PRECISION - (precision - scale)
)
precision = min(cudf.Decimal128Dtype.MAX_PRECISION, max_precision)
return cudf.Decimal128Dtype(precision=precision, scale=scale)
8 changes: 7 additions & 1 deletion python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

import decimal
from decimal import Decimal
Expand Down Expand Up @@ -377,3 +377,9 @@ def test_decimal_invalid_precision():

with pytest.raises(pa.ArrowInvalid):
_ = cudf.Series([Decimal("300")], dtype=cudf.Decimal64Dtype(2, 1))


def test_decimal_overflow():
s = cudf.Series([Decimal("0.0009384233522166997927180531650178250")])
result = s * s
assert_eq(cudf.Decimal128Dtype(precision=38, scale=37), result.dtype)

0 comments on commit b1c2520

Please sign in to comment.