Skip to content

Commit

Permalink
Async write support for ORC writer
Browse files Browse the repository at this point in the history
  • Loading branch information
jihoonson committed Dec 11, 2024
1 parent 9f23763 commit 8d4a241
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
17 changes: 17 additions & 0 deletions integration_tests/src/main/python/orc_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,23 @@ def create_empty_df(spark, path):
conf={'spark.rapids.sql.format.orc.write.enabled': True})


hold_gpu_configs = [True, False]
@pytest.mark.parametrize('hold_gpu', hold_gpu_configs, ids=idfn)
def test_async_writer(spark_tmp_path, hold_gpu):
data_path = spark_tmp_path + '/ORC_DATA'
num_rows = 2048
num_cols = 10
orc_gen = [int_gen for _ in range(num_cols)]
gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gen)]
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: gen_df(spark, gen_list, length=num_rows).coalesce(1).write.orc(path),
lambda spark, path: spark.read.orc(path),
data_path,
conf={"spark.rapids.sql.asyncWrite.queryOutput.enabled": "true",
"spark.rapids.sql.batchSizeBytes": 4 * num_cols * 100, # 100 rows per batch
"spark.rapids.sql.queryOutput.holdGpuInTask": hold_gpu})


@ignore_order
@pytest.mark.skipif(is_before_spark_320(), reason="is only supported in Spark 320+")
def test_concurrent_writer(spark_tmp_path):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
options: Map[String, String],
dataSchema: StructType): ColumnarOutputWriterFactory = {

val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
val sqlConf = sparkSession.sessionState.conf
val orcOptions = new OrcOptions(options, sqlConf)

val conf = job.getConfiguration

Expand All @@ -180,11 +181,16 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
conf.asInstanceOf[JobConf]
.setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])

val asyncOutputWriteEnabled = RapidsConf.ENABLE_ASYNC_OUTPUT_WRITE.get(sqlConf)
// holdGpuBetweenBatches is on by default if asyncOutputWriteEnabled is on
val holdGpuBetweenBatches = RapidsConf.ASYNC_QUERY_OUTPUT_WRITE_HOLD_GPU_IN_TASK.get(sqlConf)
.getOrElse(asyncOutputWriteEnabled)

new ColumnarOutputWriterFactory {
override def newInstance(path: String,
dataSchema: StructType,
context: TaskAttemptContext): ColumnarOutputWriter = {
new GpuOrcWriter(path, dataSchema, context)
new GpuOrcWriter(path, dataSchema, context, holdGpuBetweenBatches, asyncOutputWriteEnabled)
}

override def getFileExtension(context: TaskAttemptContext): String = {
Expand All @@ -203,10 +209,14 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
}
}

class GpuOrcWriter(override val path: String,
dataSchema: StructType,
context: TaskAttemptContext)
extends ColumnarOutputWriter(context, dataSchema, "ORC", true) {
class GpuOrcWriter(
override val path: String,
dataSchema: StructType,
context: TaskAttemptContext,
holdGpuBetweenBatches: Boolean,
useAsyncWrite: Boolean)
extends ColumnarOutputWriter(context, dataSchema, "ORC", true, holdGpuBetweenBatches,
useAsyncWrite) {

override val tableWriter: TableWriter = {
val builder = SchemaUtils
Expand Down

0 comments on commit 8d4a241

Please sign in to comment.