Fixes creation of invalid DecimalType in GpuDivide.tagExprForGpu #1991
Changes from 2 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1714,29 +1714,39 @@ object GpuOverrides { | |
(childExprs.head.dataType, childExprs(1).dataType) match { | ||
case (l: DecimalType, r: DecimalType) => | ||
val outputType = GpuDivideUtil.decimalDataType(l, r) | ||
// We will never hit a case where outputType.precision < outputType.scale + r.scale. | ||
// So there is no need to protect against that. | ||
// The only two cases in which there is a possibility of the intermediary scale | ||
// exceeding the intermediary precision is when l.precision < l.scale or l | ||
// .precision < 0, both of which aren't possible. | ||
// Proof: | ||
// case 1: | ||
// outputType.precision = p1 - s1 + s2 + s1 + p2 + 1 + 1 | ||
// outputType.scale = p1 + s2 + p2 + 1 + 1 | ||
// To find out if outputType.precision < outputType.scale simplifies to p1 < s1, | ||
// which is never possible | ||
// | ||
// case 2: | ||
// outputType.precision = p1 - s1 + s2 + 6 + 1 | ||
// outputType.scale = 6 + 1 | ||
// To find out if outputType.precision < outputType.scale simplifies to p1 < 0 | ||
// which is never possible | ||
// Case 1: OutputType.precision doesn't get truncated | ||
// We will never hit a case where outputType.precision < outputType.scale + r.scale. | ||
I think it would be good (and fairly easy) to add a unit test for this. I had a go and found some values that do hit this case but it is possible that I am using values that couldn't happen in real life. I am not sure. Here is one example:

Perhaps I am hitting case 2 here where precision is getting truncated?

It does look like I am hitting case 2 here so I think the logic in the comments is sound, but a unit test would make me more confident and would protect against future regressions if any of this code gets updated.

+1 for the unit test to verify that this is working. The logic looks good to me. I am a little concerned about the coupling between this code and the code in

Yes, I wrote a unit test locally to verify this. I don't know why I didn't check it in. Let me do that.

So, like you said, this won't happen in reality as by the time we get the call both
||
// So there is no need to protect against that. | ||
// The only two cases in which there is a possibility of the intermediary scale | ||
// exceeding the intermediary precision is when l.precision < l.scale or l | ||
// .precision < 0, both of which aren't possible. | ||
// Proof: | ||
// case 1: | ||
// outputType.precision = p1 - s1 + s2 + s1 + p2 + 1 + 1 | ||
// outputType.scale = p1 + s2 + p2 + 1 + 1 | ||
// To find out if outputType.precision < outputType.scale simplifies to p1 < s1, | ||
// which is never possible | ||
// | ||
// case 2: | ||
// outputType.precision = p1 - s1 + s2 + 6 + 1 | ||
// outputType.scale = 6 + 1 | ||
// To find out if outputType.precision < outputType.scale simplifies to p1 < 0 | ||
// which is never possible | ||
// Case 2: OutputType.precision gets truncated to 38 | ||
// In this case we have to make sure the r.precision + l.scale + r.scale + 1 <= 38 | ||
// Otherwise the intermediate result will overflow | ||
// TODO We should revisit the proof one more time after we support 128-bit decimals | ||
val intermediateResult = DecimalType(outputType.precision, outputType.scale + r.scale) | ||
if (intermediateResult.precision > DType.DECIMAL64_MAX_PRECISION) { | ||
willNotWorkOnGpu("The actual output precision of the divide is too large" + | ||
if (l.precision + l.scale + r.scale + 1 > 38) { | ||
willNotWorkOnGpu("The intermediate output precision of the divide is too " + | ||
s"large to be supported on the GPU i.e. Decimal(${outputType.precision}, " + | ||
s"${outputType.scale + r.scale})") | ||
} else { | ||
val intermediateResult = | ||
DecimalType(outputType.precision, outputType.scale + r.scale) | ||
if (intermediateResult.precision > DType.DECIMAL64_MAX_PRECISION) { | ||
willNotWorkOnGpu("The actual output precision of the divide is too large" + | ||
s" to fit on the GPU $intermediateResult") | ||
} | ||
} | ||
case _ => // NOOP | ||
} | ||
|
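The case analysis in the code comments can be checked with a small standalone model. The sketch below is illustrative only: `Dec` is a hypothetical stand-in for Spark's `DecimalType`, and `outputType` assumes Spark's documented divide rule (precision `p1 - s1 + s2 + max(6, s1 + p2 + 1)`, scale `max(6, s1 + p2 + 1)`) with the precision capped at 38 and the scale left untouched. That assumption may not match what `GpuDivideUtil.decimalDataType` actually does.

```scala
// Hypothetical stand-in for Spark's DecimalType; not the real class.
final case class Dec(precision: Int, scale: Int)

object DecimalDivideSketch {
  val MaxPrecision = 38     // maximum decimal precision in Spark
  val MinAdjustedScale = 6  // minimum scale Spark keeps for a divide result

  // Assumed output type of l / r: Spark's documented divide rule,
  // with precision capped at 38 and scale left as computed.
  def outputType(l: Dec, r: Dec): Dec = {
    val scale = math.max(MinAdjustedScale, l.scale + r.precision + 1)
    val precision = l.precision - l.scale + r.scale + scale
    Dec(math.min(precision, MaxPrecision), scale)
  }

  // Intermediate type: output precision, scale widened by r.scale.
  // Returns None when the scale would exceed the precision, which is
  // the invalid-type situation the guard in the diff is meant to avoid.
  def intermediateType(l: Dec, r: Dec): Option[Dec] = {
    val out = outputType(l, r)
    val scale = out.scale + r.scale
    if (scale > out.precision) None else Some(Dec(out.precision, scale))
  }
}
```

Under these assumptions, when the precision is not capped, the intermediate precision minus the intermediate scale equals `l.precision - l.scale`, which is never negative, so only the truncated case can produce an invalid type. For example, `DecimalDivideSketch.intermediateType(Dec(38, 20), Dec(38, 20))` yields `None`, while `Dec(10, 2)` divided by `Dec(5, 2)` yields a valid intermediate type.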
If I remove your patch and keep the test, there will be no real difference in the results unless someone goes through the logs manually and notices that in one case we failed by falling back to the CPU, and in the other we failed by throwing an exception about going over the limit. In both cases the xfail ignored the exception that was thrown.
I have updated the test to check that it raises IllegalArgumentException. Let me know if you want me to tighten it further. Ideally we should check the message to make sure it's failing because it's not columnar, but I am not sure how to accomplish that using the pytest xfail.