Estimate and validate regular expression complexities #6006

Merged
Changes from 11 commits
19 commits
1 change: 1 addition & 0 deletions docs/configs.md
@@ -123,6 +123,7 @@ Name | Description | Default Value
<a name="sql.reader.batchSizeBytes"></a>spark.rapids.sql.reader.batchSizeBytes|Soft limit on the maximum number of bytes the reader reads per batch. The readers will read chunks of data until this limit is met or exceeded. Note that the reader may estimate the number of bytes that will be used on the GPU in some cases based on the schema and number of rows in each batch.|2147483647
<a name="sql.reader.batchSizeRows"></a>spark.rapids.sql.reader.batchSizeRows|Soft limit on the maximum number of rows the reader will read per batch. The orc and parquet readers will read row groups until this limit is met or exceeded. The limit is respected by the csv reader.|2147483647
<a name="sql.regexp.enabled"></a>spark.rapids.sql.regexp.enabled|Specifies whether supported regular expressions will be evaluated on the GPU. Unsupported expressions will fall back to CPU. However, there are some known edge cases that will still execute on GPU and produce incorrect results and these are documented in the compatibility guide. Setting this config to false will make all regular expressions run on the CPU instead.|true
<a name="sql.regexp.maxStateMemory"></a>spark.rapids.sql.regexp.maxStateMemory|Specifies the maximum memory on GPU to be used for regular expressions. The memory usage is an estimate based on an upper-bound approximation on the complexity of the regular expression. Note that the actual memory usage may still be higher than this estimate depending on the number of rows in the data column and the input strings themselves. It is recommended to not set this to more than 3 times spark.rapids.sql.batchSizeBytes|2147483647
<a name="sql.replaceSortMergeJoin.enabled"></a>spark.rapids.sql.replaceSortMergeJoin.enabled|Allow replacing sortMergeJoin with HashJoin|true
<a name="sql.rowBasedUDF.enabled"></a>spark.rapids.sql.rowBasedUDF.enabled|When set to true, optimizes a row-based UDF in a GPU operation by transferring only the data it needs between GPU and CPU inside a query operation, instead of falling this operation back to CPU. This is an experimental feature, and this config might be removed in the future.|false
<a name="sql.shuffle.spillThreads"></a>spark.rapids.sql.shuffle.spillThreads|Number of threads used to spill shuffle data to disk in the background.|6
@@ -48,9 +48,10 @@ class GpuRegExpReplaceMeta(
try {
javaPattern = Some(s.toString())
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).transpile(s.toString, replacement)
cudfPattern = Some(pat)
repl.map(GpuRegExpUtils.backrefConversion).foreach {
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, replacement)
GpuRegExpUtils.validateRegExpComplexity(this, pat)
cudfPattern = Some(pat.toRegexString)
repl.map { r => GpuRegExpUtils.backrefConversion(r.toRegexString) }.foreach {
case (hasBackref, convertedRep) =>
containsBackref = hasBackref
replacement = Some(GpuRegExpUtils.unescapeReplaceString(convertedRep))
20 changes: 20 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1049,6 +1049,16 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)

val REGEXP_MAX_STATE_MEMORY = conf("spark.rapids.sql.regexp.maxStateMemory")
.doc("Specifies the maximum memory on GPU to be used for regular expressions. " +
"The memory usage is an estimate based on an upper-bound approximation on the " +
"complexity of the regular expression. Note that the actual memory usage may " +
"still be higher than this estimate depending on the number of rows in the data " +
"column and the input strings themselves. It is recommended to not set this to " +
s"more than 3 times ${GPU_BATCH_SIZE_BYTES.key}")
.longConf
.createWithDefault(Integer.MAX_VALUE)

// INTERNAL TEST AND DEBUG CONFIGS

val TEST_CONF = conf("spark.rapids.sql.test.enabled")
@@ -1926,6 +1936,16 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isRegExpEnabled: Boolean = get(ENABLE_REGEXP)

lazy val maxRegExpStateMemory: Long = {
val size = get(REGEXP_MAX_STATE_MEMORY)
if (size > 3 * gpuTargetBatchSizeBytes) {
logWarning(s"${REGEXP_MAX_STATE_MEMORY.key} is more than 3 times " +
s"${GPU_BATCH_SIZE_BYTES.key}. This may cause regular expression operations to " +
s"encounter GPU out of memory errors.")
}
size
}

lazy val getSparkGpuResourceName: String = get(SPARK_GPU_RESOURCE_NAME)

lazy val isCpuBasedUDFEnabled: Boolean = get(ENABLE_CPU_BASED_UDF)
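
The warning logic in `maxRegExpStateMemory` can be tried outside the plugin with a small standalone sketch. The object and helper names below (`MaxStateMemoryCheck`, `warningFor`) are hypothetical and only mirror the comparison shown in the hunk above; the two config keys are the ones that appear in this diff. The real code reads both values from `RapidsConf` and logs instead of returning a message.

```scala
object MaxStateMemoryCheck {
  // Config keys as they appear in the diff.
  val BatchSizeKey = "spark.rapids.sql.batchSizeBytes"
  val MaxStateMemoryKey = "spark.rapids.sql.regexp.maxStateMemory"

  // Hypothetical helper: returns the warning text when the regexp state
  // memory limit is more than 3 times the target batch size, else None.
  def warningFor(batchSizeBytes: Long, maxStateMemory: Long): Option[String] =
    if (maxStateMemory > 3 * batchSizeBytes) {
      Some(s"$MaxStateMemoryKey is more than 3 times $BatchSizeKey. " +
        "This may cause regular expression operations to encounter GPU out of memory errors.")
    } else {
      None
    }
}
```

For example, with a 512 MiB batch size, a 2 GiB limit would produce the warning, while a 1 GiB limit would not.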
@@ -0,0 +1,72 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids

import org.apache.spark.sql.types.DataTypes

object RegexComplexityEstimator {
private def countStates(regex: RegexAST): Int = {
regex match {
case RegexSequence(parts) =>
parts.map(countStates).sum
case RegexGroup(true, term) =>
1 + countStates(term)
case RegexGroup(false, term) =>
countStates(term)
case RegexCharacterClass(_, _) =>
1
case RegexChoice(left, right) =>
countStates(left) + countStates(right)
case RegexRepetition(term, QuantifierFixedLength(length)) =>
length * countStates(term)
case RegexRepetition(term, SimpleQuantifier(ch)) =>
ch match {
case '*' =>
countStates(term)
case '+' =>
1 + countStates(term)
case '?' =>
1 + countStates(term)
}
case RegexRepetition(term, QuantifierVariableLength(minLength, maxLengthOption)) =>
maxLengthOption match {
case Some(maxLength) =>
maxLength * countStates(term)
case None =>
minLength.max(1) * countStates(term)
}
case RegexChar(_) | RegexEscaped(_) | RegexHexDigit(_) | RegexOctalChar(_) =>
1
case _ =>
0
}
}

private def estimateGpuMemory(numStates: Int, desiredBatchSizeBytes: Long): Long = {
val numRows = GpuBatchUtils.estimateRowCount(
desiredBatchSizeBytes, DataTypes.StringType.defaultSize, 1)

// Widen to Long before multiplying so large batches cannot overflow Int.
numStates.toLong * numRows * 2
}

def isValid(conf: RapidsConf, regex: RegexAST): Boolean = {
val numStates = countStates(regex)
// Valid only when the estimated state memory fits within the configured limit.
estimateGpuMemory(numStates, conf.gpuTargetBatchSizeBytes) <= conf.maxRegExpStateMemory
}
}
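
To see how the counting rules in `countStates` add up, here is a self-contained sketch over a simplified AST. The node types (`Ch`, `CharClass`, `Concat`, `Fixed`, `Bounded`, and so on) are hypothetical stand-ins for the plugin's `RegexAST` classes, not the real ones; the per-node rules and the `states * rows * 2` estimate follow the file above. The row count is passed in directly, since how `GpuBatchUtils.estimateRowCount` derives it is not shown in this diff.

```scala
object StateCountSketch {
  // Hypothetical, simplified AST; the plugin uses its own RegexAST types.
  sealed trait Node
  final case class Ch(c: Char) extends Node
  final case class CharClass(chars: String) extends Node
  final case class Concat(parts: List[Node]) extends Node
  final case class Choice(left: Node, right: Node) extends Node
  final case class Group(capture: Boolean, term: Node) extends Node
  final case class Star(term: Node) extends Node
  final case class Plus(term: Node) extends Node
  final case class Opt(term: Node) extends Node
  final case class Fixed(term: Node, n: Int) extends Node
  final case class Bounded(term: Node, min: Int, max: Option[Int]) extends Node

  // Same per-node rules as countStates in the diff.
  def countStates(node: Node): Int = node match {
    case Ch(_) | CharClass(_)        => 1
    case Concat(parts)               => parts.map(countStates).sum
    case Choice(l, r)                => countStates(l) + countStates(r)
    case Group(true, t)              => 1 + countStates(t)
    case Group(false, t)             => countStates(t)
    case Star(t)                     => countStates(t)
    case Plus(t)                     => 1 + countStates(t)
    case Opt(t)                      => 1 + countStates(t)
    case Fixed(t, n)                 => n * countStates(t)
    case Bounded(t, _, Some(max))    => max * countStates(t)
    case Bounded(t, min, None)       => min.max(1) * countStates(t)
  }

  // Same shape as estimateGpuMemory, with the row count supplied by the caller.
  def estimateBytes(numStates: Int, numRows: Long): Long =
    numStates.toLong * numRows * 2

  // (ab){3}[0-9]+
  val example: Node = Concat(List(
    Fixed(Group(capture = true, Concat(List(Ch('a'), Ch('b')))), 3),
    Plus(CharClass("0-9"))))
}
```

For `(ab){3}[0-9]+` the capture group contributes 1 + 2 = 3 states, repeated 3 times for 9, and `[0-9]+` adds 2, giving 11 states. At one million rows that is an estimate of 22,000,000 bytes.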
@@ -173,7 +173,7 @@ class RegexParser(pattern: String) {
}

private def parseCharacterClass(): RegexCharacterClass = {
val supportedMetaCharacters = "\\^-]+"
val supportedMetaCharacters = "\\^-[]+"

def getEscapedComponent(): RegexCharacterClassComponent = {
peek() match {
@@ -208,7 +208,7 @@ class RegexParser(pattern: String) {
RegexEscaped(ch)
} else {
throw new RegexUnsupportedException(
s"Unsupported escaped character in character class", Some(pos-1))
s"Unsupported escaped character '$ch' in character class", Some(pos-1))
}
}
case None =>
@@ -683,6 +683,22 @@ class CudfRegexTranspiler(mode: RegexMode) {
* @return Regular expression and optional replacement in cuDF format
*/
def transpile(pattern: String, repl: Option[String]): (String, Option[String]) = {
val (cudfRegex, replacement) = getTranspiledAST(pattern, repl)

// write out to regex string, performing minor transformations
// such as adding additional escaping
(cudfRegex.toRegexString, replacement.map(_.toRegexString))
}

/**
* Parse Java regular expression and translate into cuDF regular expression in AST form.
*
* @param pattern Regular expression that is valid in Java's engine
* @param repl Optional replacement pattern
* @return Regular expression AST and optional replacement in cuDF format
*/
def getTranspiledAST(
pattern: String, repl: Option[String]): (RegexAST, Option[RegexReplacement]) = {
// parse the source regular expression
val regex = new RegexParser(pattern).parse()
// if we have a replacement, parse the replacement string using the regex parser to account
@@ -691,11 +707,10 @@ class CudfRegexTranspiler(mode: RegexMode) {

// validate that the regex is supported by cuDF
val cudfRegex = transpile(regex, replacement, None)
// write out to regex string, performing minor transformations
// such as adding additional escaping
(cudfRegex.toRegexString, replacement.map(_.toRegexString))
}

(cudfRegex, replacement)
}

def transpileToSplittableString(e: RegexAST): Option[String] = {
e match {
case RegexEscaped(ch) if regexMetaChars.contains(ch) => Some(ch.toString)
@@ -850,6 +850,13 @@ object GpuRegExpUtils {
}
}

def validateRegExpComplexity(meta: ExprMeta[_], regex: RegexAST): Unit = {
if (!RegexComplexityEstimator.isValid(meta.conf, regex)) {
meta.willNotWorkOnGpu(s"Estimated memory needed for regular expression exceeds the maximum." +
s" Set ${RapidsConf.REGEXP_MAX_STATE_MEMORY} to change it.")
}
}

/**
* Recursively check if pattern contains only zero-match repetitions
* ?, *, {0,}, or {0,n} or any combination of them.
@@ -905,7 +912,10 @@ class GpuRLikeMeta(
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
try {
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(RegexFindMode).transpile(str.toString, None)._1)
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(str.toString, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
pattern = Some(transpiledAST.toRegexString)
} catch {
case e: RegexUnsupportedException =>
willNotWorkOnGpu(e.getMessage)
@@ -1092,8 +1102,10 @@ class GpuRegExpExtractMeta(
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1)
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(javaRegexpPattern, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
pattern = Some(transpiledAST.toRegexString)
numGroups = GpuRegExpUtils.countGroups(javaRegexpPattern)
} catch {
case e: RegexUnsupportedException =>
@@ -1211,8 +1223,10 @@ class GpuRegExpExtractAllMeta(
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
pattern = Some(new CudfRegexTranspiler(RegexFindMode)
.transpile(javaRegexpPattern, None)._1)
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(javaRegexpPattern, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
pattern = Some(transpiledAST.toRegexString)
numGroups = GpuRegExpUtils.countGroups(javaRegexpPattern)
} catch {
case e: RegexUnsupportedException =>
@@ -1588,7 +1602,9 @@ abstract class StringSplitRegExpMeta[INPUT <: TernaryExpression](expr: INPUT,
pattern = simplified
case None =>
try {
pattern = transpiler.transpile(utf8Str.toString, None)._1
val (transpiledAST, _) = transpiler.getTranspiledAST(utf8Str.toString, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
pattern = transpiledAST.toRegexString
isRegExp = true
} catch {
case e: RegexUnsupportedException =>
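
The call sites changed above all follow the same three steps: get the transpiled AST, check its estimated complexity, then render the cuDF pattern string, with unsupported patterns falling back to the CPU. A minimal sketch of that flow, under simplifying assumptions, is below. All names here (`CallSiteSketch`, `Ast`, `UnsupportedRegex`, `Plan`, `plan`) are hypothetical; the transpiler is passed in as a plain function, and the complexity check is reduced to a state-count threshold rather than the memory estimate the plugin uses. Unlike the real metas, which call `willNotWorkOnGpu` and carry on, the sketch returns the decision as a value.

```scala
object CallSiteSketch {
  // Hypothetical stand-in for a transpiled AST that knows its state count.
  final case class Ast(regex: String, states: Int) {
    def toRegexString: String = regex
  }

  // Hypothetical stand-in for RegexUnsupportedException.
  final class UnsupportedRegex(msg: String) extends Exception(msg)

  sealed trait Plan
  final case class OnGpu(cudfPattern: String) extends Plan
  final case class OnCpu(reason: String) extends Plan

  // Transpile to an AST first, validate complexity, then render the string.
  def plan(javaPattern: String, maxStates: Int)(transpile: String => Ast): Plan =
    try {
      val ast = transpile(javaPattern)
      if (ast.states > maxStates) {
        OnCpu("Estimated memory needed for regular expression exceeds the maximum.")
      } else {
        OnGpu(ast.toRegexString)
      }
    } catch {
      // Patterns the transpiler rejects fall back with the exception message.
      case e: UnsupportedRegex => OnCpu(e.getMessage)
    }
}
```

Keeping the AST available until after validation is what the new `getTranspiledAST` method enables; the older `transpile` only returned strings, which left nothing to inspect.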