From 356d32b033c7637ff87bdbaa771febc20fa66b94 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 1 Sep 2022 12:31:26 -0500 Subject: [PATCH] Change GpuKryoRegistrator to load the classes we want to register with the ShimLoader (#6475) Signed-off-by: Thomas Graves Signed-off-by: Thomas Graves --- .../com/nvidia/spark/rapids/GpuKryoRegistrator.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala index ce8bc1e69bf..85ebd084a15 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala @@ -20,13 +20,14 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.sql.rapids.execution.{SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} class GpuKryoRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo): Unit = { - kryo.register(classOf[SerializeConcatHostBuffersDeserializeBatch], - new KryoJavaSerializer()) - kryo.register(classOf[SerializeBatchDeserializeHostBuffer], - new KryoJavaSerializer()) + val allClassesToRegister = Seq( + "org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch", + "org.apache.spark.sql.rapids.execution.SerializeBatchDeserializeHostBuffer") + allClassesToRegister.foreach { classToRegister => + kryo.register(ShimLoader.loadClass(classToRegister), new KryoJavaSerializer()) + } } }