Skip to content

Commit

Permalink
Enable mixed-dtype decimal/scalar binary operations (#13171)
Browse files Browse the repository at this point in the history
Also closes #13170

Authors:
  - Ashwin Srinath (https://github.com/shwina)
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: #13171
  • Loading branch information
shwina authored Apr 26, 2023
1 parent 2a511ad commit 5df4367
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 94 deletions.
8 changes: 7 additions & 1 deletion python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@ def normalize_binop_value(self, other):
other = other.astype(self.dtype)
return other
elif is_scalar(other) and isinstance(other, (int, Decimal)):
return cudf.Scalar(Decimal(other), dtype=self.dtype)
other = Decimal(other)
metadata = other.as_tuple()
precision = max(len(metadata.digits), metadata.exponent)
scale = -metadata.exponent
return cudf.Scalar(
other, dtype=self.dtype.__class__(precision, scale)
)
return NotImplemented

def _decimal_quantile(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ def _validate(cls, precision, scale=0):
@classmethod
def _from_decimal(cls, decimal):
"""
Create a cudf.Decimal32Dtype from a decimal.Decimal object
Create a cudf.DecimalDtype from a decimal.Decimal object
"""
metadata = decimal.as_tuple()
precision = max(len(metadata.digits), -metadata.exponent)
Expand Down
115 changes: 24 additions & 91 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2554,7 +2554,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal(1),
["101", "201"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
False,
),
(
Expand All @@ -2563,7 +2563,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
1,
["101", "201"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
False,
),
(
Expand All @@ -2572,16 +2572,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal("1.5"),
["101.5", "201.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
False,
),
(
operator.add,
["100", "200"],
cudf.Decimal64Dtype(scale=-2, precision=3),
cudf.Scalar(decimal.Decimal("1.5")),
["101.5", "201.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
cudf.Decimal64Dtype(scale=1, precision=7),
False,
),
(
Expand All @@ -2590,7 +2581,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal(1),
["101", "201"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
True,
),
(
Expand All @@ -2599,7 +2590,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
1,
["101", "201"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
True,
),
(
Expand All @@ -2608,16 +2599,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal("1.5"),
["101.5", "201.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
True,
),
(
operator.add,
["100", "200"],
cudf.Decimal64Dtype(scale=-2, precision=3),
cudf.Scalar(decimal.Decimal("1.5")),
["101.5", "201.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
cudf.Decimal64Dtype(scale=1, precision=7),
True,
),
(
Expand All @@ -2626,7 +2608,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
1,
["100", "200"],
cudf.Decimal32Dtype(scale=-2, precision=5),
cudf.Decimal64Dtype(scale=-2, precision=5),
False,
),
(
Expand All @@ -2635,7 +2617,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal(2),
["200", "400"],
cudf.Decimal32Dtype(scale=-2, precision=5),
cudf.Decimal64Dtype(scale=-2, precision=5),
False,
),
(
Expand All @@ -2644,16 +2626,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal("1.5"),
["150", "300"],
cudf.Decimal32Dtype(scale=-1, precision=6),
False,
),
(
operator.mul,
["100", "200"],
cudf.Decimal64Dtype(scale=-2, precision=3),
cudf.Scalar(decimal.Decimal("1.5")),
["150", "300"],
cudf.Decimal32Dtype(scale=-1, precision=6),
cudf.Decimal64Dtype(scale=-1, precision=6),
False,
),
(
Expand All @@ -2662,7 +2635,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
1,
["100", "200"],
cudf.Decimal32Dtype(scale=-2, precision=5),
cudf.Decimal64Dtype(scale=-2, precision=5),
True,
),
(
Expand All @@ -2671,7 +2644,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal(2),
["200", "400"],
cudf.Decimal32Dtype(scale=-2, precision=5),
cudf.Decimal64Dtype(scale=-2, precision=5),
True,
),
(
Expand All @@ -2680,16 +2653,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal("1.5"),
["150", "300"],
cudf.Decimal32Dtype(scale=-1, precision=6),
True,
),
(
operator.mul,
["100", "200"],
cudf.Decimal64Dtype(scale=-2, precision=3),
cudf.Scalar(decimal.Decimal("1.5")),
["150", "300"],
cudf.Decimal32Dtype(scale=-1, precision=6),
cudf.Decimal64Dtype(scale=-1, precision=6),
True,
),
(
Expand All @@ -2707,7 +2671,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=2, precision=5),
decimal.Decimal(2),
["50", "100"],
cudf.Decimal32Dtype(scale=6, precision=9),
cudf.Decimal64Dtype(scale=6, precision=9),
False,
),
(
Expand All @@ -2716,16 +2680,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=2, precision=4),
decimal.Decimal("1.5"),
["23.4", "36.6"],
cudf.Decimal32Dtype(scale=6, precision=9),
False,
),
(
operator.truediv,
["22.2", "93.6"],
cudf.Decimal64Dtype(scale=1, precision=3),
cudf.Scalar(decimal.Decimal("1.5")),
["14", "62"],
cudf.Decimal32Dtype(scale=6, precision=9),
cudf.Decimal64Dtype(scale=6, precision=9),
False,
),
(
Expand All @@ -2734,7 +2689,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=2, precision=5),
1,
["0", "0"],
cudf.Decimal32Dtype(scale=6, precision=9),
cudf.Decimal64Dtype(scale=6, precision=9),
True,
),
(
Expand All @@ -2752,16 +2707,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=2, precision=3),
decimal.Decimal("8.55"),
["7", "1"],
cudf.Decimal32Dtype(scale=6, precision=9),
True,
),
(
operator.truediv,
["1.1", "42.8"],
cudf.Decimal64Dtype(scale=1, precision=3),
cudf.Scalar(decimal.Decimal("90.84")),
["82.5", "2.1"],
cudf.Decimal32Dtype(scale=6, precision=9),
cudf.Decimal64Dtype(scale=6, precision=9),
True,
),
(
Expand All @@ -2770,7 +2716,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal(2),
["98", "198"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
False,
),
(
Expand All @@ -2779,7 +2725,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal("2.5"),
["97.5", "197.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
cudf.Decimal64Dtype(scale=1, precision=7),
False,
),
(
Expand All @@ -2788,16 +2734,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
4,
["96", "196"],
cudf.Decimal32Dtype(scale=0, precision=6),
False,
),
(
operator.sub,
["100", "200"],
cudf.Decimal64Dtype(scale=-2, precision=3),
cudf.Scalar(decimal.Decimal("2.5")),
["97.5", "197.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
cudf.Decimal64Dtype(scale=0, precision=6),
False,
),
(
Expand All @@ -2806,7 +2743,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal(2),
["-98", "-198"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
True,
),
(
Expand All @@ -2815,7 +2752,7 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
4,
["-96", "-196"],
cudf.Decimal32Dtype(scale=0, precision=6),
cudf.Decimal64Dtype(scale=0, precision=6),
True,
),
(
Expand All @@ -2824,24 +2761,20 @@ def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected):
cudf.Decimal64Dtype(scale=-2, precision=3),
decimal.Decimal("2.5"),
["-97.5", "-197.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
cudf.Decimal64Dtype(scale=1, precision=7),
True,
),
(
operator.sub,
["100", "200"],
cudf.Decimal64Dtype(scale=-2, precision=3),
cudf.Scalar(decimal.Decimal("2.5")),
decimal.Decimal("2.5"),
["-97.5", "-197.5"],
cudf.Decimal32Dtype(scale=1, precision=7),
cudf.Decimal64Dtype(scale=1, precision=7),
True,
),
],
)
@pytest_xfail(
reason="binop operations not supported for different "
"bit-width decimal types"
)
def test_binops_decimal_scalar(args):
op, lhs, l_dtype, rhs, expect, expect_dtype, reflect = args

Expand Down
6 changes: 5 additions & 1 deletion python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2023, NVIDIA CORPORATION.

import decimal
from decimal import Decimal
Expand Down Expand Up @@ -385,3 +385,7 @@ def test_decimal_overflow():
s = cudf.Series([Decimal("0.0009384233522166997927180531650178250")])
result = s * s
assert_eq(cudf.Decimal128Dtype(precision=38, scale=37), result.dtype)

s = cudf.Series([1, 2], dtype=cudf.Decimal128Dtype(precision=38, scale=0))
result = s * Decimal("1.0")
assert_eq(cudf.Decimal128Dtype(precision=38, scale=-2), result.dtype)

0 comments on commit 5df4367

Please sign in to comment.