Add combiner for string contains
Signed-off-by: Chong Gao <[email protected]>
Chong Gao committed Sep 2, 2024
1 parent ee2049a commit 3677ff6
Showing 2 changed files with 123 additions and 4 deletions.
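
What the change does, in brief: when a plan contains several string-contains predicates (for example repeated INSTR(c1, '...') > 0 checks) against the same input column and with literal search targets, the new ContainsCombiner merges them into a single GpuMultiContains that scans the column once for all targets and returns a struct of booleans; each original expression is then replaced by a struct-field lookup. Below is a minimal, editor-added sketch of that rewrite in plain Scala. The types Contains, MultiContains and GetStructField are stand-ins for the GpuContains, GpuMultiContains and GpuGetStructField classes in the diff, not the plugin's actual API.

object MultiContainsRewriteSketch {
  // Stand-in expression nodes (illustrative only).
  final case class Contains(column: String, target: String)
  final case class MultiContains(column: String, targets: Seq[String])
  final case class GetStructField(struct: MultiContains, ordinal: Int, name: String)

  // All expressions handed to one combiner share the same input column, so a single
  // MultiContains can cover every target; each original Contains becomes a field lookup
  // into its output struct, mirroring getReplacementExpression in the diff below.
  def rewrite(exprs: Seq[Contains]): Seq[GetStructField] = {
    val multi = MultiContains(exprs.head.column, exprs.map(_.target))
    exprs.indices.map(i => GetStructField(multi, i, s"_mc_$i"))
  }

  def main(args: Array[String]): Unit = {
    val original = Seq(
      Contains("c1", "substring1"),
      Contains("c1", "substring2"),
      Contains("c1", "substring3"))
    rewrite(original).foreach(println)
    // GetStructField(MultiContains(c1,List(substring1, substring2, substring3)),0,_mc_0), etc.
  }
}
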
18 changes: 18 additions & 0 deletions integration_tests/src/main/python/conditionals_test.py
@@ -379,3 +379,21 @@ def test_case_when_all_then_values_are_scalars_with_nulls():
"tab",
sql_without_else,
conf = {'spark.rapids.sql.case_when.fuse': 'true'})

@pytest.mark.parametrize('combine_string_contains_enabled', ['true', 'false'])
def test_combine_string_contains_in_case_when(combine_string_contains_enabled):
    data_gen = [("c1", string_gen)]
    sql = """
    SELECT
        INSTR(c1, 'substring1') > 0,
        INSTR(c1, 'substring2') > 0,
        INSTR(c1, 'substring3') > 0
    from tab
    """
    # spark.rapids.sql.combined.expressions.enabled is true by default
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark : gen_df(spark, data_gen),
        "tab",
        sql,
        { "spark.rapids.sql.expression.combined.GpuContains" : combine_string_contains_enabled}
    )
@@ -20,6 +20,8 @@ import java.nio.charset.Charset
import java.text.DecimalFormatSymbols
import java.util.{Locale, Optional}

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, BinaryOperable, CaptureGroups, ColumnVector, ColumnView, DType, PadSide, RegexProgram, RoundMode, Scalar}
@@ -32,6 +34,7 @@ import com.nvidia.spark.rapids.jni.RegexRewriteUtils
import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuCombinable, GpuExpressionCombiner, GpuExpressionEquals}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String
@@ -388,10 +391,11 @@ case class GpuConcatWs(children: Seq[Expression])
}

case class GpuContains(left: Expression, right: Expression)
-    extends GpuBinaryExpressionArgsAnyScalar
-    with Predicate
-    with ImplicitCastInputTypes
-    with NullIntolerant {
+    extends GpuBinaryExpressionArgsAnyScalar
+    with Predicate
+    with ImplicitCastInputTypes
+    with NullIntolerant
+    with GpuCombinable {

override def inputTypes: Seq[DataType] = Seq(StringType)

@@ -411,6 +415,103 @@ case class GpuContains(left: Expression, right: Expression)
doColumnar(expandedLhs, rhs)
}
}

/**
* Get a combiner that can be used to find candidates to combine
*/
override def getCombiner(): GpuExpressionCombiner = new ContainsCombiner(this)
}

case class GpuMultiContains(left: Expression, targets: Seq[UTF8String], output: StructType)
extends GpuExpression with ShimExpression {

override def otherCopyArgs: Seq[AnyRef] = Nil

override def dataType: DataType = output

override def nullable: Boolean = false

override def prettyName: String = "multi_contains"

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
val targetsBytes = targets.map(t => t.getBytes).toArray
withResource(ColumnVector.fromUTF8Strings(targetsBytes : _*)) { targetsCv =>
withResource(left.columnarEval(batch)) { lhs =>
withResource(lhs.getBase.stringContains(targetsCv)) { boolCvs =>
GpuColumnVector.from(ColumnVector.makeStruct(batch.numRows(), boolCvs: _*), dataType)
}
}
}
}
override def children: Seq[Expression] = Seq(left)
}

class ContainsCombiner(private val exp: GpuContains) extends GpuExpressionCombiner {
private var outputLocation = 0
/**
* A mapping between an expression and the position of its result in the
* output struct produced by GpuMultiContains.
*/
private val toCombine = mutable.HashMap.empty[GpuExpressionEquals, Int]
addExpression(exp)

override def toString: String = s"ContainsCombiner $toCombine"

override def hashCode: Int = {
// We already know this is a Contains; what can be combined is determined by
// the string column being the same.
"Contains".hashCode + (exp.left.semanticHash() * 17)
}

/**
* Only combine when the search targets are literals.
*/
override def equals(o: Any): Boolean = o match {
case other: ContainsCombiner => exp.left.semanticEquals(other.exp.left) &&
exp.right.isInstanceOf[GpuLiteral] && other.exp.right.isInstanceOf[GpuLiteral]
case _ => false
}

override def addExpression(e: Expression): Unit = {
val localOutputLocation = outputLocation
outputLocation += 1
val key = GpuExpressionEquals(e)
if (!toCombine.contains(key)) {
toCombine.put(key, localOutputLocation)
}
}

override def useCount: Int = toCombine.size

private def fieldName(id: Int): String =
s"_mc_$id"

@tailrec
private def extractLiteral(exp: Expression): GpuLiteral = exp match {
case l: GpuLiteral => l
case a: Alias => extractLiteral(a.child)
case other => throw new RuntimeException("Unsupported expression in contains combiner, " +
"should be a literal type, actual type is " + other.getClass.getName)
}

private lazy val multiContains: GpuMultiContains = {
val input = toCombine.head._1.e.asInstanceOf[GpuContains].left
val fieldsNPaths = toCombine.toSeq.map {
case (k, id) =>
(id, k.e)
}.sortBy(_._1).map {
case (id, e: GpuContains) =>
val target = extractLiteral(e.right).value.asInstanceOf[UTF8String]
(StructField(fieldName(id), e.dataType, e.nullable), target)
}
val dt = StructType(fieldsNPaths.map(_._1))
GpuMultiContains(input, fieldsNPaths.map(_._2), dt)
}

override def getReplacementExpression(e: Expression): Expression = {
val localId = toCombine(GpuExpressionEquals(e))
GpuGetStructField(multiContains, localId, Some(fieldName(localId)))
}
}

case class GpuSubstring(str: Expression, pos: Expression, len: Expression)

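For reference, the row-level semantics that GpuMultiContains evaluates on the GPU (one boolean per search target, all produced from a single pass over the input column and packed into a struct) can be modelled in plain Scala as below. This is an editor-added illustration; the object and method names are made up and nothing here is part of the plugin.

object MultiContainsSemanticsSketch {
  // For each input row, test every target once: one inner Seq per row, one Boolean per target.
  def multiContains(column: Seq[String], targets: Seq[String]): Seq[Seq[Boolean]] =
    column.map(row => targets.map(target => row.contains(target)))

  def main(args: Array[String]): Unit = {
    val rows = Seq("xx substring1 yy", "no match here", "substring2 and substring3")
    println(multiContains(rows, Seq("substring1", "substring2", "substring3")))
    // List(List(true, false, false), List(false, false, false), List(false, true, true))
  }
}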