feat: add scalar subquery pushdown to scan #678

Merged
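For context, a minimal sketch of the feature the title refers to (the tables and query below are illustrative, not taken from this PR or its tests): a data filter whose right-hand side is a scalar subquery can be pushed down to the Parquet scan once the subquery has been evaluated and its result substituted as a literal.

```scala
// Illustrative only: table names and data are assumptions, not from the PR.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.sql("CREATE TABLE t1 USING parquet AS SELECT id AS a FROM range(10)")
spark.sql("CREATE TABLE t2 USING parquet AS SELECT id AS b FROM range(10)")

// The predicate `a > (SELECT max(b) - 5 FROM t2)` ends up as a pushed-down data filter on
// the scan of t1; with Comet enabled that scan is a CometScanExec rather than a
// FileSourceScanExec, which is what this PR and the patched SPARK-43402 test cover.
spark.sql("SELECT * FROM t1 WHERE a > (SELECT max(b) - 5 FROM t2)").show()
```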
41 changes: 33 additions & 8 deletions dev/diffs/4.0.0-preview1.diff
@@ -900,7 +900,7 @@ index 56c364e2084..11779ee3b4b 100644
withTable("dt") {
sql("create table dt using parquet as select 9000000000BD as d")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 68f14f13bbd..4b8e967102f 100644
index 68f14f13bbd..174636cefb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
@@ -938,16 +938,41 @@ index 68f14f13bbd..4b8e967102f 100644
}
assert(exchanges.size === 1)
}
@@ -2668,7 +2675,8 @@ class SubquerySuite extends QueryTest
}
}

- test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery") {
+ test("SPARK-43402: FileSourceScanExec supports push down data filter with scalar subquery",
+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/551")) {
@@ -2672,18 +2679,26 @@ class SubquerySuite extends QueryTest
def checkFileSourceScan(query: String, answer: Seq[Row]): Unit = {
val df = sql(query)
checkAnswer(df, answer)
- val fileSourceScanExec = collect(df.queryExecution.executedPlan) {
- case f: FileSourceScanExec => f
+ val dataSourceScanExec = collect(df.queryExecution.executedPlan) {
+ case f: FileSourceScanLike => f
+ case c: CometScanExec => c
}
sparkContext.listenerBus.waitUntilEmpty()
- assert(fileSourceScanExec.size === 1)
- val scalarSubquery = fileSourceScanExec.head.dataFilters.flatMap(_.collect {
- case s: ScalarSubquery => s
- })
+ assert(dataSourceScanExec.size === 1)
+ val scalarSubquery = dataSourceScanExec.head match {
+ case f: FileSourceScanLike =>
+ f.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ case c: CometScanExec =>
+ c.dataFilters.flatMap(_.collect {
+ case s: ScalarSubquery => s
+ })
+ }
assert(scalarSubquery.length === 1)
assert(scalarSubquery.head.plan.isInstanceOf[ReusedSubqueryExec])
- assert(fileSourceScanExec.head.metrics("numFiles").value === 1)
- assert(fileSourceScanExec.head.metrics("numOutputRows").value === answer.size)
+ assert(dataSourceScanExec.head.metrics("numFiles").value === 1)
+ assert(dataSourceScanExec.head.metrics("numOutputRows").value === answer.size)
}

withTable("t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index 1de535df246..cc7ffc4eeb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
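The SubquerySuite patch above makes the SPARK-43402 test accept either a FileSourceScanLike or a CometScanExec when extracting the scan's data filters. A hedged sketch of that extraction, with hypothetical helper names and ignoring any adaptive-execution wrapping that the suite's own `collect` helper may handle:

```scala
// Sketch only: helper names are hypothetical, not part of the PR. The idea is to collect
// `dataFilters` from whichever scan node the plan contains, then look for the scalar
// subquery expressions inside those filters.
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.{FileSourceScanLike, ScalarSubquery, SparkPlan}

def scanDataFilters(plan: SparkPlan): Seq[Expression] =
  plan.collect {
    case f: FileSourceScanLike => f.dataFilters
    case c: CometScanExec      => c.dataFilters
  }.flatten

def scanScalarSubqueries(plan: SparkPlan): Seq[ScalarSubquery] =
  scanDataFilters(plan).flatMap(_.collect { case s: ScalarSubquery => s })
```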
@@ -135,10 +135,7 @@ case class CometScanExec(
(wrapped.outputPartitioning, wrapped.outputOrdering)

@transient
private lazy val pushedDownFilters = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}
private lazy val pushedDownFilters = getPushedDownFilters(relation, dataFilters)

override lazy val metadata: Map[String, String] =
if (wrapped == null) Map.empty else wrapped.metadata
@@ -24,11 +24,12 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

trait ShimCometScanExec {
@@ -67,4 +68,9 @@ trait ShimCometScanExec {
maxSplitBytes: Long,
partitionValues: InternalRow): Seq[PartitionedFile] =
PartitionedFileUtil.splitFiles(sparkSession, file, isSplitable, maxSplitBytes, partitionValues)

protected def getPushedDownFilters(relation: HadoopFsRelation, dataFilters: Seq[Expression]): Seq[Filter] = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}
}
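This shim keeps the translation that CometScanExec previously inlined: every data filter is passed through DataSourceStrategy.translateFilter, and anything that cannot be expressed as a v1 data source Filter is dropped. A standalone sketch of that translation, with illustrative names (not from the PR):

```scala
// Minimal sketch of the v1 filter translation the shim performs per data filter.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val pred = GreaterThan(a, Literal(5))

// Some(GreaterThan("a", 5)) as an org.apache.spark.sql.sources.Filter, ready to be handed
// to the Parquet reader; untranslatable predicates yield None and are dropped by flatMap.
val v1Filter: Option[sources.Filter] =
  DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true)
```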
@@ -23,11 +23,12 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, FileSourceConstantMetadataAttribute, Literal}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, ScalarSubquery}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

trait ShimCometScanExec {
@@ -68,4 +69,30 @@ trait ShimCometScanExec {
maxSplitBytes: Long,
partitionValues: InternalRow): Seq[PartitionedFile] =
PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, partitionValues)

protected def getPushedDownFilters(relation: HadoopFsRelation, dataFilters: Seq[Expression]): Seq[Filter] = {
translateToV1Filters(relation, dataFilters, _.toLiteral)
}

// From Spark FileSourceScanLike
private def translateToV1Filters(relation: HadoopFsRelation,
dataFilters: Seq[Expression],
scalarSubqueryToLiteral: ScalarSubquery => Literal): Seq[Filter] = {
val scalarSubqueryReplaced = dataFilters.map(_.transform {
// Replace scalar subquery to literal so that `DataSourceStrategy.translateFilter` can
// support translating it.
case scalarSubquery: ScalarSubquery => scalarSubqueryToLiteral(scalarSubquery)
})

val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
// `dataFilters` should not include any constant metadata col filters
// because the metadata struct has been flatted in FileSourceStrategy
// and thus metadata col filters are invalid to be pushed down. Metadata that is generated
// during the scan can be used for filters.
scalarSubqueryReplaced.filterNot(_.references.exists {
case FileSourceConstantMetadataAttribute(_) => true
case _ => false
}).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}

}
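This variant mirrors Spark's FileSourceScanLike, as the comment above notes: a ScalarSubquery is replaced by the Literal carrying its result (via `_.toLiteral`), and predicates on constant metadata columns are filtered out, before the usual v1 translation runs. A rough sketch of why the substitution matters, using an ordinary attribute as a stand-in for the subquery node (building a real execution.ScalarSubquery would need a physical subquery plan, so this is an assumption-laden simplification):

```scala
// Rough sketch, not from the PR: translateFilter cannot handle a non-literal right-hand
// side, so the scalar subquery must first be rewritten to a Literal.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GreaterThan, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val subqueryStandIn = AttributeReference("scalar_subquery_result", IntegerType)()
val filter: Expression = GreaterThan(a, subqueryStandIn)

// Not pushable yet: the right-hand side is not a literal, so translateFilter returns None.
assert(DataSourceStrategy.translateFilter(filter, supportNestedPredicatePushdown = true).isEmpty)

// After substituting the evaluated result (the shim does this through `_.toLiteral`),
// the predicate becomes an ordinary pushable filter.
val rewritten = filter.transform {
  case s: AttributeReference if s.name == "scalar_subquery_result" => Literal(42)
}
assert(DataSourceStrategy.translateFilter(rewritten, supportNestedPredicatePushdown = true).isDefined)
```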
@@ -20,16 +20,15 @@
package org.apache.spark.sql.comet.shims

import org.apache.comet.shims.ShimFileFormat

import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

trait ShimCometScanExec {
@@ -102,4 +101,10 @@ trait ShimCometScanExec {
maxSplitBytes: Long,
partitionValues: InternalRow): Seq[PartitionedFile] =
PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues)

protected def getPushedDownFilters(relation: HadoopFsRelation, dataFilters: Seq[Expression]): Seq[Filter] = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}

}