Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support string repeat SQL #2728

Merged
merged 29 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ec008ef
Implement `RepeatString` SQL
ttnghia Jun 16, 2021
21aa640
Update auto-generated doc
ttnghia Jun 16, 2021
1c82b52
Implement StringRepeat with repeatTimes is a numeric column
ttnghia Jul 28, 2021
521b05a
Update `supported_ops.md`
ttnghia Jul 28, 2021
289d106
Remove spaces
ttnghia Jul 28, 2021
f45c689
Merge branch 'branch-21.10' into repeat_strings
ttnghia Jul 28, 2021
b127723
Merge branch 'branch-21.10' into repeat_strings
ttnghia Jul 28, 2021
f880c95
Add unit tests
ttnghia Jul 29, 2021
9054651
Fix ParamCheck
ttnghia Jul 29, 2021
45e8bc0
Fix resource leaking, add null check and type check assert
ttnghia Jul 29, 2021
e4ed829
Rewrite tests
ttnghia Jul 29, 2021
7433fda
Merge branch 'branch-21.10' into repeat_strings
ttnghia Jul 29, 2021
3cf5cfb
Rewrite tests
ttnghia Jul 29, 2021
ddaf723
Fix description
ttnghia Jul 29, 2021
62928a9
Update automatically generated docs
ttnghia Jul 29, 2021
540c6cf
Cleanup test
ttnghia Jul 30, 2021
4dcb4c2
Remove pattern matching
ttnghia Jul 30, 2021
89d52a7
Merge branch 'branch-21.10' into repeat_strings
ttnghia Aug 2, 2021
c0bf0de
Add scala tests
ttnghia Aug 2, 2021
fefe2b5
Cleanup python tests
ttnghia Aug 2, 2021
0872a24
MISC
ttnghia Aug 2, 2021
8ad1ab7
Rewrite test suite into unit tests
ttnghia Aug 3, 2021
ca74036
Rename file, and remove repeated tests
ttnghia Aug 3, 2021
8566c3c
Change StringRepeat operator description
ttnghia Aug 4, 2021
b19c869
Break long line into 2 lines
ttnghia Aug 4, 2021
68fb470
Update tests/src/test/scala/com/nvidia/spark/rapids/unit/StringRepeat…
ttnghia Aug 4, 2021
0af805d
Update tests/src/test/scala/com/nvidia/spark/rapids/unit/StringRepeat…
ttnghia Aug 4, 2021
0c75649
Merge branch 'branch-21.10' into repeat_strings
ttnghia Aug 4, 2021
8e419e7
Update docs
ttnghia Aug 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.StringLPad"></a>spark.rapids.sql.expression.StringLPad|`lpad`|Pad a string on the left|true|None|
<a name="sql.expression.StringLocate"></a>spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
<a name="sql.expression.StringRPad"></a>spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
<a name="sql.expression.StringRepeat"></a>spark.rapids.sql.expression.StringRepeat|`repeat`|StringRepeat operator, which repeat the given strings by given number(s) of times|true|None|
<a name="sql.expression.StringReplace"></a>spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
<a name="sql.expression.StringSplit"></a>spark.rapids.sql.expression.StringSplit|`split`|Splits `str` around occurrences that match `regex`|true|None|
<a name="sql.expression.StringTrim"></a>spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
Expand Down
132 changes: 132 additions & 0 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -15179,6 +15179,138 @@ Accelerator support is described below.
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">StringRepeat</td>
<td rowSpan="6">`repeat`</td>
<td rowSpan="6">StringRepeat operator, which repeat the given strings by given number(s) of times</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>repeatTimes</td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>repeatTimes</td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="8">StringReplace</td>
<td rowSpan="8">`replace`</td>
<td rowSpan="8">StringReplace operator</td>
Expand Down
22 changes: 22 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,28 @@ def test_substring():
'SUBSTRING(a, 1, NULL)',
'SUBSTRING(a, 0, 0)'))

def test_repeat_scalar_and_column():
gen_s = StringGen(nullable=False)
gen_r = IntegerGen(min_val=-100, max_val=100, special_cases=[0], nullable=True)
(s,) = gen_scalars_for_sql(gen_s, 1)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen_r).selectExpr('repeat({}, a)'.format(s)))

