diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuery.scala b/spark/src/main/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuery.scala index 78b7e2fe7dd..0d3e74c17b0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuery.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuery.scala @@ -55,66 +55,59 @@ trait OptimizeMetadataOnlyDeltaQuery { private def createLocalRelationPlan( plan: Aggregate, tahoeLogFileIndex: TahoeLogFileIndex): LogicalPlan = { - val rowCount = extractGlobalCount(tahoeLogFileIndex) - - if (rowCount.isDefined) { - val aggColumnsNames = Set(extractMinMaxFieldNames(plan).map(_.toLowerCase(Locale.ROOT)) : _*) - val columnStats = extractMinMaxFromDeltaLog(tahoeLogFileIndex, aggColumnsNames) - - def checkStatsExists(attrRef: AttributeReference): Boolean = { - columnStats.contains(attrRef.name) && - // Avoid StructType, it is not supported by this optimization - // Sanity check only. If reference is nested column it would be GetStructType - // instead of AttributeReference - attrRef.references.size == 1 && - attrRef.references.head.dataType != StructType - } - def convertValueIfRequired(attrRef: AttributeReference, value: Any): Any = { - if (attrRef.dataType == DateType && value != null) { - DateTimeUtils.fromJavaDate(value.asInstanceOf[Date]) - } else { - value - } - } + val aggColumnsNames = Set(extractMinMaxFieldNames(plan).map(_.toLowerCase(Locale.ROOT)) : _*) + val (rowCount, columnStats) = extractCountMinMaxFromDeltaLog(tahoeLogFileIndex, aggColumnsNames) - val rewrittenAggregationValues = plan.aggregateExpressions.collect { - case Alias(AggregateExpression( - Count(Seq(Literal(1, _))), Complete, false, None, _), _) => - rowCount.get - case Alias(tps@ToPrettyString(AggregateExpression( - Count(Seq(Literal(1, _))), Complete, false, None, _), _), _) => - tps.copy(child = Literal(rowCount.get)).eval() - case Alias(AggregateExpression( - Min(minReference: AttributeReference), Complete, false, None, _), _) - if checkStatsExists(minReference) => - convertValueIfRequired(minReference, columnStats(minReference.name).min) - case Alias(tps@ToPrettyString(AggregateExpression( - Min(minReference: AttributeReference), Complete, false, None, _), _), _) - if checkStatsExists(minReference) => - val v = columnStats(minReference.name).min - tps.copy(child = Literal(v)).eval() - case Alias(AggregateExpression( - Max(maxReference: AttributeReference), Complete, false, None, _), _) - if checkStatsExists(maxReference) => - convertValueIfRequired(maxReference, columnStats(maxReference.name).max) - case Alias(tps@ToPrettyString(AggregateExpression( - Max(maxReference: AttributeReference), Complete, false, None, _), _), _) - if checkStatsExists(maxReference) => - val v = columnStats(maxReference.name).max - tps.copy(child = Literal(v)).eval() - } + def checkStatsExists(attrRef: AttributeReference): Boolean = { + columnStats.contains(attrRef.name) && + // Avoid StructType, it is not supported by this optimization + // Sanity check only. 
If reference is nested column it would be GetStructType + // instead of AttributeReference + attrRef.references.size == 1 && attrRef.references.head.dataType != StructType + } - if (plan.aggregateExpressions.size == rewrittenAggregationValues.size) { - val r = LocalRelation( - plan.output, - Seq(InternalRow.fromSeq(rewrittenAggregationValues))) - r + def convertValueIfRequired(attrRef: AttributeReference, value: Any): Any = { + if (attrRef.dataType == DateType && value != null) { + DateTimeUtils.fromJavaDate(value.asInstanceOf[Date]) } else { - plan + value } } - else { + + val rewrittenAggregationValues = plan.aggregateExpressions.collect { + case Alias(AggregateExpression( + Count(Seq(Literal(1, _))), Complete, false, None, _), _) if rowCount.isDefined => + rowCount.get + case Alias(tps@ToPrettyString(AggregateExpression( + Count(Seq(Literal(1, _))), Complete, false, None, _), _), _) if rowCount.isDefined => + tps.copy(child = Literal(rowCount.get)).eval() + case Alias(AggregateExpression( + Min(minReference: AttributeReference), Complete, false, None, _), _) + if checkStatsExists(minReference) => + convertValueIfRequired(minReference, columnStats(minReference.name).min) + case Alias(tps@ToPrettyString(AggregateExpression( + Min(minReference: AttributeReference), Complete, false, None, _), _), _) + if checkStatsExists(minReference) => + val v = columnStats(minReference.name).min + tps.copy(child = Literal(v)).eval() + case Alias(AggregateExpression( + Max(maxReference: AttributeReference), Complete, false, None, _), _) + if checkStatsExists(maxReference) => + convertValueIfRequired(maxReference, columnStats(maxReference.name).max) + case Alias(tps@ToPrettyString(AggregateExpression( + Max(maxReference: AttributeReference), Complete, false, None, _), _), _) + if checkStatsExists(maxReference) => + val v = columnStats(maxReference.name).max + tps.copy(child = Literal(v)).eval() + } + + if (plan.aggregateExpressions.size == rewrittenAggregationValues.size) { + val r = LocalRelation( + plan.output, + Seq(InternalRow.fromSeq(rewrittenAggregationValues))) + r + } else { plan } } @@ -136,122 +129,114 @@ trait OptimizeMetadataOnlyDeltaQuery { } } - /** Return the number of rows in the table or `None` if we cannot calculate it from stats */ - private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = { - // account for deleted rows according to deletion vectors - val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0)) - val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords") - val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil) - .agg( - sum(numLogicalRecords), - // Calculate the number of files missing `numRecords` - count(when(col("stats.numRecords").isNull, 1))) - .first - - // The count agg is never null. A non-zero value means we have incomplete stats; otherwise, - // the sum agg is either null (for an empty table) or gives an accurate record count. - if (row.getLong(1) > 0) return None - val numRecords = if (row.isNullAt(0)) 0 else row.getLong(0) - Some(numRecords) - } - /** * Min and max values from Delta Log stats or partitionValues. 
*/ case class DeltaColumnStat(min: Any, max: Any) - private def extractMinMaxFromStats( + private def extractCountMinMaxFromStats( deltaScanGenerator: DeltaScanGenerator, - lowerCaseColumnNames: Set[String]): Map[String, DeltaColumnStat] = { - + lowerCaseColumnNames: Set[String]): (Option[Long], Map[String, DeltaColumnStat]) = { // TODO Update this to work with DV (https://github.com/delta-io/delta/issues/1485) + val snapshot = deltaScanGenerator.snapshotToScan - val dataColumns = snapshot.statCollectionPhysicalSchema.filter(col => - AggregateDeltaTable.isSupportedDataType(col.dataType) && - lowerCaseColumnNames.contains(col.name.toLowerCase(Locale.ROOT))) + // Count - account for deleted rows according to deletion vectors + val dvCardinality = coalesce(col("deletionVector.cardinality"), lit(0)) + val numLogicalRecords = (col("stats.numRecords") - dvCardinality).as("numLogicalRecords") + + val filesWithStatsForScan = deltaScanGenerator.filesWithStatsForScan(Nil) // Validate all the files has stats - lazy val filesStatsCount = deltaScanGenerator.filesWithStatsForScan(Nil).select( + val filesStatsCount = filesWithStatsForScan.select( + sum(numLogicalRecords).as("numLogicalRecords"), count(when(col("stats.numRecords").isNull, 1)).as("missingNumRecords"), count(when(col("stats.numRecords") > 0, 1)).as("countNonEmptyFiles")).head - lazy val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0 + // If any numRecords is null, we have incomplete stats; + val allRecordsHasStats = filesStatsCount.getAs[Long]("missingNumRecords") == 0 + if (!allRecordsHasStats) { + return (None, Map.empty) + } + // the sum agg is either null (for an empty table) or gives an accurate record count. + val numRecords = if (filesStatsCount.isNullAt(0)) 0 else filesStatsCount.getLong(0) lazy val numFiles: Long = filesStatsCount.getAs[Long]("countNonEmptyFiles") + val dataColumns = snapshot.statCollectionPhysicalSchema.filter(col => + AggregateDeltaTable.isSupportedDataType(col.dataType) && + lowerCaseColumnNames.contains(col.name.toLowerCase(Locale.ROOT))) + // DELETE operations creates AddFile records with 0 rows, and no column stats. // We can safely ignore it since there is no data. 
- lazy val files = deltaScanGenerator.filesWithStatsForScan(Nil) - .filter(col("stats.numRecords") > 0) + lazy val files = filesWithStatsForScan.filter(col("stats.numRecords") > 0) lazy val statsMinMaxNullColumns = files.select(col("stats.*")) if (dataColumns.isEmpty || !isTableDVFree(snapshot) - || !allRecordsHasStats || numFiles == 0 || !statsMinMaxNullColumns.columns.contains("minValues") || !statsMinMaxNullColumns.columns.contains("maxValues") || !statsMinMaxNullColumns.columns.contains("nullCount")) { - Map.empty - } else { - // dataColumns can contain columns without stats if dataSkippingNumIndexedCols - // has been increased - val columnsWithStats = files.select( - col("stats.minValues.*"), - col("stats.maxValues.*"), - col("stats.nullCount.*")) - .columns.groupBy(identity).mapValues(_.size) - .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount - .map(x => x._1).toSet - - // Creates a tuple with physical name to avoid recalculating it multiple times - val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x))) - .filter(x => columnsWithStats.contains(x._2)) - - val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName => - val dataType = columnAndPhysicalName._1.dataType - val physicalName = columnAndPhysicalName._2 - - Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"), - col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"), - col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName")) - } ++ Seq(col(s"stats.numRecords").as(s"numRecords")) + return (Some(numRecords), Map.empty) + } - val minMaxExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName => - val physicalName = columnAndPhysicalName._2 + // dataColumns can contain columns without stats if dataSkippingNumIndexedCols + // has been increased + val columnsWithStats = files.select( + col("stats.minValues.*"), + col("stats.maxValues.*"), + col("stats.nullCount.*")) + .columns.groupBy(identity).mapValues(_.size) + .filter(x => x._2 == 3) // 3: minValues, maxValues, nullCount + .map(x => x._1).toSet + + // Creates a tuple with physical name to avoid recalculating it multiple times + val dataColumnsWithStats = dataColumns.map(x => (x, DeltaColumnMapping.getPhysicalName(x))) + .filter(x => columnsWithStats.contains(x._2)) + + val columnsToQuery = dataColumnsWithStats.flatMap { columnAndPhysicalName => + val dataType = columnAndPhysicalName._1.dataType + val physicalName = columnAndPhysicalName._2 + + Seq(col(s"stats.minValues.`$physicalName`").cast(dataType).as(s"min.$physicalName"), + col(s"stats.maxValues.`$physicalName`").cast(dataType).as(s"max.$physicalName"), + col(s"stats.nullCount.`$physicalName`").as(s"nullCount.$physicalName")) + } ++ Seq(col(s"stats.numRecords").as(s"numRecords")) + + val minMaxExpr = dataColumnsWithStats.flatMap { columnAndPhysicalName => + val physicalName = columnAndPhysicalName._2 + + // To validate if the column has stats we do two validation: + // 1-) COUNT(nullCount.columnName) should be equals to numFiles, + // since nullCount is always non-null. + // 2-) The number of files with non-null min/max: + // a. count(min.columnName)|count(max.columnName) + + // the number of files where all rows are NULL: + // b. 
count of (ISNULL(min.columnName) and nullCount.columnName == numRecords) + // should be equals to numFiles + Seq( + s"""case when $numFiles = count(`nullCount.$physicalName`) + | AND $numFiles = (count(`min.$physicalName`) + sum(case when + | ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords + | then 1 else 0 end)) + | AND $numFiles = (count(`max.$physicalName`) + sum(case when + | ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords + | then 1 else 0 end)) + | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin, + s"min(`min.$physicalName`) as `min_$physicalName`", + s"max(`max.$physicalName`) as `max_$physicalName`") + } - // To validate if the column has stats we do two validation: - // 1-) COUNT(nullCount.columnName) should be equals to numFiles, - // since nullCount is always non-null. - // 2-) The number of files with non-null min/max: - // a. count(min.columnName)|count(max.columnName) + - // the number of files where all rows are NULL: - // b. count of (ISNULL(min.columnName) and nullCount.columnName == numRecords) - // should be equals to numFiles - Seq( - s"""case when $numFiles = count(`nullCount.$physicalName`) - | AND $numFiles = (count(`min.$physicalName`) + sum(case when - | ISNULL(`min.$physicalName`) and `nullCount.$physicalName` = numRecords - | then 1 else 0 end)) - | AND $numFiles = (count(`max.$physicalName`) + sum(case when - | ISNULL(`max.$physicalName`) AND `nullCount.$physicalName` = numRecords - | then 1 else 0 end)) - | then TRUE else FALSE end as `complete_$physicalName`""".stripMargin, - s"min(`min.$physicalName`) as `min_$physicalName`", - s"max(`max.$physicalName`) as `max_$physicalName`") - } + val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxExpr: _*).head - val statsResults = files.select(columnsToQuery: _*).selectExpr(minMaxExpr: _*).head - - dataColumnsWithStats - .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}")) - .map { columnAndPhysicalName => - val column = columnAndPhysicalName._1 - val physicalName = columnAndPhysicalName._2 - column.name -> - DeltaColumnStat( - statsResults.getAs(s"min_$physicalName"), - statsResults.getAs(s"max_$physicalName")) - }.toMap - } + (Some(numRecords), dataColumnsWithStats + .filter(x => statsResults.getAs[Boolean](s"complete_${x._2}")) + .map { columnAndPhysicalName => + val column = columnAndPhysicalName._1 + val physicalName = columnAndPhysicalName._2 + column.name -> + DeltaColumnStat( + statsResults.getAs(s"min_$physicalName"), + statsResults.getAs(s"max_$physicalName")) + }.toMap) } private def extractMinMaxFromPartitionValue( @@ -295,21 +280,28 @@ trait OptimizeMetadataOnlyDeltaQuery { } } - private def extractMinMaxFromDeltaLog( + /** + * Extract the Count, Min and Max values from Delta Log stats and partitionValues. + * The first field is the rows count in the table or `None` if we cannot calculate it from stats + * If the column is not partitioned, the values are extracted from stats when it exists. + * If the column is partitioned, the values are extracted from partitionValues. 
+ */ + private def extractCountMinMaxFromDeltaLog( tahoeLogFileIndex: TahoeLogFileIndex, lowerCaseColumnNames: Set[String]): - CaseInsensitiveMap[DeltaColumnStat] = { - val deltaScanGenerator = getDeltaScanGenerator(tahoeLogFileIndex) - val snapshot = deltaScanGenerator.snapshotToScan - val columnFromStats = extractMinMaxFromStats(deltaScanGenerator, lowerCaseColumnNames) + (Option[Long], CaseInsensitiveMap[DeltaColumnStat]) = { + val deltaScanGen = getDeltaScanGenerator(tahoeLogFileIndex) + val (rowCount, columnStats) = extractCountMinMaxFromStats(deltaScanGen, lowerCaseColumnNames) - if(lowerCaseColumnNames.equals(columnFromStats.keySet)) { - CaseInsensitiveMap(columnFromStats) + val minMaxValues = if (lowerCaseColumnNames.equals(columnStats.keySet)) { + CaseInsensitiveMap(columnStats) } else { CaseInsensitiveMap( - columnFromStats.++ - (extractMinMaxFromPartitionValue(snapshot, lowerCaseColumnNames))) + columnStats.++ + (extractMinMaxFromPartitionValue(deltaScanGen.snapshotToScan, lowerCaseColumnNames))) } + + (rowCount, minMaxValues) } object AggregateDeltaTable { @@ -322,7 +314,9 @@ trait OptimizeMetadataOnlyDeltaQuery { dataType.isInstanceOf[DateType] } - def getAggFunctionOptimizable(aggExpr: AggregateExpression): Option[DeclarativeAggregate] = { + private def getAggFunctionOptimizable( + aggExpr: AggregateExpression): Option[DeclarativeAggregate] = { + aggExpr match { case AggregateExpression( c@Count(Seq(Literal(1, _))), Complete, false, None, _) => diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuerySuite.scala index adfbc4aa742..b7b7de02d91 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/perf/OptimizeMetadataOnlyDeltaQuerySuite.scala @@ -114,8 +114,7 @@ class OptimizeMetadataOnlyDeltaQuerySuite min(col("DoubleColumn")), max(col("DoubleColumn")), min(col("DateColumn")), max(col("DateColumn"))), expectedPlan = "LocalRelation [none#0L, none#1L, none#2, none#3, none#4, none#5, none#6" + - ", none#7, none#8L, none#9L, none#10, none#11, none#12, none#13, none#14, none#15]"), - ) + ", none#7, none#8L, none#9L, none#10, none#11, none#12, none#13, none#14, none#15]")) .foreach { testParams => test(s"optimization supported - Scala - ${testParams.name}") { checkResultsAndOptimizedPlan(testParams.queryScala, testParams.expectedPlan) @@ -482,8 +481,7 @@ class OptimizeMetadataOnlyDeltaQuerySuite ", MIN(Column4), MAX(Column4)" + " FROM TestColumnMappingPartitioned", expectedPlan = "LocalRelation [none#0L, none#1, none#2, none#3," + - " none#4, none#5, none#6, none#7, none#8]"), - ) + " none#4, none#5, none#6, none#7, none#8]")) .foreach { testParams => test(s"optimization supported - SQL - ${testParams.name}") { if(testParams.querySetup.isDefined) { @@ -506,6 +504,27 @@ class OptimizeMetadataOnlyDeltaQuerySuite } } + test("min-max - partitioned column stats disabled") { + withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false") { + val tableName = "TestPartitionedNoStats" + + spark.sql(s"CREATE TABLE $tableName (Column1 INT, Column2 INT)" + + " USING DELTA PARTITIONED BY (Column2)") + + spark.sql(s"INSERT INTO $tableName (Column1, Column2) VALUES (1, 3);") + spark.sql(s"INSERT INTO $tableName (Column1, Column2) VALUES (2, 4);") + + //Has no stats, including COUNT + checkOptimizationIsNotTriggered( + s"SELECT COUNT(*), MIN(Column2), 
MAX(Column2) FROM $tableName")
+
+      // Partitioned columns should still work: MIN/MAX comes from partitionValues, not stats
+      checkResultsAndOptimizedPlan(
+        s"SELECT MIN(Column2), MAX(Column2) FROM $tableName",
+        "LocalRelation [none#0, none#1]")
+    }
+  }
+
   test("min-max - recompute column missing stats") {
     val tableName = "TestRecomputeMissingStat"
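
Reviewer note: the behavioural core of the hunk above is that the row count and the per-column MIN/MAX are now derived in a single pass over the per-file stats (extractCountMinMaxFromStats) instead of the separate extractGlobalCount scan. The following is a minimal, self-contained sketch of that idea against a toy stats DataFrame; the FileStats case class, the column names, and the local SparkSession are illustrative stand-ins and not the PR's code, which reads stats.numRecords, deletionVector.cardinality, and stats.minValues/maxValues from filesWithStatsForScan.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object StatsOnlyAggregationSketch {
  // One row per data file: its record count, rows removed via deletion vectors,
  // and the min/max values recorded for a single column.
  case class FileStats(
      numRecords: java.lang.Long,
      dvCardinality: java.lang.Long,
      minId: java.lang.Long,
      maxId: java.lang.Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("stats-sketch").getOrCreate()
    import spark.implicits._

    val fileStats = Seq(
      FileStats(100L, 0L, 1L, 100L),
      FileStats(50L, 5L, 101L, 150L)).toDF()

    // Single aggregation pass: logical row count (numRecords minus DV cardinality),
    // how many files are missing numRecords, and the column-level MIN/MAX.
    val row = fileStats.agg(
      sum(col("numRecords") - coalesce(col("dvCardinality"), lit(0))).as("numLogicalRecords"),
      count(when(col("numRecords").isNull, 1)).as("missingNumRecords"),
      min(col("minId")).as("minId"),
      max(col("maxId")).as("maxId")).head()

    // Any file without numRecords means the stats are incomplete, so the count cannot be
    // answered from metadata alone; an empty table yields a null sum, which maps to 0.
    val rowCount: Option[Long] =
      if (row.getLong(1) > 0) None
      else if (row.isNullAt(0)) Some(0L)
      else Some(row.getLong(0))

    println(s"rowCount=$rowCount min=${row.get(2)} max=${row.get(3)}")
    spark.stop()
  }
}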
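
The new "min-max - partitioned column stats disabled" test exercises exactly the split result of extractCountMinMaxFromDeltaLog: with stats collection turned off, the row count is unavailable, but the partition-column MIN/MAX can still be served from partitionValues. A rough spark-shell style walk-through of that scenario, assuming a SparkSession with the Delta SQL extensions already configured and with the conf key spelled out here from DeltaSQLConf.DELTA_COLLECT_STATS:

// Assumes `spark` has Delta Lake set up (delta-spark on the classpath, Delta extensions enabled).
spark.conf.set("spark.databricks.delta.stats.collect", "false")

spark.sql("CREATE TABLE TestPartitionedNoStats (Column1 INT, Column2 INT) USING DELTA PARTITIONED BY (Column2)")
spark.sql("INSERT INTO TestPartitionedNoStats VALUES (1, 3), (2, 4)")

// COUNT(*) needs stats.numRecords, which was never written, so the metadata-only
// rewrite must not kick in and the query falls back to scanning the data files.
spark.sql("SELECT COUNT(*), MIN(Column2), MAX(Column2) FROM TestPartitionedNoStats").show()

// MIN/MAX of the partition column only needs partitionValues from the Delta log,
// so this query can still be rewritten into a LocalRelation.
spark.sql("SELECT MIN(Column2), MAX(Column2) FROM TestPartitionedNoStats").show()

This is also why the rewritten extractCountMinMaxFromDeltaLog returns the row count and the min/max map as separate values: one can be usable while the other is not.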