diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index 5e32e07d230..ef9f4d03f24 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -1178,6 +1178,19 @@ def do_join(spark):
         "spark.sql.sources.useV1SourceList": "",
         "spark.rapids.sql.input." + scan_name: False})
 
+@ignore_order(local=True)
+@pytest.mark.parametrize("join_type", ["Inner", "LeftOuter"], ids=idfn)
+@pytest.mark.parametrize("batch_size", ["500", "1g"], ids=idfn)
+def test_distinct_join(join_type, batch_size):
+    join_conf = {
+        "spark.rapids.sql.batchSizeBytes": batch_size
+    }
+    def do_join(spark):
+        left_df = spark.range(1024).withColumn("x", f.col("id") + 1)
+        right_df = spark.range(768).withColumn("x", f.col("id") * 2)
+        return left_df.join(right_df, ["x"], join_type)
+    assert_gpu_and_cpu_are_equal_collect(do_join, conf=join_conf)
+
 @ignore_order(local=True)
 @pytest.mark.parametrize("join_type", ["Inner", "FullOuter"], ids=idfn)
 @pytest.mark.parametrize("is_left_host_shuffle", [False, True], ids=idfn)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
index 91e5235a22c..087d6b59098 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala
@@ -328,52 +328,54 @@ abstract class SplittableJoinIterator(
       joinType: JoinType): Option[JoinGatherer] = {
     assert(maps.length > 0 && maps.length <= 2)
     try {
-      val leftMap = maps.head
-      val rightMap = if (maps.length > 1) {
-        if (rightData.numCols == 0) {
-          // No data so don't bother with it
-          None
-        } else {
-          Some(maps(1))
-        }
-      } else {
-        None
+      val leftGatherer = joinType match {
+        case LeftOuter if maps.length == 1 =>
+          // Distinct left outer joins only produce a single gather map since left table rows
+          // are not rearranged by the join.
+          new JoinGathererSameTable(leftData)
+        case _ =>
+          val lazyLeftMap = LazySpillableGatherMap(maps.head, "left_map")
+          // Inner joins -- manifest the intersection of both left and right sides. The gather maps
+          // contain the number of rows that must be manifested, and every index
+          // must be within bounds, so we can skip the bounds checking.
+          //
+          // Left outer -- Left outer manifests all rows for the left table. The left gather map
+          // must contain valid indices, so we skip the check for the left side.
+          val leftOutOfBoundsPolicy = joinType match {
+            case _: InnerLike | LeftOuter => OutOfBoundsPolicy.DONT_CHECK
+            case _ => OutOfBoundsPolicy.NULLIFY
+          }
+          JoinGatherer(lazyLeftMap, leftData, leftOutOfBoundsPolicy)
+      }
+      val rightMap = joinType match {
+        case _ if rightData.numCols == 0 => None
+        case LeftOuter if maps.length == 1 =>
+          // Distinct left outer joins only produce a single gather map since left table rows
+          // are not rearranged by the join.
+          Some(maps.head)
+        case _ if maps.length == 1 => None
+        case _ => Some(maps(1))
       }
-
-      val lazyLeftMap = LazySpillableGatherMap(leftMap, "left_map")
       val gatherer = rightMap match {
         case None =>
           // When there isn't a `rightMap` we are in either LeftSemi or LeftAnti joins.
           // In these cases, the map and the table are both the left side, and everything in the map
           // is a match on the left table, so we don't want to check for bounds.
           rightData.close()
-          JoinGatherer(lazyLeftMap, leftData, OutOfBoundsPolicy.DONT_CHECK)
+          leftGatherer
         case Some(right) =>
           // Inner joins -- manifest the intersection of both left and right sides. The gather maps
           // contain the number of rows that must be manifested, and every index
           // must be within bounds, so we can skip the bounds checking.
           //
-          // Left outer -- Left outer manifests all rows for the left table. The left gather map
-          // must contain valid indices, so we skip the check for the left side. The right side
-          // has to be checked, since we need to produce nulls (for the right) for those
-          // rows on the left side that don't have a match on the right.
-          //
           // Right outer -- Is the opposite from left outer (skip right bounds check, keep left)
-          //
-          // Full outer -- Can produce nulls for any left or right rows that don't have a match
-          // in the opposite table. So we must check both gather maps.
-          //
-          val leftOutOfBoundsPolicy = joinType match {
-            case _: InnerLike | LeftOuter => OutOfBoundsPolicy.DONT_CHECK
-            case _ => OutOfBoundsPolicy.NULLIFY
-          }
           val rightOutOfBoundsPolicy = joinType match {
             case _: InnerLike | RightOuter => OutOfBoundsPolicy.DONT_CHECK
             case _ => OutOfBoundsPolicy.NULLIFY
           }
           val lazyRightMap = LazySpillableGatherMap(right, "right_map")
-          JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData,
-            leftOutOfBoundsPolicy, rightOutOfBoundsPolicy)
+          val rightGatherer = JoinGatherer(lazyRightMap, rightData, rightOutOfBoundsPolicy)
+          MultiJoinGather(leftGatherer, rightGatherer)
       }
       if (gatherer.isDone) {
         // Nothing matched...
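
The single-gather-map fast path above rests on one property: when the build (right) side keys are distinct, each left row matches at most one right row, so the left table comes back unchanged and in its original order, and only a right-side gather map is needed. A toy model of that map in plain Scala collections -- the names here are illustrative, not the plugin API:

    object DistinctLeftOuterSketch {
      def main(args: Array[String]): Unit = {
        val leftKeys = Seq(1, 2, 3, 4)
        val rightKeys = Seq(2, 4, 6) // distinct build-side keys
        val rightIndex = rightKeys.zipWithIndex.toMap
        // One entry per left row: the index of its single right match, or -1
        // when the right columns must be null-filled (the NULLIFY policy above)
        val rightGatherMap = leftKeys.map(k => rightIndex.getOrElse(k, -1))
        println(rightGatherMap) // List(-1, 0, -1, 1)
      }
    }
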
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
index 80ec540ff84..c4584086173 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/JoinGatherer.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,10 +19,12 @@ package com.nvidia.spark.rapids
 import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}
 import com.nvidia.spark.Retryable
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
@@ -644,6 +646,104 @@ class JoinGathererImpl(
   }
 }
 
+/**
+ * JoinGatherer for the case where the gather produces the same table as the input table.
+ */
+class JoinGathererSameTable(
+    private val data: LazySpillableColumnarBatch) extends JoinGatherer {
+
+  assert(data.numCols > 0, "data with no columns should have been filtered out already")
+
+  // How much of the gather map we have output so far
+  private var gatheredUpTo: Long = 0
+  private var gatheredUpToCheckpoint: Long = 0
+  private val totalRows: Long = data.numRows
+  private val fixedWidthRowSizeBits = {
+    val dts = data.dataTypes
+    JoinGathererImpl.fixedWidthRowSizeBits(dts)
+  }
+
+  override def checkpoint: Unit = {
+    gatheredUpToCheckpoint = gatheredUpTo
+    data.checkpoint()
+  }
+
+  override def restore: Unit = {
+    gatheredUpTo = gatheredUpToCheckpoint
+    data.restore()
+  }
+
+  override def toString: String = {
+    s"SAMEGATHER $gatheredUpTo/$totalRows $data"
+  }
+
+  override def realCheapPerRowSizeEstimate: Double = {
+    val totalInputRows: Int = data.numRows
+    val totalInputSize: Long = data.deviceMemorySize
+    // Avoid divide by 0 here and later on
+    if (totalInputRows > 0 && totalInputSize > 0) {
+      totalInputSize.toDouble / totalInputRows
+    } else {
+      1.0
+    }
+  }
+
+  override def getFixedWidthBitSize: Option[Int] = fixedWidthRowSizeBits
+
+  override def gatherNext(n: Int): ColumnarBatch = {
+    assert(gatheredUpTo + n <= totalRows)
+    val ret = sliceForGather(n)
+    gatheredUpTo += n
+    ret
+  }
+
+  override def isDone: Boolean =
+    gatheredUpTo >= totalRows
+
+  override def numRowsLeft: Long = totalRows - gatheredUpTo
+
+  override def allowSpilling(): Unit = {
+    data.allowSpilling()
+  }
+
+  override def getBitSizeMap(n: Int): ColumnView = {
+    withResource(sliceForGather(n)) { cb =>
+      withResource(GpuColumnVector.from(cb)) { table =>
+        withResource(table.rowBitCount()) { bits =>
+          bits.castTo(DType.INT64)
+        }
+      }
+    }
+  }
+
+  override def close(): Unit = {
+    data.close()
+  }
+
+  private def isFullBatchGather(n: Int): Boolean = gatheredUpTo == 0 && n == totalRows
+
+  private def sliceForGather(n: Int): ColumnarBatch = {
+    val cb = data.getBatch
+    if (isFullBatchGather(n)) {
+      GpuColumnVector.incRefCounts(cb)
+    } else {
+      val splitStart = gatheredUpTo.toInt
+      val splitEnd = splitStart + n
+      val inputColumns = GpuColumnVector.extractColumns(cb)
+      val outputColumns: Array[vectorized.ColumnVector] = inputColumns.safeMap { c =>
+        val views = c.getBase.splitAsViews(splitStart, splitEnd)
+        assert(views.length == 3, s"Unexpected number of views: ${views.length}")
+        views(0).safeClose()
+        views(2).safeClose()
+        withResource(views(1)) { v =>
+          GpuColumnVector.from(v.copyToColumnVector(), c.dataType())
+        }
+      }
+      new ColumnarBatch(outputColumns, splitEnd - splitStart)
+    }
+  }
+}
+
 /**
  * Join Gatherer for a left table and a right table
  */
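
sliceForGather above leans on cudf's splitAsViews: two split points yield three views, and only the middle one is copied out while the outer two are closed immediately. A minimal plain-collections sketch of that assumed slicing contract (threeWaySplit is a made-up helper, not a cudf call):

    object ThreeWaySplitSketch {
      // Assumed splitAsViews(s, e) contract: views over [0, s), [s, e),
      // and [e, numRows); sliceForGather keeps only the middle slice
      def threeWaySplit[A](rows: Vector[A], s: Int, e: Int): (Vector[A], Vector[A], Vector[A]) =
        (rows.slice(0, s), rows.slice(s, e), rows.slice(e, rows.length))

      def main(args: Array[String]): Unit = {
        // Gathering rows [1, 3) out of 5 total, as when gatheredUpTo = 1 and n = 2
        println(threeWaySplit(Vector(10, 11, 12, 13, 14), 1, 3))
        // (Vector(10),Vector(11, 12),Vector(13, 14))
      }
    }
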
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
index 0ce0c1a8608..631ca0da090 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
@@ -474,6 +474,8 @@ class HashJoinIterator(
           None
         } else {
           val maps = joinType match {
+            case LeftOuter if isDistinctJoin =>
+              Array(leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual))
             case LeftOuter => leftKeys.leftJoinGatherMaps(rightKeys, compareNullsEqual)
             case RightOuter =>
               // Reverse the output of the join, because we expect the right gather map to
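
To see how JoinGathererSameTable pages its output when the batch budget is small (the batch_size="500" case in the new test), here is a standalone sketch of the gatheredUpTo arithmetic; rowsPerBatch is a hypothetical stand-in for the row count the plugin derives from spark.rapids.sql.batchSizeBytes:

    object PagingSketch {
      def main(args: Array[String]): Unit = {
        val totalRows = 1024L  // matches spark.range(1024) in the new test
        val rowsPerBatch = 500 // hypothetical budget derived from batchSizeBytes
        var gatheredUpTo = 0L  // mirrors JoinGathererSameTable.gatheredUpTo
        while (gatheredUpTo < totalRows) {
          val n = math.min(rowsPerBatch.toLong, totalRows - gatheredUpTo).toInt
          // Each call slices rows [gatheredUpTo, gatheredUpTo + n) from the
          // retained input batch instead of gathering through a map
          println(s"gatherNext($n) -> rows [$gatheredUpTo, ${gatheredUpTo + n})")
          gatheredUpTo += n
        }
      }
    }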