Handle TIMESTAMP_MILLIS for rebase check (#8687)
* Handle TIMESTAMP_MILLIS in isDateTimeRebaseNeeded

Signed-off-by: Haoyang Li <[email protected]>
thirtiseven authored Jul 12, 2023
1 parent c40abc9 commit 04d9080
Showing 3 changed files with 21 additions and 14 deletions.
16 changes: 9 additions & 7 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -45,7 +45,6 @@ def limited_timestamp(nullable=True):
     return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc),
                         nullable=nullable)
 
-# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
 # TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
 def limited_int96():
     return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))
@@ -262,19 +261,20 @@ def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_fact
         df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
     assert e_info.match(r".*SparkUpgradeException.*")
 
-# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
 # TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
-@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
+@pytest.mark.parametrize('ts_write_data_gen',
+    [('INT96', TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(1899, 12, 31, tzinfo=timezone.utc))),
+     ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1899, 12, 31, tzinfo=timezone.utc))),
+     ('TIMESTAMP_MILLIS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1899, 12, 31, tzinfo=timezone.utc)))])
 @pytest.mark.parametrize('rebase', ["CORRECTED","EXCEPTION"])
 def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, spark_tmp_table_factory, rebase):
     ts_write, gen = ts_write_data_gen
     data_path = spark_tmp_path + '/PARQUET_DATA'
     int96_rebase = "EXCEPTION" if (ts_write == "INT96") else rebase
-    date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS") else rebase
+    date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS" or ts_write == "TIMESTAMP_MILLIS") else rebase
     with_gpu_session(
         lambda spark : writeParquetUpgradeCatchException(spark,
-            unary_op_df(spark, gen),
-            data_path,
+            unary_op_df(spark, gen), data_path,
             spark_tmp_table_factory,
             int96_rebase, date_time_rebase, ts_write))
     with_cpu_session(
@@ -458,7 +458,9 @@ def generate_map_with_empty_validity(spark, path):
         lambda spark, path: spark.read.parquet(path),
         data_path)
 
-@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
+@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()),
+    ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc))),
+    ('TIMESTAMP_MILLIS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
 @pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"])
 @pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"])
 @pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"])
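For anyone who wants to try the behavior that test_ts_write_fails_datetime_exception now covers for TIMESTAMP_MILLIS outside the integration-test harness, here is a minimal sketch in plain Spark. It assumes a local SparkSession; the plugin setting and the output path are illustrative and not taken from this commit. With datetimeRebaseModeInWrite set to EXCEPTION, writing a pre-1900 timestamp as TIMESTAMP_MILLIS is expected to fail with a SparkUpgradeException, which is what the parametrized test asserts on both the CPU and GPU paths.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.SparkSession

object MillisRebaseRepro {
  def main(args: Array[String]): Unit = {
    // Local session; to exercise the GPU path the RAPIDS jar would also have to be
    // on the classpath with spark.plugins=com.nvidia.spark.SQLPlugin configured.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("timestamp-millis-rebase-sketch")
      .getOrCreate()
    import spark.implicits._

    // Ask Spark to write TIMESTAMP_MILLIS and to fail on ambiguous ancient timestamps.
    spark.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MILLIS")
    spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "EXCEPTION")

    // A timestamp before 1900-01-01, inside the generator range used in the test above.
    val df = Seq(Timestamp.valueOf("1850-01-01 00:00:00")).toDF("ts")

    // Expected to throw org.apache.spark.SparkUpgradeException under EXCEPTION mode.
    df.write.mode("overwrite").parquet("/tmp/parquet_millis_rebase_sketch")

    spark.stop()
  }
}
```

The second parametrized test above covers the complementary case: with the write-side rebase mode set to CORRECTED, the same kind of data is expected to write successfully and round-trip through a read.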
2 changes: 0 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -48,8 +48,6 @@ object RebaseHelper {
       startTs: Long): Boolean = {
     val dtype = column.getType
     if (dtype.hasTimeResolution) {
-      // TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to properly handle
-      // TIMESTAMP_MILLIS, for use require so we fail if that happens
       require(dtype == DType.TIMESTAMP_MICROSECONDS)
       withResource(
         Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
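The hunk above only removes the TODO; the body of the check itself is collapsed in this view. For orientation, the sketch below shows the general shape such a rebase check can take, built on cudf's Java API (lessThan and any); it is an illustration, not the code from this commit, and the real RebaseHelper.isDateTimeRebaseNeeded may differ in structure. The idea is that a timestamp column needs rebasing if any value falls before the cut-over point passed in as startTs. Note that the require on TIMESTAMP_MICROSECONDS can remain even though TIMESTAMP_MILLIS output is now supported, because the writer change in the next file runs this scan before the batch is transformed for writing.

```scala
import ai.rapids.cudf.{ColumnVector, DType, Scalar}

object RebaseCheckSketch {
  // Small local helper so the sketch is self-contained; spark-rapids has its own
  // withResource utility for the same purpose.
  private def withResource[R <: AutoCloseable, T](r: R)(f: R => T): T =
    try f(r) finally r.close()

  // Illustrative only: true if any non-null microsecond timestamp in `column` is
  // earlier than `startTs`, i.e. old enough that Julian/Gregorian rebasing matters.
  def needsRebase(column: ColumnVector, startTs: Long): Boolean = {
    require(column.getType == DType.TIMESTAMP_MICROSECONDS)
    withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
      withResource(column.lessThan(minGood)) { isBad =>
        withResource(isBad.any()) { anyBad =>
          anyBad.isValid && anyBad.getBoolean
        }
      }
    }
  }
}
```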
17 changes: 12 additions & 5 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
@@ -174,14 +174,15 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
       // See https://github.com/NVIDIA/spark-rapids/issues/8262
       RmmRapidsRetryIterator.withRestoreOnRetry(cr) {
         withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
+          scan(cb)
           transform(cb) match {
             case Some(transformed) =>
               // because we created a new transformed batch, we need to make sure we close it
               withResource(transformed) { _ =>
-                scanAndWrite(transformed)
+                write(transformed)
               }
             case _ =>
-              scanAndWrite(cb)
+              write(cb)
           }
         }
       }
@@ -198,14 +199,15 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
     try {
       val startTimestamp = System.nanoTime
       withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
+        scan(batch)
         transform(batch) match {
           case Some(transformed) =>
             // because we created a new transformed batch, we need to make sure we close it
             withResource(transformed) { _ =>
-              scanAndWrite(transformed)
+              write(transformed)
             }
           case _ =>
-            scanAndWrite(batch)
+            write(batch)
         }
       }
 
@@ -223,9 +225,14 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
     }
   }
 
-  private def scanAndWrite(batch: ColumnarBatch): Unit = {
+  private def scan(batch: ColumnarBatch): Unit = {
     withResource(GpuColumnVector.from(batch)) { table =>
       scanTableBeforeWrite(table)
+    }
+  }
+
+  private def write(batch: ColumnarBatch): Unit = {
+    withResource(GpuColumnVector.from(batch)) { table =>
       anythingWritten = true
       tableWriter.write(table)
     }
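The split of scanAndWrite into scan and write is what actually unblocks TIMESTAMP_MILLIS: the datetime rebase scan now runs on the incoming batch while its columns are still microsecond timestamps, and only the transformed batch is handed to the table writer. For a TIMESTAMP_MILLIS Parquet write that transform would involve casting the columns down to millisecond resolution, which is sketched below; the helper name and its placement are hypothetical and not part of this commit.

```scala
import ai.rapids.cudf.{ColumnVector, DType}

object MillisCastSketch {
  // Hypothetical helper: the kind of per-column transform a Parquet writer applies
  // when the requested output type is TIMESTAMP_MILLIS. Because ColumnarOutputWriter
  // now calls scan(cb) before transform(cb), the rebase check only ever sees
  // microsecond columns; the cast happens afterwards, on the batch that is written.
  def microsToMillis(col: ColumnVector): ColumnVector = {
    require(col.getType == DType.TIMESTAMP_MICROSECONDS)
    col.castTo(DType.TIMESTAMP_MILLISECONDS)
  }
}
```

Keeping the check on the pre-transform batch is presumably also why the RebaseHelper.scala change above is limited to deleting the TODO while its microsecond-only require stays in place.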
