Commit

update

ArnavBalyan committed Jan 30, 2025
1 parent f0336c0 commit b715f6e

Showing 7 changed files with 100 additions and 8 deletions.
@@ -16,14 +16,18 @@
  */
 package org.apache.gluten.execution
 
+import org.apache.gluten.backendsapi.BackendsApiManager
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, JoinType}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+import com.google.protobuf.{Any, StringValue}
+
 case class VeloxBroadcastNestedLoopJoinExecTransformer(
     left: SparkPlan,
     right: SparkPlan,
@@ -51,4 +55,17 @@ case class VeloxBroadcastNestedLoopJoinExecTransformer(
       newRight: SparkPlan): VeloxBroadcastNestedLoopJoinExecTransformer =
     copy(left = newLeft, right = newRight)
 
+  override def genJoinParameters(): Any = {
+    val joinParametersStr = new StringBuffer("JoinParameters:")
+    joinParametersStr
+      .append("isExistenceJoin=")
+      .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
+      .append("\n")
+    val message = StringValue
+      .newBuilder()
+      .setValue(joinParametersStr.toString)
+      .build()
+    BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
+  }
+
 }
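For context: for an existence join, the method above serializes the string `JoinParameters:isExistenceJoin=1\n` into a protobuf `StringValue` and packs it into a protobuf `Any`, which rides along with the Substrait plan as an advanced extension. A minimal sketch of that round trip, using the standard `Any.pack` as a stand-in for the backend-specific `packPBMessage` helper:

```scala
import com.google.protobuf.{Any, StringValue}

// What genJoinParameters effectively serializes for an existence join.
val message = StringValue
  .newBuilder()
  .setValue("JoinParameters:isExistenceJoin=1\n")
  .build()
val packed: Any = Any.pack(message)

// The Velox side later scans this string for the "isExistenceJoin=" key
// (see SubstraitToVeloxPlan.cc below).
assert(packed.unpack(classOf[StringValue]).getValue.contains("isExistenceJoin=1"))
```

Encoding the flag as a newline-terminated key=value string keeps the extension opaque to Substrait itself; only the Velox converter needs to understand it.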
8 changes: 8 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -341,6 +341,14 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::
     case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT:
       joinType = core::JoinType::kLeft;
       break;
+    case ::substrait::CrossRel_JoinType::CrossRel_JoinType_JOIN_TYPE_LEFT_SEMI:
+      if (crossRel.has_advanced_extension() &&
+          SubstraitParser::configSetInOptimization(crossRel.advanced_extension(), "isExistenceJoin=")) {
+        joinType = core::JoinType::kLeftSemiProject;
+      } else {
+        VELOX_NYI("Unsupported Join type: {}", std::to_string(crossRel.type()));
+      }
+      break;
     default:
       VELOX_NYI("Unsupported Join type: {}", std::to_string(crossRel.type()));
   }
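The new `JOIN_TYPE_LEFT_SEMI` case above upgrades the join to Velox's `kLeftSemiProject` only when the `isExistenceJoin=` flag emitted by `genJoinParameters` is present in the advanced extension; a left semi `CrossRel` without the flag still falls through to `VELOX_NYI`. A left-semi-project join returns every left row plus a boolean match column, which is exactly the shape Spark's `ExistenceJoin` produces. As a hedged illustration, a query of roughly this shape is one Spark typically plans with an `ExistenceJoin` (the view names and data are made up):

```scala
// EXISTS under a disjunction cannot be rewritten into a filtering semi join,
// so Spark's RewritePredicateSubquery turns it into an ExistenceJoin that
// appends a boolean column; the non-equi condition (r.id > l.id) then forces
// a broadcast nested loop join.
spark.range(10).createOrReplaceTempView("l")
spark.range(5).createOrReplaceTempView("r")
spark.sql(
  "SELECT id FROM l WHERE id > 7 OR EXISTS (SELECT 1 FROM r WHERE r.id > l.id)"
).show()
```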
1 change: 1 addition & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1060,6 +1060,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR
   switch (crossRel.type()) {
     case ::substrait::CrossRel_JoinType_JOIN_TYPE_INNER:
     case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT:
+    case ::substrait::CrossRel_JoinType_JOIN_TYPE_LEFT_SEMI:
       break;
     default:
       LOG_VALIDATION_MSG("Unsupported Join type in CrossRel");
4 changes: 2 additions & 2 deletions ep/build-velox/src/get_velox.sh
@@ -16,8 +16,8 @@
 
 set -exu
 
-VELOX_REPO=https://github.com/oap-project/velox.git
-VELOX_BRANCH=2025_01_22
+VELOX_REPO=https://github.com/ArnavBalyan/velox.git
+VELOX_BRANCH=2025_01_30_nested_join
 VELOX_HOME=""
 
 OS=`uname -s`
@@ -25,7 +25,7 @@ import org.apache.gluten.utils.SubstraitUtil
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
 import org.apache.spark.sql.execution.joins.BaseJoinExec
@@ -87,6 +87,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
       left.output ++ right.output.map(_.withNullability(true))
     case RightOuter =>
       left.output.map(_.withNullability(true)) ++ right.output
+    case j: ExistenceJoin =>
+      left.output :+ j.exists
     case LeftExistence(_) =>
       left.output
     case FullOuter =>
@@ -108,7 +110,7 @@
     case BuildRight =>
       joinType match {
         case _: InnerLike => left.outputPartitioning
-        case LeftOuter => left.outputPartitioning
+        case LeftOuter | ExistenceJoin(_) => left.outputPartitioning
         case x =>
           throw new IllegalArgumentException(
             s"BroadcastNestedLoopJoin should not take $x as the JoinType with building right side")
@@ -169,6 +171,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
   def validateJoinTypeAndBuildSide(): ValidationResult = {
     val result = joinType match {
       case _: InnerLike | LeftOuter | RightOuter => ValidationResult.succeeded
+      case ExistenceJoin(_) => ValidationResult.succeeded
       case _ =>
         ValidationResult.failed(s"$joinType join is not supported with BroadcastNestedLoopJoin")
     }
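One subtlety in the hunks above: `ExistenceJoin` is wired only into the `BuildRight` branch of `outputPartitioning`, so an existence join that builds the left side would still hit the `IllegalArgumentException`. A small sketch of that asymmetry (the helper is illustrative, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, JoinType}

// Assumption for illustration: existence joins are only handled when the
// right (broadcast) side is the build side, mirroring outputPartitioning.
def existenceJoinSupported(joinType: JoinType, buildSide: BuildSide): Boolean =
  (joinType, buildSide) match {
    case (ExistenceJoin(_), BuildRight) => true
    case (ExistenceJoin(_), BuildLeft) => false
    case _ => true // other join types are validated elsewhere
  }
```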
@@ -23,7 +23,7 @@ import org.apache.gluten.substrait.SubstraitContext
 import org.apache.gluten.substrait.expression.ExpressionNode
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 
 import io.substrait.proto.{CrossRel, JoinRel, NamedStruct, Type}
 
@@ -59,7 +59,7 @@ object SubstraitUtil
       // the left and right relations are exchanged and the
       // join type is reverted.
       CrossRel.JoinType.JOIN_TYPE_LEFT
-    case LeftSemi =>
+    case LeftSemi | ExistenceJoin(_) =>
       CrossRel.JoinType.JOIN_TYPE_LEFT_SEMI
     case FullOuter =>
       CrossRel.JoinType.JOIN_TYPE_OUTER
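Because Substrait's `CrossRel` has no dedicated existence-join type, the mapping above reuses `JOIN_TYPE_LEFT_SEMI` for both `LeftSemi` and `ExistenceJoin`; the native side tells them apart solely via the `isExistenceJoin=1` flag packed in `genJoinParameters`. A small sketch of the Spark-side distinction (the helper name is illustrative, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, JoinType, LeftSemi}
import org.apache.spark.sql.types.BooleanType

// Both join types map to CrossRel JOIN_TYPE_LEFT_SEMI; only ExistenceJoin
// additionally causes genJoinParameters to emit isExistenceJoin=1.
def emitsExistenceFlag(joinType: JoinType): Boolean =
  joinType.isInstanceOf[ExistenceJoin]

val exists = AttributeReference("exists", BooleanType, nullable = false)()
assert(emitsExistenceFlag(ExistenceJoin(exists))) // flag set -> kLeftSemiProject
assert(!emitsExistenceFlag(LeftSemi))             // no flag -> plain left semi
```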
@@ -18,4 +18,67 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.GlutenSQLTestsBaseTrait
 
-class GlutenExistenceJoinSuite extends ExistenceJoinSuite with GlutenSQLTestsBaseTrait {}
+class GlutenExistenceJoinSuite extends ExistenceJoinSuite with GlutenSQLTestsBaseTrait {
+
+  test("force existence join with BNJ using Catalyst APIs") {
+    import org.apache.spark.sql.catalyst.expressions._
+    import org.apache.spark.sql.catalyst.plans.logical.{JoinHint, _}
+    import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+    import org.apache.spark.sql.types._
+    import org.apache.spark.sql.{DataFrame, Dataset, Row}
+
+    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
+    spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
+
+    val left: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(
+        Seq(
+          Row(1, "a"),
+          Row(2, "b"),
+          Row(3, "c")
+        )),
+      new StructType().add("id", IntegerType).add("val", StringType)
+    )
+
+    val right: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(
+        Seq(
+          Row(1, "x"),
+          Row(3, "y")
+        )),
+      new StructType().add("id", IntegerType).add("val", StringType)
+    )
+
+    val leftPlan = left.logicalPlan
+    val rightPlan = right.logicalPlan
+
+    val existsAttr = AttributeReference("exists", BooleanType, nullable = false)()
+
+    val joinCondition: Expression = LessThan(leftPlan.output(0), rightPlan.output(0))
+
+    val existenceJoin = Join(
+      left = leftPlan,
+      right = rightPlan,
+      joinType = ExistenceJoin(existsAttr),
+      condition = Some(joinCondition),
+      hint = JoinHint.NONE
+    )
+
+    val project = Project(
+      projectList = leftPlan.output :+ existsAttr,
+      child = existenceJoin
+    )
+
+    val df = Dataset.ofRows(spark, project)
+
+    assert(existenceJoin.joinType == ExistenceJoin(existsAttr))
+    assert(existenceJoin.condition.contains(joinCondition))
+    val expected = Seq(
+      Row(1, "a", true),
+      Row(2, "b", true),
+      Row(3, "c", false)
+    )
+    assert(df.collect() === expected)
+
+  }
+}
