Skip to content

Commit

Permalink
Avoid decimal type narrowing for decimal binops (#10299)
Browse files Browse the repository at this point in the history
Fixes: #10282 

This PR removes decimal type narrowing and also updates the tests accordingly.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #10299
  • Loading branch information
galipremsagar authored Feb 23, 2022
1 parent 0ae9dc6 commit 496f452
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 20 deletions.
38 changes: 30 additions & 8 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,40 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op):
else:
raise NotImplementedError()

try:
if isinstance(lhs_dtype, type(rhs_dtype)):
# SCENARIO 1: If `lhs_dtype` & `rhs_dtype` are same, then try to
# see if `precision` & `scale` can be fit into this type.
return lhs_dtype.__class__(precision=precision, scale=scale)
else:
# SCENARIO 2: If `lhs_dtype` & `rhs_dtype` are of different dtypes,
# then try to see if `precision` & `scale` can be fit into the type
# with greater MAX_PRECISION (i.e., the bigger dtype).
if lhs_dtype.MAX_PRECISION >= rhs_dtype.MAX_PRECISION:
return lhs_dtype.__class__(precision=precision, scale=scale)
else:
return rhs_dtype.__class__(precision=precision, scale=scale)
except ValueError:
# Call to _validate fails, which means we need
# to goto SCENARIO 3.
pass

# SCENARIO 3: If either of the above two scenarios fail, then get the
# MAX_PRECISION of `lhs_dtype` & `rhs_dtype` so that we can only check
# and return a dtype that is greater than or equal to input dtype that
# can fit `precision` & `scale`.
max_precision = max(lhs_dtype.MAX_PRECISION, rhs_dtype.MAX_PRECISION)
for decimal_type in (
cudf.Decimal32Dtype,
cudf.Decimal64Dtype,
cudf.Decimal128Dtype,
):
try:
min_decimal_type = decimal_type(precision=precision, scale=scale)
except ValueError:
# Call to _validate fails, which means we need
# to try the next dtype
pass
else:
return min_decimal_type
if decimal_type.MAX_PRECISION >= max_precision:
try:
return decimal_type(precision=precision, scale=scale)
except ValueError:
# Call to _validate fails, which means we need
# to try the next dtype
continue

raise OverflowError("Maximum supported decimal type is Decimal128")
24 changes: 12 additions & 12 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,7 @@ def test_binops_with_NA_consistent(dtype, op):
["1.5", "2.0"],
cudf.Decimal64Dtype(scale=2, precision=3),
["3.0", "4.0"],
cudf.Decimal32Dtype(scale=2, precision=4),
cudf.Decimal64Dtype(scale=2, precision=4),
),
(
operator.add,
Expand All @@ -1809,7 +1809,7 @@ def test_binops_with_NA_consistent(dtype, op):
["2.25", "1.005"],
cudf.Decimal64Dtype(scale=3, precision=4),
["3.75", "3.005"],
cudf.Decimal32Dtype(scale=3, precision=5),
cudf.Decimal64Dtype(scale=3, precision=5),
),
(
operator.add,
Expand All @@ -1827,7 +1827,7 @@ def test_binops_with_NA_consistent(dtype, op):
["2.25", "1.005"],
cudf.Decimal64Dtype(scale=3, precision=4),
["-0.75", "0.995"],
cudf.Decimal32Dtype(scale=3, precision=5),
cudf.Decimal64Dtype(scale=3, precision=5),
),
(
operator.sub,
Expand All @@ -1836,7 +1836,7 @@ def test_binops_with_NA_consistent(dtype, op):
["2.25", "1.005"],
cudf.Decimal64Dtype(scale=3, precision=4),
["-0.75", "0.995"],
cudf.Decimal32Dtype(scale=3, precision=5),
cudf.Decimal64Dtype(scale=3, precision=5),
),
(
operator.sub,
Expand All @@ -1854,7 +1854,7 @@ def test_binops_with_NA_consistent(dtype, op):
["1.5", "3.0"],
cudf.Decimal64Dtype(scale=3, precision=4),
["2.25", "6.0"],
cudf.Decimal32Dtype(scale=5, precision=8),
cudf.Decimal64Dtype(scale=5, precision=8),
),
(
operator.mul,
Expand All @@ -1863,7 +1863,7 @@ def test_binops_with_NA_consistent(dtype, op):
["0.1", "0.2"],
cudf.Decimal64Dtype(scale=3, precision=4),
["10.0", "40.0"],
cudf.Decimal32Dtype(scale=1, precision=8),
cudf.Decimal64Dtype(scale=1, precision=8),
),
(
operator.mul,
Expand All @@ -1872,7 +1872,7 @@ def test_binops_with_NA_consistent(dtype, op):
["0.343", "0.500"],
cudf.Decimal64Dtype(scale=3, precision=3),
["343.0", "1000.0"],
cudf.Decimal32Dtype(scale=0, precision=8),
cudf.Decimal64Dtype(scale=0, precision=8),
),
(
operator.truediv,
Expand Down Expand Up @@ -1908,7 +1908,7 @@ def test_binops_with_NA_consistent(dtype, op):
["1.5", None, "2.0"],
cudf.Decimal64Dtype(scale=1, precision=2),
["3.0", None, "4.0"],
cudf.Decimal32Dtype(scale=1, precision=3),
cudf.Decimal64Dtype(scale=1, precision=3),
),
(
operator.add,
Expand All @@ -1917,7 +1917,7 @@ def test_binops_with_NA_consistent(dtype, op):
["2.25", "1.005"],
cudf.Decimal64Dtype(scale=3, precision=4),
["3.75", None],
cudf.Decimal32Dtype(scale=3, precision=5),
cudf.Decimal64Dtype(scale=3, precision=5),
),
(
operator.sub,
Expand All @@ -1926,7 +1926,7 @@ def test_binops_with_NA_consistent(dtype, op):
["2.25", None],
cudf.Decimal64Dtype(scale=3, precision=4),
["-0.75", None],
cudf.Decimal32Dtype(scale=3, precision=5),
cudf.Decimal64Dtype(scale=3, precision=5),
),
(
operator.sub,
Expand All @@ -1935,7 +1935,7 @@ def test_binops_with_NA_consistent(dtype, op):
["2.25", None],
cudf.Decimal64Dtype(scale=3, precision=4),
["-0.75", None],
cudf.Decimal32Dtype(scale=3, precision=5),
cudf.Decimal64Dtype(scale=3, precision=5),
),
(
operator.mul,
Expand All @@ -1944,7 +1944,7 @@ def test_binops_with_NA_consistent(dtype, op):
["1.5", None],
cudf.Decimal64Dtype(scale=3, precision=4),
["2.25", None],
cudf.Decimal32Dtype(scale=5, precision=8),
cudf.Decimal64Dtype(scale=5, precision=8),
),
(
operator.mul,
Expand Down

0 comments on commit 496f452

Please sign in to comment.