def test_repeat_column_and_scalar():
gen_s = StringGen(nullable=True)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen_s).selectExpr(
'repeat(a, -10)',
'repeat(a, 0)',
'repeat(a, 10)'
))

def test_repeat_column_and_column():
gen_s = StringGen(nullable=True)
gen_r = IntegerGen(min_val=-100, max_val=100, special_cases=[0], nullable=True)
assert_gpu_and_cpu_are_equal_collect(
lambda spark: two_col_df(spark, gen_s, gen_r).selectExpr('repeat(a, b)'))

def test_replace():
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2484,6 +2484,16 @@ object GpuOverrides {
.withPsNote(TypeEnum.STRING, "only a single character is allowed"), TypeSig.STRING),
ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
(in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)),
expr[StringRepeat](
"StringRepeat operator, which repeat the given strings by given number(s) of times",
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
ExprChecks.projectNotLambda(TypeSig.STRING, TypeSig.STRING,
Seq(ParamCheck("input", TypeSig.STRING, TypeSig.STRING),
ParamCheck("repeatTimes", TypeSig.INT, TypeSig.INT))),
(in, conf, p, r) => new BinaryExprMeta[StringRepeat](in, conf, p, r) {
override def convertToGpu(
input: Expression,
repeatTimes: Expression): GpuExpression = GpuStringRepeat(input, repeatTimes)
}),
expr[StringReplace](
"StringReplace operator",
ExprChecks.projectNotLambda(TypeSig.STRING, TypeSig.STRING,
Expand Down Expand Up @@ -2575,7 +2585,7 @@ object GpuOverrides {
(a, conf, p, r) => new ExprMeta[ConcatWs](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (a.children.size <= 1) {
// If only a separator specified and its a column, Spark returns an empty
// If only a separator specified and its a column, Spark returns an empty
// string for all entries unless they are null, then it returns null.
// This seems like edge case so instead of handling on GPU just fallback.
willNotWorkOnGpu("Only specifying separator column not supported on GPU")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,80 @@ case class GpuInitCap(child: Expression) extends GpuUnaryExpression with Implici
}
}

case class GpuStringRepeat(input: Expression, repeatTimes: Expression)
extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = input
override def right: Expression = repeatTimes
override def dataType: DataType = input.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType)

def doColumnar(input: GpuScalar, repeatTimes: GpuColumnVector): ColumnVector = {
assert(input.dataType == StringType)

withResource(GpuColumnVector.from(input, repeatTimes.getRowCount.asInstanceOf[Int],
input.dataType)) {
replicatedInput => doColumnar(replicatedInput, repeatTimes)
}
}

def doColumnar(input: GpuColumnVector, repeatTimes: GpuColumnVector): ColumnVector = {
val repeatTimesCV = repeatTimes.getBase

// Compute the output size to check for overflow.
withResource(input.getBase.repeatStringsSizes(repeatTimesCV)) { outputSizes =>
if (outputSizes.getTotalSize > Int.MaxValue.asInstanceOf[Long]) {
throw new RuntimeException("Output strings have total size exceed maximum allowed size")
}

// Finally repeat the strings using the pre-computed strings' sizes.
input.getBase.repeatStrings(repeatTimesCV, outputSizes.getStringSizes)
}
}

def doColumnar(input: GpuColumnVector, repeatTimes: GpuScalar): ColumnVector = {
if (!repeatTimes.isValid) {
// If the input scala repeatTimes is invalid, the results should be all nulls.
withResource(Scalar.fromNull(DType.STRING)) {
nullString => ColumnVector.fromScalar(nullString, input.getRowCount.asInstanceOf[Int])
}
} else {
assert(repeatTimes.dataType == IntegerType)
val repeatTimesVal = repeatTimes.getBase.getInt

// Get the input size to check for overflow for the output.
// Note that this is not an accurate check since the total buffer size of the input
// strings column may be larger than the total length of strings that will be repeated in
// this function.
val inputBufferSize = input.getBase.getData.getLength
if (repeatTimesVal > 0 && inputBufferSize > Int.MaxValue / repeatTimesVal) {
throw new RuntimeException("Output strings have total size exceed maximum allowed size")
}

// Finally repeat the strings.
input.getBase.repeatStrings(repeatTimesVal)
}
}

def doColumnar(numRows: Int, input: GpuScalar, repeatTimes: GpuScalar): ColumnVector = {
assert(input.dataType == StringType)

if (!repeatTimes.isValid) {
// If the input scala repeatTimes is invalid, the results should be all nulls.
withResource(Scalar.fromNull(DType.STRING)) {
nullString => ColumnVector.fromScalar(nullString, numRows)
}
} else {
assert(repeatTimes.dataType == IntegerType)
val repeatTimesVal = repeatTimes.getBase.getInt

withResource(input.getBase.repeatString(repeatTimesVal)) {
repeatedString => ColumnVector.fromScalar(repeatedString, numRows)
}
}
}

}

case class GpuStringReplace(
srcExpr: Expression,
searchExpr: Expression,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.unit

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.{GpuColumnVector, GpuScalar, GpuUnitTests}

import org.apache.spark.sql.rapids.GpuStringRepeat
import org.apache.spark.sql.types.DataTypes

// This class just covers some test cases that are not possible to call in PySpark tests.
// The remaining tests are in `integration_test/src/main/python/string_test.py`.
class StringRepeatUnitTest extends GpuUnitTests {
test("Test StringRepeat with scalar string and scalar repeatTimes") {
// Test repeat(NULL|str, NULL).
withResource(GpuScalar(null, DataTypes.IntegerType)) { nullInt =>
val doTest = (strScalar: GpuScalar) => {
withResource(GpuStringRepeat(null, null).doColumnar(1, strScalar, nullInt)) { result =>
assertResult(1)(result.getRowCount)
assertResult(result.getNullCount)(result.getRowCount)
}
}

withResource(GpuScalar(null, DataTypes.StringType)) { nullStr => doTest(nullStr) }
withResource(GpuScalar("abc123", DataTypes.StringType)) { str => doTest(str) }
withResource(GpuScalar("á é í", DataTypes.StringType)) { str => doTest(str) }
}

// Test repeat(NULL, intVal).
withResource(GpuScalar(null, DataTypes.StringType)) { nullStr =>
val doTest = (intVal: Int) =>
withResource(GpuScalar(intVal, DataTypes.IntegerType)) { intScalar =>
withResource(GpuStringRepeat(null, null).doColumnar(1, nullStr, intScalar)) { result =>
assertResult(1)(result.getRowCount)
assertResult(result.getNullCount)(result.getRowCount)
}
}

doTest(-1)
doTest(0)
doTest(1)
}

// Test repeat(str, intVal).
withResource(GpuScalar("abc123", DataTypes.StringType)) { strScalar =>
val doTest = (intVal: Int, expectedStr: String) =>
withResource(GpuScalar(intVal, DataTypes.IntegerType)) { intScalar =>
withResource(GpuStringRepeat(null, null).doColumnar(1, strScalar, intScalar)) {
result =>
withResource(result.copyToHost()) { hostResult =>
assertResult(1)(hostResult.getRowCount)
assertResult(0)(hostResult.getNullCount)
assertResult(expectedStr)(hostResult.getJavaString(0))
}
}
}

doTest(-1, "")
doTest(0, "")
doTest(1, "abc123")
doTest(2, "abc123abc123")
doTest(3, "abc123abc123abc123")
}
}

test("Test StringRepeat with NULL scalar string and column repeatTimes") {
val intCol = ColumnVector.fromBoxedInts(-3, null, -1, 0, 1, 2, null)

withResource(GpuColumnVector.from(intCol, DataTypes.IntegerType)) { intGpuColumn =>
// Test repeat(NULL, intVal).
withResource(GpuScalar(null, DataTypes.StringType)) { nullStr =>
withResource(GpuStringRepeat(null, null).doColumnar(nullStr, intGpuColumn)) { result =>
assertResult(intGpuColumn.getRowCount)(result.getRowCount)
assertResult(result.getNullCount)(result.getRowCount)
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}

test("Test StringRepeat with column strings and NULL scalar repeatTimes") {
val strsCol = ColumnVector.fromStrings(null, "a", "", "123", "éá")

withResource(GpuColumnVector.from(strsCol, DataTypes.StringType)) { strGpuColumn =>
// Test repeat(strs, NULL).
withResource(GpuScalar(null, DataTypes.IntegerType)) { nullInt =>
withResource(GpuStringRepeat(null, null).doColumnar(strGpuColumn, nullInt)) { result =>
assertResult(strGpuColumn.getRowCount)(result.getRowCount)
assertResult(result.getNullCount)(result.getRowCount)
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}

}