diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 98d5730b14e..fc0b7f1dc46 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.{Map => MutableMap} import scala.util.Try import scala.util.matching.Regex +import ai.rapids.cudf.{CudaException, CudaFatalException, CudfException} import com.nvidia.spark.rapids.python.PythonWorkerSemaphore import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskFailedReason} @@ -288,19 +289,21 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { override def onTaskFailed(failureReason: TaskFailedReason): Unit = { failureReason match { case ef: ExceptionFailure => - val unrecoverableErrors = Seq("cudaErrorIllegalAddress", "cudaErrorLaunchTimeout", - "cudaErrorHardwareStackError", "cudaErrorIllegalInstruction", - "cudaErrorMisalignedAddress", "cudaErrorInvalidAddressSpace", "cudaErrorInvalidPc", - "cudaErrorLaunchFailure", "cudaErrorExternalDevice", "cudaErrorUnknown", - "cudaErrorECCUncorrectable") - if (unrecoverableErrors.exists(ef.description.contains(_)) || - unrecoverableErrors.exists(ef.toErrorString.contains(_))) { - logError("Stopping the Executor based on exception being a fatal CUDA error: " + - s"${ef.toErrorString}") - System.exit(20) + ef.exception match { + case Some(_: CudaFatalException) => + logError("Stopping the Executor based on exception being a fatal CUDA error: " + + s"${ef.toErrorString}") + System.exit(20) + case Some(_: CudaException) => + logDebug(s"Executor onTaskFailed because of a non-fatal CUDA error: " + + s"${ef.toErrorString}") + case Some(_: CudfException) => + logDebug(s"Executor onTaskFailed because of a CUDF error: ${ef.toErrorString}") + case _ => + logDebug(s"Executor onTaskFailed: ${ef.toErrorString}") } case other => - logDebug(s"Executor onTaskFailed not a CUDA fatal error: ${other.toString}") + logDebug(s"Executor onTaskFailed: ${other.toString}") } } }