Support string repeat SQL #2728

Merged
merged 29 commits into from
Aug 5, 2021
Commits
ec008ef
Implement `RepeatString` SQL
ttnghia Jun 16, 2021
21aa640
Update auto-generated doc
ttnghia Jun 16, 2021
1c82b52
Implement StringRepeat with repeatTimes is a numeric column
ttnghia Jul 28, 2021
521b05a
Update `supported_ops.md`
ttnghia Jul 28, 2021
289d106
Remove spaces
ttnghia Jul 28, 2021
f45c689
Merge branch 'branch-21.10' into repeat_strings
ttnghia Jul 28, 2021
b127723
Merge branch 'branch-21.10' into repeat_strings
ttnghia Jul 28, 2021
f880c95
Add unit tests
ttnghia Jul 29, 2021
9054651
Fix ParamCheck
ttnghia Jul 29, 2021
45e8bc0
Fix resource leaking, add null check and type check assert
ttnghia Jul 29, 2021
e4ed829
Rewrite tests
ttnghia Jul 29, 2021
7433fda
Merge branch 'branch-21.10' into repeat_strings
ttnghia Jul 29, 2021
3cf5cfb
Rewrite tests
ttnghia Jul 29, 2021
ddaf723
Fix description
ttnghia Jul 29, 2021
62928a9
Update automatically generated docs
ttnghia Jul 29, 2021
540c6cf
Cleanup test
ttnghia Jul 30, 2021
4dcb4c2
Remove pattern matching
ttnghia Jul 30, 2021
89d52a7
Merge branch 'branch-21.10' into repeat_strings
ttnghia Aug 2, 2021
c0bf0de
Add scala tests
ttnghia Aug 2, 2021
fefe2b5
Cleanup python tests
ttnghia Aug 2, 2021
0872a24
MISC
ttnghia Aug 2, 2021
8ad1ab7
Rewrite test suite into unit tests
ttnghia Aug 3, 2021
ca74036
Rename file, and remove repeated tests
ttnghia Aug 3, 2021
8566c3c
Change StringRepeat operator description
ttnghia Aug 4, 2021
b19c869
Break long line into 2 lines
ttnghia Aug 4, 2021
68fb470
Update tests/src/test/scala/com/nvidia/spark/rapids/unit/StringRepeat…
ttnghia Aug 4, 2021
0af805d
Update tests/src/test/scala/com/nvidia/spark/rapids/unit/StringRepeat…
ttnghia Aug 4, 2021
0c75649
Merge branch 'branch-21.10' into repeat_strings
ttnghia Aug 4, 2021
8e419e7
Update docs
ttnghia Aug 5, 2021
1 change: 1 addition & 0 deletions docs/configs.md
@@ -267,6 +267,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.StringLPad"></a>spark.rapids.sql.expression.StringLPad|`lpad`|Pad a string on the left|true|None|
<a name="sql.expression.StringLocate"></a>spark.rapids.sql.expression.StringLocate|`position`, `locate`|Substring search operator|true|None|
<a name="sql.expression.StringRPad"></a>spark.rapids.sql.expression.StringRPad|`rpad`|Pad a string on the right|true|None|
<a name="sql.expression.StringRepeat"></a>spark.rapids.sql.expression.StringRepeat|`repeat`|StringRepeat operator|true|None|
<a name="sql.expression.StringReplace"></a>spark.rapids.sql.expression.StringReplace|`replace`|StringReplace operator|true|None|
<a name="sql.expression.StringSplit"></a>spark.rapids.sql.expression.StringSplit|`split`|Splits `str` around occurrences that match `regex`|true|None|
<a name="sql.expression.StringTrim"></a>spark.rapids.sql.expression.StringTrim|`trim`|StringTrim operator|true|None|
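The new `configs.md` row means the GPU version of `repeat` is enabled by default and can be switched off per expression. As a minimal sketch of how that key might be passed on a command line: `spark_submit_conf_args` is a hypothetical helper written for this note, not part of the plugin; only the key name comes from the row above, and `--conf key=value` is the standard `spark-submit` form.

```python
from typing import Dict, List, Union

ConfValue = Union[bool, int, str]


def spark_submit_conf_args(conf: Dict[str, ConfValue]) -> List[str]:
    """Hypothetical helper: turn a dict of Spark settings into --conf arguments."""
    args: List[str] = []
    for key in sorted(conf):
        value = conf[key]
        # Spark expects lowercase "true"/"false" for boolean settings.
        text = str(value).lower() if isinstance(value, bool) else str(value)
        args += ["--conf", f"{key}={text}"]
    return args


# Disable the GPU implementation of `repeat` only; other expressions are unaffected.
rapids_conf = {"spark.rapids.sql.expression.StringRepeat": False}
print(" ".join(spark_submit_conf_args(rapids_conf)))
```

With the setting at `false`, the expectation is that `repeat` falls back to the CPU while the rest of the plan can still run on the GPU.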
132 changes: 132 additions & 0 deletions docs/supported_ops.md
@@ -15063,6 +15063,138 @@ Accelerator support is described below.
<td> </td>
</tr>
<tr>
<td rowSpan="6">StringRepeat</td>
<td rowSpan="6">`repeat`</td>
<td rowSpan="6">StringRepeat operator</td>
<td rowSpan="6">None</td>
<td rowSpan="3">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS (Literal value only)</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>repeatTimes</td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS (Literal value only)</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">lambda</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>repeatTimes</td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="8">StringReplace</td>
<td rowSpan="8">`replace`</td>
<td rowSpan="8">StringReplace operator</td>
35 changes: 35 additions & 0 deletions integration_tests/src/main/python/string_test.py
@@ -286,6 +286,41 @@ def test_substring():
'SUBSTRING(a, 1, NULL)',
'SUBSTRING(a, 0, 0)'))

def test_repeat_scalar_and_scalar():
    gen_s = StringGen()
    # Bound the repeat count so that the repeated strings stay small.
    gen_r = IntegerGen(min_val=-100, max_val=100, special_cases=[0])
    s = gen_scalar(gen_s)
    r = gen_scalar_value(gen_r, force_no_nulls=True)
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark: unary_op_df(spark, gen_s).select(f.repeat(s, r)))

def test_repeat_column_and_scalar():
    gen_s = StringGen()
    gen_r = IntegerGen(min_val=-100, max_val=100, special_cases=[0])
    r = gen_scalar_value(gen_r, force_no_nulls=True)
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark: unary_op_df(spark, gen_s).select(
                f.repeat(f.col('a'), r)))

def test_repeat_column_and_column():
    gen_s = StringGen()
    gen_r = IntegerGen(min_val=-100, max_val=100, special_cases=[0])
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark: two_col_df(spark, gen_s, gen_r).select(
                f.repeat(f.col('a'), f.col('b'))))
    gen_r = ByteGen(min_val=-100, max_val=100, special_cases=[0])
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark: two_col_df(spark, gen_s, gen_r).select(
                f.repeat(f.col('a'), f.col('b'))))
    gen_r = ShortGen(min_val=-100, max_val=100, special_cases=[0])
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark: two_col_df(spark, gen_s, gen_r).select(
                f.repeat(f.col('a'), f.col('b'))))
    gen_r = LongGen(min_val=-100, max_val=100, special_cases=[0])
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark: two_col_df(spark, gen_s, gen_r).select(
                f.repeat(f.col('a'), f.col('b'))))

def test_replace():
gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
assert_gpu_and_cpu_are_equal_collect(
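The `repeat` tests in this hunk deliberately include zero and negative counts. As a reviewing aid, the CPU semantics they compare against can be modeled in a few lines of plain Python. This is a minimal sketch, not code from the PR: `repeat_reference` is a hypothetical name, and it assumes Spark's behavior for `repeat` (null if either argument is null, empty string for a non-positive count).

```python
from typing import Optional


def repeat_reference(s: Optional[str], n: Optional[int]) -> Optional[str]:
    """Hypothetical reference model of SQL repeat(str, n)."""
    if s is None or n is None:
        # The expression is null-intolerant: any null input yields null.
        return None
    if n <= 0:
        # Zero and negative counts yield an empty string, not an error.
        return ""
    return s * n


for args in [("ab", 3), ("ab", 0), ("ab", -5), (None, 2), ("ab", None)]:
    print(args, "->", repr(repeat_reference(*args)))
```

The `special_cases=[0]` and `min_val=-100` settings in the generators above exist to hit the second branch of this model.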
@@ -2472,6 +2472,16 @@ object GpuOverrides {
.withPsNote(TypeEnum.STRING, "only a single character is allowed"), TypeSig.STRING),
ParamCheck("count", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
(in, conf, p, r) => new SubstringIndexMeta(in, conf, p, r)),
expr[StringRepeat](
  "StringRepeat operator",
  ExprChecks.projectNotLambda(TypeSig.STRING, TypeSig.STRING,
    Seq(ParamCheck("input", TypeSig.lit(TypeEnum.STRING) + TypeSig.STRING, TypeSig.STRING),
      ParamCheck("repeatTimes", TypeSig.lit(TypeEnum.INT) + TypeSig.INT, TypeSig.INT))),
  (in, conf, p, r) => new BinaryExprMeta[StringRepeat](in, conf, p, r) {
    override def convertToGpu(
        input: Expression,
        repeatTimes: Expression): GpuExpression = GpuStringRepeat(input, repeatTimes)
  }),
expr[StringReplace](
"StringReplace operator",
ExprChecks.projectNotLambda(TypeSig.STRING, TypeSig.STRING,
@@ -2563,7 +2573,7 @@ object GpuOverrides {
(a, conf, p, r) => new ExprMeta[ConcatWs](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (a.children.size <= 1) {
// If only a separator is specified and it's a column, Spark returns an empty
// string for all entries unless they are null, then it returns null.
// This seems like edge case so instead of handling on GPU just fallback.
willNotWorkOnGpu("Only specifying separator column not supported on GPU")
@@ -487,6 +487,60 @@ case class GpuInitCap(child: Expression) extends GpuUnaryExpression with Implici
}
}

case class GpuStringRepeat(input: Expression, repeatTimes: Expression)
    extends GpuBinaryExpression with ImplicitCastInputTypes with NullIntolerant {
  override def left: Expression = input
  override def right: Expression = repeatTimes
  override def dataType: DataType = input.dataType
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType)

  // A scalar input string with a column of repeat counts is not supported.
  def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector =
    throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this")

  def doColumnar(input: GpuColumnVector, repeatTimes: GpuColumnVector): ColumnVector = {
    val repeatTimesCV = repeatTimes.getBase

    // Compute the output sizes to check for overflow. The sizes hold a device column,
    // so close them after use to avoid leaking it.
    withResource(input.getBase.repeatStringsSizes(repeatTimesCV)) { outputSizes =>
      if (outputSizes.getTotalSize > Int.MaxValue) {
        throw new RuntimeException("Total size of output strings exceeds maximum allowed size")
      }

      // Finally repeat the strings using the pre-computed strings' sizes.
      input.getBase.repeatStrings(repeatTimesCV, outputSizes.getStringSizes)
    }
  }

  def doColumnar(input: GpuColumnVector, repeatTimes: GpuScalar): ColumnVector = {
    repeatTimes.getValue match {
      case repeatTimesVal: Int => { // only support Int type for this overload
        // Get the input size to check for overflow for the output.
        // Note that this check is not exact: the total buffer size of the input strings
        // column may be larger than the total length of the strings that are repeated here.
        val inputBufferSize = input.getBase.getData.getLength
        // Non-positive repeat counts produce empty strings and cannot overflow. Checking
        // `repeatTimesVal > 0` first also avoids a division by zero.
        if (repeatTimesVal > 0 && inputBufferSize > Int.MaxValue / repeatTimesVal) {
          throw new RuntimeException("Total size of output strings exceeds maximum allowed size")
        }

        // Finally repeat the strings.
        input.getBase.repeatStrings(repeatTimesVal)
      }
      case _ => throw new IllegalStateException("Invalid data type for repeatTimes (must be INT32)")
    }
  }

  def doColumnar(numRows: Int, input: GpuScalar, repeatTimes: GpuScalar): ColumnVector = {
    repeatTimes.getValue match {
      case repeatTimesVal: Int => { // only support Int type for this overload
        withResource(input.getBase.repeatString(repeatTimesVal)) {
          // Expand the repeated scalar to the full batch size, not to a single row.
          repeatedString => ColumnVector.fromScalar(repeatedString, numRows)
        }
      }
      case _ => throw new IllegalStateException("Invalid data type for repeatTimes (must be INT32)")
    }
  }

}

case class GpuStringReplace(
srcExpr: Expression,
searchExpr: Expression,
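The two overflow checks in `GpuStringRepeat` differ in precision: the column/column path sums exact per-row output sizes, while the column/scalar path only bounds the output by the size of the whole input data buffer. The following is a rough Python model of that difference, under stated assumptions: the function names are hypothetical, sizes are counted in UTF-8 bytes, null rows and non-positive counts contribute zero bytes, and `INT32_MAX` stands in for Scala's `Int.MaxValue`. It does not call cuDF.

```python
from typing import List, Optional, Sequence

INT32_MAX = 2**31 - 1  # stands in for Int.MaxValue in the Scala code


def utf8_size(s: str) -> int:
    """Size of a string in bytes once encoded as UTF-8."""
    return len(s.encode("utf-8"))


def column_column_sizes(strings: Sequence[Optional[str]],
                        counts: Sequence[Optional[int]]) -> List[int]:
    """Exact check: compute every output row size, then test the total."""
    sizes = []
    for s, n in zip(strings, counts):
        if s is None or n is None or n <= 0:
            sizes.append(0)  # null or empty output row
        else:
            sizes.append(utf8_size(s) * n)
    if sum(sizes) > INT32_MAX:
        raise OverflowError("total output size exceeds INT32_MAX")
    return sizes


def check_column_scalar(input_buffer_size: int, count: int) -> None:
    """Conservative check: bound the output by the input buffer size times count."""
    # Non-positive counts cannot overflow; testing count > 0 first also
    # keeps the division from ever seeing a zero divisor.
    if count > 0 and input_buffer_size > INT32_MAX // count:
        raise OverflowError("output size may exceed INT32_MAX")


print(column_column_sizes(["ab", None, "xyz"], [3, 2, -1]))
check_column_scalar(input_buffer_size=1024, count=100)  # passes silently
```

Because the scalar check uses the buffer size rather than the lengths of the rows actually present, it can reject a batch that would have fit, but it never accepts one that overflows.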