diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala
index f25628cb66b3..fc307bd0a755 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RmmRapidsRetryIterator.scala
@@ -18,6 +18,8 @@ package com.nvidia.spark.rapids
 
 import scala.collection.mutable
 
+import ai.rapids.cudf.CudfColumnSizeOverflowException
+
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
@@ -580,9 +582,14 @@ object RmmRapidsRetryIterator extends Logging {
           lastException = ex
           if (!topLevelIsRetry && !causedByRetry) {
-            // we want to throw early here, since we got an exception
-            // we were not prepared to handle
-            throw lastException
+            // If the exception is the result of a CUDF column size overflow, attempt split-retry.
+            ex match {
+              case _: CudfColumnSizeOverflowException => doSplit = true
+              case _ =>
+                // we want to throw early here, since we got an exception
+                // we were not prepared to handle
+                throw lastException
+            }
           } // else another exception wrapped a retry. So we are going to try again
         }
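
The hunk above changes the retry loop so that a `CudfColumnSizeOverflowException` sets `doSplit = true` and the surrounding retry framework re-attempts on smaller inputs, while any other unexpected exception is still rethrown immediately. The sketch below is only an illustration of that control flow under assumed names: `ColumnSizeOverflowException`, `SplitRetrySketch`, and `runWithSplitRetry` are hypothetical stand-ins and not the plugin's or cuDF's actual API.

```scala
// Hypothetical stand-in for ai.rapids.cudf.CudfColumnSizeOverflowException.
class ColumnSizeOverflowException(msg: String) extends RuntimeException(msg)

object SplitRetrySketch {
  /**
   * Run `body` on each chunk of work. If a chunk fails with a size overflow,
   * split it in half and retry the halves (the "doSplit" path in the diff).
   * Any other exception is rethrown immediately (the "throw early" path).
   */
  def runWithSplitRetry[T](chunks: List[Vector[T]])(body: Vector[T] => Unit): Unit = {
    var pending = chunks
    while (pending.nonEmpty) {
      val chunk = pending.head
      pending = pending.tail
      try {
        body(chunk)
      } catch {
        case _: ColumnSizeOverflowException if chunk.size > 1 =>
          // Overflow: shrink the work and go around the loop again.
          val (left, right) = chunk.splitAt(chunk.size / 2)
          pending = left :: right :: pending
        case e: Throwable =>
          // An exception we were not prepared to handle: fail fast.
          throw e
      }
    }
  }
}

object SplitRetrySketchExample extends App {
  // Pretend anything over 4 elements overflows the column size limit.
  SplitRetrySketch.runWithSplitRetry(List(Vector.range(0, 10))) { chunk =>
    if (chunk.size > 4) throw new ColumnSizeOverflowException(s"too big: ${chunk.size}")
    println(s"processed ${chunk.size} rows")
  }
}
```

In the real code the split itself is not performed here; this block only flags `doSplit` so the existing split-and-retry machinery in `RmmRapidsRetryIterator` decides how to divide the input on the next attempt.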