Returns min/max results from Delta Stats
Signed-off-by: Felipe Fujiy Pessoto <[email protected]>
felipepessoto committed Dec 17, 2022
1 parent 301bfd6 commit a3ace44
Showing 2 changed files with 840 additions and 69 deletions.
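
For context, a hedged sketch (not part of this commit) of the query shape the change targets: an aggregate consisting only of COUNT(1), MIN and MAX over an unfiltered Delta table, which the OptimizeMetadataOnlyDeltaQuery trait below can now answer from the min/max column stats in the Delta log instead of scanning data files. The table path and column names here are hypothetical.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("metadata-only-aggregates")
  .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
  .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
  .getOrCreate()

// Hypothetical Delta table; with this optimization the aggregate is rewritten into a
// single-row LocalRelation built from Delta log stats, so no Parquet files are read.
spark.sql("SELECT COUNT(*), MIN(id), MAX(id) FROM delta.`/tmp/events`").show()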
@@ -17,50 +17,299 @@
package org.apache.spark.sql.delta.perf

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaTable, Snapshot}
import org.apache.spark.sql.delta.files.TahoeLogFileIndex
import org.apache.spark.sql.delta.stats.DeltaScanGenerator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import scala.collection.immutable.HashSet

trait OptimizeMetadataOnlyDeltaQuery {
def optimizeQueryWithMetadata(plan: LogicalPlan): LogicalPlan = {
plan.transformUpWithSubqueries {
case agg@AggregateDeltaTable(tahoeLogFileIndex) =>
createLocalRelationPlan(agg, tahoeLogFileIndex)
}
}
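
// For illustration (hypothetical names): an eligible plan such as
// Aggregate [count(1), min(id), max(id)] over an unfiltered Delta scan of `events`
// is replaced by a single-row LocalRelation whose values are computed from the
// Delta log stats, so no data files are read.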

protected def getDeltaScanGenerator(index: TahoeLogFileIndex): DeltaScanGenerator

protected def createLocalRelationPlan(
plan: Aggregate,
tahoeLogFileIndex: TahoeLogFileIndex): LogicalPlan = {
val rowCount = extractGlobalCount(tahoeLogFileIndex)

if (rowCount.isDefined) {
lazy val columnStats = extractGlobalColumnStats(tahoeLogFileIndex)

val aggregatedValues = plan.aggregateExpressions.collect {
case Alias(AggregateExpression(
Count(Seq(Literal(1, _))), Complete, false, None, _), _) =>
rowCount.get
case Alias(AggregateExpression(
Min(minReference: AttributeReference), Complete, false, None, _), _)
if columnStats.contains(minReference.name) &&
// Avoid StructType, it is not supported by this optimization.
// Sanity check only. minReference would be a GetStructField if it were a struct column.
minReference.references.size == 1 &&
minReference.references.head.dataType != StructType =>
val value = if (minReference.dataType == DateType
&& columnStats(minReference.name).min != null) {
DateTimeUtils.fromJavaDate(
columnStats(minReference.name).min.asInstanceOf[java.sql.Date])
} else {
columnStats(minReference.name).min
}
value
case Alias(AggregateExpression(
Max(maxReference: AttributeReference), Complete, false, None, _), _)
if columnStats.contains(maxReference.name) &&
// Avoid StructType, it is not supported by this optimization.
// Sanity check only. maxReference would be a GetStructField if it were a struct column.
maxReference.references.size == 1 &&
maxReference.references.head.dataType != StructType =>
val value = if (maxReference.dataType == DateType
&& columnStats(maxReference.name).max != null) {
DateTimeUtils.fromJavaDate(
columnStats(maxReference.name).max.asInstanceOf[java.sql.Date])
} else {
columnStats(maxReference.name).max
}
value
}

if (plan.aggregateExpressions.size == aggregatedValues.size) {
val r = LocalRelation(
plan.output,
Seq(InternalRow.fromSeq(aggregatedValues)))
r
} else {
plan
}
} else {
plan
}
}

object AggregateDeltaTable {
def unapply(plan: Aggregate): Option[TahoeLogFileIndex] = plan match {
case Aggregate(Nil,
seqTest: Seq[Alias],
PhysicalOperation(projectList, Nil, DeltaTable(i: TahoeLogFileIndex)))
if i.partitionFilters.isEmpty
&& projectList.forall {
case _: AttributeReference => true
// Disable the optimization if Project is renaming the column
// to avoid getting the incorrect column from stats, example:
// SELECT MAX(Column2) FROM (SELECT Column1 AS Column2 FROM TableName)
// We could create a mapping (alias -> actual name) to avoid the problem
case a@Alias(_, _) => a.child.references.size == 1 &&
a.name.equals(a.child.references.head.name)
case _ => false
}
&& seqTest.forall {
case Alias(AggregateExpression(
Count(Seq(Literal(1, _))) | Min(_) | Max(_), Complete, false, None, _), _) => true
case _ => false
} =>
Some(i)
// When all columns are selected, there are no Project/PhysicalOperation
case Aggregate(Nil,
seqTest: Seq[Alias],
DeltaTable(i: TahoeLogFileIndex))
if i.partitionFilters.isEmpty
&& seqTest.forall {
case Alias(AggregateExpression(
Count(Seq(Literal(1, _))) | Min(_) | Max(_), Complete, false, None, _), _) => true
case _ => false
} =>
Some(i)
case _ => None
}
}

/** Return the number of rows in the table or `None` if we cannot calculate it from stats */
private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = {
// TODO Update this to work with DV (https://github.com/delta-io/delta/issues/1485)
val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil)
.agg(
sum("stats.numRecords"),
// Calculate the number of files missing `numRecords`
count(when(col("stats.numRecords").isNull, 1)))
.first

// The count agg is never null. A non-zero value means we have incomplete stats; otherwise,
// the sum agg is either null (for an empty table) or gives an accurate record count.
if (row.getLong(1) > 0) return None
val numRecords = if (row.isNullAt(0)) 0 else row.getLong(0)
Some(numRecords)
}
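
// For illustration (hypothetical numbers): three files with numRecords 10, 20 and 30
// and no missing stats yield Some(60); if any file lacks numRecords the method
// returns None and the metadata-only rewrite of COUNT is skipped.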

val columnStatsSupportedDataTypes: HashSet[DataType] = HashSet(
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DateType)

case class DeltaColumnStat(
min: Any,
max: Any,
nullCount: Option[Long],
distinctCount: Option[Long])
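
// Example (hypothetical values): a LongType column whose files report an overall
// minimum of 1, maximum of 42 and zero nulls, with no distinct count available,
// would be summarized as DeltaColumnStat(1L, 42L, Some(0L), None).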

def extractGlobalColumnStats(tahoeLogFileIndex: TahoeLogFileIndex):
CaseInsensitiveMap[DeltaColumnStat] = {

// TODO Update this to work with DV (https://github.com/delta-io/delta/issues/1485)

val deltaScanGenerator = getDeltaScanGenerator(tahoeLogFileIndex)
val snapshot = deltaScanGenerator.snapshotToScan

def extractGlobalColumnStatsDeltaLog(snapshot: Snapshot):
Map[String, DeltaColumnStat] = {

val dataColumns = snapshot.statCollectionSchema
.filter(col => columnStatsSupportedDataTypes.contains(col.dataType))

// Validate that all the files have stats
val filesStatsCount = deltaScanGenerator.filesWithStatsForScan(Nil).select(
count(when(col("stats.numRecords").isNull, 1)).as("missingNumRecords"),
count(when(col("stats.numRecords") > 0, 1)).as("countNonEmptyFiles")).head

val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0
// DELETE operations create AddFile records with 0 rows and no column stats.
// We can safely ignore them since there is no data.
lazy val files = deltaScanGenerator.filesWithStatsForScan(Nil)
.filter(col("stats.numRecords") > 0)
val numFiles: Long = filesStatsCount.getAs[Long]("countNonEmptyFiles")
lazy val statsMinMaxNullColumns = files.select(col("stats.*"))
if (dataColumns.isEmpty
|| !allRecordsHasStats
|| numFiles == 0
|| !statsMinMaxNullColumns.columns.contains("minValues")
|| !statsMinMaxNullColumns.columns.contains("maxValues")
|| !statsMinMaxNullColumns.columns.contains("nullCount")) {
Map.empty
} else {
// dataColumns can contain columns without stats if dataSkippingNumIndexedCols
// has been increased
val columnsWithStats = files.select(
col("stats.minValues.*"),
col("stats.maxValues.*"),
col("stats.nullCount.*"))
.columns.groupBy(identity).mapValues(_.size)
.filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount
.map(x => x._1).toSet

// Pairs each column with its physical name to avoid recalculating it multiple times
val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
.filter(x => columnsWithStats.contains(x._2))

val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
val dataType = columnAndPhysicalName._1.dataType
val physicalName = columnAndPhysicalName._2

Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"),
col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"),
col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName"))
} ++ Seq(col(s"stats.numRecords").as(s"numRecords"))

val minMaxNullCountExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
val physicalName = columnAndPhysicalName._2

// To validate that the column has complete stats we do two checks:
// 1) COUNT(nullCount.columnName) should be equal to numFiles,
// since nullCount is always non-null.
// 2) The number of files with a non-null min/max:
// a. count(min.columnName) | count(max.columnName), plus
// the number of files where all rows are NULL:
// b. count of (ISNULL(min.columnName) AND nullCount.columnName == numRecords),
// should be equal to numFiles.
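// Worked example (hypothetical numbers): with numFiles = 3, a column where two files
// record min/max and the third file has null min/max but nullCount == numRecords
// (all of its rows are NULL) passes both checks, so the global min/max can be taken
// from stats; if that third file were simply missing min/max stats, the check would fail.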
Seq(
s"""case when $numFiles = count(`nullCount.$physicalName`)
| AND $numFiles = (count(`min.$physicalName`) + sum(case when
| ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords
| then 1 else 0 end))
| AND $numFiles = (count(`max.$physicalName`) + sum(case when
| ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords
| then 1 else 0 end))
| then TRUE else FALSE end as `complete_$physicalName`""".stripMargin,
s"min(`min.$physicalName`) as `min_$physicalName`",
s"max(`max.$physicalName`) as `max_$physicalName`",
s"sum(`nullCount.$physicalName`) as `nullCount_$physicalName`")
}

val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxNullCountExpr: _*).head

dataColumnsWithStats
.filter(x => statsResults.getAs[Boolean](s"complete_${x._2}"))
.map { columnAndPhysicalName =>
val column = columnAndPhysicalName._1
val physicalName = columnAndPhysicalName._2
column.name ->
DeltaColumnStat(
statsResults.getAs(s"min_$physicalName"),
statsResults.getAs(s"max_$physicalName"),
Some(statsResults.getAs[Long](s"nullCount_$physicalName")),
None)
}.toMap
}
}

def extractGlobalPartitionedColumnStatsDeltaLog(snapshot: Snapshot):
Map[String, DeltaColumnStat] = {

val partitionedColumns = snapshot.metadata.partitionSchema
.filter(x => columnStatsSupportedDataTypes.contains(x.dataType))
.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
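// Partition column values are recorded exactly in each AddFile's partitionValues,
// so their global min/max (and distinct count) can be computed directly from the
// Delta log even though they are not part of the per-file column stats.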

if (partitionedColumns.isEmpty) {
Map.empty
} else {
val partitionedColumnsValues = partitionedColumns.map { partitionedColumn =>
val physicalName = partitionedColumn._2
col(s"partitionValues.`$physicalName`")
.cast(partitionedColumn._1.dataType).as(physicalName)
}

val partitionedColumnsAgg = partitionedColumns.flatMap { partitionedColumn =>
val physicalName = partitionedColumn._2

Seq(min(s"`$physicalName`").as(s"min_$physicalName"),
max(s"`$physicalName`").as(s"max_$physicalName"),
count_distinct(col(s"`$physicalName`")).as(s"distinctCount_$physicalName"))
}

val partitionedColumnsQuery = snapshot.allFiles
.select(partitionedColumnsValues: _*)
.agg(partitionedColumnsAgg.head, partitionedColumnsAgg.tail: _*)
.head()

partitionedColumns.map { partitionedColumn =>
val physicalName = partitionedColumn._2

partitionedColumn._1.name ->
DeltaColumnStat(
partitionedColumnsQuery.getAs(s"min_$physicalName"),
partitionedColumnsQuery.getAs(s"max_$physicalName"),
None,
Some(partitionedColumnsQuery.getAs[Long](s"distinctCount_$physicalName")))
}.toMap
}
}

CaseInsensitiveMap(
extractGlobalColumnStatsDeltaLog(snapshot).++
(extractGlobalPartitionedColumnStatsDeltaLog(snapshot)))
}
}
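
For reference, a minimal sketch (not part of this commit) of how an optimizer rule might mix in this trait. The rule name and the way the DeltaScanGenerator is obtained are assumptions for illustration, not the wiring Delta actually uses.

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.delta.files.TahoeLogFileIndex
import org.apache.spark.sql.delta.perf.OptimizeMetadataOnlyDeltaQuery
import org.apache.spark.sql.delta.stats.DeltaScanGenerator

object MetadataOnlyDeltaAggregates extends Rule[LogicalPlan] with OptimizeMetadataOnlyDeltaQuery {
  // Assumption: the scan's snapshot can act as the DeltaScanGenerator; the production
  // rule resolves this differently, so treat this as a placeholder.
  override protected def getDeltaScanGenerator(index: TahoeLogFileIndex): DeltaScanGenerator =
    index.getSnapshot.asInstanceOf[DeltaScanGenerator]

  override def apply(plan: LogicalPlan): LogicalPlan = optimizeQueryWithMetadata(plan)
}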