Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to new cudf AST API #3252

Merged
merged 1 commit into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ case class GpuBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// Spark treats all inputs as a single sequence of columns. For example, a join will put all
// the columns of the left table followed by all the columns of the right table. cudf AST
// instead uses explicit table references to distinguish which table is being indexed by a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ trait GpuExpression extends Expression with Arm {
* single sequence into cudf's separate sequences.
* @return top node of the equivalent AST
*/
def convertToAst(numFirstTableColumns: Int): ast.AstNode =
def convertToAst(numFirstTableColumns: Int): ast.AstExpression =
throw new IllegalStateException(s"Cannot convert ${this.getClass.getSimpleName} to AST")
}

Expand Down Expand Up @@ -227,10 +227,10 @@ trait CudfUnaryExpression extends GpuUnaryExpression {

override def doColumnar(input: GpuColumnVector): ColumnVector = input.getBase.unaryOp(unaryOp)

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
val astOp = CudfUnaryExpression.opToAstMap.getOrElse(unaryOp,
throw new IllegalStateException(s"${this.getClass.getSimpleName} is not supported by AST"))
new ast.UnaryExpression(astOp,
new ast.UnaryOperation(astOp,
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))
}
}
Expand Down Expand Up @@ -330,11 +330,11 @@ trait CudfBinaryExpression extends GpuBinaryExpression {
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
val astOp = CudfBinaryExpression.opToAstMap.getOrElse(binaryOp,
throw new IllegalStateException(s"$this is not supported by AST"))
assert(left.dataType == right.dataType)
new ast.BinaryExpression(astOp,
new ast.BinaryOperation(astOp,
left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns),
right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,7 @@ case class GpuProjectAstExec(
withResource(new NvtxWithMetrics("Compile ASTs", NvtxColor.ORANGE, opTime)) { _ =>
boundProjectList.safeMap { expr =>
// Use Int.MaxValue for the left table column count since there's only one input table here.
val astExpr = expr.convertToAst(Int.MaxValue) match {
case e: ast.Expression => e
case e => new ast.UnaryExpression(ast.UnaryOperator.IDENTITY, e)
}
astExpr.compile()
expr.convertToAst(Int.MaxValue).compile()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression
GpuScalar(value, dataType)
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
dataType match {
case BooleanType => ast.Literal.ofBoolean(value.asInstanceOf[java.lang.Boolean])
case ByteType => ast.Literal.ofByte(value.asInstanceOf[java.lang.Byte])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ case class GpuAlias(child: Expression, name: String)(
override def doColumnar(input: GpuColumnVector): ColumnVector =
throw new IllegalStateException("GpuAlias should never have doColumnar called")

override def convertToAst(numLeftTableColumns: Int): ast.AstNode = child match {
override def convertToAst(numLeftTableColumns: Int): ast.AstExpression = child match {
case e: GpuExpression => e.convertToAst(numLeftTableColumns)
case e => throw new IllegalStateException(s"Attempt to convert $e to AST")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ case class GpuUnaryMinus(child: Expression) extends GpuUnaryExpression
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
val literalZero = dataType match {
case LongType => ast.Literal.ofLong(0)
case FloatType => ast.Literal.ofFloat(0)
case DoubleType => ast.Literal.ofDouble(0)
case IntegerType => ast.Literal.ofInt(0)
}
new ast.BinaryExpression(ast.BinaryOperator.SUB, literalZero,
new ast.BinaryOperation(ast.BinaryOperator.SUB, literalZero,
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns));
}
}
Expand All @@ -81,7 +81,7 @@ case class GpuUnaryPositive(child: Expression) extends GpuUnaryExpression

override def doColumnar(input: GpuColumnVector) : ColumnVector = input.getBase.incRefCount()

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,7 @@ object GpuBroadcastNestedLoopJoinExecBase extends Arm {
assert(joinType.isInstanceOf[InnerLike], s"Unexpected unconditional join type: $joinType")
new CrossJoinIterator(builtBatch, stream, targetSize, buildSide, joinTime, totalTime)
} else {
val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns) match {
case e: ast.Expression => e.compile()
case e => new ast.UnaryExpression(ast.UnaryOperator.IDENTITY, e).compile()
}
val compiledAst = boundCondition.get.convertToAst(numFirstTableColumns).compile()
joinType match {
case LeftAnti | LeftSemi =>
assert(buildSide == GpuBuildRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ case class GpuAcoshCompat(child: Expression) extends GpuUnaryMathExpression("ACO
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// Typically we would just use UnaryOp.ARCCOSH, but there are corner cases where cudf
// produces a better result (it does not overflow) than Spark does. Our goal is to match
// Spark's result, so the AST is built from the same formula:
// StrictMath.log(x + math.sqrt(x * x - 1.0))
val x = child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)
new ast.UnaryExpression(ast.UnaryOperator.LOG,
new ast.BinaryExpression(ast.BinaryOperator.ADD, x,
new ast.UnaryExpression(ast.UnaryOperator.SQRT,
new ast.BinaryExpression(ast.BinaryOperator.SUB,
new ast.BinaryExpression(ast.BinaryOperator.MUL, x, x), ast.Literal.ofDouble(1)))))
new ast.UnaryOperation(ast.UnaryOperator.LOG,
new ast.BinaryOperation(ast.BinaryOperator.ADD, x,
new ast.UnaryOperation(ast.UnaryOperator.SQRT,
new ast.BinaryOperation(ast.BinaryOperator.SUB,
new ast.BinaryOperation(ast.BinaryOperator.MUL, x, x), ast.Literal.ofDouble(1)))))
}
}

Expand Down Expand Up @@ -211,8 +211,8 @@ case class GpuExpm1(child: Expression) extends CudfUnaryMathExpression("EXPM1")
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
new ast.BinaryExpression(ast.BinaryOperator.SUB,
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
new ast.BinaryOperation(ast.BinaryOperator.SUB,
super.convertToAst(numFirstTableColumns),
ast.Literal.ofDouble(1))
}
Expand Down Expand Up @@ -378,9 +378,9 @@ case class GpuCot(child: Expression) extends GpuUnaryMathExpression("COT") {
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
new ast.BinaryExpression(ast.BinaryOperator.DIV, ast.Literal.ofDouble(1),
new ast.UnaryExpression(ast.UnaryOperator.TAN,
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
new ast.BinaryOperation(ast.BinaryOperator.DIV, ast.Literal.ofDouble(1),
new ast.UnaryOperation(ast.UnaryOperator.TAN,
child.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ case class GpuNot(child: Expression) extends CudfUnaryExpression

override def unaryOp: UnaryOp = UnaryOp.NOT

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
child match {
case c: GpuEqualTo =>
// optimize the AST expression since Spark doesn't have a NotEqual
new ast.BinaryExpression(ast.BinaryOperator.NOT_EQUAL,
new ast.BinaryOperation(ast.BinaryOperator.NOT_EQUAL,
c.left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns),
c.right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns))
case _ => super.convertToAst(numFirstTableColumns)
Expand Down Expand Up @@ -293,11 +293,11 @@ case class GpuEqualTo(left: Expression, right: Expression) extends CudfBinaryCom
}
}

override def convertToAst(numFirstTableColumns: Int): ast.AstNode = {
override def convertToAst(numFirstTableColumns: Int): ast.AstExpression = {
// Currently AST computeColumn assumes nulls compare true for EQUAL, but NOT_EQUAL will
// return null for null input, so equality is expressed here as NOT(NOT_EQUAL).
new ast.UnaryExpression(ast.UnaryOperator.NOT,
new ast.BinaryExpression(ast.BinaryOperator.NOT_EQUAL,
new ast.UnaryOperation(ast.UnaryOperator.NOT,
new ast.BinaryOperation(ast.BinaryOperator.NOT_EQUAL,
left.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns),
right.asInstanceOf[GpuExpression].convertToAst(numFirstTableColumns)))
}
Expand Down