diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index a17cace3c81..a31eaa52641 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -364,18 +364,40 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): else: raise NotImplementedError() + try: + if isinstance(lhs_dtype, type(rhs_dtype)): + # SCENARIO 1: If `lhs_dtype` & `rhs_dtype` are same, then try to + # see if `precision` & `scale` can be fit into this type. + return lhs_dtype.__class__(precision=precision, scale=scale) + else: + # SCENARIO 2: If `lhs_dtype` & `rhs_dtype` are of different dtypes, + # then try to see if `precision` & `scale` can be fit into the type + # with greater MAX_PRECISION (i.e., the bigger dtype). + if lhs_dtype.MAX_PRECISION >= rhs_dtype.MAX_PRECISION: + return lhs_dtype.__class__(precision=precision, scale=scale) + else: + return rhs_dtype.__class__(precision=precision, scale=scale) + except ValueError: + # Call to _validate fails, which means we need + # to goto SCENARIO 3. + pass + + # SCENARIO 3: If either of the above two scenarios fail, then get the + # MAX_PRECISION of `lhs_dtype` & `rhs_dtype` so that we can only check + # and return a dtype that is greater than or equal to input dtype that + # can fit `precision` & `scale`. + max_precision = max(lhs_dtype.MAX_PRECISION, rhs_dtype.MAX_PRECISION) for decimal_type in ( cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype, ): - try: - min_decimal_type = decimal_type(precision=precision, scale=scale) - except ValueError: - # Call to _validate fails, which means we need - # to try the next dtype - pass - else: - return min_decimal_type + if decimal_type.MAX_PRECISION >= max_precision: + try: + return decimal_type(precision=precision, scale=scale) + except ValueError: + # Call to _validate fails, which means we need + # to try the next dtype + continue raise OverflowError("Maximum supported decimal type is Decimal128") diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 02ca7a0cd58..f688cc3b642 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1800,7 +1800,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "2.0"], cudf.Decimal64Dtype(scale=2, precision=3), ["3.0", "4.0"], - cudf.Decimal32Dtype(scale=2, precision=4), + cudf.Decimal64Dtype(scale=2, precision=4), ), ( operator.add, @@ -1809,7 +1809,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", "3.005"], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.add, @@ -1827,7 +1827,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1836,7 +1836,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1854,7 +1854,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "3.0"], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", "6.0"], - cudf.Decimal32Dtype(scale=5, precision=8), + cudf.Decimal64Dtype(scale=5, precision=8), ), ( operator.mul, @@ -1863,7 +1863,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", "0.2"], cudf.Decimal64Dtype(scale=3, precision=4), ["10.0", "40.0"], - cudf.Decimal32Dtype(scale=1, precision=8), + cudf.Decimal64Dtype(scale=1, precision=8), ), ( operator.mul, @@ -1872,7 +1872,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.343", "0.500"], cudf.Decimal64Dtype(scale=3, precision=3), ["343.0", "1000.0"], - cudf.Decimal32Dtype(scale=0, precision=8), + cudf.Decimal64Dtype(scale=0, precision=8), ), ( operator.truediv, @@ -1908,7 +1908,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", None, "2.0"], cudf.Decimal64Dtype(scale=1, precision=2), ["3.0", None, "4.0"], - cudf.Decimal32Dtype(scale=1, precision=3), + cudf.Decimal64Dtype(scale=1, precision=3), ), ( operator.add, @@ -1917,7 +1917,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", None], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1926,7 +1926,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1935,7 +1935,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.mul, @@ -1944,7 +1944,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", None], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", None], - cudf.Decimal32Dtype(scale=5, precision=8), + cudf.Decimal64Dtype(scale=5, precision=8), ), ( operator.mul,