Skip to content

Commit

Permalink
Update to new cudf AST API (#3252)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Lowe <[email protected]>
  • Loading branch information
jlowe authored Aug 19, 2021
1 parent c7cf7d4 commit 37986e4
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ case class GpuBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// Spark treats all inputs as a single sequence of columns. For example, a join will put all
// the columns of the left table followed by all the columns of the right table. cudf AST
// instead uses explicit table references to distinguish which table is being indexed by a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ trait GpuExpression extends Expression with Arm {
* single sequence into cudf's separate sequences.
* @return top node of the equivalent AST
*/
// Default implementation for expressions with no cudf AST equivalent: always throws
// IllegalStateException naming the concrete class. Subclasses that can be expressed as an
// AST override this method.
// Diff pair: the first signature line is the text removed by this commit, the second is
// the text added (return type ast.AstNode -> ast.AstExpression).
def convertToAst(numFirstTableColumns: Int): ast.AstNode =
def convertToAst(numFirstTableColumns: Int): ast.AstExpression =
throw new IllegalStateException(s"Cannot convert ${this.getClass.getSimpleName} to AST")
}

Expand Down Expand Up @@ -227,10 +227,10 @@ trait CudfUnaryExpression extends GpuUnaryExpression {

// Columnar evaluation: applies this expression's `unaryOp` to the input's base column.
override def doColumnar(input: GpuColumnVector): ColumnVector = input.getBase.unaryOp(unaryOp)

// Builds the AST for this unary expression: looks up the AST operator registered for
// `unaryOp` in CudfUnaryExpression.opToAstMap, throwing IllegalStateException when there is
// no mapping, and applies it to the AST produced by the child.
// Diff pairs: in each adjacent pair of near-identical lines the first is the removed text
// and the second is the added text (ast.AstNode -> ast.AstExpression,
// ast.UnaryExpression -> ast.UnaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
val astOp = CudfUnaryExpression.opToAstMap.getOrElse(unaryOp,
throw new IllegalStateException(s"${this.getClass.getSimpleName} is not supported by AST"))
new ast.UnaryExpression(astOp,
new ast.UnaryOperation(astOp,
// The child is cast to GpuExpression unchecked; numFirstTableColumns is passed through as-is.
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))
}
}
Expand Down Expand Up @@ -330,11 +330,11 @@ trait CudfBinaryExpression extends GpuBinaryExpression {
}
}

