From bd68db8cf2c8d1d0b60cc7c1e38b1ee52d370cef Mon Sep 17 00:00:00 2001 From: Parth Chandra Date: Wed, 13 Nov 2024 03:20:17 -0800 Subject: [PATCH] wip - CometNativeScan (#1078) --- .../scala/org/apache/comet/CometConf.scala | 10 + .../core/src/execution/datafusion/planner.rs | 186 ++++--- native/proto/src/proto/operator.proto | 9 + .../comet/CometSparkSessionExtensions.scala | 19 +- .../apache/comet/serde/QueryPlanSerde.scala | 108 ++-- .../spark/sql/comet/CometNativeScanExec.scala | 509 ++++++++++++++++++ .../apache/spark/sql/comet/operators.scala | 8 +- .../apache/comet/exec/CometExecSuite.scala | 10 +- .../org/apache/spark/sql/CometTestBase.scala | 4 +- 9 files changed, 715 insertions(+), 148 deletions(-) create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 7450d27a6..09355446c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -77,6 +77,16 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_FULL_NATIVE_SCAN_ENABLED: ConfigEntry[Boolean] = conf( + "spark.comet.native.scan.enabled") + .internal() + .doc( + "Whether to enable the fully native scan. When this is turned on, Spark will use Comet to " + + "read supported data sources (currently only Parquet is supported natively)." + + " By default, this config is true.") + .booleanConf + .createWithDefault(true) + val COMET_PARQUET_PARALLEL_IO_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.parquet.read.parallel.io.enabled") .doc( diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 48a653add..b2940eabc 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -948,118 +948,116 @@ impl PhysicalPlanner { Arc::new(SortExec::new(LexOrdering::new(exprs?), child).with_fetch(fetch)), )) } - OpStruct::Scan(scan) => { + OpStruct::NativeScan(scan) => { let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); - if scan.source == "CometScan parquet (unknown)" { - let data_schema = parse_message_type(&scan.data_schema).unwrap(); - let required_schema = parse_message_type(&scan.required_schema).unwrap(); - println!("data_schema: {:?}", data_schema); - println!("required_schema: {:?}", required_schema); - - let data_schema_descriptor = - parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema)); - let data_schema_arrow = Arc::new( - parquet::arrow::schema::parquet_to_arrow_schema( - &data_schema_descriptor, - None, - ) - .unwrap(), - ); - println!("data_schema_arrow: {:?}", data_schema_arrow); - - let required_schema_descriptor = - parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema)); - let required_schema_arrow = Arc::new( - parquet::arrow::schema::parquet_to_arrow_schema( - &required_schema_descriptor, - None, - ) + println!("NATIVE: SCAN: {:?}", scan); + let data_schema = parse_message_type(&*scan.data_schema).unwrap(); + let required_schema = parse_message_type(&*scan.required_schema).unwrap(); + println!("data_schema: {:?}", data_schema); + println!("required_schema: {:?}", required_schema); + + let data_schema_descriptor = + parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema)); + let data_schema_arrow = Arc::new( + 
parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None) .unwrap(), - ); - println!("required_schema_arrow: {:?}", required_schema_arrow); + ); + println!("data_schema_arrow: {:?}", data_schema_arrow); + + let required_schema_descriptor = + parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema)); + let required_schema_arrow = Arc::new( + parquet::arrow::schema::parquet_to_arrow_schema( + &required_schema_descriptor, + None, + ) + .unwrap(), + ); + println!("required_schema_arrow: {:?}", required_schema_arrow); - assert!(!required_schema_arrow.fields.is_empty()); + assert!(!required_schema_arrow.fields.is_empty()); - let mut projection_vector: Vec = - Vec::with_capacity(required_schema_arrow.fields.len()); - // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of. - required_schema_arrow.fields.iter().for_each(|field| { - projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap()); - }); - println!("projection_vector: {:?}", projection_vector); + let mut projection_vector: Vec = + Vec::with_capacity(required_schema_arrow.fields.len()); + // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of. + required_schema_arrow.fields.iter().for_each(|field| { + projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap()); + }); + println!("projection_vector: {:?}", projection_vector); - assert_eq!(projection_vector.len(), required_schema_arrow.fields.len()); + assert_eq!(projection_vector.len(), required_schema_arrow.fields.len()); - // Convert the Spark expressions to Physical expressions - let data_filters: Result>, ExecutionError> = scan - .data_filters - .iter() - .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow))) - .collect(); - - // Create a conjunctive form of the vector because ParquetExecBuilder takes - // a single expression - let data_filters = data_filters?; - let test_data_filters = - data_filters.clone().into_iter().reduce(|left, right| { - Arc::new(BinaryExpr::new( - left, - datafusion::logical_expr::Operator::And, - right, - )) - }); - - println!("data_filters: {:?}", data_filters); - println!("test_data_filters: {:?}", test_data_filters); - - let object_store_url = ObjectStoreUrl::local_filesystem(); - let paths: Vec = scan - .path - .iter() - .map(|path| Url::parse(path).unwrap()) - .collect(); + // Convert the Spark expressions to Physical expressions + let data_filters: Result>, ExecutionError> = scan + .data_filters + .iter() + .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow))) + .collect(); - let object_store = object_store::local::LocalFileSystem::new(); - // register the object store with the runtime environment - let url = Url::try_from("file://").unwrap(); - self.session_ctx - .runtime_env() - .register_object_store(&url, Arc::new(object_store)); + // Create a conjunctive form of the vector because ParquetExecBuilder takes + // a single expression + let data_filters = data_filters?; + let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| { + Arc::new(BinaryExpr::new( + left, + datafusion::logical_expr::Operator::And, + right, + )) + }); - let files: Vec = paths - .iter() - .map(|path| PartitionedFile::from_path(path.path().to_string()).unwrap()) - .collect(); + println!("data_filters: {:?}", data_filters); + println!("test_data_filters: {:?}", test_data_filters); - // partition the files - // TODO really should partition the row groups + let object_store_url = 
ObjectStoreUrl::local_filesystem(); + let paths: Vec = scan + .path + .iter() + .map(|path| Url::parse(path).unwrap()) + .collect(); - let mut file_groups = vec![vec![]; partition_count]; - files.iter().enumerate().for_each(|(idx, file)| { - file_groups[idx % partition_count].push(file.clone()); - }); + let object_store = object_store::local::LocalFileSystem::new(); + // register the object store with the runtime environment + let url = Url::try_from("file://").unwrap(); + self.session_ctx + .runtime_env() + .register_object_store(&url, Arc::new(object_store)); - let file_scan_config = - FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow)) - .with_file_groups(file_groups) - .with_projection(Some(projection_vector)); + let files: Vec = paths + .iter() + .map(|path| PartitionedFile::from_path(path.path().to_string()).unwrap()) + .collect(); - let mut table_parquet_options = TableParquetOptions::new(); - table_parquet_options.global.pushdown_filters = true; - table_parquet_options.global.reorder_filters = true; + // partition the files + // TODO really should partition the row groups - let mut builder = ParquetExecBuilder::new(file_scan_config) - .with_table_parquet_options(table_parquet_options); + let mut file_groups = vec![vec![]; partition_count]; + files.iter().enumerate().for_each(|(idx, file)| { + file_groups[idx % partition_count].push(file.clone()); + }); - if let Some(filter) = test_data_filters { - builder = builder.with_predicate(filter); - } + let file_scan_config = + FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow)) + .with_file_groups(file_groups) + .with_projection(Some(projection_vector)); - let scan = builder.build(); - return Ok((vec![], Arc::new(scan))); + let mut table_parquet_options = TableParquetOptions::new(); + table_parquet_options.global.pushdown_filters = true; + table_parquet_options.global.reorder_filters = true; + + let mut builder = ParquetExecBuilder::new(file_scan_config) + .with_table_parquet_options(table_parquet_options); + + if let Some(filter) = test_data_filters { + builder = builder.with_predicate(filter); } + let scan = builder.build(); + return Ok((vec![], Arc::new(scan))); + } + OpStruct::Scan(scan) => { + let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); + // If it is not test execution context for unit test, we should have at least one // input source if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index afd5d1951..fbcea8721 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -43,6 +43,7 @@ message Operator { SortMergeJoin sort_merge_join = 108; HashJoin hash_join = 109; Window window = 110; + NativeScan native_scan = 111; } } @@ -52,6 +53,14 @@ message Scan { // is purely for informational purposes when viewing native query plans in // debug mode. string source = 2; +} + +message NativeScan { + repeated spark.spark_expression.DataType fields = 1; + // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This + // is purely for informational purposes when viewing native query plans in + // debug mode. 
+ string source = 2; repeated string path = 3; string required_schema = 4; string data_schema = 5; diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 952ef39e9..6026fcfff 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -189,6 +189,22 @@ class CometSparkSessionExtensions } // data source V1 + case scanExec @ FileSourceScanExec( + HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _), + _: Seq[_], + requiredSchema, + _, + _, + _, + _, + _, + _) + if CometNativeScanExec.isSchemaSupported(requiredSchema) + && CometNativeScanExec.isSchemaSupported(partitionSchema) + && COMET_FULL_NATIVE_SCAN_ENABLED.get => + logInfo("Comet extension enabled for v1 Scan") + CometNativeScanExec(scanExec, session) + // data source V1 case scanExec @ FileSourceScanExec( HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _), _: Seq[_], @@ -1205,7 +1221,8 @@ object CometSparkSessionExtensions extends Logging { } def isCometScan(op: SparkPlan): Boolean = { - op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec] + op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec] || + op.isInstanceOf[CometNativeScanExec] } private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 1170c55a3..058c809c3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, Normalize import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometNativeScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ @@ -2479,6 +2479,69 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim childOp.foreach(result.addChildren) op match { + case scan: CometNativeScanExec => + val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() + nativeScanBuilder.setSource(op.simpleStringWithNodeId()) + + val scanTypes = op.output.flatten { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + nativeScanBuilder.addAllFields(scanTypes.asJava) + + // Sink operators don't have children + result.clearChildren() + + // scalastyle:off println + // System.out.println(op.simpleStringWithNodeId()) + // System.out.println(scanTypes.asJava) // Spark types for output. + System.out.println(scan.output) // This is the names of the output columns. + // System.out.println(cometScan.requiredSchema); // This is the projected columns. 
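+          // The data filters that Spark pushed down to this scan are converted to Comet
+          // protobuf expressions below and sent to the native planner, where they are
+          // AND-ed together and applied as the ParquetExec predicate.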
+ System.out.println( + scan.dataFilters + ); // This is the filter expressions that have been pushed down. + + val dataFilters = scan.dataFilters.map(exprToProto(_, scan.output)) + nativeScanBuilder.addAllDataFilters(dataFilters.map(_.get).asJava) + // System.out.println(cometScan.relation.location.inputFiles(0)) + // System.out.println(cometScan.partitionFilters); + // System.out.println(cometScan.relation.partitionSchema) + // System.out.println(cometScan.metadata); + + // System.out.println("requiredSchema:") + // cometScan.requiredSchema.fields.foreach(field => { + // System.out.println(field.dataType) + // }) + + // System.out.println("relation.dataSchema:") + // cometScan.relation.dataSchema.fields.foreach(field => { + // System.out.println(field.dataType) + // }) + // scalastyle:on println + + val requiredSchemaParquet = + new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema) + val dataSchemaParquet = + new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema) + + nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString) + nativeScanBuilder.setDataSchema(dataSchemaParquet.toString) + scan.relation.location.inputFiles.foreach { f => + nativeScanBuilder.addPath(f) + } + + Some(result.setNativeScan(nativeScanBuilder).build()) + + } else { + // There are unsupported scan type + val msg = + s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above" + emitWarning(msg) + withInfo(op, msg) + None + } + case ProjectExec(projectList, child) if CometConf.COMET_EXEC_PROJECT_ENABLED.get(conf) => val exprs = projectList.map(exprToProto(_, child.output)) @@ -2888,49 +2951,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // Sink operators don't have children result.clearChildren() - op match { - case cometScan: CometScanExec => - // scalastyle:off println -// System.out.println(op.simpleStringWithNodeId()) -// System.out.println(scanTypes.asJava) // Spark types for output. - System.out.println(cometScan.output) // This is the names of the output columns. -// System.out.println(cometScan.requiredSchema); // This is the projected columns. - System.out.println( - cometScan.dataFilters - ); // This is the filter expressions that have been pushed down. 
- - val dataFilters = cometScan.dataFilters.map(exprToProto(_, cometScan.output)) - scanBuilder.addAllDataFilters(dataFilters.map(_.get).asJava) -// System.out.println(cometScan.relation.location.inputFiles(0)) -// System.out.println(cometScan.partitionFilters); -// System.out.println(cometScan.relation.partitionSchema) -// System.out.println(cometScan.metadata); - -// System.out.println("requiredSchema:") -// cometScan.requiredSchema.fields.foreach(field => { -// System.out.println(field.dataType) -// }) - -// System.out.println("relation.dataSchema:") -// cometScan.relation.dataSchema.fields.foreach(field => { -// System.out.println(field.dataType) -// }) - // scalastyle:on println - - val requiredSchemaParquet = - new SparkToParquetSchemaConverter(conf).convert(cometScan.requiredSchema) - val dataSchemaParquet = - new SparkToParquetSchemaConverter(conf).convert(cometScan.relation.dataSchema) - - scanBuilder.setRequiredSchema(requiredSchemaParquet.toString) - scanBuilder.setDataSchema(dataSchemaParquet.toString) - - cometScan.relation.location.inputFiles.foreach { f => - scanBuilder.addPath(f) - } - case _ => - } - Some(result.setScan(scanBuilder).build()) } else { // There are unsupported scan type diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala new file mode 100644 index 000000000..ccd7de0d6 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet + +import scala.collection.mutable.HashMap +import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.comet.shims.ShimCometScanExec +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection._ + +import org.apache.comet.{CometConf, DataTypeSupport, MetricsSupport} +import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory} + +/** + * Comet physical scan node for DataSource V1. Most of the code here follow Spark's + * [[FileSourceScanExec]], + */ +case class CometNativeScanExec( + @transient relation: HadoopFsRelation, + override val output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false, + originalPlan: FileSourceScanExec) + extends CometPlan + with DataSourceScanExec + with ShimCometScanExec { + + def wrapped: FileSourceScanExec = originalPlan + + // FIXME: ideally we should reuse wrapped.supportsColumnar, however that fails many tests + override lazy val supportsColumnar: Boolean = + relation.fileFormat.supportBatch(relation.sparkSession, schema) + + override def vectorTypes: Option[Seq[String]] = originalPlan.vectorTypes + + private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty + + /** + * Send the driver-side metrics. Before calling this function, selectedPartitions has been + * initialized. See SPARK-26327 for more details. + */ + private def sendDriverMetrics(): Unit = { + driverMetrics.foreach(e => metrics(e._1).add(e._2)) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, + executionId, + metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq) + } + + private def isDynamicPruningFilter(e: Expression): Boolean = + e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + + @transient lazy val selectedPartitions: Array[PartitionDirectory] = { + val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) + val startTime = System.nanoTime() + val ret = + relation.location.listFiles(partitionFilters.filterNot(isDynamicPruningFilter), dataFilters) + setFilesNumAndSizeMetric(ret, true) + val timeTakenMs = + NANOSECONDS.toMillis((System.nanoTime() - startTime) + optimizerMetadataTimeNs) + driverMetrics("metadataTime") = timeTakenMs + ret + }.toArray + + // We can only determine the actual partitions at runtime when a dynamic partition filter is + // present. 
This is because such a filter relies on information that is only available at run + // time (for instance the keys used in the other side of a join). + @transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = { + val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter) + + if (dynamicPartitionFilters.nonEmpty) { + val startTime = System.nanoTime() + // call the file index for the files matching all filters except dynamic partition filters + val predicate = dynamicPartitionFilters.reduce(And) + val partitionColumns = relation.partitionSchema + val boundPredicate = Predicate.create( + predicate.transform { case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }, + Nil) + val ret = selectedPartitions.filter(p => boundPredicate.eval(p.values)) + setFilesNumAndSizeMetric(ret, false) + val timeTakenMs = (System.nanoTime() - startTime) / 1000 / 1000 + driverMetrics("pruningTime") = timeTakenMs + ret + } else { + selectedPartitions + } + } + + // exposed for testing + lazy val bucketedScan: Boolean = originalPlan.bucketedScan + + override lazy val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = + (originalPlan.outputPartitioning, originalPlan.outputOrdering) + + @transient + private lazy val pushedDownFilters = getPushedDownFilters(relation, dataFilters) + + override lazy val metadata: Map[String, String] = + if (originalPlan == null) Map.empty else originalPlan.metadata + + override def verboseStringWithOperatorId(): String = { + val metadataStr = metadata.toSeq.sorted + .filterNot { + case (_, value) if (value.isEmpty || value.equals("[]")) => true + case (key, _) if (key.equals("DataFilters") || key.equals("Format")) => true + case (_, _) => false + } + .map { + case (key, _) if (key.equals("Location")) => + val location = relation.location + val numPaths = location.rootPaths.length + val abbreviatedLocation = if (numPaths <= 1) { + location.rootPaths.mkString("[", ", ", "]") + } else { + "[" + location.rootPaths.head + s", ... ${numPaths - 1} entries]" + } + s"$key: ${location.getClass.getSimpleName} ${redact(abbreviatedLocation)}" + case (key, value) => s"$key: ${redact(value)}" + } + + s""" + |$formattedNodeName + |${ExplainUtils.generateFieldString("Output", output)} + |${metadataStr.mkString("\n")} + |""".stripMargin + } + + lazy val inputRDD: RDD[InternalRow] = { + val options = relation.options + + (FileFormat.OPTION_RETURNING_BATCH -> supportsColumnar.toString) + val readFile: (PartitionedFile) => Iterator[InternalRow] = + relation.fileFormat.buildReaderWithPartitionValues( + sparkSession = relation.sparkSession, + dataSchema = relation.dataSchema, + partitionSchema = relation.partitionSchema, + requiredSchema = requiredSchema, + filters = pushedDownFilters, + options = options, + hadoopConf = + relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options)) + + val readRDD = if (bucketedScan) { + createBucketedReadRDD( + relation.bucketSpec.get, + readFile, + dynamicallySelectedPartitions, + relation) + } else { + createReadRDD(readFile, dynamicallySelectedPartitions, relation) + } + sendDriverMetrics() + readRDD + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + inputRDD :: Nil + } + + /** Helper for computing total number and size of files in selected partitions. 
*/ + private def setFilesNumAndSizeMetric( + partitions: Seq[PartitionDirectory], + static: Boolean): Unit = { + val filesNum = partitions.map(_.files.size.toLong).sum + val filesSize = partitions.map(_.files.map(_.getLen).sum).sum + if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { + driverMetrics("numFiles") = filesNum + driverMetrics("filesSize") = filesSize + } else { + driverMetrics("staticFilesNum") = filesNum + driverMetrics("staticFilesSize") = filesSize + } + if (relation.partitionSchema.nonEmpty) { + driverMetrics("numPartitions") = partitions.length + } + } + + override lazy val metrics: Map[String, SQLMetric] = originalPlan.metrics ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. + if (supportsColumnar) { + Map( + "scanTime" -> SQLMetrics.createNanoTimingMetric( + sparkContext, + "scan time")) ++ CometMetricNode.scanMetrics(sparkContext) + } else { + Map.empty + } + } ++ { + relation.fileFormat match { + case f: MetricsSupport => f.initMetrics(sparkContext) + case _ => Map.empty + } + } + + override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + new Iterator[ColumnarBatch] { + + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. + val startNs = System.nanoTime() + val res = batches.hasNext + scanTime += System.nanoTime() - startNs + res + } + + override def next(): ColumnarBatch = { + val batch = batches.next() + numOutputRows += batch.numRows() + batch + } + } + } + } + + override def executeCollect(): Array[InternalRow] = { + ColumnarToRowExec(this).executeCollect() + } + + override val nodeName: String = + s"CometNativeScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}" + + /** + * Create an RDD for bucketed reads. The non-bucketed variant of this function is + * [[createReadRDD]]. + * + * The algorithm is pretty simple: each RDD partition being returned should include all the + * files with the same bucket id from all the given Hive partitions. + * + * @param bucketSpec + * the bucketing spec. + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. 
+ */ + private def createBucketedReadRDD( + bucketSpec: BucketSpec, + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") + val filesGroupedToBuckets = + selectedPartitions + .flatMap { p => + p.files.map { f => + getPartitionedFile(f, p) + } + } + .groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath.toString()).getName) + .getOrElse(throw invalidBucketFile(f.filePath.toString(), sparkContext.version)) + } + + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { f => + bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + + val filePartitions = optionalNumCoalescedBuckets + .map { numCoalescedBuckets => + logInfo(s"Coalescing to ${numCoalescedBuckets} buckets") + val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets) + Seq.tabulate(numCoalescedBuckets) { bucketId => + val partitionedFiles = coalescedBuckets + .get(bucketId) + .map { + _.values.flatten.toArray + } + .getOrElse(Array.empty) + FilePartition(bucketId, partitionedFiles) + } + } + .getOrElse { + Seq.tabulate(bucketSpec.numBuckets) { bucketId => + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) + } + } + + prepareRDD(fsRelation, readFile, filePartitions) + } + + /** + * Create an RDD for non-bucketed reads. The bucketed variant of this function is + * [[createBucketedReadRDD]]. + * + * @param readFile + * a function to read each (part of a) file. + * @param selectedPartitions + * Hive-style partition that are part of the read. + * @param fsRelation + * [[HadoopFsRelation]] associated with the read. + */ + private def createReadRDD( + readFile: (PartitionedFile) => Iterator[InternalRow], + selectedPartitions: Array[PartitionDirectory], + fsRelation: HadoopFsRelation): RDD[InternalRow] = { + val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes + val maxSplitBytes = + FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions) + logInfo( + s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"open cost is considered as scanning $openCostInBytes bytes.") + + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + + val splitFiles = selectedPartitions + .flatMap { partition => + partition.files.flatMap { file => + // getPath() is very expensive so we only want to call it once in this block: + val filePath = file.getPath + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, + relation.options, + filePath) && + // SPARK-39634: Allow file splitting in combination with row index generation once + // the fix for PARQUET-2161 is available. 
+ !isNeededForSchema(requiredSchema) + super.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values) + } else { + Seq.empty + } + } + } + .sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + prepareRDD( + fsRelation, + readFile, + FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)) + } + + private def prepareRDD( + fsRelation: HadoopFsRelation, + readFile: (PartitionedFile) => Iterator[InternalRow], + partitions: Seq[FilePartition]): RDD[InternalRow] = { + val hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options) + val prefetchEnabled = hadoopConf.getBoolean( + CometConf.COMET_SCAN_PREFETCH_ENABLED.key, + CometConf.COMET_SCAN_PREFETCH_ENABLED.defaultValue.get) + + val sqlConf = fsRelation.sparkSession.sessionState.conf + if (prefetchEnabled) { + CometParquetFileFormat.populateConf(sqlConf, hadoopConf) + val broadcastedConf = + fsRelation.sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val partitionReaderFactory = CometParquetPartitionReaderFactory( + sqlConf, + broadcastedConf, + requiredSchema, + relation.partitionSchema, + pushedDownFilters.toArray, + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf), + metrics) + + new DataSourceRDD( + fsRelation.sparkSession.sparkContext, + partitions.map(Seq(_)), + partitionReaderFactory, + true, + Map.empty) + } else { + newFileScanRDD( + fsRelation, + readFile, + partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), + new ParquetOptions(CaseInsensitiveMap(relation.options), sqlConf)) + } + } + + // Filters unused DynamicPruningExpression expressions - one which has been replaced + // with DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning + private def filterUnusedDynamicPruningExpressions( + predicates: Seq[Expression]): Seq[Expression] = { + predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + } + + override def doCanonicalize(): CometNativeScanExec = { + CometNativeScanExec( + relation, + output.map(QueryPlan.normalizeExpressions(_, output)), + requiredSchema, + QueryPlan.normalizePredicates( + filterUnusedDynamicPruningExpressions(partitionFilters), + output), + optionalBucketSet, + optionalNumCoalescedBuckets, + QueryPlan.normalizePredicates(dataFilters, output), + None, + disableBucketedScan, + null) + } + +} + +object CometNativeScanExec extends DataTypeSupport { + def apply(scanExec: FileSourceScanExec, session: SparkSession): CometNativeScanExec = { + // TreeNode.mapProductIterator is protected method. + def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = { + val arr = Array.ofDim[B](product.productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(product.productElement(i)) + i += 1 + } + arr + } + + // Replacing the relation in FileSourceScanExec by `copy` seems causing some issues + // on other Spark distributions if FileSourceScanExec constructor is changed. + // Using `makeCopy` to avoid the issue. 
+ // https://github.com/apache/arrow-datafusion-comet/issues/190 + def transform(arg: Any): AnyRef = arg match { + case _: HadoopFsRelation => + scanExec.relation.copy(fileFormat = new CometParquetFileFormat)(session) + case other: AnyRef => other + case null => null + } + val newArgs = mapProductIterator(scanExec, transform(_)) + val wrapped = scanExec.makeCopy(newArgs).asInstanceOf[FileSourceScanExec] + val batchScanExec = CometNativeScanExec( + wrapped.relation, + wrapped.output, + wrapped.requiredSchema, + wrapped.partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan, + wrapped) + scanExec.logicalLink.foreach(batchScanExec.setLogicalLink) + batchScanExec + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index dd1526d82..293bc35fa 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -319,10 +319,10 @@ abstract class CometNativeExec extends CometExec { */ def foreachUntilCometInput(plan: SparkPlan)(func: SparkPlan => Unit): Unit = { plan match { - case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec | - _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec | - _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec | - _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec | + case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec | + _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | + _: CometUnionExec | _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | + _: ReusedExchangeExec | _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec | _: CometSparkToColumnarExec => func(plan) case _: CometPlan => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 99007d0c9..73ccbbd63 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Hex} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate} -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometSparkToColumnarExec, CometTakeOrderedAndProjectExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometNativeScanExec, CometProjectExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometSparkToColumnarExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec @@ 
-559,7 +559,8 @@ class CometExecSuite extends CometTestBase { val df = sql("SELECT * FROM tbl WHERE _2 > _3") df.collect() - val metrics = find(df.queryExecution.executedPlan)(_.isInstanceOf[CometScanExec]) + val metrics = find(df.queryExecution.executedPlan)(s => + s.isInstanceOf[CometScanExec] || s.isInstanceOf[CometNativeScanExec]) .map(_.metrics) .get @@ -1484,7 +1485,10 @@ class CometExecSuite extends CometTestBase { val projected = df.selectExpr("_1 as x") val unioned = projected.union(df) val p = unioned.queryExecution.executedPlan.find(_.isInstanceOf[UnionExec]) - assert(p.get.collectLeaves().forall(_.isInstanceOf[CometScanExec])) + assert( + p.get + .collectLeaves() + .forall(o => o.isInstanceOf[CometScanExec] || o.isInstanceOf[CometNativeScanExec])) } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 1709cce61..35ba06902 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -36,7 +36,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark._ import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER} -import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometNativeScanExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -172,7 +172,7 @@ abstract class CometTestBase protected def checkCometOperators(plan: SparkPlan, excludedClasses: Class[_]*): Unit = { val wrapped = wrapCometSparkToColumnar(plan) wrapped.foreach { - case _: CometScanExec | _: CometBatchScanExec => + case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec => case _: CometSinkPlaceHolder | _: CometScanWrapper => case _: CometSparkToColumnarExec => case _: CometExec | _: CometShuffleExchangeExec =>
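
To try the new scan path locally, a minimal sketch (Comet jars and native library already on the
classpath with the usual Comet memory settings; the local master, the /tmp/example.parquet path,
and the _2/_3 column names are placeholders):

    import org.apache.spark.sql.SparkSession

    object NativeScanSmokeTest {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .master("local[*]") // placeholder deployment
          .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
          .config("spark.comet.enabled", "true")
          .config("spark.comet.exec.enabled", "true")
          // Flag added by this patch; it defaults to true, set explicitly here for clarity.
          .config("spark.comet.native.scan.enabled", "true")
          .getOrCreate()

        // /tmp/example.parquet is a placeholder for any local Parquet file.
        val df = spark.read.parquet("/tmp/example.parquet").filter("_2 > _3")
        df.explain() // the leaf node should show up as CometNativeScan
        df.show()
        spark.stop()
      }
    }

With spark.comet.native.scan.enabled=true the plan's leaf should be CometNativeScan; setting it to
false falls back to the existing CometScanExec path.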