From 0c349da8c70d5abacc4aa6decd57c6daf511696b Mon Sep 17 00:00:00 2001
From: Felipe Pessoto
Date: Tue, 15 Nov 2022 18:00:43 -0800
Subject: [PATCH] Optimize common case: SELECT COUNT(*) FROM Table

Fix #1192

## Description

Running the query "SELECT COUNT(*) FROM Table" takes a long time on big tables: Spark scans all the Parquet files just to return the number of rows, even though that information is already available in the Delta log. This change rewrites such queries to read the row count from the per-file statistics in the Delta log instead.

Resolves #1192

Created unit tests to validate that the optimization works, including cases not covered by this optimization.
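For intuition, the rewrite answers the count from the `numRecords` statistic that each `AddFile` in the Delta log may carry. A minimal sketch of the idea (illustrative only: it uses the internal `DeltaLog` API and a placeholder table path, whereas the actual rule goes through `DeltaScanGenerator.filesWithStatsForScan`):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.functions.{col, count, from_json, sum}
import org.apache.spark.sql.types.{LongType, StructType}

// Sketch: answer SELECT COUNT(*) from per-file stats instead of scanning Parquet.
// AddFile.stats is a JSON string, so extract numRecords from it first.
val snapshot = DeltaLog.forTable(spark, "/tmp/delta/events").snapshot // placeholder path
val statsSchema = new StructType().add("numRecords", LongType)
val row = snapshot.allFiles
  .withColumn("parsedStats", from_json(col("stats"), statsSchema))
  .agg(
    sum("parsedStats.numRecords"),                // total rows (null when there are no files)
    count(new Column("*")),                       // number of files in the snapshot
    count(new Column("parsedStats.numRecords")))  // number of files that carry stats
  .first

// The rewrite is only safe when every file has stats; otherwise fall back to a scan.
val metadataCount: Option[Long] =
  if (row.getLong(1) == row.getLong(2)) {
    Some(if (row.isNullAt(0)) 0L else row.getLong(0))
  } else None
```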
## Does this PR introduce _any_ user-facing changes?

No, only a performance improvement.

Closes delta-io/delta#1377

Signed-off-by: Shixiong Zhu

GitOrigin-RevId: a9116e42a9c805adc967dd3e802f84d502f50a8b
---
 .../OptimizeMetadataOnlyDeltaQuery.scala      |  75 +++++
 .../sql/delta/sources/DeltaSQLConf.scala      |   8 +
 .../sql/delta/stats/PrepareDeltaScan.scala    |  12 +-
 .../sql/delta/DeltaHistoryManagerSuite.scala  |   3 +-
 .../apache/spark/sql/delta/DeltaSuite.scala   |   6 +-
 .../OptimizeMetadataOnlyDeltaQuerySuite.scala | 314 ++++++++++++++++++
 6 files changed, 412 insertions(+), 6 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuery.scala
 create mode 100644 core/src/test/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuerySuite.scala

diff --git a/core/src/main/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuery.scala b/core/src/main/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuery.scala
new file mode 100644
index 00000000000..484138adc9a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuery.scala
@@ -0,0 +1,75 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.optimizer
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.delta.DeltaTable
+import org.apache.spark.sql.delta.files.TahoeLogFileIndex
+import org.apache.spark.sql.delta.stats.DeltaScanGenerator
+import org.apache.spark.sql.functions.{count, sum}
+
+trait OptimizeMetadataOnlyDeltaQuery {
+  def optimizeQueryWithMetadata(plan: LogicalPlan): LogicalPlan = {
+    plan.transformUpWithSubqueries {
+      case agg@CountStarDeltaTable(countValue) =>
+        LocalRelation(agg.output, Seq(InternalRow(countValue)))
+    }
+  }
+
+  protected def getDeltaScanGenerator(index: TahoeLogFileIndex): DeltaScanGenerator
+
+  object CountStarDeltaTable {
+    def unapply(plan: Aggregate): Option[Long] = {
+      plan match {
+        case Aggregate(
+          Nil,
+          Seq(Alias(AggregateExpression(Count(Seq(Literal(1, _))), Complete, false, None, _), _)),
+          PhysicalOperation(_, Nil, DeltaTable(tahoeLogFileIndex: TahoeLogFileIndex))) =>
+          extractGlobalCount(tahoeLogFileIndex)
+        case _ => None
+      }
+    }
+
+    private def extractGlobalCount(tahoeLogFileIndex: TahoeLogFileIndex): Option[Long] = {
+      val row = getDeltaScanGenerator(tahoeLogFileIndex).filesWithStatsForScan(Nil)
+        .agg(
+          sum("stats.numRecords"),
+          count(new Column("*")),
+          count(new Column("stats.numRecords")))
+        .first
+
+      val numOfFiles = row.getLong(1)
+      val numOfFilesWithStats = row.getLong(2)
+
+      if (numOfFiles == numOfFilesWithStats) {
+        val numRecords = if (row.isNullAt(0)) {
+          0 // It is null if deltaLog.snapshot.allFiles is empty
+        } else { row.getLong(0) }
+
+        Some(numRecords)
+      } else {
+        // COUNT(*) being greater than COUNT(numRecords) means not every AddFile record
+        // has stats, so the count cannot be answered from metadata alone
+        None
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala b/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
index 0419b4d6abd..575c8600df7 100644
--- a/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
+++ b/core/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala
@@ -911,6 +911,14 @@ trait DeltaSQLConfBase {
       " concurrent queries accessing the table until the history wipe is complete.")
     .booleanConf
     .createWithDefault(false)
+
+  val DELTA_OPTIMIZE_METADATA_QUERY_ENABLED =
+    buildConf("optimizeMetadataQuery.enabled")
+      .internal()
+      .doc("Whether we can use the metadata in the DeltaLog to" +
+        " optimize queries that can be run purely on metadata.")
+      .booleanConf
+      .createWithDefault(true)
 }
 
 object DeltaSQLConf extends DeltaSQLConfBase
diff --git a/core/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala b/core/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala
index 1d20258f893..b4b16a4a499 100644
--- a/core/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala
+++ b/core/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.delta._
 import org.apache.spark.sql.delta.actions.{AddFile, Metadata}
 import org.apache.spark.sql.delta.files.{TahoeFileIndex, TahoeLogFileIndex}
 import org.apache.spark.sql.delta.metering.DeltaLogging
+import org.apache.spark.sql.delta.optimizer.OptimizeMetadataOnlyDeltaQuery
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 
 import org.apache.hadoop.fs.Path
@@ -49,7 +50,8 @@ import org.apache.spark.sql.types.StructType
  */
 trait PrepareDeltaScanBase extends Rule[LogicalPlan]
   with PredicateHelper
-  with DeltaLogging { self: PrepareDeltaScan =>
+  with DeltaLogging
+  with OptimizeMetadataOnlyDeltaQuery { self: PrepareDeltaScan =>
 
   private val snapshotIsolationEnabled = spark.conf.get(DeltaSQLConf.DELTA_SNAPSHOT_ISOLATION)
 
@@ -199,7 +201,6 @@ trait PrepareDeltaScanBase extends Rule[LogicalPlan]
       spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_STATS_SKIPPING)
     )
     if (shouldPrepareDeltaScan) {
-
       // Should not be applied to subqueries to avoid duplicate delta jobs.
       val isSubquery = plan.isInstanceOf[Subquery] || plan.isInstanceOf[SupportsSubquery]
       // Should not be applied to DataSourceV2 write plans, because they'll be planned later
@@ -209,6 +210,13 @@ trait PrepareDeltaScanBase extends Rule[LogicalPlan]
         return plan
       }
 
+      val optimizeMetadataQueryEnabled =
+        spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_OPTIMIZE_METADATA_QUERY_ENABLED)
+
+      if (optimizeMetadataQueryEnabled) {
+        plan = optimizeQueryWithMetadata(plan)
+      }
+
       prepareDeltaScan(plan)
     } else {
       // If this query is running inside an active transaction and is touching the same table
diff --git a/core/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala
index b2b85f407ca..9ec5ef1d953 100644
--- a/core/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala
@@ -554,8 +554,9 @@ abstract class DeltaHistoryManagerBase extends DeltaTimeTravelTests
       sql(s"optimize $tblName")
 
       withSQLConf(
+        // Disable query rewrite or else the parquet files are not scanned.
+        DeltaSQLConf.DELTA_OPTIMIZE_METADATA_QUERY_ENABLED.key -> "false",
         DeltaSQLConf.DELTA_VACUUM_RETENTION_CHECK_ENABLED.key -> "false") {
-
         sql(s"vacuum $tblName retain 0 hours")
         intercept[SparkException] {
           sql(s"select * from ${versionAsOf(tblName, 0)}").collect()
diff --git a/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala
index 3768bbfa97b..83284e52521 100644
--- a/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala
+++ b/core/src/test/scala/org/apache/spark/sql/delta/DeltaSuite.scala
@@ -1498,7 +1498,7 @@ class DeltaSuite extends QueryTest
       }
 
       val thrown = intercept[SparkException] {
-        data.toDF().count()
+        data.toDF().collect()
       }
       assert(thrown.getMessage.contains("is not a Parquet file"))
     }
@@ -1528,7 +1528,7 @@ class DeltaSuite extends QueryTest
       // We don't have a good way to tell which specific values got deleted, so just check that
       // the right number remain. (Note that this works because there's 1 value per append, which
       // means 1 value per file.)
-      assert(data.toDF().count() == 6)
+      assert(data.toDF().collect().size == 6)
     }
   }
 }
@@ -1553,7 +1553,7 @@ class DeltaSuite extends QueryTest
       }
 
       val thrown = intercept[SparkException] {
-        data.toDF().count()
+        data.toDF().collect()
       }
       assert(thrown.getMessage.contains("FileNotFound"))
     }
diff --git a/core/src/test/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuerySuite.scala b/core/src/test/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuerySuite.scala
new file mode 100644
index 00000000000..4ba837bb0fb
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/sql/delta/optimizer/OptimizeMetadataOnlyDeltaQuerySuite.scala
@@ -0,0 +1,314 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta.optimizer
+
+// scalastyle:off import.ordering.noEmptyLine
+import org.apache.spark.sql.delta.sources.DeltaSQLConf
+import org.apache.spark.sql.delta.stats.PrepareDeltaScanBase
+import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
+
+class OptimizeMetadataOnlyDeltaQuerySuite
+  extends QueryTest
+    with SharedSparkSession
+    with BeforeAndAfterAll
+    with DeltaSQLCommandTest {
+  val testTableName = "table_basic"
+  val testTablePath = Utils.createTempDir().getAbsolutePath
+  val noStatsTableName = "table_nostats"
+  val mixedStatsTableName = "table_mixstats"
+  val totalRows = 9L
+  val totalNonNullData = 8L
+  val totalDistinctData = 5L
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val df = spark.createDataFrame(Seq((1L, "a", 1L), (2L, "b", 1L), (3L, "c", 1L)))
+      .toDF("id", "data", "group")
+    val df2 = spark.createDataFrame(Seq(
+      (4L, "d", 1L),
+      (5L, "e", 1L),
+      (6L, "f", 1L),
+      (7L, null, 1L),
+      (8L, "b", 1L),
+      (9L, "b", 1L),
+      (10L, "b", 1L))).toDF("id", "data", "group")
+
+    withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "false") {
+      df.write.format("delta").mode(SaveMode.Overwrite).saveAsTable(noStatsTableName)
+      df.write.format("delta").mode(SaveMode.Overwrite).saveAsTable(mixedStatsTableName)
+
+      spark.sql(s"DELETE FROM $noStatsTableName WHERE id = 1")
+      spark.sql(s"DELETE FROM $mixedStatsTableName WHERE id = 1")
+
+      df2.write.format("delta").mode("append").saveAsTable(noStatsTableName)
+    }
+
+    withSQLConf(DeltaSQLConf.DELTA_COLLECT_STATS.key -> "true") {
+      import io.delta.tables._
+
+      df.write.format("delta").mode(SaveMode.Overwrite).saveAsTable(testTableName)
+      df.write.format("delta").mode(SaveMode.Overwrite).save(testTablePath)
+
+      spark.sql(s"DELETE FROM $testTableName WHERE id = 1")
+      DeltaTable.forPath(spark, testTablePath).delete("id = 1")
+
+      df2.write.format("delta").mode(SaveMode.Append).saveAsTable(testTableName)
+      df2.write.format("delta").mode(SaveMode.Append).save(testTablePath)
+      df2.write.format("delta").mode(SaveMode.Append).saveAsTable(mixedStatsTableName)
+    }
+  }
+
+  test("Select Count: basic") {
+    checkResultsAndOptimizedPlan(
+      s"SELECT COUNT(*) FROM $testTableName",
+      Seq(Row(totalRows)),
+      "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: column alias") {
+    checkResultsAndOptimizedPlan(
+      s"SELECT COUNT(*) as MyColumn FROM $testTableName",
+      Seq(Row(totalRows)),
+      "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: table alias") {
+    checkResultsAndOptimizedPlan(
+      s"SELECT COUNT(*) FROM $testTableName MyTable",
+      Seq(Row(totalRows)),
+      "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: time travel") {
+    checkResultsAndOptimizedPlan(s"SELECT COUNT(*) FROM $testTableName VERSION AS OF 0",
+      Seq(Row(3L)),
+      "LocalRelation [none#0L]")
+
+    checkResultsAndOptimizedPlan(s"SELECT COUNT(*) FROM $testTableName VERSION AS OF 1",
+      Seq(Row(2L)),
+      "LocalRelation [none#0L]")
+
+    checkResultsAndOptimizedPlan(s"SELECT COUNT(*) FROM $testTableName VERSION AS OF 2",
+      Seq(Row(totalRows)),
+      "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: external") {
+    checkResultsAndOptimizedPlan(
+      s"SELECT COUNT(*) FROM delta.`$testTablePath`",
+      Seq(Row(totalRows)),
+      "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: sub-query") {
+    checkResultsAndOptimizedPlan(
+      s"SELECT (SELECT COUNT(*) FROM $testTableName)",
+      Seq(Row(totalRows)),
+      "Project [scalar-subquery#0 [] AS #0L]\n: +- LocalRelation [none#0L]\n+- OneRowRelation")
+  }
+
+  test("Select Count: as sub-query filter") {
+    checkResultsAndOptimizedPlan(
+      s"SELECT 'ABC' WHERE (SELECT COUNT(*) FROM $testTableName) = $totalRows",
+      Seq(Row("ABC")),
+      "Project [ABC AS #0]\n+- Filter (scalar-subquery#0 [] = " +
+        totalRows + ")\n : +- LocalRelation [none#0L]\n +- OneRowRelation")
+  }
+
+  test("Select Count: limit") {
+    // Limit doesn't affect COUNT results
+    checkResultsAndOptimizedPlan(
+      s"SELECT COUNT(*) FROM $testTableName LIMIT 3",
+      Seq(Row(totalRows)),
+      "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: empty table") {
+    sql(s"CREATE TABLE TestEmpty (c1 int) USING DELTA")
+
+    val query = "SELECT COUNT(*) FROM TestEmpty"
+
+    checkResultsAndOptimizedPlan(query, Seq(Row(0)), "LocalRelation [none#0L]")
+  }
+
+  test("Select Count: snapshot isolation") {
+    sql(s"CREATE TABLE TestSnapshotIsolation (c1 int) USING DELTA")
+    spark.sql("INSERT INTO TestSnapshotIsolation VALUES (1)")
+
+    val query = "SELECT (SELECT COUNT(*) FROM TestSnapshotIsolation), " +
+      "(SELECT COUNT(*) FROM TestSnapshotIsolation)"
+
+    checkResultsAndOptimizedPlan(
+      query,
+      Seq(Row(1, 1)),
+      "Project [scalar-subquery#0 [] AS #0L, scalar-subquery#0 [] AS #1L]\n" +
+        ": :- LocalRelation [none#0L]\n" +
+        ": +- LocalRelation [none#0L]\n" +
+        "+- OneRowRelation")
+
+    PrepareDeltaScanBase.withCallbackOnGetDeltaScanGenerator(_ => {
+      // Insert a row after each call to getDeltaScanGenerator to verify that
+      // the count does not change within the same query (snapshot isolation).
+      spark.sql("INSERT INTO TestSnapshotIsolation VALUES (1)")
+    }) {
+      val result = spark.sql(query).collect()(0)
+      val c1 = result.getLong(0)
+      val c2 = result.getLong(1)
+      assertResult(c1, "Snapshot isolation should guarantee the results are always the same")(c2)
+    }
+  }
+
+  // Tests to validate the optimizer won't use missing or partial stats
+  test("Select Count: missing stats") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) FROM $mixedStatsTableName",
+      Seq(Row(totalRows)))
+
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) FROM $noStatsTableName",
+      Seq(Row(totalRows)))
+  }
+
+
+  // Tests to validate the optimizer won't incorrectly change queries it can't correctly handle
+  test("Select Count: multiple aggregations") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) AS MyCount, MAX(id) FROM $testTableName",
+      Seq(Row(totalRows, 10L)))
+  }
+
+  test("Select Count: group by") {
+    checkSameQueryPlanAndResults(
+      s"SELECT group, COUNT(*) FROM $testTableName GROUP BY group",
+      Seq(Row(1L, totalRows)))
+  }
+
+  test("Select Count: count twice") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*), COUNT(*) FROM $testTableName",
+      Seq(Row(totalRows, totalRows)))
+  }
+
+  test("Select Count: plus literal") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) + 1 FROM $testTableName",
+      Seq(Row(totalRows + 1)))
+  }
+
+  test("Select Count: distinct") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(DISTINCT data) FROM $testTableName",
+      Seq(Row(totalDistinctData)))
+  }
+
+  test("Select Count: filter") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) FROM $testTableName WHERE id > 0",
+      Seq(Row(totalRows)))
+  }
+
+  test("Select Count: sub-query with filter") {
+    checkSameQueryPlanAndResults(
+      s"SELECT (SELECT COUNT(*) FROM $testTableName WHERE id > 0)",
+      Seq(Row(totalRows)))
+  }
+
+  test("Select Count: non-null") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(ALL data) FROM $testTableName",
+      Seq(Row(totalNonNullData)))
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(data) FROM $testTableName",
+      Seq(Row(totalNonNullData)))
+  }
+
+  test("Select Count: join") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) FROM $testTableName A, $testTableName B",
+      Seq(Row(totalRows * totalRows)))
+  }
+
+  test("Select Count: over") {
+    checkSameQueryPlanAndResults(
+      s"SELECT COUNT(*) OVER() FROM $testTableName LIMIT 1",
+      Seq(Row(totalRows)))
+  }
+
+  private def checkResultsAndOptimizedPlan(
+      query: String,
+      expectedAnswer: scala.Seq[Row],
+      expectedOptimizedPlan: String): Unit = {
+    checkResultsAndOptimizedPlan(() => spark.sql(query), expectedAnswer, expectedOptimizedPlan)
+  }
+
+  private def checkResultsAndOptimizedPlan(
+      generateQueryDf: () => DataFrame,
+      expectedAnswer: scala.Seq[Row],
+      expectedOptimizedPlan: String): Unit = {
+    withSQLConf(DeltaSQLConf.DELTA_OPTIMIZE_METADATA_QUERY_ENABLED.key -> "true") {
+      val queryDf = generateQueryDf()
+      val optimizedPlan = queryDf.queryExecution.optimizedPlan.canonicalized.toString()
+
+      assertResult(expectedAnswer(0)(0)) {
+        queryDf.collect()(0)(0)
+      }
+
+      assertResult(expectedOptimizedPlan.trim) {
+        optimizedPlan.trim
+      }
+    }
+  }
+
+  /**
+   * Verify the query plans and results are the same with/without metadata query optimization.
+   * This method can be used to verify cases where we shouldn't trigger the optimization,
+   * or cases that we can potentially improve.
+   * @param query the SQL query to run with the optimization enabled and disabled
+   * @param expectedAnswer the rows the query is expected to return in both runs
+   */
+  private def checkSameQueryPlanAndResults(
+      query: String,
+      expectedAnswer: scala.Seq[Row]): Unit = {
+    var optimizationEnabledQueryPlan: String = null
+    var optimizationDisabledQueryPlan: String = null
+
+    withSQLConf(DeltaSQLConf.DELTA_OPTIMIZE_METADATA_QUERY_ENABLED.key -> "true") {
+      val queryDf = spark.sql(query)
+      optimizationEnabledQueryPlan = queryDf.queryExecution.optimizedPlan
+        .canonicalized.toString()
+      checkAnswer(queryDf, expectedAnswer)
+    }
+
+    withSQLConf(DeltaSQLConf.DELTA_OPTIMIZE_METADATA_QUERY_ENABLED.key -> "false") {
+      val countQuery = spark.sql(query)
+      optimizationDisabledQueryPlan = countQuery.queryExecution.optimizedPlan
+        .canonicalized.toString()
+      checkAnswer(countQuery, expectedAnswer)
+    }
+
+    assertResult(optimizationEnabledQueryPlan) {
+      optimizationDisabledQueryPlan
+    }
+  }
+}
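---

Usage note (illustrative, not part of the patch): with the conf enabled, the optimized plan for a plain `COUNT(*)` collapses to a `LocalRelation`; disabling it restores the file scan. The full key below assumes the usual `spark.databricks.delta.` prefix that `DeltaSQLConf` applies to its conf names:

```scala
// Hypothetical spark-shell session against an existing Delta table named `events`.
spark.conf.set("spark.databricks.delta.optimizeMetadataQuery.enabled", "true")
println(spark.sql("SELECT COUNT(*) FROM events").queryExecution.optimizedPlan)
// LocalRelation [count(1)#...L]   -- answered from the Delta log, no file scan

spark.conf.set("spark.databricks.delta.optimizeMetadataQuery.enabled", "false")
println(spark.sql("SELECT COUNT(*) FROM events").queryExecution.optimizedPlan)
// Aggregate [count(1) AS count(1)#...L] over a Delta file scan -- full scan again
```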