diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py
index 72a9c80c5c5..c96b7d0a5ac 100644
--- a/integration_tests/src/main/python/parquet_write_test.py
+++ b/integration_tests/src/main/python/parquet_write_test.py
@@ -45,7 +45,6 @@ def limited_timestamp(nullable=True):
     return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc), nullable=nullable)
 
-# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
 # TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
 def limited_int96():
     return TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(2262, 4, 11, tzinfo=timezone.utc))
@@ -262,19 +261,20 @@ def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_fact
         df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
     assert e_info.match(r".*SparkUpgradeException.*")
 
-# TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS
 # TODO - we are limiting the INT96 values, see https://github.com/rapidsai/cudf/issues/8070
-@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
+@pytest.mark.parametrize('ts_write_data_gen',
+    [('INT96', TimestampGen(start=datetime(1677, 9, 22, tzinfo=timezone.utc), end=datetime(1899, 12, 31, tzinfo=timezone.utc))),
+     ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1899, 12, 31, tzinfo=timezone.utc))),
+     ('TIMESTAMP_MILLIS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1899, 12, 31, tzinfo=timezone.utc)))])
 @pytest.mark.parametrize('rebase', ["CORRECTED","EXCEPTION"])
 def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write_data_gen, spark_tmp_table_factory, rebase):
     ts_write, gen = ts_write_data_gen
     data_path = spark_tmp_path + '/PARQUET_DATA'
     int96_rebase = "EXCEPTION" if (ts_write == "INT96") else rebase
-    date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS") else rebase
+    date_time_rebase = "EXCEPTION" if (ts_write == "TIMESTAMP_MICROS" or ts_write == "TIMESTAMP_MILLIS") else rebase
     with_gpu_session(
         lambda spark : writeParquetUpgradeCatchException(spark,
-                                                         unary_op_df(spark, gen),
-                                                         data_path,
+                                                         unary_op_df(spark, gen), data_path,
                                                          spark_tmp_table_factory, int96_rebase, date_time_rebase, ts_write))
     with_cpu_session(
@@ -458,7 +458,9 @@ def generate_map_with_empty_validity(spark, path):
         lambda spark, path: spark.read.parquet(path),
         data_path)
 
-@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()), ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
+@pytest.mark.parametrize('ts_write_data_gen', [('INT96', limited_int96()),
+    ('TIMESTAMP_MICROS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc))),
+    ('TIMESTAMP_MILLIS', TimestampGen(start=datetime(1, 1, 1, tzinfo=timezone.utc), end=datetime(1582, 1, 1, tzinfo=timezone.utc)))])
 @pytest.mark.parametrize('date_time_rebase_write', ["CORRECTED"])
 @pytest.mark.parametrize('date_time_rebase_read', ["EXCEPTION", "CORRECTED"])
 @pytest.mark.parametrize('int96_rebase_write', ["CORRECTED"])
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
index 0540bfee48f..81fb4c79ddc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -48,8 +48,6 @@ object RebaseHelper {
       startTs: Long): Boolean = {
     val dtype = column.getType
     if (dtype.hasTimeResolution) {
-      // TODO - https://github.com/NVIDIA/spark-rapids/issues/1130 to properly handle
-      // TIMESTAMP_MILLIS, for use require so we fail if that happens
       require(dtype == DType.TIMESTAMP_MICROSECONDS)
       withResource(
         Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
index 014d8fdf498..6d8f078ec81 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
@@ -174,14 +174,15 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
       // See https://github.com/NVIDIA/spark-rapids/issues/8262
       RmmRapidsRetryIterator.withRestoreOnRetry(cr) {
         withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
+          scan(cb)
           transform(cb) match {
             case Some(transformed) =>
               // because we created a new transformed batch, we need to make sure we close it
               withResource(transformed) { _ =>
-                scanAndWrite(transformed)
+                write(transformed)
               }
             case _ =>
-              scanAndWrite(cb)
+              write(cb)
           }
         }
       }
@@ -198,14 +199,15 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
     try {
       val startTimestamp = System.nanoTime
       withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
+        scan(batch)
         transform(batch) match {
           case Some(transformed) =>
             // because we created a new transformed batch, we need to make sure we close it
             withResource(transformed) { _ =>
-              scanAndWrite(transformed)
+              write(transformed)
             }
           case _ =>
-            scanAndWrite(batch)
+            write(batch)
         }
       }
 
@@ -223,9 +225,14 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
     }
   }
 
-  private def scanAndWrite(batch: ColumnarBatch): Unit = {
+  private def scan(batch: ColumnarBatch): Unit = {
     withResource(GpuColumnVector.from(batch)) { table =>
       scanTableBeforeWrite(table)
+    }
+  }
+
+  private def write(batch: ColumnarBatch): Unit = {
+    withResource(GpuColumnVector.from(batch)) { table =>
       anythingWritten = true
       tableWriter.write(table)
     }