Skip to content

Commit

Permalink
Support string repeat SQL (#2728)
Browse files Browse the repository at this point in the history
* Implement `StringRepeat`

Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia authored Aug 5, 2021
1 parent d1327f8 commit 081ecfc
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.StringLPad"></a>spark.rapids.sql.expression.StringLPad|`lpad`|Pad a string on the left|true|None|
<a name="sql.expression.StringLocate"></a>spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
<a name="sql.expression.StringRPad"></a>spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
<a name="sql.expression.StringRepeat"></a>spark.rapids.sql.expression.StringRepeat|`repeat`|StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes|true|None|
<a name="sql.expression.StringReplace"></a>spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
<a name="sql.expression.StringSplit"></a>spark.rapids.sql.expression.StringSplit|`split`|Splits `str` around occurrences that match `regex`|true|None|
<a name="sql.expression.StringTrim"></a>spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
Expand Down
132 changes: 132 additions & 0 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -15202,6 +15202,138 @@ Accelerator support is described below.
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">StringRepeat</td>
<td rowSpan="6">`repeat`</td>
<td rowSpan="6">StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>repeatTimes</td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>repeatTimes</td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="8">StringReplace</td>
<td rowSpan="8">`replace`</td>
<td rowSpan="8">StringReplace operator</td>
Expand Down
22 changes: 22 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,28 @@ def test_substring():
'SUBSTRING(a, 1, NULL)',
'SUBSTRING(a, 0, 0)'))

def test_repeat_scalar_and_column():
    """Compare CPU and GPU results of repeat(<string literal>, <integer column>)."""
    str_gen = StringGen(nullable=False)
    times_gen = IntegerGen(min_val=-100, max_val=100, special_cases=[0], nullable=True)
    # Exactly one scalar is requested, so unpack it directly into a single name.
    [str_literal] = gen_scalars_for_sql(str_gen, 1)
    repeat_expr = 'repeat({}, a)'.format(str_literal)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, times_gen).selectExpr(repeat_expr))

def test_repeat_column_and_scalar():
    """Compare CPU and GPU results of repeat(<string column>, <integer literal>)."""
    str_gen = StringGen(nullable=True)
    # One negative, one zero and one positive repeat count.
    repeat_exprs = ['repeat(a, {})'.format(times) for times in (-10, 0, 10)]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, str_gen).selectExpr(*repeat_exprs))

def test_repeat_column_and_column():
    """Compare CPU and GPU results of repeat(<string column>, <integer column>)."""
    str_gen = StringGen(nullable=True)
    times_gen = IntegerGen(min_val=-100, max_val=100, special_cases=[0], nullable=True)

    def repeat_a_by_b(spark):
        # Column `a` holds the strings and column `b` the repeat counts.
        return two_col_df(spark, str_gen, times_gen).selectExpr('repeat(a, b)')

    assert_gpu_and_cpu_are_equal_collect(repeat_a_by_b)

