Fix the broadcast join issues caused by InputFileBlockRule [databricks] #9673
Diff for InputFileBlockRule.scala (com.nvidia.spark.rapids):

@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,36 +15,33 @@
  */
 package com.nvidia.spark.rapids
 
-import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.catalyst.expressions.{Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
 import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
+import org.apache.spark.sql.rapids.{GpuInputFileBlockLength, GpuInputFileBlockStart, GpuInputFileName}
 
 /**
- * InputFileBlockRule is to prevent the SparkPlans
- * [SparkPlan (with first input_file_xxx expression), FileScan) to run on GPU
- *
- * See https://github.com/NVIDIA/spark-rapids/issues/3333
+ * A rule that prevents the plans [SparkPlan (with first input_file_xxx expression), FileScan)
+ * from running on the GPU.
+ * For more details, please go to https://github.com/NVIDIA/spark-rapids/issues/3333.
  */
 object InputFileBlockRule {
+  private type PlanMeta = SparkPlanMeta[SparkPlan]
 
-  private def checkHasInputFileExpressions(plan: SparkPlan): Boolean = {
-    plan.expressions.exists(GpuTransitionOverrides.checkHasInputFileExpressions)
-  }
-
   // Apply the rule on SparkPlanMeta
-  def apply(plan: SparkPlanMeta[SparkPlan]) = {
-    /**
-     * key: the SparkPlanMeta where has the first input_file_xxx expression
-     * value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan)
-     */
-    val resultOps = LinkedHashMap[SparkPlanMeta[SparkPlan], ArrayBuffer[SparkPlanMeta[SparkPlan]]]()
+  def apply(plan: PlanMeta): Unit = {
+    // key: the SparkPlanMeta which has the first input_file_xxx expression
+    // value: an array of the SparkPlanMeta chain [SparkPlan (with first input_file_xxx), FileScan)
+    val resultOps = mutable.LinkedHashMap[PlanMeta, ArrayBuffer[PlanMeta]]()
     recursivelyResolve(plan, None, resultOps)
 
     // If we've found some chains, we should prevent the transition.
-    resultOps.foreach { item =>
-      item._2.foreach(p => p.inputFilePreventsRunningOnGpu())
+    resultOps.foreach { case (_, metas) =>
+      metas.foreach(_.willNotWorkOnGpu("GPU plans may get incorrect file name" +
+        ", or file start or file length from a CPU scan"))
     }
   }
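To make the chain that apply tracks concrete, the sketch below builds a tiny query whose physical plan puts an input_file_name() projection directly above a file scan, forming the [SparkPlan (with first input_file_xxx), FileScan) chain from the doc comment. This is only a shape illustration, not part of the PR; the path and column names are hypothetical.

// A minimal sketch of a plan shape this rule targets. If the scan falls back
// to the CPU while the projection runs on the GPU, the file info expressions
// would be wrong, which is what the willNotWorkOnGpu call above guards against.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.input_file_name

object InputFileChainExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("input-file-chain-example")
      .master("local[*]")
      .getOrCreate()

    // Project(input_file_name) sits directly above the FileScan.
    val df = spark.read.parquet("/tmp/example_parquet") // hypothetical path
      .withColumn("source_file", input_file_name())

    df.explain() // shows the Project over the scan
    spark.stop()
  }
}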
@@ -54,39 +51,51 @@ object InputFileBlockRule {
    * @param key the SparkPlanMeta with the first input_file_xxx
    * @param resultOps the found SparkPlan chain
    */
-  private def recursivelyResolve(
-      plan: SparkPlanMeta[SparkPlan],
-      key: Option[SparkPlanMeta[SparkPlan]],
-      resultOps: LinkedHashMap[SparkPlanMeta[SparkPlan],
-        ArrayBuffer[SparkPlanMeta[SparkPlan]]]): Unit = {
-
+  private def recursivelyResolve(plan: PlanMeta, key: Option[PlanMeta],
+      resultOps: mutable.LinkedHashMap[PlanMeta, ArrayBuffer[PlanMeta]]): Unit = {
     plan.wrapped match {
-      case _: ShuffleExchangeExec => // Exchange will invalid the input_file_xxx
+      case _: ShuffleExchangeLike => // Exchange will invalidate the input_file_xxx
         key.map(p => resultOps.remove(p)) // Remove the chain from Map
         plan.childPlans.foreach(p => recursivelyResolve(p, None, resultOps))
       case _: FileSourceScanExec | _: BatchScanExec =>
         if (plan.canThisBeReplaced) { // FileScan can be replaced
           key.map(p => resultOps.remove(p)) // Remove the chain from Map
         }
+      case _: BroadcastExchangeLike =>
+        // noop: Don't go any further, the file info cannot come from a broadcast.
       case _: LeafExecNode => // We've reached the LeafNode but without any FileScan
         key.map(p => resultOps.remove(p)) // Remove the chain from Map
       case _ =>
         val newKey = if (key.isDefined) {
           // The node is in the middle of chain [SparkPlan with input_file_xxx, FileScan)
-          resultOps.getOrElseUpdate(key.get, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan
+          resultOps.getOrElseUpdate(key.get, new ArrayBuffer[PlanMeta]) += plan
           key
-        } else { // There is no parent Node who has input_file_xxx
-          if (checkHasInputFileExpressions(plan.wrapped)) {
-            // Current node has input_file_xxx. Mark it as the first Node with input_file_xxx
-            resultOps.getOrElseUpdate(plan, new ArrayBuffer[SparkPlanMeta[SparkPlan]]) += plan
+        } else { // There is no parent node which has input_file_xxx
+          if (hasInputFileExpression(plan.wrapped)) {
+            // Current node has input_file_xxx. Mark it as the first node with input_file_xxx
+            resultOps.getOrElseUpdate(plan, new ArrayBuffer[PlanMeta]) += plan
             Some(plan)
           } else {
             None
           }
         }
         plan.childPlans.foreach(p => recursivelyResolve(p, newKey, resultOps))
     }
   }
 
+  private def hasInputFileExpression(expr: Expression): Boolean = expr match {
+    case _: InputFileName => true
+    case _: InputFileBlockStart => true
+    case _: InputFileBlockLength => true
+    case _: GpuInputFileName => true
+    case _: GpuInputFileBlockStart => true
+    case _: GpuInputFileBlockLength => true
+    case e => e.children.exists(hasInputFileExpression)
+  }
+
+  /** Whether a plan has any InputFile{Name, BlockStart, BlockLength} expression. */
+  def hasInputFileExpression(plan: SparkPlan): Boolean = {
+    plan.expressions.exists(hasInputFileExpression)
+  }
 }

Review thread on the GpuInputFileName case:

Reviewer: Hmm, why do we still need to return true here, given it's already been converted to the GPU case? Given the reason mentioned above is …
Author: This will be used for two stages during the overriding process. The stage after inserting the transitions for rows and columns may get an InputFileName or a GpuInputFileName.
Reviewer: Concerning this issue, we will never get a GpuInputFileName since plan conversion does not happen.
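For orientation, here is a hedged usage sketch of the two public entry points this file now exposes. The wrapping object and method names below are hypothetical; only the InputFileBlockRule calls come from the diff.

// Hypothetical caller, not the plugin's real wiring.
import com.nvidia.spark.rapids.{InputFileBlockRule, SparkPlanMeta}
import org.apache.spark.sql.execution.SparkPlan

object InputFileRuleUsage {
  // During tagging: prune every chain [first input_file_xxx node, FileScan)
  // that would read file info from a CPU scan into GPU operators.
  def tagPlan(meta: SparkPlanMeta[SparkPlan]): Unit =
    InputFileBlockRule(meta)

  // Per the review thread, the check runs at two stages of the overriding
  // process, so it matches both the CPU expressions (InputFileName, ...) and
  // the GPU ones (GpuInputFileName, ...).
  def needsInputFileInfo(plan: SparkPlan): Boolean =
    InputFileBlockRule.hasInputFileExpression(plan)
}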
Diff in GpuBroadcastHashJoinMetaBase (shim base class):

@@ -78,6 +78,22 @@ abstract class GpuBroadcastHashJoinMetaBase(
     }
   }
 
+  // Called in runAfterTagRules for a special post tagging for this broadcast join.
+  def checkTagForBuildSide(): Unit = {
+    val Seq(leftChild, rightChild) = childPlans
+    val buildSideMeta = buildSide match {
+      case GpuBuildLeft => leftChild
+      case GpuBuildRight => rightChild
+    }
+    // Check both of the conditions to avoid a duplicate reason string.
+    if (!canThisBeReplaced && canBuildSideBeReplaced(buildSideMeta)) {
+      buildSideMeta.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
+    }
+    if (canThisBeReplaced && !canBuildSideBeReplaced(buildSideMeta)) {
+      willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
+    }
+  }
 
   def convertToGpu(): GpuExec
 }

Review thread on checkTagForBuildSide:

Reviewer: It would make more sense to move this into …
Author: I did not do that because there are four shims for GpuBroadcastJoinMeta, which means I would need to duplicate this code four times. The current option looks much simpler, only two times.
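The new method reads as a symmetry check: either the join and its broadcast build side both run on the GPU, or both fall back to the CPU. Below is a self-contained sketch of just that logic with simplified stand-in types (Meta here is hypothetical, not the plugin's real meta classes), to show why checking the two directions separately avoids a duplicated reason string.

// A minimal sketch of the symmetric fallback check, assuming a toy Meta type.
object BuildSideCheckSketch {
  trait Meta {
    var reason: Option[String] = None
    def canBeReplaced: Boolean = reason.isEmpty
    def willNotWorkOnGpu(why: String): Unit =
      if (reason.isEmpty) reason = Some(why) // record only the first reason
  }

  def checkTagForBuildSide(join: Meta, buildSide: Meta): Unit = {
    // Join falls back but its broadcast would run on GPU: pull the broadcast back too.
    if (!join.canBeReplaced && buildSide.canBeReplaced) {
      buildSide.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
    }
    // Join runs on GPU but its broadcast falls back: the join must fall back too.
    if (join.canBeReplaced && !buildSide.canBeReplaced) {
      join.willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
    }
    // If both already fall back, neither branch fires, so no duplicate reason.
  }

  def main(args: Array[String]): Unit = {
    val join = new Meta {}
    val broadcast = new Meta {}
    join.willNotWorkOnGpu("join unsupported") // simulate a tagged fallback
    checkTagForBuildSide(join, broadcast)
    assert(!broadcast.canBeReplaced) // the broadcast falls back with it
  }
}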
Reviewer: I ran these tests on the current 23.12: test_broadcast_hash_join_fix_fallback_by_inputfile[BroadcastExchange-ParquetScan] produced the wrong answer, and test_broadcast_hash_join_fix_fallback_by_inputfile[GpuBroadcastExchange-ParquetScan] failed by not falling back as expected. test_broadcast_nested_join_fix_fallback_by_inputfile passed in all cases, and none of them triggered the error described in #9469. Can we please add a test that is the same as #9469, so we can be sure that it is fixed?

Author: The case in #9469 requires Iceberg to run, so we cannot test this for Spark 330+. Is that OK?

Author: I updated the tests; now they can reproduce the same error as #9469 on the current 23.12.

Author: @revans2 Could you take a look again? Thanks in advance.
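To make the test discussion concrete, here is a hedged, spark-shell style Scala sketch (spark is the shell's session; paths, schema, and the join key are hypothetical) of the query shape those fallback tests and #9469 exercise: a broadcast hash join whose streamed side projects input_file_name().

// Hypothetical repro shape. With this PR, InputFileBlockRule stops at
// BroadcastExchangeLike because the file info cannot come from a broadcast,
// and checkTagForBuildSide keeps the join and its broadcast on the same side.
import org.apache.spark.sql.functions.{broadcast, input_file_name}

val fact = spark.read.parquet("/tmp/fact") // hypothetical data, streamed side
val dim = spark.read.parquet("/tmp/dim")   // small table, broadcast side

val joined = fact
  .withColumn("src_file", input_file_name()) // input_file_xxx above the join
  .join(broadcast(dim), Seq("id"))           // assumes both sides have an "id" column

joined.show()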