diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py
index 7e415c79a46..e78d7171dd7 100644
--- a/integration_tests/src/main/python/orc_write_test.py
+++ b/integration_tests/src/main/python/orc_write_test.py
@@ -336,6 +336,23 @@ def create_empty_df(spark, path):
         conf={'spark.rapids.sql.format.orc.write.enabled': True})
 
+hold_gpu_configs = [True, False]
+@pytest.mark.parametrize('hold_gpu', hold_gpu_configs, ids=idfn)
+def test_async_writer(spark_tmp_path, hold_gpu):
+    data_path = spark_tmp_path + '/ORC_DATA'
+    num_rows = 2048
+    num_cols = 10
+    orc_gen = [int_gen for _ in range(num_cols)]
+    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gen)]
+    assert_gpu_and_cpu_writes_are_equal_collect(
+        lambda spark, path: gen_df(spark, gen_list, length=num_rows).coalesce(1).write.orc(path),
+        lambda spark, path: spark.read.orc(path),
+        data_path,
+        conf={"spark.rapids.sql.asyncWrite.queryOutput.enabled": "true",
+              "spark.rapids.sql.batchSizeBytes": 4 * num_cols * 100,  # 100 rows per batch
+              "spark.rapids.sql.queryOutput.holdGpuInTask": hold_gpu})
+
+
 @ignore_order
 @pytest.mark.skipif(is_before_spark_320(), reason="is only supported in Spark 320+")
 def test_concurrent_writer(spark_tmp_path):
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 1d4bc66a1da..f64ac25a014 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -169,7 +169,8 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
                              options: Map[String, String],
                              dataSchema: StructType): ColumnarOutputWriterFactory = {
 
-    val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
+    val sqlConf = sparkSession.sessionState.conf
+    val orcOptions = new OrcOptions(options, sqlConf)
 
     val conf = job.getConfiguration
 
@@ -180,11 +181,16 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
     conf.asInstanceOf[JobConf]
       .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])
 
+    val asyncOutputWriteEnabled = RapidsConf.ENABLE_ASYNC_OUTPUT_WRITE.get(sqlConf)
+    // holdGpuBetweenBatches is on by default if asyncOutputWriteEnabled is on
+    val holdGpuBetweenBatches = RapidsConf.ASYNC_QUERY_OUTPUT_WRITE_HOLD_GPU_IN_TASK.get(sqlConf)
+      .getOrElse(asyncOutputWriteEnabled)
+
     new ColumnarOutputWriterFactory {
       override def newInstance(path: String,
                                dataSchema: StructType,
                                context: TaskAttemptContext): ColumnarOutputWriter = {
-        new GpuOrcWriter(path, dataSchema, context)
+        new GpuOrcWriter(path, dataSchema, context, holdGpuBetweenBatches, asyncOutputWriteEnabled)
       }
 
       override def getFileExtension(context: TaskAttemptContext): String = {
@@ -203,10 +209,14 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
   }
 }
 
-class GpuOrcWriter(override val path: String,
-                   dataSchema: StructType,
-                   context: TaskAttemptContext)
-  extends ColumnarOutputWriter(context, dataSchema, "ORC", true) {
+class GpuOrcWriter(
+    override val path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext,
+    holdGpuBetweenBatches: Boolean,
+    useAsyncWrite: Boolean)
+  extends ColumnarOutputWriter(context, dataSchema, "ORC", true, holdGpuBetweenBatches,
+    useAsyncWrite) {
 
   override val tableWriter: TableWriter = {
     val builder = SchemaUtils
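
For reference, the two config keys exercised by test_async_writer can also be set on a live session. A minimal sketch, assuming a PySpark session with the RAPIDS Accelerator plugin already installed; the session setup and output path below are illustrative, not part of this patch:

    # Illustrative sketch only: the config keys come from test_async_writer above;
    # the app name and output path are assumptions.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("async-orc-write-demo").getOrCreate()

    # Enable the asynchronous query-output writer.
    spark.conf.set("spark.rapids.sql.asyncWrite.queryOutput.enabled", "true")
    # Optional: whether the task holds the GPU between batches. Per the change to
    # GpuOrcFileFormat, leaving this unset defaults it to the async-write setting.
    spark.conf.set("spark.rapids.sql.queryOutput.holdGpuInTask", "true")

    spark.range(1000).selectExpr("id AS _c0").write.orc("/tmp/ORC_DATA")

In a real deployment these confs would more typically be passed via --conf at submit time rather than set per session.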