Distinct left join #10520

Merged · 6 commits · Mar 7, 2024
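A "distinct" left outer join is one where the build (right) side is known to be unique on the join keys, so each left row matches at most one right row and the output keeps left rows in their original order. As the comments in the diff below put it, such a join "only produce[s] a single gather map since left table rows are not rearranged by the join": the left table can be passed through unchanged, and only a right-side gather map is needed. A minimal sketch of the idea in plain Scala (illustrative names only, not the plugin's cuDF-backed API):

```scala
// Toy model of a distinct left outer join: right keys are unique, so each left
// row matches at most once and left rows keep their order. Names are hypothetical.
case class Row(key: Int, value: String)

def distinctLeftOuterJoin(left: Seq[Row], right: Seq[Row]): Seq[(Row, Option[Row])] = {
  // Unique right keys mean a plain Map loses nothing.
  val buildMap: Map[Int, Row] = right.map(r => r.key -> r).toMap
  // Exactly one output row per left row, in left order: the "left gather map"
  // would be the identity, which is why it never needs to be materialized.
  left.map(l => (l, buildMap.get(l.key)))
}
```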
13 changes: 13 additions & 0 deletions integration_tests/src/main/python/join_test.py
@@ -1178,6 +1178,19 @@ def do_join(spark):
             "spark.sql.sources.useV1SourceList": "",
             "spark.rapids.sql.input." + scan_name: False})
 
+@ignore_order(local=True)
+@pytest.mark.parametrize("join_type", ["Inner", "LeftOuter"], ids=idfn)
+@pytest.mark.parametrize("batch_size", ["500", "1g"], ids=idfn)
+def test_distinct_join(join_type, batch_size):
+    join_conf = {
+        "spark.rapids.sql.batchSizeBytes": batch_size
+    }
+    def do_join(spark):
+        left_df = spark.range(1024).withColumn("x", f.col("id") + 1)
+        right_df = spark.range(768).withColumn("x", f.col("id") * 2)
+        return left_df.join(right_df, ["x"], join_type)
+    assert_gpu_and_cpu_are_equal_collect(do_join, conf=join_conf)
+
 @ignore_order(local=True)
 @pytest.mark.parametrize("join_type", ["Inner", "FullOuter"], ids=idfn)
 @pytest.mark.parametrize("is_left_host_shuffle", [False, True], ids=idfn)
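The new test is built so the right side is distinct on the join key: left keys are 1..1024 (`id + 1`) and right keys are the even numbers 0..1534 (`id * 2`), each generated from `range` and therefore unique. The "500" batch size should force the gatherer to emit multiple output batches, exercising the slicing path added below, while "1g" covers the single full-batch case. A quick check of the expected row counts (a sketch mirroring the Python above, not part of the PR):

```scala
// Expected row counts for the new test's data.
val leftKeys = (0L until 1024L).map(_ + 1)   // 1, 2, ..., 1024 (distinct)
val rightKeys = (0L until 768L).map(_ * 2)   // 0, 2, ..., 1534 (distinct)
val matched = leftKeys.toSet.intersect(rightKeys.toSet).size
println(matched)         // 512 rows for the Inner case (even keys 2..1024)
println(leftKeys.size)   // 1024 rows for the LeftOuter case: one per left row
```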
@@ -328,52 +328,54 @@ abstract class SplittableJoinIterator(
       joinType: JoinType): Option[JoinGatherer] = {
     assert(maps.length > 0 && maps.length <= 2)
     try {
-      val leftMap = maps.head
-      val rightMap = if (maps.length > 1) {
-        if (rightData.numCols == 0) {
-          // No data so don't bother with it
-          None
-        } else {
-          Some(maps(1))
-        }
-      } else {
-        None
-      }
+      val leftGatherer = joinType match {
+        case LeftOuter if maps.length == 1 =>
+          // Distinct left outer joins only produce a single gather map since left table rows
+          // are not rearranged by the join.
+          new JoinGathererSameTable(leftData)
+        case _ =>
+          val lazyLeftMap = LazySpillableGatherMap(maps.head, "left_map")
+          // Inner joins -- manifest the intersection of both left and right sides. The gather maps
+          // contain the number of rows that must be manifested, and every index
+          // must be within bounds, so we can skip the bounds checking.
+          //
+          // Left outer -- Left outer manifests all rows for the left table. The left gather map
+          // must contain valid indices, so we skip the check for the left side.
+          val leftOutOfBoundsPolicy = joinType match {
+            case _: InnerLike | LeftOuter => OutOfBoundsPolicy.DONT_CHECK
+            case _ => OutOfBoundsPolicy.NULLIFY
+          }
+          JoinGatherer(lazyLeftMap, leftData, leftOutOfBoundsPolicy)
+      }
+      val rightMap = joinType match {
+        case _ if rightData.numCols == 0 => None
+        case LeftOuter if maps.length == 1 =>
+          // Distinct left outer joins only produce a single gather map since left table rows
+          // are not rearranged by the join.
+          Some(maps.head)
+        case _ if maps.length == 1 => None
+        case _ => Some(maps(1))
+      }
 
-      val lazyLeftMap = LazySpillableGatherMap(leftMap, "left_map")
       val gatherer = rightMap match {
         case None =>
           // When there isn't a `rightMap` we are in either LeftSemi or LeftAnti joins.
           // In these cases, the map and the table are both the left side, and everything in the map
           // is a match on the left table, so we don't want to check for bounds.
           rightData.close()
-          JoinGatherer(lazyLeftMap, leftData, OutOfBoundsPolicy.DONT_CHECK)
+          leftGatherer
         case Some(right) =>
           // Inner joins -- manifest the intersection of both left and right sides. The gather maps
           // contain the number of rows that must be manifested, and every index
           // must be within bounds, so we can skip the bounds checking.
           //
           // Left outer -- Left outer manifests all rows for the left table. The left gather map
           // must contain valid indices, so we skip the check for the left side. The right side
           // has to be checked, since we need to produce nulls (for the right) for those
           // rows on the left side that don't have a match on the right.
           //
           // Right outer -- Is the opposite from left outer (skip right bounds check, keep left)
           //
           // Full outer -- Can produce nulls for any left or right rows that don't have a match
           // in the opposite table. So we must check both gather maps.
           //
-          val leftOutOfBoundsPolicy = joinType match {
-            case _: InnerLike | LeftOuter => OutOfBoundsPolicy.DONT_CHECK
-            case _ => OutOfBoundsPolicy.NULLIFY
-          }
           val rightOutOfBoundsPolicy = joinType match {
             case _: InnerLike | RightOuter => OutOfBoundsPolicy.DONT_CHECK
             case _ => OutOfBoundsPolicy.NULLIFY
           }
           val lazyRightMap = LazySpillableGatherMap(right, "right_map")
-          JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData,
-            leftOutOfBoundsPolicy, rightOutOfBoundsPolicy)
+          val rightGatherer = JoinGatherer(lazyRightMap, rightData, rightOutOfBoundsPolicy)
+          MultiJoinGather(leftGatherer, rightGatherer)
       }
       if (gatherer.isDone) {
         // Nothing matched...
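The out-of-bounds policies encode, per join type, whether a gather map may contain indices outside the target table (which must become nulls) or is guaranteed in-bounds (so the check can be skipped). Restated as a standalone table (a sketch with toy types; the real code uses Spark's `JoinType` and cuDF's `OutOfBoundsPolicy`):

```scala
sealed trait Policy
case object DontCheck extends Policy // every index is known to be in bounds
case object Nullify extends Policy   // out-of-bounds indices become nulls

// (left policy, right policy) per join type, per the comments in the diff above.
def boundsPolicies(joinType: String): (Policy, Policy) = joinType match {
  case "Inner"      => (DontCheck, DontCheck) // both maps hold only matches
  case "LeftOuter"  => (DontCheck, Nullify)   // unmatched left rows need nulls on the right
  case "RightOuter" => (Nullify, DontCheck)   // mirror image of left outer
  case "FullOuter"  => (Nullify, Nullify)     // either side may lack a match
  case other        => throw new IllegalArgumentException(other)
}
```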
102 changes: 101 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,10 +19,12 @@ package com.nvidia.spark.rapids
 import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}
 import com.nvidia.spark.Retryable
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
@@ -644,6 +646,104 @@
   }
 }
 
+/**
+ * JoinGatherer for the case where the gather produces the same table as the input table.
+ */
+class JoinGathererSameTable(
+    private val data: LazySpillableColumnarBatch) extends JoinGatherer {
+
+  assert(data.numCols > 0, "data with no columns should have been filtered out already")
+
+  // How much of the gather map we have output so far
+  private var gatheredUpTo: Long = 0
+  private var gatheredUpToCheckpoint: Long = 0
+  private val totalRows: Long = data.numRows
+  private val fixedWidthRowSizeBits = {
+    val dts = data.dataTypes
+    JoinGathererImpl.fixedWidthRowSizeBits(dts)
+  }
+
+  override def checkpoint: Unit = {
+    gatheredUpToCheckpoint = gatheredUpTo
+    data.checkpoint()
+  }
+
+  override def restore: Unit = {
+    gatheredUpTo = gatheredUpToCheckpoint
+    data.restore()
+  }
+
+  override def toString: String = {
+    s"SAMEGATHER $gatheredUpTo/$totalRows $data"
+  }
+
+  override def realCheapPerRowSizeEstimate: Double = {
+    val totalInputRows: Int = data.numRows
+    val totalInputSize: Long = data.deviceMemorySize
+    // Avoid divide by 0 here and later on
+    if (totalInputRows > 0 && totalInputSize > 0) {
+      totalInputSize.toDouble / totalInputRows
+    } else {
+      1.0
+    }
+  }
+
+  override def getFixedWidthBitSize: Option[Int] = fixedWidthRowSizeBits
+
+  override def gatherNext(n: Int): ColumnarBatch = {
+    assert(gatheredUpTo + n <= totalRows)
+    val ret = sliceForGather(n)
+    gatheredUpTo += n
+    ret
+  }
+
+  override def isDone: Boolean =
+    gatheredUpTo >= totalRows
+
+  override def numRowsLeft: Long = totalRows - gatheredUpTo
+
+  override def allowSpilling(): Unit = {
+    data.allowSpilling()
+  }
+
+  override def getBitSizeMap(n: Int): ColumnView = {
+    withResource(sliceForGather(n)) { cb =>
+      withResource(GpuColumnVector.from(cb)) { table =>
+        withResource(table.rowBitCount()) { bits =>
+          bits.castTo(DType.INT64)
+        }
+      }
+    }
+  }
+
+  override def close(): Unit = {
+    data.close()
+  }
+
+  private def isFullBatchGather(n: Int): Boolean = gatheredUpTo == 0 && n == totalRows
+
+  private def sliceForGather(n: Int): ColumnarBatch = {
+    val cb = data.getBatch
+    if (isFullBatchGather(n)) {
+      GpuColumnVector.incRefCounts(cb)
+    } else {
+      val splitStart = gatheredUpTo.toInt
+      val splitEnd = splitStart + n
+      val inputColumns = GpuColumnVector.extractColumns(cb)
+      val outputColumns: Array[vectorized.ColumnVector] = inputColumns.safeMap { c =>
+        val views = c.getBase.splitAsViews(splitStart, splitEnd)
+        assert(views.length == 3, s"Unexpected number of views: ${views.length}")
+        views(0).safeClose()
+        views(2).safeClose()
+        withResource(views(1)) { v =>
+          GpuColumnVector.from(v.copyToColumnVector(), c.dataType())
+        }
+      }
+      new ColumnarBatch(outputColumns, splitEnd - splitStart)
+    }
+  }
+}
+
 /**
  * Join Gatherer for a left table and a right table
  */
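`sliceForGather` is the heart of `JoinGathererSameTable`: a full-batch gather just bumps reference counts on the input, while a partial gather copies out the `[gatheredUpTo, gatheredUpTo + n)` window of every column. `splitAsViews(start, end)` yields three views covering `[0, start)`, `[start, end)`, and `[end, numRows)`; only the middle one is kept. A plain-Scala stand-in for the slicing semantics (the real code operates on reference-counted cuDF column views):

```scala
// Toy model of sliceForGather: keep only the middle of three splits.
def sliceForGather[T](column: Vector[T], gatheredUpTo: Int, n: Int): Vector[T] = {
  val splitStart = gatheredUpTo
  val splitEnd = splitStart + n
  if (splitStart == 0 && splitEnd == column.length) {
    column // full-batch gather: reuse the input as-is (incRefCounts in the real code)
  } else {
    // splitAsViews(splitStart, splitEnd) ~ (prefix, slice, suffix); only the slice survives
    column.slice(splitStart, splitEnd)
  }
}
```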
@@ -474,6 +474,8 @@ class HashJoinIterator(
       None
     } else {
       val maps = joinType match {
+        case LeftOuter if isDistinctJoin =>
+          Array(leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual))
         case LeftOuter => leftKeys.leftJoinGatherMaps(rightKeys, compareNullsEqual)
        case RightOuter =>
          // Reverse the output of the join, because we expect the right gather map to
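`leftDistinctJoinGatherMap` returns a single gather map rather than the usual pair: one right-table index per left row, with unmatched rows marked out of bounds so the `NULLIFY` policy applied to the right side (see the `SplittableJoinIterator` hunk above) turns them into nulls. A hypothetical CPU-side illustration of the shape of that map (not the cuDF implementation):

```scala
// One entry per left row: the index of its unique right-side match, or an
// out-of-bounds sentinel that NULLIFY later converts to nulls on the right.
def distinctLeftGatherMap(leftKeys: Array[Int], rightKeys: Array[Int]): Array[Int] = {
  val rightIndex = rightKeys.zipWithIndex.toMap // right keys are distinct, so no collisions
  leftKeys.map(k => rightIndex.getOrElse(k, Int.MinValue)) // Int.MinValue: "no match"
}
```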