Fix merge conflict 3122 from branch-21.06 [skip ci] #3124

Merged · 2 commits · Aug 4, 2021
Fix databricks 3.0.1 for ParquetFilters api change (#3118)
* Update spark 301db shim to new ParquetFilter api (#3100)

Signed-off-by: Thomas Graves <[email protected]>

* Fix ParquetFilters issue (#2961)

* Fix ParquetFilters issue

Spark has added a datetime rebase mode to the constructor of ParquetFilters
(see SPARK-36034), which causes the build to fail.

This change adds getParquetFilters to the shim layers to match the
different Spark versions.

Signed-off-by: Bobby Wang <[email protected]>

* Fix compiling error

Co-authored-by: Bobby Wang <[email protected]>
tgravescs and wbo4958 authored Aug 4, 2021
commit c14d0a0337d5d240cd999283a2154a861f4d343a
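For context, the fix follows the plugin's usual shim pattern: common code no longer constructs ParquetFilters directly but asks the active shim for one, and each shim calls the constructor shape that exists in the Spark release it is compiled against. Below is a minimal, self-contained sketch of that pattern. It is illustrative only: FiltersV1/FiltersV2 are hypothetical stand-ins for the two ParquetFilters constructor shapes (before and after SPARK-36034), OldConstructorShims/NewConstructorShims stand in for the real shim classes, and the hard-coded shim choice stands in for the plugin's ShimLoader.

object ParquetFiltersShimSketch {
  // Stand-ins for the two constructor shapes; not real Spark classes.
  class FiltersV1(schema: String, caseSensitive: Boolean)                     // pre-SPARK-36034
  class FiltersV2(schema: String, caseSensitive: Boolean, rebaseMode: String) // post-SPARK-36034

  trait Shims {
    // Common code always passes the rebase mode; each shim decides whether the
    // constructor of its Spark version accepts it.
    def getParquetFilters(schema: String, caseSensitive: Boolean, rebaseMode: String): Any
  }

  object OldConstructorShims extends Shims {
    // Older constructor: no rebase-mode parameter, so the argument is dropped.
    override def getParquetFilters(schema: String, caseSensitive: Boolean, rebaseMode: String): Any =
      new FiltersV1(schema, caseSensitive)
  }

  object NewConstructorShims extends Shims {
    // Newer constructor (e.g. Databricks 3.0.1): the rebase mode is passed through.
    override def getParquetFilters(schema: String, caseSensitive: Boolean, rebaseMode: String): Any =
      new FiltersV2(schema, caseSensitive, rebaseMode)
  }

  def main(args: Array[String]): Unit = {
    // The real plugin selects the shim via ShimLoader at runtime; hard-coded here.
    val shims: Shims = NewConstructorShims
    println(shims.getParquetFilters("fileSchema", caseSensitive = true, rebaseMode = "CORRECTED"))
  }
}

Because each real shim module is compiled only against its own Spark version, the constructor call that would not build elsewhere never appears in a shared compilation unit; callers such as GpuParquetFileFilterHandler (see the GpuParquetScan hunk below) depend only on the common getParquetFilters signature.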
@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids._
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
@@ -41,6 +42,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -80,6 +82,19 @@ abstract class SparkBaseShims extends SparkShims {
  override def parquetRebaseWrite(conf: SQLConf): String =
    conf.getConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE)

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters = {
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive)
  }

  override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
    AlterTableRecoverPartitionsCommand(tableName)

@@ -20,6 +20,7 @@ import com.databricks.sql.execution.window.RunningWindowFunctionExec
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark301.Spark301Shims
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.sql.rapids.shims.spark301db._
import org.apache.spark.rdd.RDD
@@ -33,11 +34,13 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.{AggregateInPandasExec, ArrowEvalPythonExec, FlatMapGroupsInPandasExec, MapInPandasExec, WindowInPandasExec}
import org.apache.spark.sql.execution.window.WindowExecBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuFileSourceScanExec
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.execution.python.{GpuAggregateInPandasExecMeta, GpuArrowEvalPythonExec, GpuFlatMapGroupsInPandasExecMeta, GpuMapInPandasExecMeta, GpuPythonUDF, GpuWindowInPandasExecMetaBase}
@@ -47,6 +50,18 @@ class Spark301dbShims extends Spark301Shims {

  override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)

  override def getGpuBroadcastNestedLoopJoinShim(
      left: SparkPlan,
      right: SparkPlan,
@@ -19,6 +19,10 @@ package com.nvidia.spark.rapids.shims.spark313
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark312.Spark312Shims
import com.nvidia.spark.rapids.spark313.RapidsShuffleManager
import org.apache.parquet.schema.MessageType

import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.internal.SQLConf

class Spark313Shims extends Spark312Shims {

@@ -27,4 +31,16 @@ class Spark313Shims extends Spark312Shims {
  override def getRapidsShuffleManagerClass: String = {
    classOf[RapidsShuffleManager].getCanonicalName
  }

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode)
}
@@ -44,6 +44,18 @@ class Spark320Shims extends Spark311Shims {
  override def parquetRebaseWrite(conf: SQLConf): String =
    conf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE)

  override def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: SQLConf.LegacyBehaviorPolicy.Value): ParquetFilters =
    new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith,
      pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode, datetimeRebaseMode)

  override def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand =
    RepairTableCommand(tableName,
      // These match the one place that this is called, if we start to call this in more places
@@ -54,7 +54,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, ParquetReadSupport}
import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -257,8 +257,8 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
  private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
  private val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
  private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
  private val isCorrectedRebase =
    "CORRECTED" == ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)
  private val rebaseMode = ShimLoader.getSparkShims.parquetRebaseRead(sqlConf)
  private val isCorrectedRebase = "CORRECTED" == rebaseMode

  def filterBlocks(
      file: PartitionedFile,
@@ -272,8 +272,11 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte
      ParquetMetadataConverter.range(file.start, file.start + file.length))
    val fileSchema = footer.getFileMetaData.getSchema
    val pushedFilters = if (enableParquetFilterPushDown) {
      val parquetFilters = new ParquetFilters(fileSchema, pushDownDate, pushDownTimestamp,
        pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
      val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
        footer.getFileMetaData.getKeyValueMetaData.get, rebaseMode)
      val parquetFilters = ShimLoader.getSparkShims.getParquetFilters(fileSchema, pushDownDate,
        pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold,
        isCaseSensitive, datetimeRebaseMode)
      filters.flatMap(parquetFilters.createFilter).reduceOption(FilterApi.and)
    } else {
      None
13 changes: 13 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import org.apache.arrow.memory.ReferenceManager
import org.apache.arrow.vector.ValueVector
import org.apache.hadoop.fs.Path
import org.apache.parquet.schema.MessageType

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
@@ -40,9 +41,11 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.sources.BaseRelation
@@ -88,6 +91,16 @@ trait SparkShims {
  def parquetRebaseWrite(conf: SQLConf): String
  def v1RepairTableCommand(tableName: TableIdentifier): RunnableCommand

  def getParquetFilters(
      schema: MessageType,
      pushDownDate: Boolean,
      pushDownTimestamp: Boolean,
      pushDownDecimal: Boolean,
      pushDownStartWith: Boolean,
      pushDownInFilterThreshold: Int,
      caseSensitive: Boolean,
      datetimeRebaseMode: LegacyBehaviorPolicy.Value): ParquetFilters

  def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
  def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
  def isWindowFunctionExec(plan: SparkPlan): Boolean