From f1bd983e6c8b9d69179d0a888dd209338251c57e Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 18 Aug 2021 14:01:01 -0500 Subject: [PATCH] Update to new cudf AST API Signed-off-by: Jason Lowe --- .../spark/rapids/GpuBoundAttribute.scala | 2 +- .../nvidia/spark/rapids/GpuExpressions.scala | 10 ++++----- .../spark/rapids/basicPhysicalOperators.scala | 6 +---- .../com/nvidia/spark/rapids/literals.scala | 2 +- .../spark/rapids/namedExpressions.scala | 2 +- .../apache/spark/sql/rapids/arithmetic.scala | 6 ++--- .../GpuBroadcastNestedLoopJoinExec.scala | 5 +---- .../spark/sql/rapids/mathExpressions.scala | 22 +++++++++---------- .../apache/spark/sql/rapids/predicates.scala | 10 ++++----- 9 files changed, 29 insertions(+), 36 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala index 42335fc5af6..379a81cf3e3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala @@ -109,7 +109,7 @@ case class GpuBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { // Spark treats all inputs as a single sequence of columns. For example, a join will put all // the columns of the left table followed by all the columns of the right table. cudf AST // instead uses explicit table references to distinguish which table is being indexed by a diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala index b3a4adb819e..f406519f82b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala @@ -159,7 +159,7 @@ trait GpuExpression extends Expression with Arm { * single sequence into cudf's separate sequences. * @return top node of the equivalent AST */ - def convertToAst(numFirstTableColumns: Int): ast.AstNode = + def convertToAst(numFirstTableColumns: Int): ast.AstExpression = throw new IllegalStateException(s"Cannot convert ${this.getClass.getSimpleName} to AST") } @@ -227,10 +227,10 @@ trait CudfUnaryExpression extends GpuUnaryExpression { override def doColumnar(input: GpuColumnVector): ColumnVector = input.getBase.unaryOp(unaryOp) - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { val astOp = CudfUnaryExpression.opToAstMap.getOrElse(unaryOp, throw new IllegalStateException(s"${this.getClass.getSimpleName} is not supported by AST")) - new ast.UnaryExpression(astOp, + new ast.UnaryOperation(astOp, child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)) } } @@ -330,11 +330,11 @@ trait CudfBinaryExpression extends GpuBinaryExpression { } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { val astOp = CudfBinaryExpression.opToAstMap.getOrElse(binaryOp, throw new IllegalStateException(s"$this is not supported by AST")) assert(left.dataType == right.dataType) - new ast.BinaryExpression(astOp, + new ast.BinaryOperation(astOp, left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns), right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 1e988c8c7f9..81f74a88994 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -185,11 +185,7 @@ case class GpuProjectAstExec( withResource(new NvtxWithMetrics("Compile ASTs", NvtxColor.ORANGE, opTime)) { _ => boundProjectList.safeMap { expr => // Use intmax for the left table column count since there's only one input table here. - val astExpr = expr.convertToAst(Int.MaxValue) match { - case e: ast.Expression => e - case e => new ast.UnaryExpression(ast.UnaryOperator.IDENTITY, e) - } - astExpr.compile() + expr.convertToAst(Int.MaxValue).compile() } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 9b8addfdba0..14af56cdb84 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -635,7 +635,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression GpuScalar(value, dataType) } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { dataType match { case BooleanType => ast.Literal.ofBoolean(value.asInstanceOf[java.lang.Boolean]) case ByteType => ast.Literal.ofByte(value.asInstanceOf[java.lang.Byte]) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala index 28a8c9e11ce..81071b7ce71 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala @@ -110,7 +110,7 @@ case class GpuAlias(child: Expression, name: String)( override def doColumnar(input: GpuColumnVector): ColumnVector = throw new IllegalStateException("GpuAlias should never have doColumnar called") - override def convertToAst(numLeftTableColumns: Int): ast.AstNode = child match { + override def convertToAst(numLeftTableColumns: Int): ast.AstExpression = child match { case e: GpuExpression => e.convertToAst(numLeftTableColumns) case e => throw new IllegalStateException(s"Attempt to convert $e to AST") } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala index e028646c67c..9eca63f763f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -57,14 +57,14 @@ case class GpuUnaryMinus(child: Expression) extends GpuUnaryExpression } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { val literalZero = dataType match { case LongType => ast.Literal.ofLong(0) case FloatType => ast.Literal.ofFloat(0) case DoubleType => ast.Literal.ofDouble(0) case IntegerType => ast.Literal.ofInt(0) } - new ast.BinaryExpression(ast.BinaryOperator.SUB, literalZero, + new ast.BinaryOperation(ast.BinaryOperator.SUB, literalZero, child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)); } } @@ -81,7 +81,7 @@ case class GpuUnaryPositive(child: Expression) extends GpuUnaryExpression override def doColumnar(input: GpuColumnVector) : ColumnVector = input.getBase.incRefCount() - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala index 4b276ae22c1..e3c7f01997f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala @@ -326,10 +326,7 @@ object GpuBroadcastNestedLoopJoinExecBase extends Arm { assert(joinType.isInstanceOf[InnerLike], s"Unexpected unconditional join type: $joinType") new CrossJoinIterator(builtBatch, stream, targetSize, buildSide, joinTime, totalTime) } else { - val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns) match { - case e: ast.Expression => e.compile() - case e => new ast.UnaryExpression(ast.UnaryOperator.IDENTITY, e).compile() - } + val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns).compile() joinType match { case LeftAnti | LeftSemi => assert(buildSide == GpuBuildRight) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index 0af7f6ecfd1..3390b92f25a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -87,17 +87,17 @@ case class GpuAcoshCompat(child: Expression) extends GpuUnaryMathExpression("ACO } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { // Typically we would just use UnaryOp.ARCCOSH, but there are corner cases where cudf // produces a better result (it does not overflow) than spark does, but our goal is // to match Spark's // StrictMath.log(x + math.sqrt(x * x - 1.0)) val x = child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns) - new ast.UnaryExpression(ast.UnaryOperator.LOG, - new ast.BinaryExpression(ast.BinaryOperator.ADD, x, - new ast.UnaryExpression(ast.UnaryOperator.SQRT, - new ast.BinaryExpression(ast.BinaryOperator.SUB, - new ast.BinaryExpression(ast.BinaryOperator.MUL, x, x), ast.Literal.ofDouble(1))))) + new ast.UnaryOperation(ast.UnaryOperator.LOG, + new ast.BinaryOperation(ast.BinaryOperator.ADD, x, + new ast.UnaryOperation(ast.UnaryOperator.SQRT, + new ast.BinaryOperation(ast.BinaryOperator.SUB, + new ast.BinaryOperation(ast.BinaryOperator.MUL, x, x), ast.Literal.ofDouble(1))))) } } @@ -211,8 +211,8 @@ case class GpuExpm1(child: Expression) extends CudfUnaryMathExpression("EXPM1") } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { - new ast.BinaryExpression(ast.BinaryOperator.SUB, + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { + new ast.BinaryOperation(ast.BinaryOperator.SUB, super.convertToAst(numFirstTableColumns), ast.Literal.ofDouble(1)) } @@ -378,9 +378,9 @@ case class GpuCot(child: Expression) extends GpuUnaryMathExpression("COT") { } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { - new ast.BinaryExpression(ast.BinaryOperator.DIV, ast.Literal.ofDouble(1), - new ast.UnaryExpression(ast.UnaryOperator.TAN, + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { + new ast.BinaryOperation(ast.BinaryOperator.DIV, ast.Literal.ofDouble(1), + new ast.UnaryOperation(ast.UnaryOperator.TAN, child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala index 2995728403f..f86c24d89bb 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/predicates.scala @@ -44,11 +44,11 @@ case class GpuNot(child: Expression) extends CudfUnaryExpression override def unaryOp: UnaryOp = UnaryOp.NOT - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { child match { case c: GpuEqualTo => // optimize the AST expression since Spark doesn't have a NotEqual - new ast.BinaryExpression(ast.BinaryOperator.NOT_EQUAL, + new ast.BinaryOperation(ast.BinaryOperator.NOT_EQUAL, c.left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns), c.right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)) case _ => super.convertToAst(numFirstTableColumns) @@ -293,11 +293,11 @@ case class GpuEqualTo(left: Expression, right: Expression) extends CudfBinaryCom } } - override def convertToAst(numFirstTableColumns: Int): ast.AstNode = { + override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = { // Currently AST computeColumn assumes nulls compare true for EQUAL, but NOT_EQUAL will // return null for null input. - new ast.UnaryExpression(ast.UnaryOperator.NOT, - new ast.BinaryExpression(ast.BinaryOperator.NOT_EQUAL, + new ast.UnaryOperation(ast.UnaryOperator.NOT, + new ast.BinaryOperation(ast.BinaryOperator.NOT_EQUAL, left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns), right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))) }