Support TimeAdd for non-UTC time zone #10068

Closed · wants to merge 26 commits
Changes from 5 commits
25 changes: 20 additions & 5 deletions integration_tests/src/main/python/date_time_test.py
@@ -36,20 +36,26 @@
def test_timesub(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
- # We are starting at year 0015 to make sure we don't go before year 0001 while doing TimeSub
- lambda spark: unary_op_df(spark, TimestampGen(start=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
+ lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a - (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', vals, ids=idfn)
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd(data_gen):
days, seconds = data_gen
assert_gpu_and_cpu_are_equal_collect(
- # We are starting at year 0005 to make sure we don't go before year 0001
- # and beyond year 10000 while doing TimeAdd
- lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1)
+ lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a + (interval {} days {} seconds)".format(days, seconds)))

@pytest.mark.parametrize('data_gen', [-pow(2, 63), pow(2, 63)], ids=idfn)
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_long_overflow(data_gen):
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, TimestampGen())
.selectExpr("a + (interval {} microseconds)".format(data_gen)),
conf={},
error_message='long overflow')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_daytime_column():
@@ -61,6 +67,15 @@ def test_timeadd_daytime_column():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@allow_non_gpu(*non_supported_tz_allow)
def test_timeadd_daytime_column_long_overflow():
res-life (Collaborator) commented on Dec 29, 2023:

How do we ensure the random df will overflow 100% of the time?
Maybe specify some constant values to guarantee overflow.

Collaborator reply:

By not making it actually random.

DayTimeIntervalGen has both a min_value and a max_value. You could set it up so that all of the generated values would overflow. You might also need to remove the special cases and disable nulls to be 100% sure of it:

def __init__(self, min_value=MIN_DAY_TIME_INTERVAL, max_value=MAX_DAY_TIME_INTERVAL, start_field="day", end_field="second",

You could also use SetValuesGen with only values in it that would overflow:

class SetValuesGen(DataGen):

Author reply:

Updated to SetValuesGen.

gen_list = [('t', TimestampGen()),('d', DayTimeIntervalGen())]
assert_gpu_and_cpu_error(
lambda spark : gen_df(spark, gen_list).selectExpr("t + d").collect(),
conf={},
error_message='long overflow')
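
The conversation above suggests pinning the generated values so the overflow is deterministic. A minimal sketch of such a pinned variant (not part of this diff, and the test name is hypothetical): it assumes SetValuesGen(data_type, values), gen_df, and assert_gpu_and_cpu_error from the integration-test helpers (data_gen.py and asserts.py), plus Spark 3.3+ for DayTimeIntervalType.

# Hypothetical pinned variant: every row is a near-max timestamp plus a near-max
# day-time interval, so "t + d" overflows a 64-bit microsecond value on every row.
from datetime import datetime, timedelta, timezone
from pyspark.sql.types import DayTimeIntervalType, TimestampType

def test_timeadd_daytime_column_long_overflow_pinned():
    gen_list = [
        ('t', SetValuesGen(TimestampType(), [datetime(9999, 12, 31, tzinfo=timezone.utc)])),
        ('d', SetValuesGen(DayTimeIntervalType(), [timedelta(microseconds=pow(2, 63) - 1)]))]
    assert_gpu_and_cpu_error(
        lambda spark: gen_df(spark, gen_list).selectExpr("t + d").collect(),
        conf={},
        error_message='long overflow')

Because every generated row overflows, the result no longer depends on the random seed.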

@pytest.mark.skipif(is_before_spark_350(), reason='DayTimeInterval overflow check for seconds is not supported before Spark 3.5.0')
def test_interval_seconds_overflow_exception():
assert_gpu_and_cpu_error(
@@ -25,7 +25,7 @@ import java.util.concurrent.TimeUnit

import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
- import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
+ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.GpuOverrides
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
@@ -107,10 +107,56 @@ case class GpuTimeAdd(
}
}

// A trick to check for overflow: the addition overflows only when positive + positive yields
// a negative result or negative + negative yields a positive result, so we check whether the
// operands share a sign while the result has the opposite sign.
private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = {
// Do not use cv.add(duration): it would go through BinaryOperable.implicitConversion,
// which currently returns a Long-typed result. Instead, specify the return type
// explicitly as TIMESTAMP_MICROSECONDS.
- cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
+ val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
closeOnExcept(resWithOverflow) { _ =>
val isCvPos = withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero =>
cv.greaterOrEqualTo(zero)
}
val sameSignal = closeOnExcept(isCvPos) { isCvPos =>
val isDurationPos = duration match {
case durScalar: Scalar =>
val isPosBool = durScalar.isValid && durScalar.getLong >= 0
Scalar.fromBool(isPosBool)
case dur : AutoCloseable =>
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero =>
dur.greaterOrEqualTo(zero)
}
}
withResource(isDurationPos) { _ =>
isCvPos.equalTo(isDurationPos)
}
}
val isOverflow = withResource(sameSignal) { _ =>
val sameSignalWithRes = withResource(isCvPos) { _ =>
val isResNeg = withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero =>
resWithOverflow.lessThan(zero)
}
withResource(isResNeg) { _ =>
isCvPos.equalTo(isResNeg)
}
}
withResource(sameSignalWithRes) { _ =>
sameSignal.and(sameSignalWithRes)
}
}
val anyOverflow = withResource(isOverflow) { _ =>
isOverflow.any()
}
withResource(anyOverflow) { _ =>
if (anyOverflow.isValid && anyOverflow.getBoolean) {
throw new ArithmeticException("long overflow")
}
}
}
resWithOverflow
}
}
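
For reference, the sign-based rule that timestampAddDuration implements column-wise above can be written out in scalar form. The following is only an illustrative Python sketch (not plugin code); it simulates the 64-bit wraparound that the raw GPU addition is assumed to produce, since Python integers never overflow.

# Illustration of the overflow rule used above: an addition of two 64-bit microsecond
# values overflows exactly when both operands share a sign and the wrapped result has
# the opposite sign.
def add_micros_checked(ts_micros, dur_micros):
    # Simulate 64-bit two's-complement wraparound of the raw addition.
    wrapped = (ts_micros + dur_micros + 2**63) % 2**64 - 2**63
    same_sign = (ts_micros >= 0) == (dur_micros >= 0)
    sign_flipped = (wrapped >= 0) != (ts_micros >= 0)
    if same_sign and sign_flipped:
        raise ArithmeticError('long overflow')
    return wrapped

# Example: a timestamp near Long.MaxValue microseconds plus one day must overflow.
# add_micros_checked(2**63 - 10, 86_400_000_000)  ->  ArithmeticError: long overflow

In the column code, sameSignal plays the role of same_sign, and the isCvPos.equalTo(isResNeg) comparison plays the role of sign_flipped.
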
@@ -44,7 +44,7 @@ import java.util.concurrent.TimeUnit

import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar}
- import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
+ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.GpuOverrides
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
@@ -164,10 +164,56 @@ case class GpuTimeAdd(start: Expression,
}
}

// A trick to check for overflow: the addition overflows only when positive + positive yields
// a negative result or negative + negative yields a positive result, so we check whether the
// operands share a sign while the result has the opposite sign.
private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = {
Collaborator comment:

Duplicated.
Could we extract this function into a file like datetimeExpressionsUtils.scala? It seems that it applies to all Spark versions, so this function should not go into a shim.

Collaborator reply:

There is already an overflow-check utility: AddOverflowChecks.basicOpOverflowCheck

Author reply:

Done, thanks.

// Do not use cv.add(duration): it would go through BinaryOperable.implicitConversion,
// which currently returns a Long-typed result. Instead, specify the return type
// explicitly as TIMESTAMP_MICROSECONDS.
- cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
+ val resWithOverflow = cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS)
closeOnExcept(resWithOverflow) { _ =>
val isCvPos = withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero =>
cv.greaterOrEqualTo(zero)
}
val sameSignal = closeOnExcept(isCvPos) { isCvPos =>
val isDurationPos = duration match {
case durScalar: Scalar =>
val isPosBool = durScalar.isValid && durScalar.getLong >= 0
Scalar.fromBool(isPosBool)
case dur : AutoCloseable =>
withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, 0)) { zero =>
dur.greaterOrEqualTo(zero)
}
}
withResource(isDurationPos) { _ =>
isCvPos.equalTo(isDurationPos)
}
}
val isOverflow = withResource(sameSignal) { _ =>
val sameSignalWithRes = withResource(isCvPos) { _ =>
val isResNeg = withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, 0)) { zero =>
resWithOverflow.lessThan(zero)
}
withResource(isResNeg) { _ =>
isCvPos.equalTo(isResNeg)
}
}
withResource(sameSignalWithRes) { _ =>
sameSignal.and(sameSignalWithRes)
}
}
val anyOverflow = withResource(isOverflow) { _ =>
isOverflow.any()
}
withResource(anyOverflow) { _ =>
if (anyOverflow.isValid && anyOverflow.getBoolean) {
throw new ArithmeticException("long overflow")
}
}
}
resWithOverflow
}
}
@@ -132,7 +132,7 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll
println(s"test,type,zone,used MS")
for (zoneStr <- zones) {
// run 6 rounds, but ignore the first round.
- for (i <- 1 to 6) {
+ val elapses = (1 to 6).map { i =>
// run on Cpu
val startOnCpu = System.nanoTime()
withCpuSparkSession(
@@ -153,8 +153,16 @@
val elapseOnGpuMS = (endOnGpu - startOnGpu) / 1000000L
if (i != 1) {
println(s"$testName,Gpu,$zoneStr,$elapseOnGpuMS")
(elapseOnCpuMS, elapseOnGpuMS)
} else {
(0L, 0L) // skip the first round
}
}
val meanCpu = elapses.map(_._1).sum / 5.0
val meanGpu = elapses.map(_._2).sum / 5.0
val speedup = meanCpu.toDouble / meanGpu.toDouble
println(f"$testName, $zoneStr: mean cpu time: $meanCpu%.2f ms, " +
f"mean gpu time: $meanGpu%.2f ms, speedup: $speedup%.2f x")
}
}

@@ -173,4 +181,29 @@ class TimeZonePerfSuite extends SparkQueryCompareTestSuite with BeforeAndAfterAll

runAndRecordTime("from_utc_timestamp", perfTest)
}

test("test timeadd") {
assume(enablePerfTest)

// cache time zone DB in advance
GpuTimeZoneDB.cacheDatabase()
Thread.sleep(5L)

def perfTest(spark: SparkSession, zone: String): DataFrame = {
spark.read.parquet(path).selectExpr(
"count(c_ts - (interval -584 days 1563 seconds))",
"count(c_ts - (interval 1943 days 1101 seconds))",
"count(c_ts - (interval 2693 days 2167 seconds))",
"count(c_ts - (interval 2729 days 0 seconds))",
"count(c_ts - (interval 44 days 1534 seconds))",
"count(c_ts - (interval 2635 days 3319 seconds))",
"count(c_ts - (interval 1885 days -2828 seconds))",
"count(c_ts - (interval 0 days 2463 seconds))",
"count(c_ts - (interval 932 days 2286 seconds))",
"count(c_ts - (interval 0 days 0 seconds))"
)
}

runAndRecordTime("time_add", perfTest)
}
}