From 7b084c17bd556b937fe0b102dbdc4098f4c05017 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Tue, 15 Feb 2022 09:32:42 -0800 Subject: [PATCH 1/8] change type handling --- python/cudf/cudf/core/column/decimal.py | 49 +++++++++++++++++++++---- python/cudf/cudf/tests/test_binops.py | 24 ++++++------ 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index a17cace3c81..a05c7da4e34 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -364,18 +364,51 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): else: raise NotImplementedError() - for decimal_type in ( - cudf.Decimal32Dtype, - cudf.Decimal64Dtype, - cudf.Decimal128Dtype, - ): + if isinstance(lhs_dtype, type(rhs_dtype)): + # SCENARIO 1: If `lhs_dtype` & `rhs_dtype` are same, then try to + # see if `precision` & `scale` can be fit into this type. + try: + return lhs_dtype.__class__(precision=precision, scale=scale) + except ValueError: + # Call to _validate fails, which means we need + # to SCENARIO 3. + pass + else: + # SCENARIO 2: If `lhs_dtype` & `rhs_dtype` are of different dtypes, + # then try to see if `precision` & `scale` can be fit into the type + # with greater MAX_PRECISION (i.e., the bigger dtype). try: - min_decimal_type = decimal_type(precision=precision, scale=scale) + if lhs_dtype.MAX_PRECISION > rhs_dtype.MAX_PRECISION: + return lhs_dtype.__class__(precision=precision, scale=scale) + else: + return rhs_dtype.__class__(precision=precision, scale=scale) except ValueError: # Call to _validate fails, which means we need # to try the next dtype pass - else: - return min_decimal_type + + # SCENARIO 3: If either of the above two scenarios fail, then get the + # MAX_PRECISION of `lhs_dtype` & `rhs_dtype` so that we can only check + # and return a dtype that is greater than or equal to input dtype that + # can fit `precision` & `scale`. + lhs_rhs_max_precision = max( + lhs_dtype.MAX_PRECISION, rhs_dtype.MAX_PRECISION + ) + for decimal_type in ( + cudf.Decimal32Dtype, + cudf.Decimal64Dtype, + cudf.Decimal128Dtype, + ): + if decimal_type.MAX_PRECISION >= lhs_rhs_max_precision: + try: + min_decimal_type = decimal_type( + precision=precision, scale=scale + ) + except ValueError: + # Call to _validate fails, which means we need + # to try the next dtype + pass + else: + return min_decimal_type raise OverflowError("Maximum supported decimal type is Decimal128") diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 02ca7a0cd58..f688cc3b642 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1800,7 +1800,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "2.0"], cudf.Decimal64Dtype(scale=2, precision=3), ["3.0", "4.0"], - cudf.Decimal32Dtype(scale=2, precision=4), + cudf.Decimal64Dtype(scale=2, precision=4), ), ( operator.add, @@ -1809,7 +1809,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", "3.005"], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.add, @@ -1827,7 +1827,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1836,7 +1836,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", "0.995"], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1854,7 +1854,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", "3.0"], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", "6.0"], - cudf.Decimal32Dtype(scale=5, precision=8), + cudf.Decimal64Dtype(scale=5, precision=8), ), ( operator.mul, @@ -1863,7 +1863,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.1", "0.2"], cudf.Decimal64Dtype(scale=3, precision=4), ["10.0", "40.0"], - cudf.Decimal32Dtype(scale=1, precision=8), + cudf.Decimal64Dtype(scale=1, precision=8), ), ( operator.mul, @@ -1872,7 +1872,7 @@ def test_binops_with_NA_consistent(dtype, op): ["0.343", "0.500"], cudf.Decimal64Dtype(scale=3, precision=3), ["343.0", "1000.0"], - cudf.Decimal32Dtype(scale=0, precision=8), + cudf.Decimal64Dtype(scale=0, precision=8), ), ( operator.truediv, @@ -1908,7 +1908,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", None, "2.0"], cudf.Decimal64Dtype(scale=1, precision=2), ["3.0", None, "4.0"], - cudf.Decimal32Dtype(scale=1, precision=3), + cudf.Decimal64Dtype(scale=1, precision=3), ), ( operator.add, @@ -1917,7 +1917,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", "1.005"], cudf.Decimal64Dtype(scale=3, precision=4), ["3.75", None], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1926,7 +1926,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.sub, @@ -1935,7 +1935,7 @@ def test_binops_with_NA_consistent(dtype, op): ["2.25", None], cudf.Decimal64Dtype(scale=3, precision=4), ["-0.75", None], - cudf.Decimal32Dtype(scale=3, precision=5), + cudf.Decimal64Dtype(scale=3, precision=5), ), ( operator.mul, @@ -1944,7 +1944,7 @@ def test_binops_with_NA_consistent(dtype, op): ["1.5", None], cudf.Decimal64Dtype(scale=3, precision=4), ["2.25", None], - cudf.Decimal32Dtype(scale=5, precision=8), + cudf.Decimal64Dtype(scale=5, precision=8), ), ( operator.mul, From 1f5ca2bd3aec9ccb0f13c62f0c36b7c2ce2a9a06 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Tue, 15 Feb 2022 09:33:18 -0800 Subject: [PATCH 2/8] change comment --- python/cudf/cudf/core/column/decimal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index a05c7da4e34..8d234eaea9f 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -384,7 +384,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): return rhs_dtype.__class__(precision=precision, scale=scale) except ValueError: # Call to _validate fails, which means we need - # to try the next dtype + # to SCENARIO 3. pass # SCENARIO 3: If either of the above two scenarios fail, then get the From 6d2cbab7b11eb40bf4447269e69cde982b6567e7 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Tue, 15 Feb 2022 11:38:02 -0600 Subject: [PATCH 3/8] Update decimal.py --- python/cudf/cudf/core/column/decimal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 8d234eaea9f..b41dab28de4 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -371,7 +371,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): return lhs_dtype.__class__(precision=precision, scale=scale) except ValueError: # Call to _validate fails, which means we need - # to SCENARIO 3. + # to goto SCENARIO 3. pass else: # SCENARIO 2: If `lhs_dtype` & `rhs_dtype` are of different dtypes, @@ -384,7 +384,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): return rhs_dtype.__class__(precision=precision, scale=scale) except ValueError: # Call to _validate fails, which means we need - # to SCENARIO 3. + # to goto SCENARIO 3. pass # SCENARIO 3: If either of the above two scenarios fail, then get the From a2c6e8a04a87a417a15e6e3c0af1acd66e1791eb Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 16 Feb 2022 14:10:41 -0600 Subject: [PATCH 4/8] Update python/cudf/cudf/core/column/decimal.py Co-authored-by: Bradley Dice --- python/cudf/cudf/core/column/decimal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index b41dab28de4..4f05ad3280e 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -378,7 +378,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): # then try to see if `precision` & `scale` can be fit into the type # with greater MAX_PRECISION (i.e., the bigger dtype). try: - if lhs_dtype.MAX_PRECISION > rhs_dtype.MAX_PRECISION: + if lhs_dtype.MAX_PRECISION >= rhs_dtype.MAX_PRECISION: return lhs_dtype.__class__(precision=precision, scale=scale) else: return rhs_dtype.__class__(precision=precision, scale=scale) From bd54b5e2c3d8a08146271dc245a41fdf7c87279e Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 16 Feb 2022 14:10:49 -0600 Subject: [PATCH 5/8] Update python/cudf/cudf/core/column/decimal.py Co-authored-by: Bradley Dice --- python/cudf/cudf/core/column/decimal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 4f05ad3280e..34e98d082fc 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -391,7 +391,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): # MAX_PRECISION of `lhs_dtype` & `rhs_dtype` so that we can only check # and return a dtype that is greater than or equal to input dtype that # can fit `precision` & `scale`. - lhs_rhs_max_precision = max( + max_precision = max( lhs_dtype.MAX_PRECISION, rhs_dtype.MAX_PRECISION ) for decimal_type in ( From c12f9f799b0758d7f27379f579e0bfd4733e869f Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 16 Feb 2022 14:11:17 -0600 Subject: [PATCH 6/8] Update python/cudf/cudf/core/column/decimal.py Co-authored-by: Bradley Dice --- python/cudf/cudf/core/column/decimal.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 34e98d082fc..ed7c982906b 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -401,9 +401,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): ): if decimal_type.MAX_PRECISION >= lhs_rhs_max_precision: try: - min_decimal_type = decimal_type( - precision=precision, scale=scale - ) + return decimal_type(precision=precision, scale=scale) except ValueError: # Call to _validate fails, which means we need # to try the next dtype From 71572ecee7395c24d15993220960e4dd684aebce Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Wed, 16 Feb 2022 12:12:40 -0800 Subject: [PATCH 7/8] cleanup --- python/cudf/cudf/core/column/decimal.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index ed7c982906b..ee01874e005 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -391,22 +391,18 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): # MAX_PRECISION of `lhs_dtype` & `rhs_dtype` so that we can only check # and return a dtype that is greater than or equal to input dtype that # can fit `precision` & `scale`. - max_precision = max( - lhs_dtype.MAX_PRECISION, rhs_dtype.MAX_PRECISION - ) + max_precision = max(lhs_dtype.MAX_PRECISION, rhs_dtype.MAX_PRECISION) for decimal_type in ( cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype, ): - if decimal_type.MAX_PRECISION >= lhs_rhs_max_precision: + if decimal_type.MAX_PRECISION >= max_precision: try: return decimal_type(precision=precision, scale=scale) except ValueError: # Call to _validate fails, which means we need # to try the next dtype pass - else: - return min_decimal_type raise OverflowError("Maximum supported decimal type is Decimal128") From 56987a4d5609bb65befdd0ebc084924be1df9ffc Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 23 Feb 2022 09:37:28 -0600 Subject: [PATCH 8/8] Apply suggestions from code review Co-authored-by: Bradley Dice --- python/cudf/cudf/core/column/decimal.py | 31 +++++++++++-------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index ee01874e005..a31eaa52641 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -364,28 +364,23 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): else: raise NotImplementedError() - if isinstance(lhs_dtype, type(rhs_dtype)): - # SCENARIO 1: If `lhs_dtype` & `rhs_dtype` are same, then try to - # see if `precision` & `scale` can be fit into this type. - try: + try: + if isinstance(lhs_dtype, type(rhs_dtype)): + # SCENARIO 1: If `lhs_dtype` & `rhs_dtype` are same, then try to + # see if `precision` & `scale` can be fit into this type. return lhs_dtype.__class__(precision=precision, scale=scale) - except ValueError: - # Call to _validate fails, which means we need - # to goto SCENARIO 3. - pass - else: - # SCENARIO 2: If `lhs_dtype` & `rhs_dtype` are of different dtypes, - # then try to see if `precision` & `scale` can be fit into the type - # with greater MAX_PRECISION (i.e., the bigger dtype). - try: + else: + # SCENARIO 2: If `lhs_dtype` & `rhs_dtype` are of different dtypes, + # then try to see if `precision` & `scale` can be fit into the type + # with greater MAX_PRECISION (i.e., the bigger dtype). if lhs_dtype.MAX_PRECISION >= rhs_dtype.MAX_PRECISION: return lhs_dtype.__class__(precision=precision, scale=scale) else: return rhs_dtype.__class__(precision=precision, scale=scale) - except ValueError: - # Call to _validate fails, which means we need - # to goto SCENARIO 3. - pass + except ValueError: + # Call to _validate fails, which means we need + # to goto SCENARIO 3. + pass # SCENARIO 3: If either of the above two scenarios fail, then get the # MAX_PRECISION of `lhs_dtype` & `rhs_dtype` so that we can only check @@ -403,6 +398,6 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op): except ValueError: # Call to _validate fails, which means we need # to try the next dtype - pass + continue raise OverflowError("Maximum supported decimal type is Decimal128")