Skip to content

Commit

Permalink
Support round and bround SQL functions (NVIDIA#1244)
Browse files Browse the repository at this point in the history
Signed-off-by: Niranjan Artal <[email protected]>
  • Loading branch information
nartal1 authored Jan 8, 2021
1 parent a53a8af commit d347102
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.Atan"></a>spark.rapids.sql.expression.Atan|`atan`|Inverse tangent|true|None|
<a name="sql.expression.Atanh"></a>spark.rapids.sql.expression.Atanh|`atanh`|Inverse hyperbolic tangent|true|None|
<a name="sql.expression.AttributeReference"></a>spark.rapids.sql.expression.AttributeReference| |References an input column|true|None|
<a name="sql.expression.BRound"></a>spark.rapids.sql.expression.BRound|`bround`|Round an expression to d decimal places using HALF_EVEN rounding mode|true|None|
<a name="sql.expression.BitwiseAnd"></a>spark.rapids.sql.expression.BitwiseAnd|`&`|Returns the bitwise AND of the operands|true|None|
<a name="sql.expression.BitwiseNot"></a>spark.rapids.sql.expression.BitwiseNot|`~`|Returns the bitwise NOT of the operands|true|None|
<a name="sql.expression.BitwiseOr"></a>spark.rapids.sql.expression.BitwiseOr|`\|`|Returns the bitwise OR of the operands|true|None|
Expand Down Expand Up @@ -197,6 +198,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.RegExpReplace"></a>spark.rapids.sql.expression.RegExpReplace|`regexp_replace`|RegExpReplace support for string literal input patterns|true|None|
<a name="sql.expression.Remainder"></a>spark.rapids.sql.expression.Remainder|`%`, `mod`|Remainder or modulo|true|None|
<a name="sql.expression.Rint"></a>spark.rapids.sql.expression.Rint|`rint`|Rounds up a double value to the nearest double equal to an integer|true|None|
<a name="sql.expression.Round"></a>spark.rapids.sql.expression.Round|`round`|Round an expression to d decimal places using HALF_UP rounding mode|true|None|
<a name="sql.expression.RowNumber"></a>spark.rapids.sql.expression.RowNumber|`row_number`|Window function that returns the index for the row within the aggregation window|true|None|
<a name="sql.expression.Second"></a>spark.rapids.sql.expression.Second|`second`|Returns the second component of the string/timestamp|true|None|
<a name="sql.expression.ShiftLeft"></a>spark.rapids.sql.expression.ShiftLeft|`shiftleft`|Bitwise shift left (<<)|true|None|
Expand Down
264 changes: 264 additions & 0 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,138 @@ Accelerator support is described below.
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="6">BRound</td>
<td rowSpan="6">`bround`</td>
<td rowSpan="6">Round an expression to d decimal places using HALF_EVEN rounding mode</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>value</td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>scale</td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>value</td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>scale</td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="6">BitwiseAnd</td>
<td rowSpan="6">`&`</td>
<td rowSpan="6">Returns the bitwise AND of the operands</td>
Expand Down Expand Up @@ -10388,6 +10520,138 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="6">Round</td>
<td rowSpan="6">`round`</td>
<td rowSpan="6">Round an expression to d decimal places using HALF_UP rounding mode</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>value</td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>scale</td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td>S*</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>value</td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>scale</td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="1">RowNumber</td>
<td rowSpan="1">`row_number`</td>
<td rowSpan="1">Window function that returns the index for the row within the aggregation window</td>
Expand Down
24 changes: 24 additions & 0 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,30 @@ def test_shift_right_unsigned(data_gen):
'shiftrightunsigned(a, cast(null as INT))',
'shiftrightunsigned(a, b)'))

@incompat
@approximate_float
@pytest.mark.parametrize('data_gen', round_gens, ids=idfn)
def test_decimal_bround(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'bround(a)',
'bround(a, -1)',
'bround(a, 1)',
'bround(a, 10)'),
conf=allow_negative_scale_of_decimal_conf)

@incompat
@approximate_float
@pytest.mark.parametrize('data_gen', round_gens, ids=idfn)
def test_decimal_round(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'round(a)',
'round(a, -1)',
'round(a, 1)',
'round(a, 10)'),
conf=allow_negative_scale_of_decimal_conf)

@approximate_float
@pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
def test_cbrt(data_gen):
Expand Down
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,9 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
# Include decimal type while testing equalTo and notEqualTo
eq_gens_with_decimal_gen = eq_gens + decimal_gens

#gen for testing round operator
round_gens = numeric_gens + decimal_gens

date_gens = [date_gen]
date_n_time_gens = [date_gen, timestamp_gen]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,26 @@ object GpuOverrides {

override def convertToGpu(child: Expression): GpuExpression = GpuAverage(child)
}),
expr[BRound](
"Round an expression to d decimal places using HALF_EVEN rounding mode",
ExprChecks.binaryProjectNotLambda(
TypeSig.numeric, TypeSig.numeric,
("value", TypeSig.numeric, TypeSig.numeric),
("scale", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))),
(a, conf, p, r) => new BinaryExprMeta[BRound](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuBRound(lhs, rhs)
}),
expr[Round](
"Round an expression to d decimal places using HALF_UP rounding mode",
ExprChecks.binaryProjectNotLambda(
TypeSig.numeric, TypeSig.numeric,
("value", TypeSig.numeric, TypeSig.numeric),
("scale", TypeSig.lit(TypeEnum.INT), TypeSig.lit(TypeEnum.INT))),
(a, conf, p, r) => new BinaryExprMeta[Round](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuRound(lhs, rhs)
}),
expr[PythonUDF](
"UDF run in an external python process. Does not actually run on the GPU, but " +
"the transfer of data to/from it can be accelerated.",
Expand Down
Loading

0 comments on commit d347102

Please sign in to comment.