Skip to content

Commit

Permalink
Merge pull request NVIDIA#1825 from NVIDIA/branch-0.4
Browse files Browse the repository at this point in the history
[auto-merge] branch-0.4 to branch-0.5 [skip ci] [bot]
  • Loading branch information
nvauto authored Feb 26, 2021
2 parents d90c724 + 4710c3e commit 85a7b81
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 7 deletions.
1 change: 0 additions & 1 deletion integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def tmp(something):
return meta + idfn(something)
return tmp

@pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/999')
@ignore_order
@approximate_float
@pytest.mark.parametrize('c_gen', lead_lag_data_gens, ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.spark311

import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, Lag, Lead, Literal, OffsetWindowFunction}
import org.apache.spark.sql.types.IntegerType

/**
* Spark 3.1.1-specific replacement for com.nvidia.spark.rapids.OffsetWindowFunctionMeta.
* This is required primarily for two reasons:
* 1. com.nvidia.spark.rapids.OffsetWindowFunctionMeta (compiled against Spark 3.0.x)
* fails class load in Spark 3.1.x. (`expr.input` is not recognized as an Expression.)
* 2. The semantics of offsets in LAG() are reversed/negated in Spark 3.1.1.
* E.g. The expression `LAG(col, 5)` causes Lag.offset to be set to `-5`,
* as opposed to `5`, in prior versions of Spark.
* This class adjusts the LAG offset to use similar semantics to Spark 3.0.x.
*/
abstract class OffsetWindowFunctionMeta[INPUT <: OffsetWindowFunction] (
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
lazy val input: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.input, conf, Some(this))
lazy val adjustedOffset: Expression = {
expr match {
case lag: Lag =>
GpuOverrides.extractLit(lag.offset) match {
case Some(Literal(offset: Int, IntegerType)) =>
Literal(-offset, IntegerType)
case _ =>
throw new IllegalStateException(
s"Only integer literal offsets are supported for LAG. Found:${lag.offset}")
}
case lead: Lead =>
GpuOverrides.extractLit(lead.offset) match {
case Some(Literal(offset: Int, IntegerType)) =>
Literal(offset, IntegerType)
case _ =>
throw new IllegalStateException(
s"Only integer literal offsets are supported for LEAD. Found:${lead.offset}")
}
case other =>
throw new IllegalStateException(s"$other is not a supported window function")
}
}
lazy val offset: BaseExprMeta[_] =
GpuOverrides.wrapExpr(adjustedOffset, conf, Some(this))
lazy val default: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.default, conf, Some(this))

override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty

override def tagExprForGpu(): Unit = {
expr match {
case Lead(_,_,_) => // Supported.
case Lag(_,_,_) => // Supported.
case other =>
willNotWorkOnGpu( s"Only LEAD/LAG offset window functions are supported. Found: $other")
}

if (GpuOverrides.extractLit(expr.offset).isEmpty) { // Not a literal offset.
willNotWorkOnGpu(
s"Only integer literal offsets are supported for LEAD/LAG. Found: ${expr.offset}")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ class Spark311Shims extends Spark301Shims {

// stringChecks are the same
// binaryChecks are the same

override val decimalChecks: TypeSig = none
override val sparkDecimalSig: TypeSig = numeric + BOOLEAN + STRING

Expand Down Expand Up @@ -196,8 +195,37 @@ class Spark311Shims extends Spark301Shims {
childExprs(1).convertToGpu(),
childExprs(2).convertToGpu())
}
}),
// Spark 3.1.1-specific LEAD expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all,
Seq(ParamCheck("input", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL, TypeSig.all))),
(lead, conf, p, r) => new OffsetWindowFunctionMeta[Lead](lead, conf, p, r) {
override def convertToGpu(): GpuExpression =
GpuLead(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}),
// Spark 3.1.1-specific LAG expression, using custom OffsetWindowFunctionMeta.
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all,
Seq(ParamCheck("input", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP, TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default", TypeSig.numeric + TypeSig.BOOLEAN +
TypeSig.DATE + TypeSig.TIMESTAMP + TypeSig.NULL, TypeSig.all))),
(lag, conf, p, r) => new OffsetWindowFunctionMeta[Lag](lag, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuLag(input.convertToGpu(), offset.convertToGpu(), default.convertToGpu())
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
super.exprs301 ++ exprs311
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,10 +738,43 @@ abstract class OffsetWindowFunctionMeta[INPUT <: OffsetWindowFunction] (
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends ExprMeta[INPUT](expr, conf, parent, rule) {
val input: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.input, conf, Some(this))
val offset: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.offset, conf, Some(this))
val default: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.default, conf, Some(this))
override val childExprs: Seq[BaseExprMeta[_]] = Seq(input, offset, default)
lazy val input: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.input, conf, Some(this))
lazy val offset: BaseExprMeta[_] = {
expr match {
case Lead(_,_,_) => // Supported.
case Lag(_,_,_) => // Supported.
case other =>
throw new IllegalStateException(
s"Only LEAD/LAG offset window functions are supported. Found: $other")
}

val literalOffset = GpuOverrides.extractLit(expr.offset) match {
case Some(Literal(offset: Int, IntegerType)) =>
Literal(offset, IntegerType)
case _ =>
throw new IllegalStateException(
s"Only integer literal offsets are supported for LEAD/LAG. Found: ${expr.offset}")
}

GpuOverrides.wrapExpr(literalOffset, conf, Some(this))
}
lazy val default: BaseExprMeta[_] = GpuOverrides.wrapExpr(expr.default, conf, Some(this))

override val childExprs: Seq[BaseExprMeta[_]] = Seq.empty

override def tagExprForGpu(): Unit = {
expr match {
case Lead(_,_,_) => // Supported.
case Lag(_,_,_) => // Supported.
case other =>
willNotWorkOnGpu( s"Only LEAD/LAG offset window functions are supported. Found: $other")
}

if (GpuOverrides.extractLit(expr.offset).isEmpty) { // Not a literal offset.
willNotWorkOnGpu(
s"Only integer literal offsets are supported for LEAD/LAG. Found: ${expr.offset}")
}
}
}

trait GpuOffsetWindowFunction extends GpuAggregateWindowFunction {
Expand Down

0 comments on commit 85a7b81

Please sign in to comment.