Fix Databricks shim layer for GpuFileSourceScanExec and GpuBroadcastExchangeExec (#571)

* Make GpuFileSourceScanExec work with Databricks

Signed-off-by: Thomas Graves <[email protected]>

* Add in GpuFileScanRDD

Signed-off-by: Thomas Graves <[email protected]>

* cleanup

Signed-off-by: Thomas Graves <[email protected]>

* Rework to get PartitionedFiles only

Signed-off-by: Thomas Graves <[email protected]>

* remove commented out code

Signed-off-by: Thomas Graves <[email protected]>

* Fix spacing in pom

Signed-off-by: Thomas Graves <[email protected]>

* Add gpu broadcast get function and fix names

Signed-off-by: Thomas Graves <[email protected]>

* remove unused imports

Signed-off-by: Thomas Graves <[email protected]>

Co-authored-by: Thomas Graves <[email protected]>
tgravescs authored Aug 18, 2020
1 parent bb39b41 commit e26e961
Showing 7 changed files with 362 additions and 40 deletions.
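Editorial note (not part of the commit): the sketch below shows the general shape of the shim dispatch this change relies on. A version-specific shim supplies the RDD that reads PartitionedFiles, so Apache Spark 3.0.0 can return the stock FileScanRDD while Databricks returns its own variant. The trait name FileScanShims and the class name ApacheSpark300FileScanShims are illustrative stand-ins for the plugin's real SparkShims trait and the Spark300Shims class shown in the diff below.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}

// Illustrative, trimmed-down view of the shim contract; the real trait in the
// plugin (SparkShims) has many more methods.
trait FileScanShims {
  // Build the RDD that reads the planned FilePartitions. Databricks gets its
  // own GpuFileScanRDD in this commit (the Databricks runtime's FileScanRDD is
  // assumed to differ from Apache Spark's).
  def getFileScanRDD(
      sparkSession: SparkSession,
      readFunction: PartitionedFile => Iterator[InternalRow],
      filePartitions: Seq[FilePartition]): RDD[InternalRow]
}

// Apache Spark 3.0.0 flavor: delegate straight to the stock FileScanRDD,
// matching the override added to Spark300Shims in this commit.
class ApacheSpark300FileScanShims extends FileScanShims {
  override def getFileScanRDD(
      sparkSession: SparkSession,
      readFunction: PartitionedFile => Iterator[InternalRow],
      filePartitions: Seq[FilePartition]): RDD[InternalRow] =
    new FileScanRDD(sparkSession, readFunction, filePartitions)
}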
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids.shims.spark300

import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuBroadcastExchangeExecBase}
import org.apache.spark.sql.rapids.execution.GpuBroadcastExchangeExecBase

case class GpuBroadcastExchangeExec(
    mode: BroadcastMode,
@@ -22,7 +22,9 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.spark300.RapidsShuffleManager

import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
@@ -31,12 +33,12 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuTimeSub, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastMeta, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase, GpuShuffleMeta}
import org.apache.spark.sql.rapids.shims.spark300._
import org.apache.spark.sql.types._
import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -236,4 +238,52 @@ class Spark300Shims extends SparkShims {
  override def getShuffleManagerShims(): ShuffleManagerShimBase = {
    new ShuffleManagerShim
  }

  override def getPartitionFileNames(
      partitions: Seq[PartitionDirectory]): Seq[String] = {
    val files = partitions.flatMap(partition => partition.files)
    files.map(_.getPath.getName)
  }

  override def getPartitionFileStatusSize(partitions: Seq[PartitionDirectory]): Long = {
    partitions.map(_.files.map(_.getLen).sum).sum
  }

  override def getPartitionedFiles(
      partitions: Array[PartitionDirectory]): Array[PartitionedFile] = {
    partitions.flatMap { p =>
      p.files.map { f =>
        PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
      }
    }
  }

  override def getPartitionSplitFiles(
      partitions: Array[PartitionDirectory],
      maxSplitBytes: Long,
      relation: HadoopFsRelation): Array[PartitionedFile] = {
    partitions.flatMap { partition =>
      partition.files.flatMap { file =>
        // getPath() is very expensive so we only want to call it once in this block:
        val filePath = file.getPath
        val isSplitable = relation.fileFormat.isSplitable(
          relation.sparkSession, relation.options, filePath)
        PartitionedFileUtil.splitFiles(
          sparkSession = relation.sparkSession,
          file = file,
          filePath = filePath,
          isSplitable = isSplitable,
          maxSplitBytes = maxSplitBytes,
          partitionValues = partition.values
        )
      }
    }
  }

  override def getFileScanRDD(
      sparkSession: SparkSession,
      readFunction: (PartitionedFile) => Iterator[InternalRow],
      filePartitions: Seq[FilePartition]): RDD[InternalRow] = {
    new FileScanRDD(sparkSession, readFunction, filePartitions)
  }
}
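Editorial context for the overrides above (not part of the commit): the snippet below sketches how a file-source scan typically combines these shim helpers in Spark 3.0-style planning. The planScan helper is hypothetical, and it assumes the SparkShims methods shown in this diff (getPartitionSplitFiles, getFileScanRDD), Spark's own FilePartition utilities, and that SparkShims lives in com.nvidia.spark.rapids.

import com.nvidia.spark.rapids.SparkShims
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}

// Hypothetical caller: roughly what a GPU file-source scan does with the shim.
def planScan(
    shims: SparkShims,
    relation: HadoopFsRelation,
    selectedPartitions: Array[PartitionDirectory],
    readFunction: PartitionedFile => Iterator[InternalRow]): RDD[InternalRow] = {
  // Spark 3.0's split-size heuristic (file-open cost vs. parallelism).
  val maxSplitBytes = FilePartition.maxSplitBytes(relation.sparkSession, selectedPartitions)
  // The shim hides the Apache/Databricks differences in PartitionedFileUtil and FileStatus.
  val splitFiles = shims.getPartitionSplitFiles(selectedPartitions, maxSplitBytes, relation)
    .sortBy(_.length)(implicitly[Ordering[Long]].reverse)
  // Pack the splits into FilePartitions using Spark's packing logic.
  val partitions = FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
  // The shim also picks the right RDD type (FileScanRDD, or GpuFileScanRDD on Databricks).
  shims.getFileScanRDD(relation.sparkSession, readFunction, partitions)
}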
13 changes: 13 additions & 0 deletions shims/spark300db/pom.xml
@@ -32,6 +32,7 @@
    <version>0.2.0-SNAPSHOT</version>

    <properties>
        <parquet.version>1.10.1</parquet.version>
        <spark30db.version>3.0.0-databricks</spark30db.version>
    </properties>

@@ -59,6 +60,18 @@
            <version>${spark30db.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-annotation_${scala.binary.version}</artifactId>
            <version>${spark30db.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.parquet</groupId>
            <artifactId>parquet-column</artifactId>
            <version>${parquet.version}</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>

</project>
@@ -20,17 +20,24 @@ import java.time.ZoneId

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark300.Spark300Shims
import org.apache.spark.sql.rapids.shims.spark300db._
import org.apache.hadoop.fs.Path

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkEnv
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuTimeSub}
import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastMeta, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase, GpuShuffleMeta}
import org.apache.spark.sql.types._
import org.apache.spark.storage.{BlockId, BlockManagerId}

@@ -48,6 +55,12 @@ class Spark300dbShims extends Spark300Shims {
    GpuBroadcastNestedLoopJoinExec(left, right, join, joinType, condition, targetSizeBytes)
  }

  override def getGpuBroadcastExchangeExec(
      mode: BroadcastMode,
      child: SparkPlan): GpuBroadcastExchangeExecBase = {
    GpuBroadcastExchangeExec(mode, child)
  }

  override def isGpuHashJoin(plan: SparkPlan): Boolean = {
    plan match {
      case _: GpuHashJoin => true
@@ -118,4 +131,53 @@ class Spark300dbShims extends Spark300Shims {
  override def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide = {
    GpuJoinUtils.getGpuBuildSide(join.buildSide)
  }

  // Databricks has a different version of FileStatus
  override def getPartitionFileNames(
      partitions: Seq[PartitionDirectory]): Seq[String] = {
    val files = partitions.flatMap(partition => partition.files)
    files.map(_.getPath.getName)
  }

  override def getPartitionFileStatusSize(partitions: Seq[PartitionDirectory]): Long = {
    partitions.map(_.files.map(_.getLen).sum).sum
  }

  override def getPartitionedFiles(
      partitions: Array[PartitionDirectory]): Array[PartitionedFile] = {
    partitions.flatMap { p =>
      p.files.map { f =>
        PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
      }
    }
  }

  override def getPartitionSplitFiles(
      partitions: Array[PartitionDirectory],
      maxSplitBytes: Long,
      relation: HadoopFsRelation): Array[PartitionedFile] = {
    partitions.flatMap { partition =>
      partition.files.flatMap { file =>
        // getPath() is very expensive so we only want to call it once in this block:
        val filePath = file.getPath
        val isSplitable = relation.fileFormat.isSplitable(
          relation.sparkSession, relation.options, filePath)
        PartitionedFileUtil.splitFiles(
          sparkSession = relation.sparkSession,
          file = file,
          filePath = filePath,
          isSplitable = isSplitable,
          maxSplitBytes = maxSplitBytes,
          partitionValues = partition.values
        )
      }
    }
  }

  override def getFileScanRDD(
      sparkSession: SparkSession,
      readFunction: (PartitionedFile) => Iterator[InternalRow],
      filePartitions: Seq[FilePartition]): RDD[InternalRow] = {
    new GpuFileScanRDD(sparkSession, readFunction, filePartitions)
  }
}
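Editorial note (not part of the commit): the getGpuBroadcastExchangeExec override added above exists so planning code never constructs a version-specific GpuBroadcastExchangeExec directly. A minimal sketch of that call pattern, assuming the plugin's ShimLoader entry point; the gpuBroadcastFor helper itself is hypothetical.

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkPlan

// Route construction through the version-specific shim so the Apache Spark
// 3.0.0 and Databricks GpuBroadcastExchangeExec case classes can differ freely.
def gpuBroadcastFor(mode: BroadcastMode, gpuChild: SparkPlan): SparkPlan =
  ShimLoader.getSparkShims.getGpuBroadcastExchangeExec(mode, gpuChild)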