// Builds the AST for this binary expression: looks up the AST operator registered for
// `binaryOp` in CudfBinaryExpression.opToAstMap, throwing IllegalStateException when there is
// no mapping, and applies it to the ASTs of the left and right children.
// Diff pairs: in each adjacent pair of near-identical lines the first is the removed text
// and the second is the added text (ast.AstNode -> ast.AstExpression,
// ast.BinaryExpression -> ast.BinaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
val astOp = CudfBinaryExpression.opToAstMap.getOrElse(binaryOp,
throw new IllegalStateException(s"$this is not supported by AST"))
// Both operands are asserted to have the same data type before the node is built.
assert(left.dataType == right.dataType)
new ast.BinaryExpression(astOp,
new ast.BinaryOperation(astOp,
left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns),
right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,7 @@ case class GpuProjectAstExec(
withResource(new NvtxWithMetrics("Compile ASTs", NvtxColor.ORANGE, opTime)) { _ =>
boundProjectList.safeMap { expr =>
// Use intmax for the left table column count since there's only one input table here.
val astExpr = expr.convertToAst(Int.MaxValue) match {
case e: ast.Expression => e
case e => new ast.UnaryExpression(ast.UnaryOperator.IDENTITY, e)
}
astExpr.compile()
expr.convertToAst(Int.MaxValue).compile()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression
GpuScalar(value, dataType)
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
dataType match {
case BooleanType => ast.Literal.ofBoolean(value.asInstanceOf[java.lang.Boolean])
case ByteType => ast.Literal.ofByte(value.asInstanceOf[java.lang.Byte])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ case class GpuAlias(child: Expression, name: String)(
// Not a real evaluation path: calling this on a GpuAlias always throws IllegalStateException.
override def doColumnar(input: GpuColumnVector): ColumnVector =
throw new IllegalStateException("GpuAlias should never have doColumnar called")

// An alias adds no AST node of its own: the result is the child's AST. A child that is not
// a GpuExpression raises IllegalStateException.
// Diff pair: the first signature line is the removed text, the second is the added text
// (return type ast.AstNode -> ast.AstExpression).
override def convertToAst(numLeftTableColumns: Int): ast.AstNode = child match {
override def convertToAst(numLeftTableColumns: Int): ast.AstExpression = child match {
case e: GpuExpression => e.convertToAst(numLeftTableColumns)
case e => throw new IllegalStateException(s"Attempt to convert $e to AST")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ case class GpuUnaryMinus(child: Expression) extends GpuUnaryExpression
}
}

// Expresses unary minus as the binary subtraction `0 - child`, using a zero literal whose
// type matches this expression's data type.
// Diff pairs: in each adjacent pair of near-identical lines the first is the removed text
// and the second is the added text (ast.AstNode -> ast.AstExpression,
// ast.BinaryExpression -> ast.BinaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// NOTE(review): this match has no default case, so any data type other than the four
// listed would raise a MatchError here - presumably such types are filtered out before
// AST conversion is attempted; confirm against the caller.
val literalZero = dataType match {
case LongType => ast.Literal.ofLong(0)
case FloatType => ast.Literal.ofFloat(0)
case DoubleType => ast.Literal.ofDouble(0)
case IntegerType => ast.Literal.ofInt(0)
}
new ast.BinaryExpression(ast.BinaryOperator.SUB, literalZero,
new ast.BinaryOperation(ast.BinaryOperator.SUB, literalZero,
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns));
}
}
Expand All @@ -81,7 +81,7 @@ case class GpuUnaryPositive(child: Expression) extends GpuUnaryExpression

// Unary plus leaves the data unchanged: the result is the input's base column with its
// reference count incremented.
override def doColumnar(input: GpuColumnVector) : ColumnVector = input.getBase.incRefCount()

// Unary plus adds no AST node: the result is the child's AST unchanged.
// Diff pair: the first signature line is the removed text, the second is the added text
// (return type ast.AstNode -> ast.AstExpression).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,7 @@ object GpuBroadcastNestedLoopJoinExecBase extends Arm {
assert(joinType.isInstanceOf[InnerLike], s"Unexpected unconditional join type: $joinType")
new CrossJoinIterator(builtBatch, stream, targetSize, buildSide, joinTime, totalTime)
} else {
val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns) match {
case e: ast.Expression => e.compile()
case e => new ast.UnaryExpression(ast.UnaryOperator.IDENTITY, e).compile()
}
val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns).compile()
joinType match {
case LeftAnti | LeftSemi =>
assert(buildSide == GpuBuildRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ case class GpuAcoshCompat(child: Expression) extends GpuUnaryMathExpression("ACO
}
}

// Builds the AST for log(x + sqrt(x * x - 1.0)), where x is the child's AST.
// Diff pairs: the signature is shown as removed line then added line, and the five-line
// expression is shown first in its removed form (ast.UnaryExpression / ast.BinaryExpression)
// and then in its added form (ast.UnaryOperation / ast.BinaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// Typically we would just use UnaryOp.ARCCOSH, but there are corner cases where cudf
// produces a better result than Spark does (it does not overflow). Our goal is to match
// Spark's behavior, so we build the same formula Spark uses:
// StrictMath.log(x + math.sqrt(x * x - 1.0))
val x = child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)
new ast.UnaryExpression(ast.UnaryOperator.LOG,
new ast.BinaryExpression(ast.BinaryOperator.ADD, x,
new ast.UnaryExpression(ast.UnaryOperator.SQRT,
new ast.BinaryExpression(ast.BinaryOperator.SUB,
new ast.BinaryExpression(ast.BinaryOperator.MUL, x, x), ast.Literal.ofDouble(1)))))
new ast.UnaryOperation(ast.UnaryOperator.LOG,
new ast.BinaryOperation(ast.BinaryOperator.ADD, x,
new ast.UnaryOperation(ast.UnaryOperator.SQRT,
new ast.BinaryOperation(ast.BinaryOperator.SUB,
new ast.BinaryOperation(ast.BinaryOperator.MUL, x, x), ast.Literal.ofDouble(1)))))
}
}

