diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml
index d746cb791..2392f8dd2 100644
--- a/.github/workflows/unittests.yml
+++ b/.github/workflows/unittests.yml
@@ -97,9 +97,7 @@ jobs:
- name: Run unit tests
run: |
mvn clean install -N
- cd arrow-data-source
- mvn clean install -DskipTests -Dbuild_arrow=OFF
- cd ..
+ mvn clean install -DskipTests -Dbuild_arrow=OFF -pl arrow-data-source
mvn clean package -P full-scala-compiler -Phadoop-2.7.4 -am -pl native-sql-engine/core -DskipTests -Dbuild_arrow=OFF
mvn test -P full-scala-compiler -DmembersOnlySuites=org.apache.spark.sql.nativesql -am -DfailIfNoTests=false -Dexec.skip=true -DargLine="-Dspark.test.home=/tmp/spark-3.1.1-bin-hadoop2.7" &> log-file.log
echo '#!/bin/bash' > grep.sh
@@ -144,9 +142,7 @@ jobs:
- name: Run unit tests
run: |
mvn clean install -N
- cd arrow-data-source
- mvn clean install -DskipTests -Dbuild_arrow=OFF
- cd ..
+ mvn clean install -DskipTests -Dbuild_arrow=OFF -pl arrow-data-source
mvn clean package -P full-scala-compiler -Phadoop-3.2 -am -pl native-sql-engine/core -DskipTests -Dbuild_arrow=OFF
mvn test -P full-scala-compiler -DmembersOnlySuites=org.apache.spark.sql.nativesql -am -DfailIfNoTests=false -Dexec.skip=true -DargLine="-Dspark.test.home=/tmp/spark-3.1.1-bin-hadoop3.2" &> log-file.log
echo '#!/bin/bash' > grep.sh
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala b/arrow-data-source/common/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala
similarity index 100%
rename from arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala
rename to arrow-data-source/common/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala
diff --git a/arrow-data-source/common/src/main/scala/com/intel/oap/sql/execution/RowToArrowColumnarExec.scala b/arrow-data-source/common/src/main/scala/com/intel/oap/sql/execution/RowToArrowColumnarExec.scala
index b9b58fcb9..a6396b0c5 100644
--- a/arrow-data-source/common/src/main/scala/com/intel/oap/sql/execution/RowToArrowColumnarExec.scala
+++ b/arrow-data-source/common/src/main/scala/com/intel/oap/sql/execution/RowToArrowColumnarExec.scala
@@ -310,4 +310,8 @@ case class RowToArrowColumnarExec(child: SparkPlan) extends UnaryExecNode {
}
}
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): RowToArrowColumnarExec =
+ copy(child = newChild)
}
diff --git a/arrow-data-source/parquet/pom.xml b/arrow-data-source/parquet/pom.xml
index 01cc5110d..28c38e76a 100644
--- a/arrow-data-source/parquet/pom.xml
+++ b/arrow-data-source/parquet/pom.xml
@@ -22,6 +22,12 @@
      <artifactId>spark-arrow-datasource-standard</artifactId>
      <version>${project.version}</version>
+    <dependency>
+      <groupId>com.intel.oap</groupId>
+      <artifactId>spark-sql-columnar-shims-common</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
diff --git a/arrow-data-source/parquet/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/arrow-data-source/parquet/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 836e9f2c9..88a432daf 100644
--- a/arrow-data-source/parquet/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/arrow-data-source/parquet/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import scala.util.{Failure, Try}
import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat
+import com.intel.oap.sql.shims.SparkShimLoader
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, JobID, OutputCommitter, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
@@ -274,6 +275,7 @@ class ParquetFileFormat
val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
val isCaseSensitive = sqlConf.caseSensitiveAnalysis
+ val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
@@ -292,11 +294,17 @@ class ParquetFileFormat
lazy val footerFileMetaData =
ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData
+
+ val datetimeRebaseMode =
+ SparkShimLoader.getSparkShims.getDatetimeRebaseMode(footerFileMetaData, parquetOptions)
+
// Try to push down filters when filter push-down is enabled.
val pushed = if (enableParquetFilterPushDown) {
val parquetSchema = footerFileMetaData.getSchema
- val parquetFilters = new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp,
- pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
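+        // ParquetFilters is created through the shim layer because its constructor
+        // arguments differ between Spark 3.1 and 3.2.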
+ val parquetFilters =
+ SparkShimLoader.getSparkShims.newParquetFilters(parquetSchema: MessageType,
+ pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith,
+ pushDownInFilterThreshold, isCaseSensitive, footerFileMetaData, parquetOptions)
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
@@ -322,10 +330,6 @@ class ParquetFileFormat
None
}
- val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
- footerFileMetaData.getKeyValueMetaData.get,
- SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ))
-
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext =
new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)
@@ -337,12 +341,14 @@ class ParquetFileFormat
}
val taskContext = Option(TaskContext.get())
if (enableVectorizedReader) {
- val vectorizedReader = new VectorizedParquetRecordReader(
- convertTz.orNull,
- datetimeRebaseMode.toString,
- "",
- enableOffHeapColumnVector && taskContext.isDefined,
- capacity)
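+          // Constructed via the shim layer: Spark 3.2's reader constructor takes
+          // additional rebase time-zone arguments.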
+ val vectorizedReader = SparkShimLoader.getSparkShims
+ .newVectorizedParquetRecordReader(
+ convertTz.orNull,
+ footerFileMetaData,
+ parquetOptions,
+ enableOffHeapColumnVector && taskContext.isDefined,
+ capacity)
+
val iter = new RecordReaderIterator(vectorizedReader)
// SPARK-23457 Register a task completion listener before `initialization`.
taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
@@ -358,8 +364,8 @@ class ParquetFileFormat
} else {
logDebug(s"Falling back to parquet-mr")
// ParquetRecordReader returns InternalRow
- val readSupport = new ParquetReadSupport(
- convertTz, enableVectorizedReader = false, datetimeRebaseMode, SQLConf.LegacyBehaviorPolicy.LEGACY)
+ val readSupport = SparkShimLoader.getSparkShims.newParquetReadSupport(
+ convertTz, false, footerFileMetaData, parquetOptions)
val reader = if (pushed.isDefined && enableRecordFilter) {
val parquetFilter = FilterCompat.get(pushed.get, null)
new ParquetRecordReader[InternalRow](readSupport, parquetFilter)
diff --git a/arrow-data-source/pom.xml b/arrow-data-source/pom.xml
index 3dc19f574..c58684a6d 100644
--- a/arrow-data-source/pom.xml
+++ b/arrow-data-source/pom.xml
@@ -42,7 +42,7 @@
-
+
    <groupId>javax.servlet</groupId>
    <artifactId>javax.servlet-api</artifactId>
    <version>3.1.0</version>
diff --git a/arrow-data-source/standard/pom.xml b/arrow-data-source/standard/pom.xml
index 6b6171723..4b37a639f 100644
--- a/arrow-data-source/standard/pom.xml
+++ b/arrow-data-source/standard/pom.xml
@@ -18,6 +18,12 @@
      <artifactId>spark-arrow-datasource-common</artifactId>
      <version>${project.version}</version>
+    <dependency>
+      <groupId>com.intel.oap</groupId>
+      <artifactId>spark-sql-columnar-shims-common</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala
index 7f1d6e153..8c4173d61 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/ArrowWriteExtension.scala
@@ -137,6 +137,10 @@ object ArrowWriteExtension {
private case class ColumnarToFakeRowLogicAdaptor(child: LogicalPlan)
extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: LogicalPlan): ColumnarToFakeRowLogicAdaptor =
+ copy(child = newChild)
}
private case class ColumnarToFakeRowAdaptor(child: SparkPlan) extends ColumnarToRowTransition {
@@ -149,6 +153,10 @@ object ArrowWriteExtension {
}
override def output: Seq[Attribute] = child.output
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarToFakeRowAdaptor =
+ copy(child = newChild)
}
case class SimpleStrategy() extends Strategy {
diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
index 47a21048e..32632bfe0 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala
@@ -94,6 +94,11 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
override def close(): Unit = {
writeQueue.close()
}
+
+        // Do NOT add the override keyword, for compatibility with Spark 3.1.
+ def path(): String = {
+ path
+ }
}
}
}
diff --git a/native-sql-engine/core/pom.xml b/native-sql-engine/core/pom.xml
index 6012d746e..1cdd279de 100644
--- a/native-sql-engine/core/pom.xml
+++ b/native-sql-engine/core/pom.xml
@@ -44,6 +44,33 @@
${build_protobuf}
${build_jemalloc}
+
+  <profiles>
+    <profile>
+      <id>spark-3.1.1</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+      </activation>
+      <dependencies>
+        <dependency>
+          <groupId>com.intel.oap</groupId>
+          <artifactId>spark-sql-columnar-shims-spark311</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+    <profile>
+      <id>spark-3.2</id>
+      <dependencies>
+        <dependency>
+          <groupId>com.intel.oap</groupId>
+          <artifactId>spark-sql-columnar-shims-spark321</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
+  </profiles>
+
@@ -166,19 +193,19 @@
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-core</artifactId>
-      <version>2.10.0</version>
+      <version>${jackson.version}</version>
      <scope>test</scope>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-annotations</artifactId>
-      <version>2.10.0</version>
+      <version>${jackson.version}</version>
      <scope>test</scope>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-databind</artifactId>
-      <version>2.10.0</version>
+      <version>${jackson.version}</version>
      <scope>test</scope>
@@ -299,7 +326,7 @@
      <groupId>com.intel.oap</groupId>
      <artifactId>spark-sql-columnar-shims-common</artifactId>
      <version>${project.version}</version>
-      <scope>compile</scope>
+      <scope>provided</scope>
      <groupId>org.apache.logging.log4j</groupId>
diff --git a/native-sql-engine/core/src/main/java/com/intel/oap/datasource/VectorizedParquetArrowReader.java b/native-sql-engine/core/src/main/java/com/intel/oap/datasource/VectorizedParquetArrowReader.java
index ff238fc32..cdd6f4b1b 100644
--- a/native-sql-engine/core/src/main/java/com/intel/oap/datasource/VectorizedParquetArrowReader.java
+++ b/native-sql-engine/core/src/main/java/com/intel/oap/datasource/VectorizedParquetArrowReader.java
@@ -24,7 +24,7 @@
import java.util.*;
import java.util.stream.Collectors;
-import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader;
+import org.apache.spark.sql.execution.datasources.VectorizedParquetRecordReaderChild;
import com.intel.oap.datasource.parquet.ParquetReader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
@@ -46,7 +46,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class VectorizedParquetArrowReader extends VectorizedParquetRecordReader {
+public class VectorizedParquetArrowReader extends VectorizedParquetRecordReaderChild {
private static final Logger LOG =
LoggerFactory.getLogger(VectorizedParquetArrowReader.class);
private ParquetReader reader = null;
@@ -70,7 +70,8 @@ public class VectorizedParquetArrowReader extends VectorizedParquetRecordReader
public VectorizedParquetArrowReader(String path, ZoneId convertTz, boolean useOffHeap,
int capacity, StructType sourceSchema, StructType readDataSchema, String tmp_dir) {
- super(convertTz, "CORRECTED", "LEGACY", useOffHeap, capacity);
+    // TODO: datetimeRebaseTz & int96RebaseTz are set to ""; the impact needs to be checked.
+ super(convertTz, "CORRECTED", "", "LEGACY", "", useOffHeap, capacity);
this.capacity = capacity;
this.path = path;
this.tmp_dir = tmp_dir;
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/CoalesceBatchesExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/CoalesceBatchesExec.scala
index 476c48e6b..08d9422a9 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/CoalesceBatchesExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/CoalesceBatchesExec.scala
@@ -147,6 +147,10 @@ case class CoalesceBatchesExec(child: SparkPlan) extends UnaryExecNode {
new CloseableColumnBatchIterator(res)
}
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): CoalesceBatchesExec =
+ copy(child = newChild)
}
object CoalesceBatchesExec {
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala
index 0a7243801..66e729137 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBasicPhysicalOperators.scala
@@ -100,8 +100,11 @@ case class ColumnarConditionProjectExec(
}
}
- def isNullIntolerant(expr: Expression): Boolean = expr match {
- case e: NullIntolerant => e.children.forall(isNullIntolerant)
+  // In Spark 3.2, PredicateHelper already provides isNullIntolerant with exactly the same
+  // code. Reusing that name would require the override keyword, but in Spark 3.1 there is
+  // no method to override. So we use a different method name.
+ def isNullIntolerantInternal(expr: Expression): Boolean = expr match {
+ case e: NullIntolerant => e.children.forall(isNullIntolerantInternal)
case _ => false
}
@@ -110,7 +113,7 @@ case class ColumnarConditionProjectExec(
val notNullAttributes = if (condition != null) {
val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
- case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+ case IsNotNull(a) => isNullIntolerantInternal(a) && a.references.subsetOf(child.outputSet)
case _ => false
}
notNullPreds.flatMap(_.references).distinct.map(_.exprId)
@@ -267,6 +270,9 @@ case class ColumnarConditionProjectExec(
}
}
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarConditionProjectExec =
+ copy(child = newChild)
}
case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan {
@@ -308,6 +314,10 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan {
: org.apache.spark.rdd.RDD[org.apache.spark.sql.catalyst.InternalRow] = {
throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
}
+
+ // For spark 3.2.
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): ColumnarUnionExec =
+ copy(children = newChildren)
}
//TODO(): consolidate locallimit and globallimit
@@ -380,6 +390,10 @@ case class ColumnarLocalLimitExec(limit: Int, child: SparkPlan) extends LimitExe
throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
}
+ protected def withNewChildInternal(newChild: SparkPlan):
+ ColumnarLocalLimitExec =
+ copy(child = newChild)
+
}
case class ColumnarGlobalLimitExec(limit: Int, child: SparkPlan) extends LimitExec {
@@ -451,4 +465,8 @@ case class ColumnarGlobalLimitExec(limit: Int, child: SparkPlan) extends LimitEx
: org.apache.spark.rdd.RDD[org.apache.spark.sql.catalyst.InternalRow] = {
throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
}
+
+ protected def withNewChildInternal(newChild: SparkPlan):
+ ColumnarGlobalLimitExec =
+ copy(child = newChild)
}
\ No newline at end of file
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala
index f5b3109a6..ed4e00bf2 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBroadcastHashJoinExec.scala
@@ -21,9 +21,11 @@ import com.google.common.collect.Lists
import com.intel.oap.GazellePluginConfig
import com.intel.oap.expression._
import com.intel.oap.vectorized.{ExpressionEvaluator, _}
+import com.intel.oap.sql.shims.SparkShimLoader
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -35,6 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils}
@@ -59,7 +62,7 @@ case class ColumnarBroadcastHashJoinExec(
nullAware: Boolean = false)
extends BaseJoinExec
with ColumnarCodegenSupport
- with ShuffledJoin {
+ with ColumnarShuffledJoin {
val sparkConf = sparkContext.getConf
val numaBindingInfo = GazellePluginConfig.getConf.numaBindingInfo
@@ -89,6 +92,9 @@ case class ColumnarBroadcastHashJoinExec(
}
buildCheck()
+  // A method declared in ShuffledJoin in Spark 3.2.
+ def isSkewJoin: Boolean = false
+
def buildCheck(): Unit = {
joinType match {
case _: InnerLike =>
@@ -145,13 +151,13 @@ case class ColumnarBroadcastHashJoinExec(
throw new UnsupportedOperationException(
s"ColumnarBroadcastHashJoinExec doesn't support doExecute")
}
-
val isNullAwareAntiJoin : Boolean = nullAware
- val broadcastHashJoinOutputPartitioningExpandLimit: Int = sqlContext.getConf(
- "spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit").trim().toInt
-
override lazy val outputPartitioning: Partitioning = {
+ val broadcastHashJoinOutputPartitioningExpandLimit: Int =
+ SparkShimLoader
+ .getSparkShims
+ .getBroadcastHashJoinOutputPartitioningExpandLimit(this: SparkPlan)
joinType match {
case _: InnerLike if broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
@@ -193,7 +199,10 @@ case class ColumnarBroadcastHashJoinExec(
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
- val maxNumCombinations = broadcastHashJoinOutputPartitioningExpandLimit
+ val maxNumCombinations =
+ SparkShimLoader
+ .getSparkShims
+ .getBroadcastHashJoinOutputPartitioningExpandLimit(this: SparkPlan)
var currentNumCombinations = 0
def generateExprCombinations(current: Seq[Expression],
@@ -640,4 +649,9 @@ case class ColumnarBroadcastHashJoinExec(
}
}
+
+ // For spark 3.2.
+ protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan):
+ ColumnarBroadcastHashJoinExec =
+ copy(left = newLeft, right = newRight)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarCoalesceExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarCoalesceExec.scala
index 2851b4a5d..49e54577f 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarCoalesceExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarCoalesceExec.scala
@@ -69,6 +69,10 @@ case class ColumnarCoalesceExec(numPartitions: Int, child: SparkPlan) extends Un
child.executeColumnar().coalesce(numPartitions, shuffle = false)
}
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarCoalesceExec =
+ copy(child = newChild)
}
object ColumnarCoalesceExec {
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala
index a92365e7b..c82234650 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarExpandExec.scala
@@ -133,4 +133,8 @@ case class ColumnarExpandExec(
new CloseableColumnBatchIterator(res)
}
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarExpandExec =
+ copy(child = newChild)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala
index 29f629fa1..72b343a2f 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala
@@ -33,7 +33,6 @@ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{UserAddedJarUtils, Utils, ExecutorManager}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -295,7 +294,7 @@ case class ColumnarHashAggregateExec(
val aggregateFunc = exp.aggregateFunction
val out_res = aggregateFunc.children.head.asInstanceOf[Literal].value
aggregateFunc match {
- case Sum(_) =>
+ case _: Sum =>
mode match {
case Partial | PartialMerge =>
val sum = aggregateFunc.asInstanceOf[Sum]
@@ -314,7 +313,7 @@ case class ColumnarHashAggregateExec(
putDataIntoVector(resultColumnVectors, out_res, idx)
idx += 1
}
- case Average(_) =>
+ case _: Average =>
mode match {
case Partial | PartialMerge =>
putDataIntoVector(resultColumnVectors, out_res, idx) // sum
@@ -389,7 +388,7 @@ case class ColumnarHashAggregateExec(
var idx = 0
for (expr <- aggregateExpressions) {
expr.aggregateFunction match {
- case Average(_) | StddevSamp(_, _) | Sum(_) | Max(_) | Min(_) =>
+ case _: Average | _: Sum | StddevSamp(_, _) | Max(_) | Min(_) =>
expr.mode match {
case Final =>
resultColumnVectors(idx).putNull(0)
@@ -471,7 +470,7 @@ case class ColumnarHashAggregateExec(
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
- case Average(_) =>
+ case _: Average =>
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, BooleanType)
val avg = aggregateFunc.asInstanceOf[Average]
@@ -493,7 +492,7 @@ case class ColumnarHashAggregateExec(
throw new UnsupportedOperationException(
s"${other} is not supported in Columnar Average")
}
- case Sum(_) =>
+ case _: Sum =>
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, BooleanType)
val sum = aggregateFunc.asInstanceOf[Sum]
@@ -695,4 +694,8 @@ case class ColumnarHashAggregateExec(
s"ColumnarHashAggregate(keys=$keyString, functions=$functionString)"
}
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarHashAggregateExec =
+ copy(child = newChild)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala
index 9a723e505..35a3de30f 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledHashJoinExec.scala
@@ -69,7 +69,7 @@ case class ColumnarShuffledHashJoinExec(
projectList: Seq[NamedExpression] = null)
extends BaseJoinExec
with ColumnarCodegenSupport
- with ShuffledJoin {
+ with ColumnarShuffledJoin {
val sparkConf = sparkContext.getConf
val numaBindingInfo = GazellePluginConfig.getConf.numaBindingInfo
@@ -82,6 +82,9 @@ case class ColumnarShuffledHashJoinExec(
buildCheck()
+ // For spark 3.2.
+ def isSkewJoin: Boolean = false
+
protected lazy val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
@@ -583,4 +586,9 @@ case class ColumnarShuffledHashJoinExec(
}
}
+ // For spark 3.2.
+ protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan):
+ ColumnarShuffledHashJoinExec =
+ copy(left = newLeft, right = newRight)
+
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledJoin.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledJoin.scala
new file mode 100644
index 000000000..ae70f6eb8
--- /dev/null
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarShuffledJoin.scala
@@ -0,0 +1,82 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.oap.execution
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.joins.BaseJoinExec
+
+/**
+ * The code in this trait is ported from Spark's ShuffledJoin. Since Spark 3.2,
+ * ShuffledJoin extends JoinCodegenSupport, which is not applicable to
+ * ColumnarShuffledHashJoinExec & ColumnarBroadcastHashJoinExec. So we create this
+ * trait for compatibility with Spark 3.2.
+ */
+trait ColumnarShuffledJoin extends BaseJoinExec {
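+  // True when AQE re-plans this join to handle skew; it affects the node name and the
+  // required child distribution below.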
+ def isSkewJoin: Boolean
+
+ override def nodeName: String = {
+ if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName
+ }
+
+ override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator
+
+ override def requiredChildDistribution: Seq[Distribution] = {
+ if (isSkewJoin) {
+ // We re-arrange the shuffle partitions to deal with skew join, and the new children
+ // partitioning doesn't satisfy `HashClusteredDistribution`.
+ UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+ } else {
+ HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+ }
+ }
+
+ override def outputPartitioning: Partitioning = joinType match {
+ case _: InnerLike =>
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case LeftExistence(_) => left.outputPartitioning
+ case x =>
+ throw new IllegalArgumentException(
+ s"ShuffledJoin should not take $x as the JoinType")
+ }
+
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case _: InnerLike =>
+ left.output ++ right.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ (left.output ++ right.output).map(_.withNullability(true))
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
+ case LeftExistence(_) =>
+ left.output
+ case x =>
+ throw new IllegalArgumentException(
+ s"${getClass.getSimpleName} not take $x as the JoinType")
+ }
+ }
+}
\ No newline at end of file
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortExec.scala
index 95352b0f2..ed78fa8ff 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortExec.scala
@@ -240,4 +240,7 @@ case class ColumnarSortExec(
}
}
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarSortExec =
+ copy(child = newChild)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala
index 8c1f669b5..d0e4d018c 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarSortMergeJoinExec.scala
@@ -458,4 +458,9 @@ case class ColumnarSortMergeJoinExec(
new CloseableColumnBatchIterator(vjoinResult)
}
}
+
+ // For spark 3.2.
+ protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan):
+ ColumnarSortMergeJoinExec =
+ copy(left = newLeft, right = newRight)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala
index fae78cb5a..59d32ef20 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWholeStageCodegenExec.scala
@@ -594,4 +594,8 @@ case class ColumnarWholeStageCodegenExec(child: SparkPlan)(val codegenStageId: I
new CloseableColumnBatchIterator(resIter)
}
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarWholeStageCodegenExec =
+ copy(child = newChild)(codegenStageId)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala
index 3b4709ae9..4ae836a09 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala
@@ -345,6 +345,10 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
override protected def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException()
}
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarWindowExec =
+ copy(child = newChild)
}
object ColumnarWindowExec extends Logging {
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/DataToArrowColumnarExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/DataToArrowColumnarExec.scala
index 5ecb4cfcb..467750188 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/DataToArrowColumnarExec.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/DataToArrowColumnarExec.scala
@@ -110,4 +110,7 @@ case class DataToArrowColumnarExec(child: SparkPlan, numPartitions: Int) extends
child.executeBroadcast[ColumnarHashedRelation]())
}
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): DataToArrowColumnarExec =
+ copy(child = newChild)
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarDateTimeExpressions.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarDateTimeExpressions.scala
index 4f4884654..04d8dc252 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarDateTimeExpressions.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarDateTimeExpressions.scala
@@ -565,6 +565,11 @@ object ColumnarDateTimeExpressions {
}
}
}
+
+ // For spark 3.2.
+ protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression):
+ ColumnarGetTimestamp =
+ copy(leftChild = newLeft, rightChild = newRight)
}
class ColumnarFromUnixTime(left: Expression, right: Expression)
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
index b752e437b..9d625a95a 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
@@ -385,8 +385,6 @@ object ColumnarExpressionConverter extends Logging {
containsSubquery(i.value)
case ss: Substring =>
containsSubquery(ss.str) || containsSubquery(ss.pos) || containsSubquery(ss.len)
- case oaps: com.intel.oap.expression.ColumnarScalarSubquery =>
- return true
case s: org.apache.spark.sql.execution.ScalarSubquery =>
return true
case c: Concat =>
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala
index 91ff8e3ff..9466570ad 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala
@@ -99,7 +99,7 @@ class ColumnarHashAggregation(
val mode = aggregateExpression.mode
try {
aggregateFunc match {
- case Average(_) =>
+ case _: Average =>
mode match {
case Partial =>
val childrenColumnarFuncNodeList =
@@ -126,7 +126,7 @@ class ColumnarHashAggregation(
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
- case Sum(_) =>
+ case _: Sum =>
mode match {
case Partial =>
val childrenColumnarFuncNodeList =
@@ -243,7 +243,7 @@ class ColumnarHashAggregation(
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
- case Average(_) =>
+ case _: Average =>
mode match {
case Partial => {
val avg = aggregateFunc.asInstanceOf[Average]
@@ -270,7 +270,7 @@ class ColumnarHashAggregation(
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
- case Sum(_) =>
+ case _: Sum =>
mode match {
case Partial | PartialMerge => {
val sum = aggregateFunc.asInstanceOf[Sum]
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarSubquery.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarSubquery.scala
deleted file mode 100644
index 054a61a15..000000000
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarSubquery.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.intel.oap.expression
-
-import com.google.common.collect.Lists
-import org.apache.arrow.gandiva.evaluator._
-import org.apache.arrow.gandiva.exceptions.GandivaException
-import org.apache.arrow.gandiva.expression._
-import org.apache.arrow.vector.types.DateUnit
-import org.apache.arrow.vector.types.pojo.ArrowType
-import org.apache.arrow.vector.types.pojo.Field
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.execution.BaseSubqueryExec
-import org.apache.spark.sql.execution.ExecSubqueryExpression
-import org.apache.spark.sql.execution.ScalarSubquery
-import org.apache.spark.sql.types._
-
-import scala.collection.mutable.ListBuffer
-
-class ColumnarScalarSubquery(
- query: ScalarSubquery)
- extends Expression with ColumnarExpression {
-
- override def dataType: DataType = query.dataType
- buildCheck()
- override def children: Seq[Expression] = Nil
- override def nullable: Boolean = true
- override def toString: String = query.toString
- override def eval(input: InternalRow): Any = query.eval(input)
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = query.doGenCode(ctx, ev)
- override def canEqual(that: Any): Boolean = query.canEqual(that)
- override def productArity: Int = query.productArity
- override def productElement(n: Int): Any = query.productElement(n)
- override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
- val value = query.eval(null)
- val resultType = CodeGeneration.getResultType(query.dataType)
- query.dataType match {
- case t: StringType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- (TreeBuilder.makeStringLiteral(value.toString().asInstanceOf[String]), resultType)
- }
- case t: IntegerType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType)
- }
- case t: LongType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Long]), resultType)
- }
- case t: DoubleType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Double]), resultType)
- }
- case d: DecimalType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- val v = value.asInstanceOf[Decimal]
- (TreeBuilder.makeDecimalLiteral(v.toString, v.precision, v.scale), resultType)
- }
- case d: DateType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- val origIntNode = TreeBuilder.makeLiteral(value.asInstanceOf[Integer])
- val dateNode = TreeBuilder.makeFunction("castDATE",
- Lists.newArrayList(origIntNode), new ArrowType.Date(DateUnit.DAY))
- (dateNode, new ArrowType.Date(DateUnit.DAY))
- }
- case b: BooleanType =>
- value match {
- case null =>
- (TreeBuilder.makeNull(resultType), resultType)
- case _ =>
- (TreeBuilder.makeLiteral(value.asInstanceOf[java.lang.Boolean]), resultType)
- }
- }
- }
- def buildCheck(): Unit = {
- val supportedTypes =
- List(StringType, IntegerType, LongType, DoubleType, DateType, BooleanType)
- if (supportedTypes.indexOf(dataType) == -1 &&
- !dataType.isInstanceOf[DecimalType]) {
- throw new UnsupportedOperationException(
- s"$dataType is not supported in ColumnarScalarSubquery")
- }
- }
-}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala
index 339dda82b..71c2fbda4 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/ColumnarOverrides.scala
@@ -19,14 +19,15 @@ package com.intel.oap.extension
import com.intel.oap.GazellePluginConfig
import com.intel.oap.GazelleSparkExtensionsInjector
-
-import scala.collection.mutable
import com.intel.oap.execution._
import com.intel.oap.extension.columnar.ColumnarGuardRule
import com.intel.oap.extension.columnar.RowGuard
import com.intel.oap.sql.execution.RowToArrowColumnarExec
+import com.intel.oap.sql.shims.SparkShimLoader
+
import org.apache.spark.{MapOutputStatistics, SparkContext}
import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
@@ -50,15 +51,20 @@ import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, ColumnarArrowEvalPythonExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
import org.apache.spark.util.ShufflePartitionUtils
+import scala.collection.mutable
+
case class ColumnarPreOverrides() extends Rule[SparkPlan] {
val columnarConf: GazellePluginConfig = GazellePluginConfig.getSessionConf
var isSupportAdaptive: Boolean = true
def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match {
- case RowGuard(child: CustomShuffleReaderExec) =>
+ case RowGuard(child: SparkPlan)
+ if SparkShimLoader.getSparkShims.isCustomShuffleReaderExec(child) =>
replaceWithColumnarPlan(child)
case plan: RowGuard =>
val actualPlan = plan.child match {
@@ -83,7 +89,27 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
ColumnarArrowEvalPythonExec(plan.udfs, plan.resultAttrs, columnarChild, plan.evalType)
case plan: BatchScanExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- new ColumnarBatchScanExec(plan.output, plan.scan)
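+      // Runtime filters are fetched through the shim layer so this code works on both
+      // Spark 3.1 and 3.2.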
+ val runtimeFilters = SparkShimLoader.getSparkShims.getRuntimeFilters(plan)
+ new ColumnarBatchScanExec(plan.output, plan.scan, runtimeFilters) {
+        // This method is a shared implementation for ColumnarBatchScanExec. We keep it
+        // outside of the shim layer to break the cyclic dependency that would otherwise
+        // be caused by ColumnarDataSourceRDD.
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val numOutputRows = longMetric("numOutputRows")
+ val numInputBatches = longMetric("numInputBatches")
+ val numOutputBatches = longMetric("numOutputBatches")
+ val scanTime = longMetric("scanTime")
+ val inputSize = longMetric("inputSize")
+ val inputColumnarRDD =
+ new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory,
+ true, scanTime, numInputBatches, inputSize, tmpDir)
+ inputColumnarRDD.map { r =>
+ numOutputRows += r.numRows()
+ numOutputBatches += 1
+ r
+ }
+ }
+ }
case plan: CoalesceExec =>
ColumnarCoalesceExec(plan.numPartitions, replaceWithColumnarPlan(plan.child))
case plan: InMemoryTableScanExec =>
@@ -236,27 +262,34 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
plan
- case plan: CustomShuffleReaderExec if columnarConf.enableColumnarShuffle =>
- plan.child match {
+ case plan
+ if (SparkShimLoader.getSparkShims.isCustomShuffleReaderExec(plan)
+ && columnarConf.enableColumnarShuffle) =>
+ val child = SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(plan)
+ val partitionSpecs =
+ SparkShimLoader.getSparkShims.getPartitionSpecsOfCustomShuffleReaderExec(plan)
+ child match {
case shuffle: ColumnarShuffleExchangeAdaptor =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
CoalesceBatchesExec(
- ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs))
- case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeAdaptor) =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- CoalesceBatchesExec(
- ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs))
- case ShuffleQueryStageExec(_, reused: ReusedExchangeExec) =>
- reused match {
- case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeAdaptor) =>
+ ColumnarCustomShuffleReaderExec(child, partitionSpecs))
+          // Match on the ShuffleQueryStageExec type (instead of its extractor pattern) for
+          // compatibility with both Spark 3.1 and 3.2.
+ case shuffleQueryStageExec: ShuffleQueryStageExec =>
+ shuffleQueryStageExec.plan match {
+ case s: ColumnarShuffleExchangeAdaptor =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+ CoalesceBatchesExec(
+ ColumnarCustomShuffleReaderExec(child, partitionSpecs))
+ case r @ ReusedExchangeExec(_, s: ColumnarShuffleExchangeAdaptor) =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
CoalesceBatchesExec(
ColumnarCustomShuffleReaderExec(
- plan.child,
- plan.partitionSpecs))
+ child,
+ partitionSpecs))
case _ =>
plan
}
+
case _ =>
plan
}
@@ -296,17 +329,15 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
def fallBackBroadcastQueryStage(curPlan: BroadcastQueryStageExec): BroadcastQueryStageExec = {
curPlan.plan match {
case originalBroadcastPlan: ColumnarBroadcastExchangeAdaptor =>
- BroadcastQueryStageExec(
- curPlan.id,
- BroadcastExchangeExec(
- originalBroadcastPlan.mode,
- DataToArrowColumnarExec(originalBroadcastPlan, 1)))
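+          // BroadcastQueryStageExec is created through the shim layer because its
+          // constructor differs between Spark 3.1 and 3.2.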
+ val newBroadcast = BroadcastExchangeExec(
+ originalBroadcastPlan.mode,
+ DataToArrowColumnarExec(originalBroadcastPlan, 1))
+ SparkShimLoader.getSparkShims.newBroadcastQueryStageExec(curPlan.id, newBroadcast)
case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeAdaptor) =>
- BroadcastQueryStageExec(
- curPlan.id,
- BroadcastExchangeExec(
- originalBroadcastPlan.mode,
- DataToArrowColumnarExec(curPlan.plan, 1)))
+ val newBroadcast = BroadcastExchangeExec(
+ originalBroadcastPlan.mode,
+ DataToArrowColumnarExec(curPlan.plan, 1))
+ SparkShimLoader.getSparkShims.newBroadcastQueryStageExec(curPlan.id, newBroadcast)
case _ =>
curPlan
}
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/StrategyOverrides.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/StrategyOverrides.scala
index 148661e38..87e6b3d12 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/StrategyOverrides.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/StrategyOverrides.scala
@@ -100,6 +100,10 @@ case class LocalWindowExec(
// todo implement this to fall back
throw new UnsupportedOperationException()
}
+
+ protected def withNewChildInternal(newChild: SparkPlan):
+ LocalWindowExec =
+ copy(child = newChild)
}
object LocalWindowApply extends Strategy {
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala
index 5da11f660..68e186b8f 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/extension/columnar/ColumnarGuardRule.scala
@@ -20,6 +20,8 @@ package com.intel.oap.extension.columnar
import com.intel.oap.GazellePluginConfig
import com.intel.oap.execution._
import com.intel.oap.extension.LocalWindowExec
+import com.intel.oap.sql.shims.SparkShimLoader
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -35,6 +37,7 @@ import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.python.ColumnarArrowEvalPythonExec
import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.vectorized.ColumnarBatch
case class RowGuard(child: SparkPlan) extends SparkPlan {
def output: Seq[Attribute] = child.output
@@ -42,6 +45,11 @@ case class RowGuard(child: SparkPlan) extends SparkPlan {
throw new UnsupportedOperationException
}
def children: Seq[SparkPlan] = Seq(child)
+
+ // For spark 3.2.
+ // TODO: can newChild have more than one element?
+ protected def withNewChildrenInternal(newChild: IndexedSeq[SparkPlan]): RowGuard =
+ copy(child = newChild.head)
}
case class ColumnarGuardRule() extends Rule[SparkPlan] {
@@ -72,7 +80,27 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
ColumnarArrowEvalPythonExec(plan.udfs, plan.resultAttrs, plan.child, plan.evalType)
case plan: BatchScanExec =>
if (!enableColumnarBatchScan) return false
- new ColumnarBatchScanExec(plan.output, plan.scan)
+ val runtimeFilters = SparkShimLoader.getSparkShims.getRuntimeFilters(plan)
+ new ColumnarBatchScanExec(plan.output, plan.scan, runtimeFilters) {
+        // This method is a shared implementation for ColumnarBatchScanExec. We keep it
+        // outside of the shim layer to break the cyclic dependency that would otherwise
+        // be caused by ColumnarDataSourceRDD.
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val numOutputRows = longMetric("numOutputRows")
+ val numInputBatches = longMetric("numInputBatches")
+ val numOutputBatches = longMetric("numOutputBatches")
+ val scanTime = longMetric("scanTime")
+ val inputSize = longMetric("inputSize")
+ val inputColumnarRDD =
+ new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory,
+ true, scanTime, numInputBatches, inputSize, tmpDir)
+ inputColumnarRDD.map { r =>
+ numOutputRows += r.numRows()
+ numOutputBatches += 1
+ r
+ }
+ }
+ }
case plan: FileSourceScanExec =>
if (plan.supportsColumnar) {
return false
@@ -141,13 +169,16 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
left match {
case exec: BroadcastExchangeExec =>
new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
- case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) =>
- new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
- case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) =>
- plan match {
- case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
- new ColumnarBroadcastExchangeExec(b.mode, b.child)
- case _ =>
+ case broadcastQueryStageExec: BroadcastQueryStageExec =>
+ broadcastQueryStageExec.plan match {
+ case plan: BroadcastExchangeExec =>
+ new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
+ case plan: ReusedExchangeExec =>
+ plan match {
+ case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
+ new ColumnarBroadcastExchangeExec(b.mode, b.child)
+ case _ =>
+ }
}
case _ =>
}
@@ -155,13 +186,16 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
right match {
case exec: BroadcastExchangeExec =>
new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
- case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) =>
- new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
- case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) =>
- plan match {
- case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
- new ColumnarBroadcastExchangeExec(b.mode, b.child)
- case _ =>
+ case broadcastQueryStageExec: BroadcastQueryStageExec =>
+ broadcastQueryStageExec.plan match {
+ case plan: BroadcastExchangeExec =>
+ new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
+ case plan: ReusedExchangeExec =>
+ plan match {
+ case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
+ new ColumnarBroadcastExchangeExec(b.mode, b.child)
+ case _ =>
+ }
}
case _ =>
}
@@ -257,7 +291,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
case p if !supportCodegen(p) =>
// insert row guard them recursively
p.withNewChildren(p.children.map(insertRowGuardOrNot))
- case p: CustomShuffleReaderExec =>
+ case p if SparkShimLoader.getSparkShims.isCustomShuffleReaderExec(p) =>
p.withNewChildren(p.children.map(insertRowGuardOrNot))
case p: BroadcastQueryStageExec =>
p
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index a3b05e30c..17b531777 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting
import com.intel.oap.GazellePluginConfig
import com.intel.oap.expression.ConverterUtils
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.Spiller
+import com.intel.oap.sql.shims.SparkShimLoader
import com.intel.oap.vectorized.{ArrowWritableColumnVector, ShuffleSplitterJniWrapper, SplitResult}
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.spark._
@@ -94,7 +95,8 @@ class ColumnarShuffleWriter[K, V](
override def write(records: Iterator[Product2[K, V]]): Unit = {
if (!records.hasNext) {
partitionLengths = new Array[Long](dep.partitioner.numPartitions)
- shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null)
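+      // Committing the shuffle index file goes through the shim layer because the
+      // resolver API differs between Spark 3.1 and 3.2.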
+ SparkShimLoader.getSparkShims.shuffleBlockResolverWriteAndCommit(
+ shuffleBlockResolver, dep.shuffleId, mapId, partitionLengths, null)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
return
}
@@ -197,11 +199,8 @@ class ColumnarShuffleWriter[K, V](
partitionLengths = splitResult.getPartitionLengths
rawPartitionLengths = splitResult.getRawPartitionLengths
try {
- shuffleBlockResolver.writeIndexFileAndCommit(
- dep.shuffleId,
- mapId,
- partitionLengths,
- dataTmp)
+ SparkShimLoader.getSparkShims.shuffleBlockResolverWriteAndCommit(
+ shuffleBlockResolver, dep.shuffleId, mapId, partitionLengths, dataTmp)
} finally {
if (dataTmp.exists() && !dataTmp.delete()) {
logError(s"Error while deleting temp file ${dataTmp.getAbsolutePath}")
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index fba719f96..8d23395cc 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,6 +20,8 @@ package org.apache.spark.shuffle.sort
import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap
+import com.intel.oap.sql.shims.SparkShimLoader
+
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
@@ -107,12 +109,13 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
- new SortShuffleWriter(
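+        // SortShuffleWriter is created through the shim layer because its constructor
+        // differs between Spark 3.1 and 3.2.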
+ SparkShimLoader.getSparkShims.newSortShuffleWriter(
shuffleBlockResolver,
other,
mapId,
context,
- shuffleExecutorComponents)
+ shuffleExecutorComponents
+ ).asInstanceOf[SortShuffleWriter[K, V, _]]
}
}
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
index b33de1cb8..99f7f7f69 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
@@ -21,6 +21,7 @@ import java.util.concurrent._
import com.google.common.collect.Lists
import com.intel.oap.expression._
+import com.intel.oap.sql.shims.SparkShimLoader
import com.intel.oap.vectorized.{ArrowWritableColumnVector, ExpressionEvaluator}
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field}
@@ -83,10 +84,13 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
lazy val completionFuture: scala.concurrent.Future[broadcast.Broadcast[Any]] =
promise.future
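+  // Obtained through the shim layer because the maximum broadcast row count is defined
+  // differently across Spark versions.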
+ @transient
+ private lazy val maxBroadcastRows = SparkShimLoader.getSparkShims.getMaxBroadcastRows(mode)
+
@transient
private[sql] lazy val relationFuture: java.util.concurrent.Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
- sqlContext.sparkSession,
+ SparkShimLoader.getSparkShims.getSparkSession(this),
BroadcastExchangeExec.executionContext) {
var relation: Any = null
try {
@@ -162,9 +166,9 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
/////////////////////////////////////////////////////////////////////////////
- if (numRows >= BroadcastExchangeExec.MAX_BROADCAST_TABLE_ROWS) {
+ if (numRows >= maxBroadcastRows) {
throw new SparkException(
- s"Cannot broadcast the table over ${BroadcastExchangeExec.MAX_BROADCAST_TABLE_ROWS} rows: $numRows rows")
+ s"Cannot broadcast the table over ${maxBroadcastRows} rows: $numRows rows")
}
longMetric("collectTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeCollect)
@@ -254,6 +258,9 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
}
}
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarBroadcastExchangeExec =
+ copy(child = newChild)
}
class ColumnarBroadcastExchangeAdaptor(mode: BroadcastMode, child: SparkPlan)
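
The `// For spark 3.2.` overrides added throughout this patch are needed because Spark 3.2 added an abstract withNewChildInternal member to single-child plan nodes (via UnaryLike); declaring the method without the `override` keyword lets the same source compile against both 3.1, where it is simply an unused method, and 3.2, where it implements the new abstract member. A minimal sketch of the pattern; the node name below is illustrative and not part of this patch:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

    // Illustrative single-child node, not a class from this patch.
    case class MyColumnarExec(child: SparkPlan) extends UnaryExecNode {
      override def output: Seq[Attribute] = child.output
      override protected def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException()

      // No `override` keyword, so this compiles on Spark 3.1 (where the method is unused)
      // and on Spark 3.2 (where it implements the abstract member added by UnaryLike).
      protected def withNewChildInternal(newChild: SparkPlan): MyColumnarExec =
        copy(child = newChild)
    }
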
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 8b54cd5ff..f13ffb688 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -156,7 +156,8 @@ case class ColumnarShuffleExchangeExec(
cachedShuffleRDD
}
- // 'shuffleDependency' is only needed when enable AQE. Columnar shuffle will use 'columnarShuffleDependency'
+ // 'shuffleDependency' is only needed when AQE is enabled. Columnar shuffle will
+ // use 'columnarShuffleDependency' instead.
@transient
lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] =
new ShuffleDependency[Int, InternalRow, InternalRow](
@@ -170,6 +171,9 @@ case class ColumnarShuffleExchangeExec(
override val shuffleHandle: ShuffleHandle = columnarShuffleDependency.shuffleHandle
}
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarShuffleExchangeExec =
+ copy(child = newChild)
}
class ColumnarShuffleExchangeAdaptor(
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarBatchRDD.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarBatchRDD.scala
index 760a5af4f..1d64aa271 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarBatchRDD.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarBatchRDD.scala
@@ -77,12 +77,15 @@ class ShuffledColumnarBatchRDD(
override def getPreferredLocations(partition: Partition): Seq[String] = {
val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
partition.asInstanceOf[ShuffledColumnarBatchRDDPartition].spec match {
- case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+ // Match on the type rather than on the extractor, i.e. instead of
+ // CoalescedPartitionSpec(startReducerIndex, endReducerIndex) for Spark 3.1 or
+ // CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) for Spark 3.2,
+ // so this piece of code is compatible with both versions.
+ case coalescedPartitionSpec: CoalescedPartitionSpec =>
// TODO order by partition size.
- startReducerIndex.until(endReducerIndex).flatMap { reducerIndex =>
+ coalescedPartitionSpec.startReducerIndex.until(
+ coalescedPartitionSpec.endReducerIndex).flatMap { reducerIndex =>
tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
}
-
case PartialReducerPartitionSpec(_, startMapIndex, endMapIndex, _) =>
tracker.getMapLocation(dependency, startMapIndex, endMapIndex)
@@ -97,11 +100,12 @@ class ShuffledColumnarBatchRDD(
// as well as the `tempMetrics` for basic shuffle metrics.
val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
val reader = split.asInstanceOf[ShuffledColumnarBatchRDDPartition].spec match {
- case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+ // Match on the type, for the same reason as in getPreferredLocations above.
+ case coalescedPartitionSpec: CoalescedPartitionSpec =>
SparkEnv.get.shuffleManager.getReader(
dependency.shuffleHandle,
- startReducerIndex,
- endReducerIndex,
+ coalescedPartitionSpec.startReducerIndex,
+ coalescedPartitionSpec.endReducerIndex,
context,
sqlMetricsReporter)
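
The switch from constructor patterns to type matching above works because Spark 3.2 extends CoalescedPartitionSpec with an extra field (the `_` in the comment), which changes the extractor's arity. Reading the fields by name keeps one source tree valid for both releases; a minimal sketch of the idea, under the assumption that both spec types live in org.apache.spark.sql.execution as in Spark 3.1:

    import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec}

    object PartitionSpecSketch {
      // No extractor arity is assumed, so this matches the 2-field (Spark 3.1) and the
      // 3-field (Spark 3.2) versions of CoalescedPartitionSpec alike.
      def reducerRange(spec: ShufflePartitionSpec): Seq[Int] = spec match {
        case c: CoalescedPartitionSpec =>
          c.startReducerIndex.until(c.endReducerIndex)
        case other =>
          throw new IllegalArgumentException(s"Expect CoalescedPartitionSpec but got $other")
      }
    }
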
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
index 19a048b42..df7828ab5 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.adaptive
+// import com.intel.oap.sql.shims.SparkShimLoader
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
@@ -48,18 +50,18 @@ case class ColumnarCustomShuffleReaderExec(
partitionSpecs.map(_.asInstanceOf[PartialMapperPartitionSpec].mapIndex).toSet.size ==
partitionSpecs.length) {
child match {
- case ShuffleQueryStageExec(_, s: ColumnarShuffleExchangeAdaptor) =>
- s.child.outputPartitioning
- case ShuffleQueryStageExec(
- _,
- r @ ReusedExchangeExec(_, s: ColumnarShuffleExchangeAdaptor)) =>
- s.child.outputPartitioning match {
- case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
- case other => other
- }
- case _ =>
- throw new IllegalStateException("operating on canonicalization plan")
- }
+ case shuffleQueryStageExec: ShuffleQueryStageExec =>
+ shuffleQueryStageExec.plan match {
+ case s: ColumnarShuffleExchangeAdaptor => s.child.outputPartitioning
+ case r @ ReusedExchangeExec(_, s: ColumnarShuffleExchangeAdaptor) =>
+ s.child.outputPartitioning match {
+ case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
+ case other => other
+ }
+ }
+ case _ =>
+ throw new IllegalStateException("operating on canonicalization plan")
+ }
} else {
UnknownPartitioning(partitionSpecs.length)
}
@@ -90,4 +92,8 @@ case class ColumnarCustomShuffleReaderExec(
override protected def doExecute(): RDD[InternalRow] =
throw new UnsupportedOperationException()
+
+ // For spark 3.2.
+ protected def withNewChildInternal(newChild: SparkPlan): ColumnarCustomShuffleReaderExec =
+ copy(child = newChild)
}
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala
index fb1875dd4..0e9c143a6 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/python/ColumnarArrowPythonRunner.scala
@@ -29,6 +29,7 @@ import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, SpecialLengths}
+import org.apache.spark.sql.BasePythonRunnerChild
//import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.python.PythonUDFRunner
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
@@ -36,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.BasePythonRunnerChild
import org.apache.spark.util.Utils
/**
@@ -48,7 +50,7 @@ class ColumnarArrowPythonRunner(
schema: StructType,
timeZoneId: String,
conf: Map[String, String])
- extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs, evalType, argOffsets) {
+ extends BasePythonRunnerChild(funcs, evalType, argOffsets) {
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
@@ -58,73 +60,6 @@ class ColumnarArrowPythonRunner(
"Pandas execution requires more than 4 bytes. Please set higher buffer. " +
s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
- protected def newReaderIterator(
- stream: DataInputStream,
- writerThread: WriterThread,
- startTime: Long,
- env: SparkEnv,
- worker: Socket,
- releasedOrClosed: AtomicBoolean,
- context: TaskContext): Iterator[ColumnarBatch] = {
-
- new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
- private val allocator = SparkMemoryUtils.contextAllocator().newChildAllocator(
- s"stdin reader for $pythonExec", 0, Long.MaxValue)
-
- private var reader: ArrowStreamReader = _
- private var root: VectorSchemaRoot = _
- private var schema: StructType = _
- private var vectors: Array[ColumnVector] = _
-
- context.addTaskCompletionListener[Unit] { _ =>
- if (reader != null) {
- reader.close(false)
- }
- allocator.close()
- }
-
- private var batchLoaded = true
-
- protected override def read(): ColumnarBatch = {
- if (writerThread.exception.isDefined) {
- throw writerThread.exception.get
- }
- try {
- if (reader != null && batchLoaded) {
- batchLoaded = reader.loadNextBatch()
- if (batchLoaded) {
- val batch = new ColumnarBatch(vectors)
- batch.setNumRows(root.getRowCount)
- batch
- } else {
- reader.close(false)
- allocator.close()
- // Reach end of stream. Call `read()` again to read control data.
- read()
- }
- } else {
- stream.readInt() match {
- case SpecialLengths.START_ARROW_STREAM =>
- reader = new ArrowStreamReader(stream, allocator)
- root = reader.getVectorSchemaRoot()
- schema = ArrowUtils.fromArrowSchema(root.getSchema())
- vectors = ArrowWritableColumnVector.loadColumns(root.getRowCount, root.getFieldVectors).toArray[ColumnVector]
- read()
- case SpecialLengths.TIMING_DATA =>
- handleTimingData()
- read()
- case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- throw handlePythonException()
- case SpecialLengths.END_OF_DATA_SECTION =>
- handleEndOfDataSection()
- null
- }
- }
- } catch handleException
- }
- }
- }
-
protected override def newWriterThread(
env: SparkEnv,
worker: Socket,
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/util/ShufflePartitionUtils.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/util/ShufflePartitionUtils.scala
index 0acf9afdb..c19e2d1c6 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/util/ShufflePartitionUtils.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/util/ShufflePartitionUtils.scala
@@ -17,13 +17,14 @@
package org.apache.spark.util
+import com.intel.oap.sql.shims.SparkShimLoader
+
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
-import org.apache.spark.sql.execution.adaptive.OptimizeSkewedJoin.supportedJoinTypes
-import org.apache.spark.sql.execution.adaptive.{CustomShuffleReaderExec, OptimizeSkewedJoin, ShuffleQueryStageExec, ShuffleStage, ShuffleStageInfo}
-import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION}
+import org.apache.spark.sql.execution.adaptive.{OptimizeSkewedJoin, ShuffleQueryStageExec, ShuffleStageInfo}
+import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import scala.collection.mutable
@@ -37,7 +38,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
object ShufflePartitionUtils {
def withCustomShuffleReaders(plan: SparkPlan): Boolean = {
- plan.children.forall(_.isInstanceOf[CustomShuffleReaderExec])
+ plan.children.forall(p => SparkShimLoader.getSparkShims.isCustomShuffleReaderExec(p))
}
def isShuffledHashJoinTypeOptimizable(joinType: JoinType): Boolean = {
@@ -50,16 +51,22 @@ object ShufflePartitionUtils {
def reoptimizeShuffledHashJoinInput(plan: ShuffledHashJoinExec): ShuffledHashJoinExec =
plan match {
- case shj @ ShuffledHashJoinExec(_, _, joinType, _, _,
- s1 @ ShuffleStage(leftStageInfo: ShuffleStageInfo),
- s2 @ ShuffleStage(rightStageInfo: ShuffleStageInfo))
- if isShuffledHashJoinTypeOptimizable(joinType) =>
+ // TODO: p.left and p.right may need to be checked.
+ case p: ShuffledHashJoinExec if isShuffledHashJoinTypeOptimizable(p.joinType) =>
+ val leftStageInfo: ShuffleStageInfo = p.left match {
+ case ShuffleStage(leftStage: ShuffleStageInfo) => leftStage
+ case _ => throw new RuntimeException("ShuffleStageInfo is expected for the left child!")
+ }
+ val rightStageInfo: ShuffleStageInfo = p.right match {
+ case ShuffleStage(rightStage: ShuffleStageInfo) => rightStage
+ case _ => throw new RuntimeException("ShuffleStageInfo is expected for the right child!")
+ }
val left = plan.left
val right = plan.right
val shuffleStages = Array(left, right)
- .map(c => c.asInstanceOf[CustomShuffleReaderExec]
- .child.asInstanceOf[ShuffleQueryStageExec]).toList
+ .map(c => SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(c)
+ .asInstanceOf[ShuffleQueryStageExec]).toList
if (shuffleStages.isEmpty) {
return plan
@@ -67,7 +74,7 @@ object ShufflePartitionUtils {
if (!shuffleStages.forall(s => s.shuffle.shuffleOrigin match {
case ENSURE_REQUIREMENTS => true
- case REPARTITION => true
+ case so if SparkShimLoader.getSparkShims.isRepartition(so) => true
case _ => false
})) {
return plan
@@ -104,8 +111,8 @@ object ShufflePartitionUtils {
val offHeapOptimizationTarget = buildSizeLimit
- val leftSpecs = left.asInstanceOf[CustomShuffleReaderExec].partitionSpecs
- val rightSpecs = right.asInstanceOf[CustomShuffleReaderExec].partitionSpecs
+ val leftSpecs = SparkShimLoader.getSparkShims.getPartitionSpecsOfCustomShuffleReaderExec(left)
+ val rightSpecs = SparkShimLoader.getSparkShims.getPartitionSpecsOfCustomShuffleReaderExec(right)
if (leftSpecs.size != rightSpecs.size) {
throw new IllegalStateException("Input partition mismatch for ColumnarShuffledHashJoin")
@@ -164,16 +171,13 @@ object ShufflePartitionUtils {
}
}
- val leftReader = left.asInstanceOf[CustomShuffleReaderExec]
- val rightReader = right.asInstanceOf[CustomShuffleReaderExec]
+ val leftReaderChild = SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(left)
+ val rightReaderChild = SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(right)
// todo equality check?
plan.withNewChildren(
- Array(
- CustomShuffleReaderExec(leftReader.child,
- leftJoinedParts),
- CustomShuffleReaderExec(rightReader.child,
- rightJoinedParts)
+ Array(SparkShimLoader.getSparkShims.newCustomShuffleReaderExec(leftReaderChild, leftJoinedParts),
+ SparkShimLoader.getSparkShims.newCustomShuffleReaderExec(rightReaderChild, rightJoinedParts)
)).asInstanceOf[ShuffledHashJoinExec]
case _ =>
plan
@@ -191,17 +195,27 @@ object ShufflePartitionUtils {
}
Some(ShuffleStageInfo(s, mapStats, s.getRuntimeStatistics, partitions))
- case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs)
- if s.mapStats.isDefined && partitionSpecs.nonEmpty &&
- OptimizeSkewedJoin.supportedShuffleOrigins.contains(s.shuffle.shuffleOrigin) =>
- val statistics = s.getRuntimeStatistics
- val mapStats = s.mapStats.get
+ case plan if SparkShimLoader.getSparkShims.isCustomShuffleReaderExec(plan) &&
+ SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(plan)
+ .isInstanceOf[ShuffleQueryStageExec] &&
+ SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(plan)
+ .asInstanceOf[ShuffleQueryStageExec].mapStats.isDefined &&
+ SparkShimLoader.getSparkShims.getPartitionSpecsOfCustomShuffleReaderExec(plan).nonEmpty &&
+ OptimizeSkewedJoin.supportedShuffleOrigins.contains(
+ SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(plan)
+ .asInstanceOf[ShuffleQueryStageExec].shuffle.shuffleOrigin) =>
+ val child = SparkShimLoader.getSparkShims.getChildOfCustomShuffleReaderExec(plan)
+ .asInstanceOf[ShuffleQueryStageExec]
+ val partitionSpecs =
+ SparkShimLoader.getSparkShims.getPartitionSpecsOfCustomShuffleReaderExec(plan)
+ val statistics = child.getRuntimeStatistics
+ val mapStats = child.mapStats.get
val sizes = mapStats.bytesByPartitionId
val partitions = partitionSpecs.map {
- case spec @ CoalescedPartitionSpec(start, end) =>
+ case spec: CoalescedPartitionSpec =>
var sum = 0L
- var i = start
- while (i < end) {
+ var i = spec.startReducerIndex
+ while (i < spec.endReducerIndex) {
sum += sizes(i)
i += 1
}
@@ -209,7 +223,7 @@ object ShufflePartitionUtils {
case other => throw new IllegalArgumentException(
s"Expect CoalescedPartitionSpec but got $other")
}
- Some(ShuffleStageInfo(s, mapStats, s.getRuntimeStatistics, partitions))
+ Some(ShuffleStageInfo(child, mapStats, child.getRuntimeStatistics, partitions))
case _ => None
}
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/util/UserAddedJarUtils.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/util/UserAddedJarUtils.scala
index ec59d4408..590c44db8 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/util/UserAddedJarUtils.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/util/UserAddedJarUtils.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.util
+import com.intel.oap.sql.shims.SparkShimLoader
+
import org.apache.spark.{SparkConf, SparkContext}
import java.io.File
import java.nio.file.Files
@@ -33,7 +35,9 @@ object UserAddedJarUtils {
//TODO: don't fetch when exists
val targetPath = Paths.get(targetDir + "/" + targetFileName)
if (Files.notExists(targetPath)) {
- Utils.doFetchFile(urlString, targetDirHandler, targetFileName, sparkConf, null, null)
+ SparkShimLoader
+ .getSparkShims
+ .doFetchFile(urlString, targetDirHandler, targetFileName, sparkConf)
} else {}
}
}
diff --git a/pom.xml b/pom.xml
index 2dba22730..8a92f613c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -46,11 +46,21 @@
- spark-3.1.1
+ spark-3.1
${spark311.version}
+ 2.12.10
+ 2.10.0
+
+ spark-3.2
+
+ ${spark321.version}
+ 2.12.15
+ 2.12.0
+
+
hadoop-2.7.4
@@ -111,9 +121,15 @@
+
+ 3.1.1
+ 3.1.1
+ 3.2.1
+
2.12.10
+ 1.8
+ 2.10.0
2.12
- 3.1.1
4.0.0
2.17.1
arrow-memory-unsafe
@@ -130,7 +146,6 @@
ON
spark-sql-columnar
OAP Project Spark Columnar Plugin
- 3.1.1
diff --git a/shims/common/pom.xml b/shims/common/pom.xml
index a807fd05e..8e8a17831 100644
--- a/shims/common/pom.xml
+++ b/shims/common/pom.xml
@@ -26,6 +26,7 @@
spark-sql-columnar-shims-common
${project.name.prefix} Shims Common
+ 1.4.0-SNAPSHOT
jar
@@ -78,7 +79,19 @@
org.apache.spark
spark-sql_${scala.binary.version}
${spark311.version}
- provided
+ provided
+
+
+ com.intel.oap
+ spark-arrow-datasource-common
+ ${project.version}
+ provided
+
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-core
+ ${hadoop.version}
+ provided
diff --git a/shims/common/src/main/scala/com/intel/oap/sql/shims/SparkShims.scala b/shims/common/src/main/scala/com/intel/oap/sql/shims/SparkShims.scala
index a67e08dcb..62420a6f8 100644
--- a/shims/common/src/main/scala/com/intel/oap/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/com/intel/oap/sql/shims/SparkShims.scala
@@ -16,6 +16,30 @@
package com.intel.oap.sql.shims
+import com.intel.oap.spark.sql.ArrowWriteQueue
+import java.io.File
+import java.time.ZoneId
+
+import org.apache.parquet.hadoop.metadata.FileMetaData
+import org.apache.parquet.schema.MessageType
+import org.apache.spark.SparkConf
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.MigratableResolver
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
+import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetOptions, ParquetReadSupport, VectorizedParquetRecordReader}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleOrigin}
+import org.apache.spark.sql.internal.SQLConf
+
sealed abstract class ShimDescriptor
case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimDescriptor {
@@ -24,4 +48,69 @@ case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimD
trait SparkShims {
def getShimDescriptor: ShimDescriptor
+
+ def shuffleBlockResolverWriteAndCommit(shuffleBlockResolver: MigratableResolver,
+ shuffleId: Int, mapId: Long, partitionLengths: Array[Long], dataTmp: File): Unit
+
+ def getDatetimeRebaseMode(fileMetaData: FileMetaData, parquetOptions: ParquetOptions): SQLConf.LegacyBehaviorPolicy.Value
+
+ def newParquetFilters(parquetSchema: MessageType,
+ pushDownDate: Boolean,
+ pushDownTimestamp: Boolean,
+ pushDownDecimal: Boolean,
+ pushDownStringStartWith: Boolean,
+ pushDownInFilterThreshold: Int,
+ isCaseSensitive: Boolean,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions): ParquetFilters
+
+ def newVectorizedParquetRecordReader(convertTz: ZoneId,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions,
+ useOffHeap: Boolean,
+ capacity: Int): VectorizedParquetRecordReader
+
+ def newParquetReadSupport(convertTz: Option[ZoneId],
+ enableVectorizedReader: Boolean,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions): ParquetReadSupport
+
+ def getRuntimeFilters(plan: BatchScanExec): Seq[Expression]
+
+ def getBroadcastHashJoinOutputPartitioningExpandLimit(plan: SparkPlan): Int
+
+ /**
+ * The access modifiers of IndexShuffleBlockResolver & BaseShuffleHandle are private[spark], so
+ * their corresponding base types are used here; they are checked and converted in the shim
+ * implementations. SortShuffleWriter's access modifier is also private[spark], so the return
+ * type is AnyRef and the cast is done at the call site.
+ */
+ def newSortShuffleWriter(resolver: MigratableResolver, shuffleHandle: ShuffleHandle,
+ mapId: Long, context: TaskContext,
+ shuffleExecutorComponents: ShuffleExecutorComponents): AnyRef
+
+ def getMaxBroadcastRows(mode: BroadcastMode): Long
+
+ def getSparkSession(plan: SparkPlan): SparkSession
+
+ def doFetchFile(urlString: String, targetDirHandler: File, targetFileName: String, sparkConf: SparkConf): Unit
+
+ def newBroadcastQueryStageExec(id: Int, plan: BroadcastExchangeExec): BroadcastQueryStageExec
+
+ def isCustomShuffleReaderExec(plan: SparkPlan): Boolean
+
+ /**
+ * Return SparkPlan since the concrete type name changed in Spark 3.2.
+ * TODO: need tests.
+ */
+ def newCustomShuffleReaderExec(child: SparkPlan, partitionSpecs : Seq[ShufflePartitionSpec]): SparkPlan
+
+ def getChildOfCustomShuffleReaderExec(plan: SparkPlan): SparkPlan
+
+ def getPartitionSpecsOfCustomShuffleReaderExec(plan: SparkPlan): Seq[ShufflePartitionSpec]
+
+ /**
+ * REPARTITION is changed to REPARTITION_BY_COL in Spark 3.2.
+ */
+ def isRepartition(shuffleOrigin: ShuffleOrigin): Boolean
}
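
Several of the methods above widen their return types on purpose (newSortShuffleWriter returns AnyRef, as the scaladoc explains), so callers cast back to the concrete Spark type, exactly as the ColumnarShuffleManager hunk earlier in this diff does. A minimal sketch of that calling pattern; the object name is illustrative, and it must sit in a package under org.apache.spark because SortShuffleWriter is private[spark]:

    package org.apache.spark.shuffle.sort

    import com.intel.oap.sql.shims.SparkShimLoader
    import org.apache.spark.TaskContext
    import org.apache.spark.shuffle.{MigratableResolver, ShuffleHandle}
    import org.apache.spark.shuffle.api.ShuffleExecutorComponents

    object ShimWriterSketch {
      // The cast is required because the shim layer widens the return type to AnyRef.
      def sortShuffleWriter[K, V](
          resolver: MigratableResolver,
          handle: ShuffleHandle,
          mapId: Long,
          context: TaskContext,
          components: ShuffleExecutorComponents): SortShuffleWriter[K, V, _] =
        SparkShimLoader.getSparkShims
          .newSortShuffleWriter(resolver, handle, mapId, context, components)
          .asInstanceOf[SortShuffleWriter[K, V, _]]
    }
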
diff --git a/shims/pom.xml b/shims/pom.xml
index ecfb69f96..9e45e9dd6 100644
--- a/shims/pom.xml
+++ b/shims/pom.xml
@@ -29,7 +29,6 @@
pom
- 2.12.10
2.12
4.3.0
3.2.2
@@ -75,7 +74,7 @@
- spark-3.1.1
+ spark-3.1
@@ -83,6 +82,15 @@
spark311
+
+ spark-3.2
+
+
+
+ common
+ spark321
+
+
diff --git a/shims/spark311/pom.xml b/shims/spark311/pom.xml
index dedd4ebe7..ef9a8cf40 100644
--- a/shims/spark311/pom.xml
+++ b/shims/spark311/pom.xml
@@ -92,5 +92,29 @@
${spark311.version}
provided
+
+ com.intel.oap
+ spark-arrow-datasource-common
+ ${project.version}
+ provided
+
+
+ com.google.guava
+ guava
+ 23.0
+ provided
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop.version}
+ provided
+
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-core
+ ${hadoop.version}
+ provided
+
diff --git a/shims/spark311/src/main/java/org/apache/spark/sql/execution/datasources/VectorizedParquetRecordReaderChild.java b/shims/spark311/src/main/java/org/apache/spark/sql/execution/datasources/VectorizedParquetRecordReaderChild.java
new file mode 100644
index 000000000..c3406e583
--- /dev/null
+++ b/shims/spark311/src/main/java/org/apache/spark/sql/execution/datasources/VectorizedParquetRecordReaderChild.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources;
+
+import java.time.ZoneId;
+import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader;
+
+/**
+ * A class to help fix compatibility issues for classes that extend VectorizedParquetRecordReader.
+ */
+public class VectorizedParquetRecordReaderChild extends VectorizedParquetRecordReader {
+
+ public VectorizedParquetRecordReaderChild(ZoneId convertTz,
+ String datetimeRebaseMode,
+ String datetimeRebaseTz,
+ String int96RebaseMode,
+ String int96RebaseTz,
+ boolean useOffHeap,
+ int capacity) {
+ super(convertTz, datetimeRebaseMode, int96RebaseMode, useOffHeap, capacity);
+ }
+}
\ No newline at end of file
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala b/shims/spark311/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
similarity index 70%
rename from native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
rename to shims/spark311/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
index ed8be5ad8..e0cfccd9d 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
+++ b/shims/spark311/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
@@ -17,17 +17,23 @@
package com.intel.oap.execution
-import com.intel.oap.GazellePluginConfig
+//import com.intel.oap.GazellePluginConfig
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-class ColumnarBatchScanExec(output: Seq[AttributeReference], @transient scan: Scan)
- extends BatchScanExec(output, scan) {
- val tmpDir: String = GazellePluginConfig.getConf.tmpFile
+/** For Spark 3.1, the runtimeFilters: Seq[Expression] parameter does not exist in BatchScanExec,
+ * so it is not passed to the parent. This class lacks the implementation of doExecuteColumnar.
+ */
+abstract class ColumnarBatchScanExec(output: Seq[AttributeReference], @transient scan: Scan,
+ runtimeFilters: Seq[Expression])
+ extends BatchScanExec(output, scan) {
+ // tmpDir is passed to ParquetReader but appears unused, so it may be removed in the future.
+ // "/tmp" is used directly here instead of reading it from the plugin configuration.
+ // val tmpDir: String = GazellePluginConfig.getConf.tmpFile
+ val tmpDir: String = "/tmp"
override def supportsColumnar(): Boolean = true
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
@@ -35,20 +41,6 @@ class ColumnarBatchScanExec(output: Seq[AttributeReference], @transient scan: Sc
"numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_batchscan"),
"inputSize" -> SQLMetrics.createSizeMetric(sparkContext, "input size in bytes"))
- override def doExecuteColumnar(): RDD[ColumnarBatch] = {
- val numOutputRows = longMetric("numOutputRows")
- val numInputBatches = longMetric("numInputBatches")
- val numOutputBatches = longMetric("numOutputBatches")
- val scanTime = longMetric("scanTime")
- val inputSize = longMetric("inputSize")
- val inputColumnarRDD =
- new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory, true, scanTime, numInputBatches, inputSize, tmpDir)
- inputColumnarRDD.map { r =>
- numOutputRows += r.numRows()
- numOutputBatches += 1
- r
- }
- }
override def canEqual(other: Any): Boolean = other.isInstanceOf[ColumnarBatchScanExec]
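
Since the shim version of ColumnarBatchScanExec is now abstract and no longer provides doExecuteColumnar, a concrete subclass on the native-sql-engine/core side has to supply it, and the body removed above is the natural implementation. The sketch below is hypothetical: the class name is illustrative, it is not added by this diff, and its imports are assumed to mirror the original core file (ColumnarDataSourceRDD lives in the core module):

    // Hypothetical concrete subclass in native-sql-engine/core; not part of this diff.
    class ColumnarBatchScanExecImpl(output: Seq[AttributeReference], @transient scan: Scan,
        runtimeFilters: Seq[Expression])
      extends ColumnarBatchScanExec(output, scan, runtimeFilters) {

      override def doExecuteColumnar(): RDD[ColumnarBatch] = {
        val numOutputRows = longMetric("numOutputRows")
        val numInputBatches = longMetric("numInputBatches")
        val numOutputBatches = longMetric("numOutputBatches")
        val scanTime = longMetric("scanTime")
        val inputSize = longMetric("inputSize")
        // Same body as the doExecuteColumnar removed above, built on ColumnarDataSourceRDD.
        val inputColumnarRDD = new ColumnarDataSourceRDD(sparkContext, partitions, readerFactory,
          true, scanTime, numInputBatches, inputSize, tmpDir)
        inputColumnarRDD.map { r =>
          numOutputRows += r.numRows()
          numOutputBatches += 1
          r
        }
      }
    }
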
diff --git a/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/Spark311Shims.scala
index c1c806672..e637877e9 100644
--- a/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/Spark311Shims.scala
+++ b/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/Spark311Shims.scala
@@ -16,8 +16,167 @@
package com.intel.oap.sql.shims.spark311
-import com.intel.oap.sql.shims.{SparkShims, ShimDescriptor}
+import com.intel.oap.execution.ColumnarBatchScanExec
+import com.intel.oap.spark.sql.ArrowWriteQueue
+import com.intel.oap.sql.shims.{ShimDescriptor, SparkShims}
+import java.io.File
+import java.time.ZoneId
+
+import org.apache.parquet.hadoop.metadata.FileMetaData
+import org.apache.parquet.schema.MessageType
+import org.apache.spark.SparkConf
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.MigratableResolver
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.util.ShimUtils
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
+import org.apache.spark.sql.execution.ShufflePartitionSpec
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetOptions, ParquetReadSupport, VectorizedParquetRecordReader}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, OutputWriter}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, REPARTITION, ReusedExchangeExec, ShuffleExchangeExec, ShuffleOrigin}
+import org.apache.spark.sql.internal.SQLConf
class Spark311Shims extends SparkShims {
+
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
+
+ override def shuffleBlockResolverWriteAndCommit(shuffleBlockResolver: MigratableResolver,
+ shuffleId: Int, mapId: Long, partitionLengths: Array[Long], dataTmp: File): Unit =
+ ShimUtils.shuffleBlockResolverWriteAndCommit(shuffleBlockResolver, shuffleId, mapId, partitionLengths, dataTmp)
+
+ override def getDatetimeRebaseMode(fileMetaData: FileMetaData, parquetOptions: ParquetOptions):
+ SQLConf.LegacyBehaviorPolicy.Value = {
+ DataSourceUtils.datetimeRebaseMode(
+ fileMetaData.getKeyValueMetaData.get,
+ SQLConf.get.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ))
+ }
+
+ override def newParquetFilters(parquetSchema: MessageType,
+ pushDownDate: Boolean,
+ pushDownTimestamp: Boolean,
+ pushDownDecimal: Boolean,
+ pushDownStringStartWith: Boolean,
+ pushDownInFilterThreshold: Int,
+ isCaseSensitive: Boolean,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions
+ ): ParquetFilters = {
+ new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp,
+ pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
+ }
+
+ override def newVectorizedParquetRecordReader(convertTz: ZoneId,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions,
+ useOffHeap: Boolean,
+ capacity: Int): VectorizedParquetRecordReader = {
+ new VectorizedParquetRecordReader(
+ convertTz,
+ getDatetimeRebaseMode(fileMetaData, parquetOptions).toString,
+ "",
+ useOffHeap,
+ capacity)
+ }
+
+ override def newParquetReadSupport(convertTz: Option[ZoneId],
+ enableVectorizedReader: Boolean,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions): ParquetReadSupport = {
+ val datetimeRebaseMode = getDatetimeRebaseMode(fileMetaData, parquetOptions)
+ new ParquetReadSupport(
+ convertTz, enableVectorizedReader = false, datetimeRebaseMode, SQLConf.LegacyBehaviorPolicy.LEGACY)
+ }
+
+ /**
+ * runtimeFilters is only available since Spark 3.2, so null is returned for Spark 3.1.
+ */
+ override def getRuntimeFilters(plan: BatchScanExec): Seq[Expression] = {
+ return null
+ }
+
+ override def getBroadcastHashJoinOutputPartitioningExpandLimit(plan: SparkPlan): Int = {
+ plan.sqlContext.getConf(
+ "spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit").trim().toInt
+ }
+
+ override def newSortShuffleWriter(resolver: MigratableResolver, shuffleHandle: ShuffleHandle,
+ mapId: Long, context: TaskContext,
+ shuffleExecutorComponents: ShuffleExecutorComponents): AnyRef = {
+ ShimUtils.newSortShuffleWriter(
+ resolver,
+ shuffleHandle,
+ mapId,
+ context,
+ shuffleExecutorComponents)
+ }
+
+ override def getMaxBroadcastRows(mode: BroadcastMode): Long = {
+ BroadcastExchangeExec.MAX_BROADCAST_TABLE_ROWS
+ }
+
+ override def getSparkSession(plan: SparkPlan): SparkSession = {
+ plan.sqlContext.sparkSession
+ }
+
+ override def doFetchFile(urlString: String, targetDirHandler: File,
+ targetFileName: String, sparkConf: SparkConf): Unit = {
+ ShimUtils.doFetchFile(urlString, targetDirHandler, targetFileName, sparkConf)
+ }
+
+ override def newBroadcastQueryStageExec(id: Int, plan: BroadcastExchangeExec):
+ BroadcastQueryStageExec = {
+ BroadcastQueryStageExec(id, plan)
+ }
+
+ /**
+ * CustomShuffleReaderExec is renamed to AQEShuffleReadExec in Spark 3.2.
+ */
+ override def isCustomShuffleReaderExec(plan: SparkPlan): Boolean = {
+ plan match {
+ case _: CustomShuffleReaderExec => true
+ case _ => false
+ }
+ }
+
+ override def newCustomShuffleReaderExec(child: SparkPlan, partitionSpecs : Seq[ShufflePartitionSpec]): SparkPlan = {
+ CustomShuffleReaderExec(child, partitionSpecs)
+ }
+
+ /**
+ * Only applicable to CustomShuffleReaderExec. Otherwise, an exception will be thrown.
+ */
+ override def getChildOfCustomShuffleReaderExec(plan: SparkPlan): SparkPlan = {
+ plan match {
+ case plan: CustomShuffleReaderExec => plan.child
+ case _ => throw new RuntimeException("CustomShuffleReaderExec is expected!")
+ }
+ }
+
+ /**
+ * Only applicable to CustomShuffleReaderExec. Otherwise, an exception will be thrown.
+ */
+ override def getPartitionSpecsOfCustomShuffleReaderExec(plan: SparkPlan): Seq[ShufflePartitionSpec] = {
+ plan match {
+ case plan: CustomShuffleReaderExec => plan.partitionSpecs
+ case _ => throw new RuntimeException("CustomShuffleReaderExec is expected!")
+ }
+ }
+
+ override def isRepartition(shuffleOrigin: ShuffleOrigin): Boolean = {
+ shuffleOrigin match {
+ case REPARTITION => true
+ case _ => false
+ }
+ }
+
}
\ No newline at end of file
diff --git a/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/SparkShimProvider.scala b/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/SparkShimProvider.scala
index 59c7505e9..725788785 100644
--- a/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/SparkShimProvider.scala
+++ b/shims/spark311/src/main/scala/com/intel/oap/sql/shims/spark311/SparkShimProvider.scala
@@ -30,5 +30,5 @@ class SparkShimProvider extends com.intel.oap.sql.shims.SparkShimProvider {
def matches(version: String): Boolean = {
SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
- }
+ }
}
diff --git a/shims/spark311/src/main/scala/org/apache/spark/sql/BasePythonRunnerChild.scala b/shims/spark311/src/main/scala/org/apache/spark/sql/BasePythonRunnerChild.scala
new file mode 100644
index 000000000..12fd87193
--- /dev/null
+++ b/shims/spark311/src/main/scala/org/apache/spark/sql/BasePythonRunnerChild.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+import java.io.DataInputStream
+import java.net.Socket
+import java.util.concurrent.atomic.AtomicBoolean
+
+import com.intel.oap.vectorized.ArrowWritableColumnVector
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, SpecialLengths}
+import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
+
+/**
+ * We put this class in this package in order to access ArrowUtils, which has
+ * private[sql] access modifier.
+ */
+abstract class BasePythonRunnerChild(funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argOffsets: Array[Array[Int]])
+ extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs, evalType, argOffsets) {
+
+ /**
+ * The implementation is exactly the same as the original Spark 3.1 code except for the
+ * arguments to newReaderIterator & ReaderIterator: a pid: Option[Int] is introduced in Spark 3.2.
+ * TODO: move the implementation to a common place.
+ */
+ protected def newReaderIterator(
+ stream: DataInputStream,
+ writerThread: WriterThread,
+ startTime: Long,
+ env: SparkEnv,
+ worker: Socket,
+ releasedOrClosed: AtomicBoolean,
+ context: TaskContext): Iterator[ColumnarBatch] = {
+
+ // The pid argument is added to ReaderIterator in Spark 3.2; for Spark 3.1 the original
+ // ReaderIterator can be used directly.
+ new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+ private val allocator = SparkMemoryUtils.contextAllocator().newChildAllocator(
+ s"stdin reader for $pythonExec", 0, Long.MaxValue)
+
+ private var reader: ArrowStreamReader = _
+ private var root: VectorSchemaRoot = _
+ private var schema: StructType = _
+ private var vectors: Array[ColumnVector] = _
+
+ context.addTaskCompletionListener[Unit] { _ =>
+ if (reader != null) {
+ reader.close(false)
+ }
+ allocator.close()
+ }
+
+ private var batchLoaded = true
+
+ protected override def read(): ColumnarBatch = {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ try {
+ if (reader != null && batchLoaded) {
+ batchLoaded = reader.loadNextBatch()
+ if (batchLoaded) {
+ val batch = new ColumnarBatch(vectors)
+ batch.setNumRows(root.getRowCount)
+ batch
+ } else {
+ reader.close(false)
+ allocator.close()
+ // Reach end of stream. Call `read()` again to read control data.
+ read()
+ }
+ } else {
+ stream.readInt() match {
+ case SpecialLengths.START_ARROW_STREAM =>
+ reader = new ArrowStreamReader(stream, allocator)
+ root = reader.getVectorSchemaRoot()
+ schema = ArrowUtils.fromArrowSchema(root.getSchema())
+ vectors = ArrowWritableColumnVector.loadColumns(root.getRowCount, root.getFieldVectors).toArray[ColumnVector]
+ read()
+ case SpecialLengths.TIMING_DATA =>
+ handleTimingData()
+ read()
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+ throw handlePythonException()
+ case SpecialLengths.END_OF_DATA_SECTION =>
+ handleEndOfDataSection()
+ null
+ }
+ }
+ } catch handleException
+ }
+ }
+ }
+
+}
\ No newline at end of file
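
For the Spark 3.2 shim, the same class needs the pid: Option[Int] parameter mentioned in the comments above. The spark321 counterpart is not shown in this diff; a hedged, signature-only fragment of what its override presumably looks like (the parameter position follows Spark 3.2's BasePythonRunner and is an assumption here, and the body is elided because it matches the 3.1 version except for forwarding pid to ReaderIterator):

    // Signature-only fragment; the real body mirrors the Spark 3.1 read() logic above.
    protected def newReaderIterator(
        stream: DataInputStream,
        writerThread: WriterThread,
        startTime: Long,
        env: SparkEnv,
        worker: Socket,
        pid: Option[Int],
        releasedOrClosed: AtomicBoolean,
        context: TaskContext): Iterator[ColumnarBatch] = ???
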
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/shims/spark311/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
similarity index 100%
rename from native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
rename to shims/spark311/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/shims/spark311/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
similarity index 100%
rename from native-sql-engine/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
rename to shims/spark311/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
diff --git a/shims/spark311/src/main/scala/org/apache/spark/util/ShimUtils.scala b/shims/spark311/src/main/scala/org/apache/spark/util/ShimUtils.scala
new file mode 100644
index 000000000..a9f79b97f
--- /dev/null
+++ b/shims/spark311/src/main/scala/org/apache/spark/util/ShimUtils.scala
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2020 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.File
+
+import org.apache.spark.SparkConf
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.MigratableResolver
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object ShimUtils {
+
+ /**
+ * Only applicable to IndexShuffleBlockResolver. We move the implementation here, because
+ * IndexShuffleBlockResolver's access modifier is private[spark].
+ */
+ def shuffleBlockResolverWriteAndCommit(shuffleBlockResolver: MigratableResolver,
+ shuffleId: Int, mapId: Long,
+ partitionLengths: Array[Long], dataTmp: File): Unit = {
+ shuffleBlockResolver match {
+ case resolver: IndexShuffleBlockResolver =>
+ resolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, dataTmp)
+ case _ => throw new RuntimeException ("IndexShuffleBlockResolver is expected!")
+ }
+ }
+
+ def newSortShuffleWriter(resolver: MigratableResolver, shuffleHandle: ShuffleHandle,
+ mapId: Long, context: TaskContext,
+ shuffleExecutorComponents: ShuffleExecutorComponents): AnyRef = {
+ resolver match {
+ case indexShuffleBlockResolver: IndexShuffleBlockResolver =>
+ shuffleHandle match {
+ case baseShuffleHandle: BaseShuffleHandle[_, _, _] =>
+ new SortShuffleWriter(
+ indexShuffleBlockResolver,
+ baseShuffleHandle,
+ mapId,
+ context,
+ shuffleExecutorComponents)
+ case _ => throw new RuntimeException("BaseShuffleHandle is expected!")
+ }
+ case _ => throw new RuntimeException("IndexShuffleBlockResolver is expected!")
+ }
+ }
+
+ /**
+ * We move the implementation into this package because Utils has a private[spark]
+ * access modifier.
+ */
+ def doFetchFile(urlString: String, targetDirHandler: File,
+ targetFileName: String, sparkConf: SparkConf): Unit = {
+ Utils.doFetchFile(urlString, targetDirHandler, targetFileName, sparkConf, null, null)
+ }
+}
\ No newline at end of file
diff --git a/shims/spark321/pom.xml b/shims/spark321/pom.xml
new file mode 100644
index 000000000..aab60e3d5
--- /dev/null
+++ b/shims/spark321/pom.xml
@@ -0,0 +1,114 @@
+
+
+
+ 4.0.0
+
+
+ com.intel.oap
+ spark-sql-columnar-shims
+ 1.4.0-SNAPSHOT
+ ../pom.xml
+
+
+ spark-sql-columnar-shims-spark321
+ ${project.name.prefix} Shims for Spark 3.2.1
+ jar
+
+
+
+
+ org.scalastyle
+ scalastyle-maven-plugin
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.2.2
+
+
+ scala-compile-first
+ process-resources
+
+ compile
+
+
+
+ scala-test-compile-first
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.3
+
+
+ ${java.version}
+ UTF-8
+ 1024m
+ true
+
+ -Xlint:all,-serial,-path
+
+
+
+
+
+
+
+ src/main/resources
+
+
+
+
+
+
+ com.intel.oap
+ ${project.prefix}-shims-common
+ ${project.version}
+ compile
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark321.version}
+ provided
+
+
+ com.intel.oap
+ spark-arrow-datasource-common
+ ${project.version}
+ provided
+
+
+ com.google.guava
+ guava
+ 11.0.2
+ provided
+
+
+ org.apache.hadoop
+ hadoop-common
+ ${hadoop.version}
+ provided
+
+
+
diff --git a/shims/spark321/src/main/java/org/apache/spark/sql/execution/datasources/VectorizedParquetRecordReaderChild.java b/shims/spark321/src/main/java/org/apache/spark/sql/execution/datasources/VectorizedParquetRecordReaderChild.java
new file mode 100644
index 000000000..5a987c4b1
--- /dev/null
+++ b/shims/spark321/src/main/java/org/apache/spark/sql/execution/datasources/VectorizedParquetRecordReaderChild.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources;
+
+import java.time.ZoneId;
+import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader;
+
+/**
+ * A class to help fix compatibility issues for classes that extend VectorizedParquetRecordReader.
+ */
+public class VectorizedParquetRecordReaderChild extends VectorizedParquetRecordReader {
+
+ public VectorizedParquetRecordReaderChild(ZoneId convertTz,
+ String datetimeRebaseMode,
+ String datetimeRebaseTz,
+ String int96RebaseMode,
+ String int96RebaseTz,
+ boolean useOffHeap,
+ int capacity) {
+ super(convertTz, datetimeRebaseMode, datetimeRebaseTz, int96RebaseMode, int96RebaseTz, useOffHeap, capacity);
+ }
+}
\ No newline at end of file
diff --git a/shims/spark321/src/main/resources/META-INF/services/com.intel.oap.sql.shims.SparkShimProvider b/shims/spark321/src/main/resources/META-INF/services/com.intel.oap.sql.shims.SparkShimProvider
new file mode 100644
index 000000000..65a1b25cf
--- /dev/null
+++ b/shims/spark321/src/main/resources/META-INF/services/com.intel.oap.sql.shims.SparkShimProvider
@@ -0,0 +1 @@
+com.intel.oap.sql.shims.spark321.SparkShimProvider
\ No newline at end of file
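
This one-line service file is what makes the new shim discoverable: providers are looked up through java.util.ServiceLoader and selected by the matches method shown for the 3.1 provider earlier in this diff. A hedged sketch of that discovery step, since SparkShimLoader's actual implementation is not part of this diff and may differ:

    import java.util.ServiceLoader
    import scala.collection.JavaConverters._

    import com.intel.oap.sql.shims.SparkShimProvider

    object ShimDiscoverySketch {
      // Enumerate every registered provider and keep the one matching the running Spark version.
      def findProvider(sparkVersion: String): SparkShimProvider =
        ServiceLoader.load(classOf[SparkShimProvider]).asScala
          .find(_.matches(sparkVersion))
          .getOrElse(throw new IllegalStateException(
            s"No SparkShimProvider matches Spark version $sparkVersion"))
    }
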
diff --git a/shims/spark321/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala b/shims/spark321/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
new file mode 100644
index 000000000..99c12d394
--- /dev/null
+++ b/shims/spark321/src/main/scala/com/intel/oap/execution/ColumnarBatchScanExec.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.oap.execution
+
+//import com.intel.oap.GazellePluginConfig
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+/**
+ * The runtimeFilters parameter is not actually used by ColumnarBatchScanExec currently.
+ * This class lacks the implementation of doExecuteColumnar.
+ */
+abstract class ColumnarBatchScanExec(output: Seq[AttributeReference], @transient scan: Scan,
+ runtimeFilters: Seq[Expression])
+ extends BatchScanExec(output, scan, runtimeFilters) {
+ // tmpDir is passed to ParquetReader but appears unused, so it may be removed in the future.
+ // "/tmp" is used directly here instead of reading it from the plugin configuration.
+ // val tmpDir: String = GazellePluginConfig.getConf.tmpFile
+ val tmpDir: String = "/tmp"
+ override def supportsColumnar(): Boolean = true
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "input_batches"),
+ "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"),
+ "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_batchscan"),
+ "inputSize" -> SQLMetrics.createSizeMetric(sparkContext, "input size in bytes"))
+
+ override def canEqual(other: Any): Boolean = other.isInstanceOf[ColumnarBatchScanExec]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ColumnarBatchScanExec =>
+ (that canEqual this) && super.equals(that)
+ case _ => false
+ }
+}
diff --git a/shims/spark321/src/main/scala/com/intel/oap/sql/shims/spark321/Spark321Shims.scala b/shims/spark321/src/main/scala/com/intel/oap/sql/shims/spark321/Spark321Shims.scala
new file mode 100644
index 000000000..ce8d2d33e
--- /dev/null
+++ b/shims/spark321/src/main/scala/com/intel/oap/sql/shims/spark321/Spark321Shims.scala
@@ -0,0 +1,233 @@
+/*
+ * Copyright 2020 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.oap.sql.shims.spark321
+
+import com.intel.oap.execution.ColumnarBatchScanExec
+import com.intel.oap.spark.sql.ArrowWriteQueue
+import com.intel.oap.sql.shims.{ShimDescriptor, SparkShims}
+import java.io.File
+import java.time.ZoneId
+
+import org.apache.parquet.hadoop.metadata.FileMetaData
+import org.apache.parquet.schema.MessageType
+import org.apache.spark.SparkConf
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.MigratableResolver
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.util.ShimUtils
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.execution.ShufflePartitionSpec
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetOptions, ParquetReadSupport, VectorizedParquetRecordReader}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, OutputWriter}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec, ShuffleOrigin}
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.LongType
+
+class Spark321Shims extends SparkShims {
+
+ override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
+
+ override def shuffleBlockResolverWriteAndCommit(shuffleBlockResolver: MigratableResolver,
+ shuffleId: Int, mapId: Long,
+ partitionLengths: Array[Long],
+ dataTmp: File): Unit =
+ ShimUtils.shuffleBlockResolverWriteAndCommit(
+ shuffleBlockResolver, shuffleId, mapId, partitionLengths, dataTmp)
+
+ def getDatetimeRebaseSpec(fileMetaData: FileMetaData, parquetOptions: ParquetOptions): RebaseSpec = {
+ DataSourceUtils.datetimeRebaseSpec(
+ fileMetaData.getKeyValueMetaData.get,
+ parquetOptions.datetimeRebaseModeInRead)
+ }
+
+ def getInt96RebaseSpec(fileMetaData: FileMetaData, parquetOptions: ParquetOptions): RebaseSpec = {
+ DataSourceUtils.int96RebaseSpec(
+ fileMetaData.getKeyValueMetaData.get,
+ parquetOptions.int96RebaseModeInRead)
+ }
+
+ override def getDatetimeRebaseMode(fileMetaData: FileMetaData, parquetOptions: ParquetOptions):
+ SQLConf.LegacyBehaviorPolicy.Value = {
+ getDatetimeRebaseSpec(fileMetaData, parquetOptions).mode
+ }
+
+ override def newParquetFilters(parquetSchema: MessageType,
+ pushDownDate: Boolean,
+ pushDownTimestamp: Boolean,
+ pushDownDecimal: Boolean,
+ pushDownStringStartWith: Boolean,
+ pushDownInFilterThreshold: Int,
+ isCaseSensitive: Boolean,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions):
+ ParquetFilters = {
+ return new ParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp,
+ pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold,
+ isCaseSensitive,
+ getDatetimeRebaseSpec(fileMetaData, parquetOptions))
+ }
+
+ override def newVectorizedParquetRecordReader(convertTz: ZoneId,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions,
+ useOffHeap: Boolean,
+ capacity: Int): VectorizedParquetRecordReader = {
+ val rebaseSpec = getDatetimeRebaseSpec(fileMetaData, parquetOptions)
+ // TODO: int96RebaseMode & int96RebaseTz are passed as ""; this needs to be verified.
+ new VectorizedParquetRecordReader(convertTz, rebaseSpec.mode.toString,
+ rebaseSpec.timeZone, "", "", useOffHeap, capacity)
+ }
+
+ override def newParquetReadSupport(convertTz: Option[ZoneId],
+ enableVectorizedReader: Boolean,
+ fileMetaData: FileMetaData,
+ parquetOptions: ParquetOptions): ParquetReadSupport = {
+ val datetimeRebaseSpec = getDatetimeRebaseSpec(fileMetaData, parquetOptions)
+ val int96RebaseSpec = getInt96RebaseSpec(fileMetaData, parquetOptions)
+ new ParquetReadSupport(convertTz, enableVectorizedReader, datetimeRebaseSpec, int96RebaseSpec)
+ }
+
+ override def getRuntimeFilters(plan: BatchScanExec): Seq[Expression] = {
+    plan.runtimeFilters
+ }
+
+ override def getBroadcastHashJoinOutputPartitioningExpandLimit(plan: SparkPlan): Int = {
+ plan.conf.broadcastHashJoinOutputPartitioningExpandLimit
+ }
+
+ override def newSortShuffleWriter(resolver: MigratableResolver, shuffleHandle: ShuffleHandle,
+ mapId: Long, context: TaskContext,
+ shuffleExecutorComponents: ShuffleExecutorComponents):
+ AnyRef = {
+ ShimUtils.newSortShuffleWriter(
+ resolver,
+ shuffleHandle,
+ mapId,
+ context,
+ shuffleExecutorComponents)
+ }
+
+  /**
+   * TODO: check whether the code below can be shared between Spark 3.1 and 3.2.
+   */
+ override def getMaxBroadcastRows(mode: BroadcastMode): Long = {
+ // The below code is ported from BroadcastExchangeExec of spark 3.2.
+ val maxBroadcastRows = mode match {
+ case HashedRelationBroadcastMode(key, _)
+ // NOTE: LongHashedRelation is used for single key with LongType. This should be kept
+ // consistent with HashedRelation.apply.
+ if !(key.length == 1 && key.head.dataType == LongType) =>
+ // Since the maximum number of keys that BytesToBytesMap supports is 1 << 29,
+ // and only 70% of the slots can be used before growing in UnsafeHashedRelation,
+ // here the limitation should not be over 341 million.
+ (BytesToBytesMap.MAX_CAPACITY / 1.5).toLong
+ case _ => 512000000
+ }
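+    // For reference: BytesToBytesMap.MAX_CAPACITY is 1 << 29 = 536,870,912, so the first branch
+    // evaluates to (536,870,912 / 1.5).toLong = 357,913,941 rows.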
+ maxBroadcastRows
+ }
+
+ override def getSparkSession(plan: SparkPlan): SparkSession = {
+ plan.session
+ }
+
+ override def doFetchFile(urlString: String, targetDirHandler: File,
+ targetFileName: String, sparkConf: SparkConf): Unit = {
+ ShimUtils.doFetchFile(urlString, targetDirHandler, targetFileName, sparkConf)
+ }
+
+// /**
+//  * Fixes a compatibility issue: ShuffleQueryStageExec has an additional argument in Spark 3.2.
+//  * ShuffleExchangeExec replaces ColumnarShuffleExchangeAdaptor to avoid a cyclic dependency.
+//  * This change needs further testing to verify.
+// */
+// override def outputPartitioningForColumnarCustomShuffleReaderExec(child: SparkPlan): Partitioning = {
+// child match {
+// case ShuffleQueryStageExec(_, s: ShuffleExchangeExec, _) =>
+// s.child.outputPartitioning
+// case ShuffleQueryStageExec(
+// _,
+// r @ ReusedExchangeExec(_, s: ShuffleExchangeExec), _) =>
+// s.child.outputPartitioning match {
+// case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
+// case other => other
+// }
+// case _ =>
+// throw new IllegalStateException("operating on canonicalization plan")
+// }
+// }
+
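+  /**
+   * Spark 3.2 adds a canonicalized-plan argument to BroadcastQueryStageExec (Spark 3.1 took only
+   * the id and the plan), so the exchange's canonicalized form is passed as the third argument.
+   */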
+ override def newBroadcastQueryStageExec(id: Int, plan: BroadcastExchangeExec):
+ BroadcastQueryStageExec = {
+ BroadcastQueryStageExec(id, plan, plan.doCanonicalize)
+ }
+
+ /**
+   * CustomShuffleReaderExec is renamed to AQEShuffleReadExec in Spark 3.2.
+ */
+ override def isCustomShuffleReaderExec(plan: SparkPlan): Boolean = {
+ plan match {
+ case _: AQEShuffleReadExec => true
+ case _ => false
+ }
+ }
+
+ override def newCustomShuffleReaderExec(child: SparkPlan, partitionSpecs:
+ Seq[ShufflePartitionSpec]): SparkPlan = {
+ AQEShuffleReadExec(child, partitionSpecs)
+ }
+
+ /**
+ * Only applicable to AQEShuffleReadExec. Otherwise, an exception will be thrown.
+ */
+ override def getChildOfCustomShuffleReaderExec(plan: SparkPlan): SparkPlan = {
+ plan match {
+ case p: AQEShuffleReadExec => p.child
+ case _ => throw new RuntimeException("AQEShuffleReadExec is expected!")
+ }
+ }
+
+ /**
+ * Only applicable to AQEShuffleReadExec. Otherwise, an exception will be thrown.
+ */
+ override def getPartitionSpecsOfCustomShuffleReaderExec(plan: SparkPlan):
+ Seq[ShufflePartitionSpec] = {
+ plan match {
+ case p: AQEShuffleReadExec => p.partitionSpecs
+ case _ => throw new RuntimeException("AQEShuffleReadExec is expected!")
+ }
+ }
+
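+  /**
+   * Spark 3.2 splits the former REPARTITION shuffle origin into REPARTITION_BY_COL and
+   * REPARTITION_BY_NUM; only the column-based origin is treated as a repartition here.
+   */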
+ override def isRepartition(shuffleOrigin: ShuffleOrigin): Boolean = {
+ shuffleOrigin match {
+ case REPARTITION_BY_COL => true
+ case _ => false
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/shims/spark321/src/main/scala/com/intel/oap/sql/shims/spark321/SparkShimProvider.scala b/shims/spark321/src/main/scala/com/intel/oap/sql/shims/spark321/SparkShimProvider.scala
new file mode 100644
index 000000000..15b013189
--- /dev/null
+++ b/shims/spark321/src/main/scala/com/intel/oap/sql/shims/spark321/SparkShimProvider.scala
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2020 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.intel.oap.sql.shims.spark321
+
+import com.intel.oap.sql.shims.{SparkShims, SparkShimDescriptor}
+
+object SparkShimProvider {
+ val DESCRIPTOR = SparkShimDescriptor(3, 2, 1)
+ val DESCRIPTOR_STRINGS = Seq(s"$DESCRIPTOR")
+}
+
+class SparkShimProvider extends com.intel.oap.sql.shims.SparkShimProvider {
+ def createShim: SparkShims = {
+ new Spark321Shims()
+ }
+
+ def matches(version: String): Boolean = {
+ SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
+ }
+}
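+
+// A minimal usage sketch, assuming SparkShimDescriptor(3, 2, 1) renders as "3.2.1" via toString:
+//   val provider = new SparkShimProvider()
+//   provider.matches("3.2.1")  // true, since DESCRIPTOR_STRINGS == Seq("3.2.1")
+//   provider.createShim        // returns a new Spark321Shims instance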
diff --git a/shims/spark321/src/main/scala/org/apache/spark/sql/BasePythonRunnerChild.scala b/shims/spark321/src/main/scala/org/apache/spark/sql/BasePythonRunnerChild.scala
new file mode 100644
index 000000000..c948ac6f6
--- /dev/null
+++ b/shims/spark321/src/main/scala/org/apache/spark/sql/BasePythonRunnerChild.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+import java.io.DataInputStream
+import java.net.Socket
+import java.util.concurrent.atomic.AtomicBoolean
+
+import com.intel.oap.vectorized.ArrowWritableColumnVector
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, SpecialLengths}
+import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
+
+/**
+ * We put this class in this package in order to access ArrowUtils, which has
+ * the private[sql] access modifier.
+ */
+abstract class BasePythonRunnerChild(funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argOffsets: Array[Array[Int]])
+ extends BasePythonRunner[ColumnarBatch, ColumnarBatch](funcs, evalType, argOffsets) {
+
+ /**
+   * The implementation is the same as that for Spark 3.1, except for the arguments of
+   * newReaderIterator & ReaderIterator: the pid: Option[Int] parameter is introduced in Spark 3.2.
+   * The pid is not actually used here.
+ * TODO: put the implementation into some common place.
+ */
+ protected def newReaderIterator(
+ stream: DataInputStream,
+ writerThread: WriterThread,
+ startTime: Long,
+ env: SparkEnv,
+ worker: Socket,
+ pid: Option[Int],
+ releasedOrClosed: AtomicBoolean,
+ context: TaskContext): Iterator[ColumnarBatch] = {
+
+    // The pid argument was added to ReaderIterator in Spark 3.2.
+ new ReaderIterator(stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
+ private val allocator = SparkMemoryUtils.contextAllocator().newChildAllocator(
+ s"stdin reader for $pythonExec", 0, Long.MaxValue)
+
+ private var reader: ArrowStreamReader = _
+ private var root: VectorSchemaRoot = _
+ private var schema: StructType = _
+ private var vectors: Array[ColumnVector] = _
+
+ context.addTaskCompletionListener[Unit] { _ =>
+ if (reader != null) {
+ reader.close(false)
+ }
+ allocator.close()
+ }
+
+ private var batchLoaded = true
+
+ protected override def read(): ColumnarBatch = {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ try {
+ if (reader != null && batchLoaded) {
+ batchLoaded = reader.loadNextBatch()
+ if (batchLoaded) {
+ val batch = new ColumnarBatch(vectors)
+ batch.setNumRows(root.getRowCount)
+ batch
+ } else {
+ reader.close(false)
+ allocator.close()
+ // Reach end of stream. Call `read()` again to read control data.
+ read()
+ }
+ } else {
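+          // Before the Arrow stream starts, the worker sends control frames: dispatch on the
+          // special length values to open the stream, record timing data, rethrow Python
+          // exceptions, or end the data section.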
+ stream.readInt() match {
+ case SpecialLengths.START_ARROW_STREAM =>
+ reader = new ArrowStreamReader(stream, allocator)
+ root = reader.getVectorSchemaRoot()
+ schema = ArrowUtils.fromArrowSchema(root.getSchema())
+ vectors = ArrowWritableColumnVector.loadColumns(root.getRowCount, root.getFieldVectors).toArray[ColumnVector]
+ read()
+ case SpecialLengths.TIMING_DATA =>
+ handleTimingData()
+ read()
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+ throw handlePythonException()
+ case SpecialLengths.END_OF_DATA_SECTION =>
+ handleEndOfDataSection()
+ null
+ }
+ }
+ } catch handleException
+ }
+ }
+ }
+
+}
\ No newline at end of file
diff --git a/shims/spark321/src/main/scala/org/apache/spark/util/ShimUtils.scala b/shims/spark321/src/main/scala/org/apache/spark/util/ShimUtils.scala
new file mode 100644
index 000000000..551b03749
--- /dev/null
+++ b/shims/spark321/src/main/scala/org/apache/spark/util/ShimUtils.scala
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2020 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.File
+
+import org.apache.spark.SparkConf
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.MigratableResolver
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.SortShuffleWriter
+
+object ShimUtils {
+
+ /**
+   * Only applicable to IndexShuffleBlockResolver. The implementation lives here because
+   * IndexShuffleBlockResolver has private[spark] access.
+ */
+ def shuffleBlockResolverWriteAndCommit(shuffleBlockResolver: MigratableResolver,
+ shuffleId: Int, mapId: Long, partitionLengths: Array[Long], dataTmp: File): Unit = {
+ shuffleBlockResolver match {
+ case resolver: IndexShuffleBlockResolver =>
+        // Checksums are not computed by this shim, so pass an empty array; Spark 3.2's
+        // writeMetadataFileAndCommit inspects the checksums argument and a null value may fail.
+        resolver.writeMetadataFileAndCommit(
+          shuffleId, mapId, partitionLengths, Array.empty[Long], dataTmp)
+ case _ => throw new RuntimeException ("IndexShuffleBlockResolver is expected!")
+ }
+ }
+
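+  /**
+   * In Spark 3.2, SortShuffleWriter no longer takes the shuffle block resolver as a constructor
+   * argument (Spark 3.1 did), so the resolver parameter is unused here; it is presumably kept to
+   * share a common signature across shims.
+   */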
+ def newSortShuffleWriter(resolver: MigratableResolver, shuffleHandle: ShuffleHandle,
+ mapId: Long, context: TaskContext,
+ shuffleExecutorComponents: ShuffleExecutorComponents): AnyRef = {
+
+ shuffleHandle match {
+ case baseShuffleHandle: BaseShuffleHandle[_, _, _] =>
+ new SortShuffleWriter(
+ baseShuffleHandle,
+ mapId,
+ context,
+ shuffleExecutorComponents)
+ case _ => throw new RuntimeException("BaseShuffleHandle is expected!")
+ }
+ }
+
+ /**
+   * The implementation lives in this package because Utils has the private[spark]
+   * access modifier.
+ */
+ def doFetchFile(urlString: String, targetDirHandler: File,
+ targetFileName: String, sparkConf: SparkConf): Unit = {
+ Utils.doFetchFile(urlString, targetDirHandler, targetFileName, sparkConf, null)
+ }
+
+}
\ No newline at end of file