def test_replace():
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2484,6 +2484,17 @@ object GpuOverrides {
.withPsNote(TypeEnum.STRING, "only a single character is allowed"), TypeSig.STRING),
ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
(in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)),
// Replacement rule for Spark's `StringRepeat` expression (SQL `repeat`).
expr[StringRepeat](
// Human-readable description of the operator.
"StringRepeat operator that repeats the given strings with numbers of times " +
"given by repeatTimes",
// Type checks: STRING result, a STRING `input` parameter and an INT `repeatTimes` parameter.
// NOTE(review): presumably `projectNotLambda` limits support to the project context and
// excludes lambda bodies - confirm against `ExprChecks`.
ExprChecks.projectNotLambda(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("input", TypeSig.STRING, TypeSig.STRING),
ParamCheck("repeatTimes", TypeSig.INT, TypeSig.INT))),
(in, conf, p, r) => new BinaryExprMeta[StringRepeat](in, conf, p, r) {
// Builds the GPU expression from the two converted children, in (input, repeatTimes) order.
override def convertToGpu(
input: Expression,
repeatTimes: Expression): GpuExpression = GpuStringRepeat(input, repeatTimes)
}),
expr[StringReplace](
"StringReplace operator",
ExprChecks.projectNotLambda(TypeSig.STRING, TypeSig.STRING,
Expand Down Expand Up @@ -2575,7 +2586,7 @@ object GpuOverrides {
(a, conf, p, r) => new ExprMeta[ConcatWs](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (a.children.size <= 1) {
// If only a separator specified and its a column, Spark returns an empty
// If only a separator specified and its a column, Spark returns an empty
// string for all entries unless they are null, then it returns null.
// This seems like edge case so instead of handling on GPU just fallback.
willNotWorkOnGpu("Only specifying separator column not supported on GPU")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,80 @@ case class GpuInitCap(child: Expression) extends GpuUnaryExpression with Implici
}
}

/**
 * GPU implementation of the `repeat(input, repeatTimes)` string expression.
 *
 * Each `doColumnar` overload handles one combination of scalar/column operands. A null
 * `repeatTimes` scalar yields an all-null result, and every overload that repeats a whole
 * column first checks that the output cannot exceed the maximum size of a strings column.
 *
 * @param input expression producing the strings to repeat
 * @param repeatTimes expression producing the number of repetitions
 */
case class GpuStringRepeat(input: Expression, repeatTimes: Expression)
    extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
  override def left: Expression = input
  override def right: Expression = repeatTimes
  override def dataType: DataType = input.dataType
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType)

  /** Builds a strings column of `numRows` rows in which every row is null. */
  private def allNullStrings(numRows: Int): ColumnVector =
    withResource(Scalar.fromNull(DType.STRING)) { nullString =>
      ColumnVector.fromScalar(nullString, numRows)
    }

  /**
   * Repeats a scalar string by a column of repeat counts.
   *
   * The scalar is expanded into a column with as many rows as `repeatTimes`, then the
   * column/column overload does the work.
   */
  def doColumnar(input: GpuScalar, repeatTimes: GpuColumnVector): ColumnVector = {
    assert(input.dataType == StringType)

    withResource(GpuColumnVector.from(input, repeatTimes.getRowCount.toInt,
      input.dataType)) { replicatedInput =>
      doColumnar(replicatedInput, repeatTimes)
    }
  }

  /**
   * Repeats each string of `input` by the count in the matching row of `repeatTimes`.
   *
   * @throws RuntimeException if the total size of the output strings exceeds `Int.MaxValue`
   */
  def doColumnar(input: GpuColumnVector, repeatTimes: GpuColumnVector): ColumnVector = {
    val repeatTimesCV = repeatTimes.getBase

    // Compute the output size to check for overflow.
    withResource(input.getBase.repeatStringsSizes(repeatTimesCV)) { outputSizes =>
      if (outputSizes.getTotalSize > Int.MaxValue.toLong) {
        throw new RuntimeException("Output strings have total size exceed maximum allowed size")
      }

      // Finally repeat the strings using the pre-computed strings' sizes.
      input.getBase.repeatStrings(repeatTimesCV, outputSizes.getStringSizes)
    }
  }

  /**
   * Repeats each string of `input` by the same scalar count.
   *
   * @throws RuntimeException if the (over-estimated) output size exceeds `Int.MaxValue`
   */
  def doColumnar(input: GpuColumnVector, repeatTimes: GpuScalar): ColumnVector = {
    if (!repeatTimes.isValid) {
      // If the input scalar repeatTimes is invalid, the results should be all nulls.
      allNullStrings(input.getRowCount.toInt)
    } else {
      assert(repeatTimes.dataType == IntegerType)
      val repeatTimesVal = repeatTimes.getBase.getInt

      // Get the input size to check for overflow for the output.
      // Note that this is not an accurate check since the total buffer size of the input
      // strings column may be larger than the total length of strings that will be repeated in
      // this function.
      // Guard against a missing data buffer (treated as zero bytes) rather than
      // dereferencing it unconditionally.
      val inputBufferSize = Option(input.getBase.getData).map(_.getLength).getOrElse(0L)
      if (repeatTimesVal > 0 && inputBufferSize > Int.MaxValue / repeatTimesVal) {
        throw new RuntimeException("Output strings have total size exceed maximum allowed size")
      }

      // Finally repeat the strings.
      input.getBase.repeatStrings(repeatTimesVal)
    }
  }

  /** Repeats a scalar string by a scalar count and expands the result to `numRows` rows. */
  def doColumnar(numRows: Int, input: GpuScalar, repeatTimes: GpuScalar): ColumnVector = {
    assert(input.dataType == StringType)

    if (!repeatTimes.isValid) {
      // If the input scalar repeatTimes is invalid, the results should be all nulls.
      allNullStrings(numRows)
    } else {
      assert(repeatTimes.dataType == IntegerType)
      val repeatTimesVal = repeatTimes.getBase.getInt

      withResource(input.getBase.repeatString(repeatTimesVal)) { repeatedString =>
        ColumnVector.fromScalar(repeatedString, numRows)
      }
    }
  }

}

case class GpuStringReplace(
srcExpr: Expression,
searchExpr: Expression,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.unit

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.{GpuColumnVector, GpuScalar, GpuUnitTests}

import org.apache.spark.sql.rapids.GpuStringRepeat
import org.apache.spark.sql.types.DataTypes

// This class only covers test cases that cannot be exercised from the PySpark tests.
// The remaining tests are in `integration_tests/src/main/python/string_test.py`.
/**
 * Unit tests that call the `doColumnar` overloads of `GpuStringRepeat` directly with
 * null scalar operands.
 */
class StringRepeatUnitTest extends GpuUnitTests {

  // The tests call the `doColumnar` overloads directly and pass every operand explicitly,
  // so the expression is built with null children.
  private def repeatExpr: GpuStringRepeat = GpuStringRepeat(null, null)

  /** Asserts that `result` has `expectedRows` rows and that all of them are null. */
  private def assertAllNulls(expectedRows: Long, result: ColumnVector): Unit = {
    assertResult(expectedRows)(result.getRowCount)
    // Expected value first, so a failure message reports expected/actual the right way round.
    assertResult(result.getRowCount)(result.getNullCount)
  }

  test("Test StringRepeat with scalar string and scalar repeatTimes") {
    // Test repeat(NULL|str, NULL).
    withResource(GpuScalar(null, DataTypes.IntegerType)) { nullInt =>
      val doTest = (strScalar: GpuScalar) => {
        withResource(repeatExpr.doColumnar(1, strScalar, nullInt)) { result =>
          assertAllNulls(1, result)
        }
      }

      withResource(GpuScalar(null, DataTypes.StringType)) { nullStr => doTest(nullStr) }
      withResource(GpuScalar("abc123", DataTypes.StringType)) { str => doTest(str) }
      withResource(GpuScalar("á é í", DataTypes.StringType)) { str => doTest(str) }
    }

    // Test repeat(NULL, intVal).
    withResource(GpuScalar(null, DataTypes.StringType)) { nullStr =>
      val doTest = (intVal: Int) =>
        withResource(GpuScalar(intVal, DataTypes.IntegerType)) { intScalar =>
          withResource(repeatExpr.doColumnar(1, nullStr, intScalar)) { result =>
            assertAllNulls(1, result)
          }
        }

      doTest(-1)
      doTest(0)
      doTest(1)
    }

    // Test repeat(str, intVal).
    withResource(GpuScalar("abc123", DataTypes.StringType)) { strScalar =>
      val doTest = (intVal: Int, expectedStr: String) =>
        withResource(GpuScalar(intVal, DataTypes.IntegerType)) { intScalar =>
          withResource(repeatExpr.doColumnar(1, strScalar, intScalar)) { result =>
            withResource(result.copyToHost()) { hostResult =>
              assertResult(1)(hostResult.getRowCount)
              assertResult(0)(hostResult.getNullCount)
              assertResult(expectedStr)(hostResult.getJavaString(0))
            }
          }
        }

      // Non-positive repeat counts produce an empty (non-null) string.
      doTest(-1, "")
      doTest(0, "")
      doTest(1, "abc123")
      doTest(2, "abc123abc123")
      doTest(3, "abc123abc123abc123")
    }
  }

  test("Test StringRepeat with NULL scalar string and column repeatTimes") {
    val intCol = ColumnVector.fromBoxedInts(-3, null, -1, 0, 1, 2, null)

    withResource(GpuColumnVector.from(intCol, DataTypes.IntegerType)) { intGpuColumn =>
      // Test repeat(NULL, intVal).
      withResource(GpuScalar(null, DataTypes.StringType)) { nullStr =>
        withResource(repeatExpr.doColumnar(nullStr, intGpuColumn)) { result =>
          assertAllNulls(intGpuColumn.getRowCount, result)
        }
      }
    }
  }

  test("Test StringRepeat with column strings and NULL scalar repeatTimes") {
    val strsCol = ColumnVector.fromStrings(null, "a", "", "123", "éá")

    withResource(GpuColumnVector.from(strsCol, DataTypes.StringType)) { strGpuColumn =>
      // Test repeat(strs, NULL).
      withResource(GpuScalar(null, DataTypes.IntegerType)) { nullInt =>
        withResource(repeatExpr.doColumnar(strGpuColumn, nullInt)) { result =>
          assertAllNulls(strGpuColumn.getRowCount, result)
        }
      }
    }
  }

}

0 comments on commit 081ecfc

Please sign in to comment.