Expand Down Expand Up @@ -211,8 +211,8 @@ case class GpuExpm1(child: Expression) extends CudfUnaryMathExpression("EXPM1")
}
}

// Builds the AST for `parent - 1.0`, where `parent` is whatever the superclass
// convertToAst produces for this expression.
// NOTE(review): presumably the superclass maps this expression to an exponential, giving
// exp(x) - 1 - confirm against the operator this class passes to its parent.
// Diff pairs: in each adjacent pair of near-identical lines the first is the removed text
// and the second is the added text (ast.AstNode -> ast.AstExpression,
// ast.BinaryExpression -> ast.BinaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
new ast.BinaryExpression(ast.BinaryOperator.SUB,
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
new ast.BinaryOperation(ast.BinaryOperator.SUB,
super.convertToAst(numFirstTableColumns),
ast.Literal.ofDouble(1))
}
Expand Down Expand Up @@ -378,9 +378,9 @@ case class GpuCot(child: Expression) extends GpuUnaryMathExpression("COT") {
}
}

// Builds the AST for cotangent as 1.0 / tan(child).
// Diff layout: the first three lines are the removed text (ast.AstNode,
// ast.BinaryExpression, ast.UnaryExpression); the next three are the added text
// (ast.AstExpression, ast.BinaryOperation, ast.UnaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
new ast.BinaryExpression(ast.BinaryOperator.DIV, ast.Literal.ofDouble(1),
new ast.UnaryExpression(ast.UnaryOperator.TAN,
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
new ast.BinaryOperation(ast.BinaryOperator.DIV, ast.Literal.ofDouble(1),
new ast.UnaryOperation(ast.UnaryOperator.TAN,
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ case class GpuNot(child: Expression) extends CudfUnaryExpression

// The cudf unary operator this expression evaluates with: logical NOT.
override def unaryOp: UnaryOp = UnaryOp.NOT

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
child match {
case c: GpuEqualTo =>
// optimize the AST expression since Spark doesn't have a NotEqual
new ast.BinaryExpression(ast.BinaryOperator.NOT_EQUAL,
new ast.BinaryOperation(ast.BinaryOperator.NOT_EQUAL,
c.left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns),
c.right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))
case _ => super.convertToAst(numFirstTableColumns)
Expand Down Expand Up @@ -293,11 +293,11 @@ case class GpuEqualTo(left: Expression, right: Expression) extends CudfBinaryCom
}
}

// Builds the AST for equality as NOT(left NOT_EQUAL right) rather than using an EQUAL
// operator directly; the comment inside the body gives the null-handling reason.
// Diff pairs: the signature is shown as removed line then added line, and the two
// constructor lines are shown first in their removed form (ast.UnaryExpression /
// ast.BinaryExpression) and then in their added form (ast.UnaryOperation /
// ast.BinaryOperation).
override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// Currently AST computeColumn assumes nulls compare true for EQUAL, but NOT_EQUAL will
// return null for null input.
new ast.UnaryExpression(ast.UnaryOperator.NOT,
new ast.BinaryExpression(ast.BinaryOperator.NOT_EQUAL,
new ast.UnaryOperation(ast.UnaryOperator.NOT,
new ast.BinaryOperation(ast.BinaryOperator.NOT_EQUAL,
left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns),
right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)))
}
Expand Down

0 comments on commit 37986e4

Please sign in to comment.