Extract Count and Min/Max in a single method. Allows extracting Min/Max
from partitioned columns even when COUNT is not available

Fix style error

Signed-off-by: Felipe Fujiy Pessoto <[email protected]>
felipepessoto committed Nov 22, 2023
1 parent 6ba89ce commit f290bc6
Showing 2 changed files with 170 additions and 157 deletions.
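
For context: this change extracts COUNT together with MIN/MAX in a single pass over the
Delta log, and because the row count now comes back as an Option instead of gating the
whole extraction, MIN/MAX over partition columns can still be answered from
partitionValues even when COUNT cannot be derived from stats. An illustrative sketch of
the query shape this serves (hypothetical SparkSession `spark` and Delta table `events`
partitioned by `ds`; not part of this commit):

// Hypothetical usage. When the metadata-only optimization applies, an
// aggregate shaped like this is answered from the Delta log
// (stats + partitionValues) instead of scanning the data files.
spark.sql("SELECT COUNT(*), MIN(ds), MAX(ds) FROM events").show()
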
@@ -55,66 +55,59 @@ trait OptimizeMetadataOnlyDeltaQuery {
   private def createLocalRelationPlan(
       plan: Aggregate,
       tahoeLogFileIndex: TahoeLogFileIndex): LogicalPlan = {
-    val rowCount = extractGlobalCount(tahoeLogFileIndex)
-
-    if (rowCount.isDefined) {
-      val aggColumnsNames = Set(extractMinMaxFieldNames(plan).map(_.toLowerCase(Locale.ROOT)) : _*)
-      val columnStats = extractMinMaxFromDeltaLog(tahoeLogFileIndex, aggColumnsNames)
-
-      def checkStatsExists(attrRef: AttributeReference): Boolean = {
-        columnStats.contains(attrRef.name) &&
-          // Avoid StructType, it is not supported by this optimization
-          // Sanity check only. If reference is nested column it would be GetStructType
-          // instead of AttributeReference
-          attrRef.references.size == 1 &&
-          attrRef.references.head.dataType != StructType
-      }
-
-      def convertValueIfRequired(attrRef: AttributeReference, value: Any): Any = {
-        if (attrRef.dataType == DateType && value != null) {
-          DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
-        } else {
-          value
-        }
-      }
-
-      val rewrittenAggregationValues = plan.aggregateExpressions.collect {
-        case Alias(AggregateExpression(
-          Count(Seq(Literal(1, _))), Complete, false, None, _), _) =>
-          rowCount.get
-        case Alias(tps@ToPrettyString(AggregateExpression(
-          Count(Seq(Literal(1, _))), Complete, false, None, _), _), _) =>
-          tps.copy(child = Literal(rowCount.get)).eval()
-        case Alias(AggregateExpression(
-          Min(minReference: AttributeReference), Complete, false, None, _), _)
-          if checkStatsExists(minReference) =>
-          convertValueIfRequired(minReference, columnStats(minReference.name).min)
-        case Alias(tps@ToPrettyString(AggregateExpression(
-          Min(minReference: AttributeReference), Complete, false, None, _), _), _)
-          if checkStatsExists(minReference) =>
-          val v = columnStats(minReference.name).min
-          tps.copy(child = Literal(v)).eval()
-        case Alias(AggregateExpression(
-          Max(maxReference: AttributeReference), Complete, false, None, _), _)
-          if checkStatsExists(maxReference) =>
-          convertValueIfRequired(maxReference, columnStats(maxReference.name).max)
-        case Alias(tps@ToPrettyString(AggregateExpression(
-          Max(maxReference: AttributeReference), Complete, false, None, _), _), _)
-          if checkStatsExists(maxReference) =>
-          val v = columnStats(maxReference.name).max
-          tps.copy(child = Literal(v)).eval()
-      }
-
-      if (plan.aggregateExpressions.size == rewrittenAggregationValues.size) {
-        val r = LocalRelation(
-          plan.output,
-          Seq(InternalRow.fromSeq(rewrittenAggregationValues)))
-        r
-      } else {
-        plan
-      }
-    }
-    else {
-      plan
-    }
+    val aggColumnsNames = Set(extractMinMaxFieldNames(plan).map(_.toLowerCase(Locale.ROOT)) : _*)
+    val (rowCount, columnStats) = extractCountMinMaxFromDeltaLog(tahoeLogFileIndex, aggColumnsNames)
+
+    def checkStatsExists(attrRef: AttributeReference): Boolean = {
+      columnStats.contains(attrRef.name) &&
+        // Avoid StructType, it is not supported by this optimization
+        // Sanity check only. If the reference were a nested column it would be a GetStructField
+        // instead of an AttributeReference
+        attrRef.references.size == 1 && attrRef.references.head.dataType != StructType
+    }
+
+    def convertValueIfRequired(attrRef: AttributeReference, value: Any): Any = {
+      if (attrRef.dataType == DateType && value != null) {
+        DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
+      } else {
+        value
+      }
+    }
+
+    val rewrittenAggregationValues = plan.aggregateExpressions.collect {
+      case Alias(AggregateExpression(
+        Count(Seq(Literal(1, _))), Complete, false, None, _), _) if rowCount.isDefined =>
+        rowCount.get
+      case Alias(tps@ToPrettyString(AggregateExpression(
+        Count(Seq(Literal(1, _))), Complete, false, None, _), _), _) if rowCount.isDefined =>
+        tps.copy(child = Literal(rowCount.get)).eval()
+      case Alias(AggregateExpression(
+        Min(minReference: AttributeReference), Complete, false, None, _), _)
+        if checkStatsExists(minReference) =>
+        convertValueIfRequired(minReference, columnStats(minReference.name).min)
+      case Alias(tps@ToPrettyString(AggregateExpression(
+        Min(minReference: AttributeReference), Complete, false, None, _), _), _)
+        if checkStatsExists(minReference) =>
+        val v = columnStats(minReference.name).min
+        tps.copy(child = Literal(v)).eval()
+      case Alias(AggregateExpression(
+        Max(maxReference: AttributeReference), Complete, false, None, _), _)
+        if checkStatsExists(maxReference) =>
+        convertValueIfRequired(maxReference, columnStats(maxReference.name).max)
+      case Alias(tps@ToPrettyString(AggregateExpression(
+        Max(maxReference: AttributeReference), Complete, false, None, _), _), _)
+        if checkStatsExists(maxReference) =>
+        val v = columnStats(maxReference.name).max
+        tps.copy(child = Literal(v)).eval()
+    }
+
+    if (plan.aggregateExpressions.size == rewrittenAggregationValues.size) {
+      val r = LocalRelation(
+        plan.output,
+        Seq(InternalRow.fromSeq(rewrittenAggregationValues)))
+      r
+    } else {
+      plan
+    }
   }
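
The closing guard in this hunk is the safety property of the rewrite: the Aggregate is
replaced by a LocalRelation only when every aggregate expression could be answered from
metadata, and the per-case `if rowCount.isDefined` guards are what now let MIN/MAX
succeed even when the count is unavailable. A minimal plain-Scala sketch of that
all-or-nothing pattern (stand-in types, hypothetical names; not the Catalyst API):

// Plain-Scala stand-ins for the Catalyst types (all names hypothetical).
sealed trait Agg
case object CountStar extends Agg
final case class MinOf(column: String) extends Agg
final case class MaxOf(column: String) extends Agg

final case class ColumnStat(min: Any, max: Any)

def rewriteFromMetadata(
    aggs: Seq[Agg],
    rowCount: Option[Long],
    stats: Map[String, ColumnStat]): Option[Seq[Any]] = {
  // collect keeps only the aggregates answerable from metadata...
  val rewritten = aggs.collect {
    case CountStar if rowCount.isDefined => rowCount.get
    case MinOf(c) if stats.contains(c) => stats(c).min
    case MaxOf(c) if stats.contains(c) => stats(c).max
  }
  // ...and the rewrite applies only if *every* aggregate was answered, mirroring
  // `plan.aggregateExpressions.size == rewrittenAggregationValues.size`.
  if (rewritten.size == aggs.size) Some(rewritten) else None
}

// MIN/MAX of a partition column still works when the row count is unknown...
val stats = Map("ds" -> ColumnStat("2023-01-01", "2023-11-22"))
assert(rewriteFromMetadata(Seq(MinOf("ds"), MaxOf("ds")), None, stats)
  .contains(Seq("2023-01-01", "2023-11-22")))
// ...but a COUNT in the same query then blocks the whole rewrite.
assert(rewriteFromMetadata(Seq(CountStar, MinOf("ds")), None, stats).isEmpty)
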
@@ -136,122 +129,114 @@ trait OptimizeMetadataOnlyDeltaQuery {
     }
   }
 
-  /** Return the number of rows in the table or `None` if we cannot calculate it from stats */
-  private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = {
-    // account for deleted rows according to deletion vectors
-    val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0))
-    val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords")
-    val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil)
-      .agg(
-        sum(numLogicalRecords),
-        // Calculate the number of files missing `numRecords`
-        count(when(col("stats.numRecords").isNull, 1)))
-      .first
-
-    // The count agg is never null. A non-zero value means we have incomplete stats; otherwise,
-    // the sum agg is either null (for an empty table) or gives an accurate record count.
-    if (row.getLong(1) > 0) return None
-    val numRecords = if (row.isNullAt(0)) 0 else row.getLong(0)
-    Some(numRecords)
-  }
-
   /**
    * Min and max values from Delta Log stats or partitionValues.
    */
   case class DeltaColumnStat(min: Any, max: Any)
 
-  private def extractMinMaxFromStats(
+  private def extractCountMinMaxFromStats(
       deltaScanGenerator: DeltaScanGenerator,
-      lowerCaseColumnNames: Set[String]): Map[String, DeltaColumnStat] = {
-
+      lowerCaseColumnNames: Set[String]): (Option[Long], Map[String, DeltaColumnStat]) = {
     // TODO Update this to work with DV (https://github.com/delta-io/delta/issues/1485)
-
     val snapshot = deltaScanGenerator.snapshotToScan
-    val dataColumns = snapshot.statCollectionPhysicalSchema.filter(col =>
-      AggregateDeltaTable.isSupportedDataType(col.dataType) &&
-        lowerCaseColumnNames.contains(col.name.toLowerCase(Locale.ROOT)))
 
-    // Validate all the files has stats
-    lazy val filesStatsCount = deltaScanGenerator.filesWithStatsForScan(Nil).select(
+    // Count - account for deleted rows according to deletion vectors
+    val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0))
+    val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords")
+
+    val filesWithStatsForScan = deltaScanGenerator.filesWithStatsForScan(Nil)
+    // Validate that all the files have stats
+    val filesStatsCount = filesWithStatsForScan.select(
       sum(numLogicalRecords).as("numLogicalRecords"),
       count(when(col("stats.numRecords").isNull, 1)).as("missingNumRecords"),
       count(when(col("stats.numRecords") > 0, 1)).as("countNonEmptyFiles")).head
 
-    lazy val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0
+    // If any numRecords is null, we have incomplete stats;
+    val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0
+    if (!allRecordsHasStats) {
+      return (None, Map.empty)
+    }
+    // the sum agg is either null (for an empty table) or gives an accurate record count.
+    val numRecords = if (filesStatsCount.isNullAt(0)) 0 else filesStatsCount.getLong(0)
     lazy val numFiles: Long = filesStatsCount.getAs[Long]("countNonEmptyFiles")
 
+    val dataColumns = snapshot.statCollectionPhysicalSchema.filter(col =>
+      AggregateDeltaTable.isSupportedDataType(col.dataType) &&
+        lowerCaseColumnNames.contains(col.name.toLowerCase(Locale.ROOT)))
+
     // DELETE operations create AddFile records with 0 rows and no column stats.
     // We can safely ignore them since they contain no data.
-    lazy val files = deltaScanGenerator.filesWithStatsForScan(Nil)
-      .filter(col("stats.numRecords") > 0)
+    lazy val files = filesWithStatsForScan.filter(col("stats.numRecords") > 0)
     lazy val statsMinMaxNullColumns = files.select(col("stats.*"))
     if (dataColumns.isEmpty
       || !isTableDVFree(snapshot)
-      || !allRecordsHasStats
       || numFiles == 0
      || !statsMinMaxNullColumns.columns.contains("minValues")
       || !statsMinMaxNullColumns.columns.contains("maxValues")
       || !statsMinMaxNullColumns.columns.contains("nullCount")) {
-      Map.empty
-    } else {
-      // dataColumns can contain columns without stats if dataSkippingNumIndexedCols
-      // has been increased
-      val columnsWithStats = files.select(
-        col("stats.minValues.*"),
-        col("stats.maxValues.*"),
-        col("stats.nullCount.*"))
-        .columns.groupBy(identity).mapValues(_.size)
-        .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount
-        .map(x => x._1).toSet
-
-      // Creates a tuple with physical name to avoid recalculating it multiple times
-      val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
-        .filter(x => columnsWithStats.contains(x._2))
-
-      val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
-        val dataType = columnAndPhysicalName._1.dataType
-        val physicalName = columnAndPhysicalName._2
-
-        Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"),
-          col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"),
-          col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName"))
-      } ++ Seq(col(s"stats.numRecords").as(s"numRecords"))
+      return (Some(numRecords), Map.empty)
+    }
 
-      val minMaxExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
-        val physicalName = columnAndPhysicalName._2
-
-        // To validate if the column has stats we do two validation:
-        // 1-) COUNT(nullCount.columnName) should be equals to numFiles,
-        // since nullCount is always non-null.
-        // 2-) The number of files with non-null min/max:
-        // a. count(min.columnName)|count(max.columnName) +
-        // the number of files where all rows are NULL:
-        // b. count of (ISNULL(min.columnName) and nullCount.columnName == numRecords)
-        // should be equals to numFiles
-        Seq(
-          s"""case when $numFiles = count(`nullCount.$physicalName`)
-             | AND $numFiles = (count(`min.$physicalName`) + sum(case when
-             |   ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords
-             |   then 1 else 0 end))
-             | AND $numFiles = (count(`max.$physicalName`) + sum(case when
-             |   ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords
-             |   then 1 else 0 end))
-             | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin,
-          s"min(`min.$physicalName`) as `min_$physicalName`",
-          s"max(`max.$physicalName`) as `max_$physicalName`")
-      }
-
-      val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxExpr: _*).head
+    // dataColumns can contain columns without stats if dataSkippingNumIndexedCols
+    // has been increased
+    val columnsWithStats = files.select(
+      col("stats.minValues.*"),
+      col("stats.maxValues.*"),
+      col("stats.nullCount.*"))
+      .columns.groupBy(identity).mapValues(_.size)
+      .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount
+      .map(x => x._1).toSet
+
+    // Creates a tuple with physical name to avoid recalculating it multiple times
+    val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x)))
+      .filter(x => columnsWithStats.contains(x._2))
+
+    val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
+      val dataType = columnAndPhysicalName._1.dataType
+      val physicalName = columnAndPhysicalName._2
+
+      Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"),
+        col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"),
+        col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName"))
+    } ++ Seq(col(s"stats.numRecords").as(s"numRecords"))
+
+    val minMaxExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName =>
+      val physicalName = columnAndPhysicalName._2
+
+      // To validate that the column has stats we do two validations:
+      // 1-) COUNT(nullCount.columnName) should be equal to numFiles,
+      // since nullCount is always non-null.
+      // 2-) The number of files with non-null min/max:
+      // a. count(min.columnName)|count(max.columnName) +
+      // the number of files where all rows are NULL:
+      // b. count of (ISNULL(min.columnName) and nullCount.columnName == numRecords)
+      // should be equal to numFiles
+      Seq(
+        s"""case when $numFiles = count(`nullCount.$physicalName`)
+           | AND $numFiles = (count(`min.$physicalName`) + sum(case when
+           |   ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords
+           |   then 1 else 0 end))
+           | AND $numFiles = (count(`max.$physicalName`) + sum(case when
+           |   ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords
+           |   then 1 else 0 end))
+           | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin,
+        s"min(`min.$physicalName`) as `min_$physicalName`",
+        s"max(`max.$physicalName`) as `max_$physicalName`")
+    }
+
+    val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxExpr: _*).head
 
-      dataColumnsWithStats
-        .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}"))
-        .map { columnAndPhysicalName =>
-          val column = columnAndPhysicalName._1
-          val physicalName = columnAndPhysicalName._2
-          column.name ->
-            DeltaColumnStat(
-              statsResults.getAs(s"min_$physicalName"),
-              statsResults.getAs(s"max_$physicalName"))
-        }.toMap
-    }
+    (Some(numRecords), dataColumnsWithStats
+      .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}"))
+      .map { columnAndPhysicalName =>
+        val column = columnAndPhysicalName._1
+        val physicalName = columnAndPhysicalName._2
+        column.name ->
+          DeltaColumnStat(
+            statsResults.getAs(s"min_$physicalName"),
+            statsResults.getAs(s"max_$physicalName"))
+      }.toMap)
   }
 
   private def extractMinMaxFromPartitionValue(
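
The `complete_<column>` expressions built above encode the completeness rule from the
comment: a column's global min/max is only trusted when every non-empty file either
carries non-null min/max stats or consists entirely of NULLs (nullCount == numRecords).
A plain-Scala sketch of the same rule over hypothetical per-file stats (equivalent in
spirit to the CASE WHEN aggregation, not the actual DataFrame code):

final case class FileStats(
    numRecords: Long,
    nullCount: Option[Long], // the nullCount stat itself may be missing
    min: Option[String],
    max: Option[String])

def minMaxComplete(files: Seq[FileStats]): Boolean = {
  val numFiles = files.size
  // 1-) every file must report nullCount (it is non-null whenever stats exist)
  val hasNullCount = files.count(_.nullCount.isDefined)
  // 2-) a file is "covered" if it has a non-null min/max, or if all of its
  // rows are NULL, in which case there is legitimately no min/max to report
  def covered(stat: FileStats => Option[String]): Int =
    files.count(f => stat(f).isDefined || f.nullCount.contains(f.numRecords))
  hasNullCount == numFiles && covered(_.min) == numFiles && covered(_.max) == numFiles
}

// One file with real min/max plus one all-NULL file: still complete.
val ok = Seq(
  FileStats(10, Some(0), Some("a"), Some("z")),
  FileStats(4, Some(4), None, None))
assert(minMaxComplete(ok))
// A file with data rows but no min/max stat makes the column unusable.
assert(!minMaxComplete(ok :+ FileStats(5, Some(1), None, None)))
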
@@ -295,21 +280,28 @@ trait OptimizeMetadataOnlyDeltaQuery {
     }
   }
 
-  private def extractMinMaxFromDeltaLog(
+  /**
+   * Extract the Count, Min and Max values from Delta Log stats and partitionValues.
+   * The first element is the row count of the table, or `None` if it cannot be
+   * calculated from stats.
+   * If a column is not partitioned, its values are extracted from stats when they exist.
+   * If a column is partitioned, its values are extracted from partitionValues.
+   */
+  private def extractCountMinMaxFromDeltaLog(
       tahoeLogFileIndex: TahoeLogFileIndex,
       lowerCaseColumnNames: Set[String]):
-    CaseInsensitiveMap[DeltaColumnStat] = {
-    val deltaScanGenerator = getDeltaScanGenerator(tahoeLogFileIndex)
-    val snapshot = deltaScanGenerator.snapshotToScan
-    val columnFromStats = extractMinMaxFromStats(deltaScanGenerator, lowerCaseColumnNames)
+    (Option[Long], CaseInsensitiveMap[DeltaColumnStat]) = {
+    val deltaScanGen = getDeltaScanGenerator(tahoeLogFileIndex)
+    val (rowCount, columnStats) = extractCountMinMaxFromStats(deltaScanGen, lowerCaseColumnNames)
 
-    if(lowerCaseColumnNames.equals(columnFromStats.keySet)) {
-      CaseInsensitiveMap(columnFromStats)
+    val minMaxValues = if (lowerCaseColumnNames.equals(columnStats.keySet)) {
+      CaseInsensitiveMap(columnStats)
     } else {
       CaseInsensitiveMap(
-        columnFromStats.++
-          (extractMinMaxFromPartitionValue(snapshot, lowerCaseColumnNames)))
+        columnStats.++
+          (extractMinMaxFromPartitionValue(deltaScanGen.snapshotToScan, lowerCaseColumnNames)))
     }
+
+    (rowCount, minMaxValues)
   }
 
   object AggregateDeltaTable {
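
The merge at the end of extractCountMinMaxFromDeltaLog consults partitionValues only
when stats did not already cover every requested column, and the row count now travels
alongside the map instead of gating it. A small sketch of that fallback with plain
Scala Maps (hypothetical values; CaseInsensitiveMap elided for brevity):

final case class ColStat(min: Any, max: Any)

val requested = Set("id", "ds")
val fromStats = Map("id" -> ColStat(1, 100)) // data column, from file stats
val fromPartitions = Map("ds" -> ColStat("2023-01-01", "2023-11-22"))

// `++` keeps the right-hand entries on key collisions; the real code only takes
// this branch when stats did NOT cover every requested column.
val merged =
  if (requested == fromStats.keySet) fromStats
  else fromStats ++ fromPartitions
assert(merged.keySet == requested)
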
@@ -322,7 +314,9 @@ trait OptimizeMetadataOnlyDeltaQuery {
       dataType.isInstanceOf[DateType]
   }
 
-  def getAggFunctionOptimizable(aggExpr: AggregateExpression): Option[DeclarativeAggregate] = {
+  private def getAggFunctionOptimizable(
+      aggExpr: AggregateExpression): Option[DeclarativeAggregate] = {
+
     aggExpr match {
       case AggregateExpression(
         c@Count(Seq(Literal(1, _))), Complete, false, None, _) =>
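
getAggFunctionOptimizable is the eligibility gate for the whole rule: only non-distinct,
unfiltered COUNT(1)/MIN/MAX in Complete mode are candidates for the rewrite. A stand-in
sketch of that gate (plain Scala, hypothetical types; the real code pattern-matches
Catalyst's AggregateExpression instead):

sealed trait Mode
case object CompleteMode extends Mode
case object PartialMode extends Mode

final case class AggCall(
    name: String, // "count", "min", "max", ...
    mode: Mode,
    isDistinct: Boolean,
    hasFilter: Boolean)

def optimizable(agg: AggCall): Boolean = agg match {
  // mirrors AggregateExpression(fn, Complete, isDistinct = false, filter = None, _)
  case AggCall("count" | "min" | "max", CompleteMode, false, false) => true
  case _ => false
}

assert(optimizable(AggCall("min", CompleteMode, isDistinct = false, hasFilter = false)))
assert(!optimizable(AggCall("count", CompleteMode, isDistinct = true, hasFilter = false)))
assert(!optimizable(AggCall("avg", CompleteMode, isDistinct = false, hasFilter = false)))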