[REVIEW] Avoid decimal type narrowing for decimal binops #10299
Conversation
Codecov Report
```
@@             Coverage Diff              @@
##           branch-22.04   #10299   +/- ##
================================================
- Coverage       10.67%    10.62%   -0.06%
================================================
  Files             122       122
  Lines           20878     20977      +99
================================================
  Hits             2228      2228
- Misses          18650     18749      +99
```
Continue to review full report at Codecov.
```
@@ -364,18 +364,51 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op):
    else:
```
I see a comment above:

> This should at some point be hooked up to libcudf's `binary_operation_fixed_point_scale`

Do we only support add/sub/mul/div operations right now in Python because of limitations in this function? I know that other operations are implemented in libcudf, so piping that through might be a significant improvement.
> Do we only support add/sub/mul/div operations right now in Python because of limitations in this function?

It's not just `binary_operation_fixed_point_scale`; I think some of the other binops are not supported on the libcudf side either.

Looking into `binary_operation_fixed_point_scale`, it seems the formula for `DIV` is wrong? I could be wrong here, but it doesn't match what is specified here: https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql

Though libcudf doesn't take precision as input, the Python side will need this calculation anyway, so it's probably better to have both computations in a single place rather than having to look in two places.
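For reference, the division rule from the linked T-SQL docs can be sketched as a small helper. This is illustrative only: `div_precision_scale` is a hypothetical name, not libcudf's or cudf's API, which works in terms of scale alone.

```python
# Sketch of the T-SQL precision/scale rule for decimal division
# (illustrative helper, not libcudf's implementation):
#   scale     = max(6, s1 + p2 + 1)
#   precision = p1 - s1 + s2 + scale
def div_precision_scale(p1, s1, p2, s2, max_precision=38):
    scale = max(6, s1 + p2 + 1)
    precision = p1 - s1 + s2 + scale
    # SQL Server caps precision at 38, shrinking scale (but not below 6)
    if precision > max_precision:
        scale = max(scale - (precision - max_precision), 6)
        precision = max_precision
    return precision, scale

print(div_precision_scale(4, 3, 2, 1))  # (8, 6)
```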
Support for other operators exists, e.g. `MOD` / `PMOD` / `PYMOD`: #10179.

I'm fine with keeping both precision/scale calculations together here. I just wanted to make a note to ask, since I saw the comment above.
There may or may not be issues with the scale/precision calculations. I think the page you referenced has different conventions than libcudf. In my understanding:
- libcudf's scale represents powers of the radix (base 10 or base 2)
- libcudf's precision (32, 64, 128) represents bits (powers of two) used to store the integral part
Neither value appears to correspond to the linked SQL docs. That page appears to always use powers of 10 for both scale and precision. Also the definition of scale is the negative of libcudf's definition. It does not surprise me that these different conventions would result in different expressions. I spent an hour looking into this but I have no idea how to make the two definitions mathematically correspond.
Working through an example calculation here, for the SQL docs:

```python
e1 = 4.096
p1 = 4
s1 = 3

e2 = 3.2
p2 = 2
s2 = 1

s = max(6, s1 + p2 + 1)
p = p1 - s1 + s2 + s

print(f"{e1/e2=}")    # e1/e2=1.28
print(f"{p=}, {s=}")  # p=8, s=6
```
I was confused and gave up at this point -- how could `1.28` have `p=8, s=6`?
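One possible resolution (my reading, not from the SQL docs themselves): the rules give a worst-case bound over all values representable as `DECIMAL(4,3)` and `DECIMAL(2,1)`, not the digits needed for 4.096 / 3.2 specifically. Other operands of the same types really do need 2 integral digits and a long fractional part:

```python
from decimal import Decimal

# Same operand types, DECIMAL(4,3) / DECIMAL(2,1), but different values:
big = Decimal("9.999") / Decimal("0.1")   # 99.99 -> needs 2 integral digits
rep = Decimal("1.000") / Decimal("0.3")   # 3.3333... -> repeating quotient,
                                          # so a fixed scale must cut it off
print(big)  # 99.99
print(rep)  # 3.3333... (truncated at the context precision)
```

Under that reading, `p=8, s=6` is the envelope for the whole type pair, and 1.28 simply doesn't use all of it.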
Thanks @bdice, I think @codereport would have a better understanding on this than me. But I'm merging these changes for now and we can have a follow-up PR if changes need to be done.
```
            pass
        else:
            return min_decimal_type
    if decimal_type.MAX_PRECISION >= lhs_rhs_max_precision:
```
It looks like this `if` is checking the same thing as the `_validate` method in the decimal dtype constructor. Is this unnecessarily duplicated? I'd fall back on the `try` and remove the `if` if possible.
I might be wrong here -- I see you're constructing the returned type with `precision=precision` instead of `precision=max_precision`. Would it be better to try and construct a type with `max_precision` and return a type with `precision` if that succeeds? (Or is that a bug -- should it be returning a type with `max_precision`?)
> I might be wrong here -- I see you're constructing the returned type with `precision=precision` instead of `precision=max_precision`. Would it be better to try and construct a type with `max_precision` and return a type with `precision` if that succeeds? (Or is that a bug -- should it be returning a type with `max_precision`?)
It's not a bug; the dtype is expected to have `precision` and not `max_precision`.
> It looks like this `if` is checking the same thing as the `_validate` method in the decimal dtype constructor. Is this unnecessarily duplicated? I'd fall back on the `try` and remove the `if` if possible.
This was a necessary duplication, because we want to pick a dtype that is not narrower than `lhs_dtype` or `rhs_dtype`, i.e., avoid type narrowing.
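The "no narrowing" rule can be sketched standalone. The dtype names and precision caps below are illustrative stand-ins, not cudf's actual classes:

```python
# Candidate types ordered narrowest to widest, with their max precision.
# (Illustrative values; e.g. 32-bit decimal types commonly cap at 9 digits.)
DECIMAL_TYPES = [("Decimal32", 9), ("Decimal64", 18), ("Decimal128", 38)]

def pick_result_type(lhs_max_prec, rhs_max_prec, result_precision):
    # The result's storage must be at least as wide as both operands',
    # even if the computed result precision would fit a smaller type.
    floor = max(lhs_max_prec, rhs_max_prec)
    for name, max_prec in DECIMAL_TYPES:
        if max_prec >= floor and max_prec >= result_precision:
            return name
    raise ValueError("result precision exceeds widest decimal type")

# A Decimal64 operand keeps the result at Decimal64, even though a
# result precision of 5 would fit in Decimal32:
print(pick_result_type(18, 9, 5))  # Decimal64
```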
Co-authored-by: Bradley Dice <[email protected]>
```python
if decimal_type.MAX_PRECISION >= max_precision:
    try:
        return decimal_type(precision=precision, scale=scale)
    except ValueError:
        # Call to _validate fails, which means we need
        # to try the next dtype
        pass
```
For a problem larger than this, I would suggest something like `bisect` to determine the type corresponding to a certain precision, but I think this is fine.
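The `bisect` idea might look like this (illustrative names and caps, not cudf's API; it assumes the candidate types are sorted by maximum precision):

```python
import bisect

# Precision caps sorted ascending, paired with illustrative type names.
MAX_PRECISIONS = [9, 18, 38]
NAMES = ["Decimal32", "Decimal64", "Decimal128"]

def smallest_fitting(precision):
    # bisect_left finds the first cap >= precision in O(log n),
    # replacing the linear scan over candidate dtypes.
    i = bisect.bisect_left(MAX_PRECISIONS, precision)
    if i == len(MAX_PRECISIONS):
        raise ValueError("precision too large for any decimal type")
    return NAMES[i]

print(smallest_fitting(10))  # Decimal64
```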
I have only a couple minor suggestions. I shared a longer comment about libcudf's decimal conventions but I'm not sure if there's anything actionable there based on what I know.
@gpucibot merge
Fixes: #10282
This PR removes decimal type narrowing and also updates the tests accordingly.