Commit 2619cab: private[spill]

abellina committed Nov 22, 2024
1 parent 4cde905 commit 2619cab
Showing 18 changed files with 66 additions and 39 deletions.
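The whole commit is a visibility cleanup: spill-framework members that used to be public are narrowed to `private[spill]` as the code moves into the new `com.nvidia.spark.rapids.spill` package. A minimal, self-contained sketch of what Scala's package-qualified modifier buys here (toy names, not the plugin's actual types):

```scala
package com.nvidia.spark.rapids.spill

trait ToyHandle {
  val sizeInBytes: Long
  // Accessible from anywhere under com.nvidia.spark.rapids.spill,
  // invisible to com.nvidia.spark.rapids and to user code.
  private[spill] def spillable: Boolean = sizeInBytes > 0
}

object ToyStore {
  // Compiles: this object lives in the spill package.
  def candidates(hs: Seq[ToyHandle]): Seq[ToyHandle] = hs.filter(_.spillable)
}
```

From any other package, `h.spillable` no longer compiles, so spill-internal state can only be observed or mutated by the framework itself.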
@@ -21,6 +21,7 @@ import java.lang.management.ManagementFactory
 import java.util.concurrent.atomic.AtomicLong
 
 import ai.rapids.cudf.{Cuda, Rmm, RmmEventHandler}
+import com.nvidia.spark.rapids.spill.SpillableDeviceStore
 import com.sun.management.HotSpotDiagnosticMXBean
 
 import org.apache.spark.internal.Logging
@@ -23,6 +23,7 @@ import scala.util.control.NonFatal
 
 import ai.rapids.cudf._
 import com.nvidia.spark.rapids.jni.RmmSpark
+import com.nvidia.spark.rapids.spill.SpillFramework
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.{DefaultHostMemoryAllocator, HostMemoryAllocator, HostMemoryBuffer, MemoryBuffer, PinnedMemoryPool}
 import com.nvidia.spark.rapids.jni.{CpuRetryOOM, RmmSpark}
+import com.nvidia.spark.rapids.spill.SpillFramework
 
 import org.apache.spark.internal.Logging
 
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
 import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, Table}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.format.TableMeta
+import com.nvidia.spark.rapids.spill.{SpillableDeviceBufferHandle, SpillableHandle}
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
 import ai.rapids.cudf.{DeviceMemoryBuffer, Table}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.format.TableMeta
+import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.types.DataType
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer}
 import com.nvidia.spark.rapids.Arm.closeOnExcept
+import com.nvidia.spark.rapids.spill.{SpillableColumnarBatchFromBufferHandle, SpillableColumnarBatchHandle, SpillableCompressedColumnarBatchHandle, SpillableDeviceBufferHandle, SpillableHostBufferHandle, SpillableHostColumnarBatchHandle}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.types.DataType
@@ -14,9 +14,9 @@
  * limitations under the License.
  */
 
-package com.nvidia.spark.rapids
+package com.nvidia.spark.rapids.spill
 
-import java.io.{DataOutputStream, File, FileInputStream, InputStream, OutputStream}
+import java.io._
 import java.nio.ByteBuffer
 import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
 import java.nio.file.StandardOpenOption
@@ -26,10 +26,12 @@ import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.mutable
 
-import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange, Table}
+import ai.rapids.cudf._
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuColumnVectorFromBuffer, GpuCompressedColumnVector, GpuDeviceManager, HostAlloc, HostMemoryOutputStream, MemoryBufferToHostByteBufferIterator, RapidsConf, RapidsHostColumnVector}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq
 import com.nvidia.spark.rapids.format.TableMeta
+import com.nvidia.spark.rapids.internal.HostByteBufferIterator
 import org.apache.commons.io.IOUtils
 
 import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
@@ -154,7 +156,7 @@ trait SpillableHandle extends StoreHandle {
    * is a `val`.
    * @return true if currently spillable, false otherwise
    */
-  def spillable: Boolean = sizeInBytes > 0
+  private[spill] def spillable: Boolean = sizeInBytes > 0
 }
 
 /**
@@ -163,9 +165,9 @@ trait SpillableHandle extends StoreHandle {
  * on the device.
  */
 trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
-  var dev: Option[T]
+  private[spill] var dev: Option[T]
 
-  override def spillable: Boolean = synchronized {
+  private[spill] override def spillable: Boolean = synchronized {
     super.spillable && dev.isDefined
   }
 }
@@ -176,9 +178,9 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
  * on the host.
  */
 trait HostSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
-  var host: Option[T]
+  private[spill] var host: Option[T]
 
-  override def spillable: Boolean = synchronized {
+  private[spill] override def spillable: Boolean = synchronized {
     super.spillable && host.isDefined
   }
 }
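The two traits deliberately mirror each other, and both read the Option under `synchronized`: a spill thread can empty `dev`/`host` concurrently with a task thread checking spillability, so the handoff and the check must share the handle's lock. A self-contained sketch of the pattern (toy types, not the plugin's):

```scala
// A handle is spillable only while it still owns its buffer; take() is the
// spill-time handoff, and can only succeed once even under races.
class ToyHandle[T <: AutoCloseable](private var res: Option[T]) {
  def spillable: Boolean = synchronized { res.isDefined }

  def take(): Option[T] = synchronized {
    val out = res
    res = None // after this, spillable is false for every other thread
    out
  }
}
```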
@@ -194,7 +196,7 @@ object SpillableHostBufferHandle extends Logging {
     handle
   }
 
-  def createHostHandleWithPacker(
+  private[spill] def createHostHandleWithPacker(
       chunkedPacker: ChunkedPacker): SpillableHostBufferHandle = {
     val handle = new SpillableHostBufferHandle(chunkedPacker.getTotalContiguousSize)
     withResource(
@@ -213,7 +215,7 @@
     }
   }
 
-  def createHostHandleFromDeviceBuff(
+  private[spill] def createHostHandleFromDeviceBuff(
       buff: DeviceMemoryBuffer): SpillableHostBufferHandle = {
     val handle = new SpillableHostBufferHandle(buff.getLength)
     withResource(
@@ -227,11 +229,11 @@
 
 class SpillableHostBufferHandle private (
     override val sizeInBytes: Long,
-    override var host: Option[HostMemoryBuffer] = None,
-    var disk: Option[DiskHandle] = None)
+    private[spill] override var host: Option[HostMemoryBuffer] = None,
+    private[spill] var disk: Option[DiskHandle] = None)
   extends HostSpillableHandle[HostMemoryBuffer] {
 
-  override def spillable: Boolean = synchronized {
+  private[spill] override def spillable: Boolean = synchronized {
     if (super.spillable) {
       host.getOrElse {
         throw new IllegalStateException(
@@ -307,7 +309,7 @@ class SpillableHostBufferHandle private (
     }
   }
 
-  def getHostBuffer: Option[HostMemoryBuffer] = synchronized {
+  private[spill] def getHostBuffer: Option[HostMemoryBuffer] = synchronized {
     host.foreach(_.incRefCount())
     host
   }
@@ -320,7 +322,7 @@
     }
   }
 
-  def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = {
+  private[spill] def materializeToDeviceMemoryBuffer(dmb: DeviceMemoryBuffer): Unit = {
     var hostBuffer: HostMemoryBuffer = null
     var diskHandle: DiskHandle = null
     synchronized {
@@ -351,11 +353,11 @@ class SpillableHostBufferHandle private (
     }
   }
 
-  def setHost(singleShotBuffer: HostMemoryBuffer): Unit = synchronized {
+  private[spill] def setHost(singleShotBuffer: HostMemoryBuffer): Unit = synchronized {
     host = Some(singleShotBuffer)
   }
 
-  def setDisk(handle: DiskHandle): Unit = synchronized {
+  private[spill] def setDisk(handle: DiskHandle): Unit = synchronized {
     disk = Some(handle)
   }
 }
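`materializeToDeviceMemoryBuffer` above shows the framework's two-phase locking shape: snapshot the handle's state (`host` or `disk`) inside `synchronized`, then run the slow copy outside the lock so a long host-to-device transfer never blocks other threads' `spillable` checks. A hedged sketch of that shape, with toy state and hypothetical I/O callbacks:

```scala
// Sketch only: Array[Byte] and a path string stand in for HostMemoryBuffer
// and DiskHandle; copyFromHost/readFromDisk are assumed helpers.
class ToyHostHandle(private var host: Option[Array[Byte]],
                    private var disk: Option[String] = None) {
  def materialize(copyFromHost: Array[Byte] => Unit,
                  readFromDisk: String => Unit): Unit = {
    // Phase 1: pick the source under the lock, no I/O yet.
    val (h, d) = synchronized { (host, disk) }
    // Phase 2: do the potentially slow copy outside the lock.
    h match {
      case Some(bytes) => copyFromHost(bytes)
      case None =>
        readFromDisk(d.getOrElse(
          throw new IllegalStateException("neither host nor disk copy exists")))
    }
  }
}
```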
@@ -370,11 +372,11 @@ object SpillableDeviceBufferHandle {
 
 class SpillableDeviceBufferHandle private (
     override val sizeInBytes: Long,
-    override var dev: Option[DeviceMemoryBuffer],
-    var host: Option[SpillableHostBufferHandle] = None)
+    private[spill] override var dev: Option[DeviceMemoryBuffer],
+    private[spill] var host: Option[SpillableHostBufferHandle] = None)
   extends DeviceSpillableHandle[DeviceMemoryBuffer] {
 
-  override def spillable: Boolean = synchronized {
+  private[spill] override def spillable: Boolean = synchronized {
     if (super.spillable) {
       dev.getOrElse {
         throw new IllegalStateException(
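Note the shape these constructor signatures encode: a device handle spills into a host handle rather than a raw buffer, and `SpillableHostBufferHandle` in turn carries an optional `DiskHandle`, so demotion chains device to host to disk one tier at a time. An illustrative sketch (names assumed, not the plugin's API):

```scala
// Each tier owns an optional link to the tier below it.
final case class ToyDiskHandle(path: String)
final class ToyHostHandle(var host: Option[Array[Byte]],
                          var disk: Option[ToyDiskHandle] = None)
final class ToyDeviceHandle(var dev: Option[Array[Byte]],
                            var host: Option[ToyHostHandle] = None) {
  // Demoting the device tier yields a host handle that can itself spill later.
  def demote(): Unit = synchronized {
    dev.foreach { bytes =>
      host = Some(new ToyHostHandle(Some(bytes))) // device -> host
      dev = None
    }
  }
}
```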
@@ -445,8 +447,8 @@ class SpillableDeviceBufferHandle private (
 
 class SpillableColumnarBatchHandle private (
     override val sizeInBytes: Long,
-    override var dev: Option[ColumnarBatch],
-    var host: Option[SpillableHostBufferHandle] = None)
+    private[spill] override var dev: Option[ColumnarBatch],
+    private[spill] var host: Option[SpillableHostBufferHandle] = None)
   extends DeviceSpillableHandle[ColumnarBatch] with Logging {
 
   override def spillable: Boolean = synchronized {
@@ -573,13 +575,13 @@ object SpillableColumnarBatchFromBufferHandle {
 
 class SpillableColumnarBatchFromBufferHandle private (
     override val sizeInBytes: Long,
-    override var dev: Option[ColumnarBatch],
-    var host: Option[SpillableHostBufferHandle] = None)
+    private[spill] override var dev: Option[ColumnarBatch],
+    private[spill] var host: Option[SpillableHostBufferHandle] = None)
   extends DeviceSpillableHandle[ColumnarBatch] {
 
   private var meta: Option[TableMeta] = None
 
-  override def spillable: Boolean = synchronized {
+  private[spill] override def spillable: Boolean = synchronized {
     if (super.spillable) {
       val dcvs = GpuColumnVector.extractBases(dev.get)
       val colRepetition = mutable.HashMap[ColumnVector, Int]()
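The `colRepetition` map the excerpt cuts off is the interesting part of this override: a batch can reference the same base `ColumnVector` more than once, and it is only safe to spill if this handle owns every reference, that is, each column's refcount equals the number of times it appears in the batch. A hedged sketch of that check (assumed to match the elided logic):

```scala
import scala.collection.mutable
import ai.rapids.cudf.ColumnVector

object OwnershipCheck {
  // Exclusive ownership: each base column's refcount must equal the number
  // of times the batch repeats that column.
  def exclusivelyOwned(cols: Seq[ColumnVector]): Boolean = {
    val colRepetition = mutable.HashMap[ColumnVector, Int]()
    cols.foreach(c => colRepetition.update(c, colRepetition.getOrElse(c, 0) + 1))
    cols.forall(c => c.getRefCount == colRepetition(c))
  }
}
```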
@@ -670,8 +672,8 @@ object SpillableCompressedColumnarBatchHandle {
 
 class SpillableCompressedColumnarBatchHandle private (
     val compressedSizeInBytes: Long,
-    override var dev: Option[ColumnarBatch],
-    var host: Option[SpillableHostBufferHandle] = None)
+    private[spill] override var dev: Option[ColumnarBatch],
+    private[spill] var host: Option[SpillableHostBufferHandle] = None)
   extends DeviceSpillableHandle[ColumnarBatch] {
 
   override val sizeInBytes: Long = compressedSizeInBytes
@@ -760,8 +762,8 @@ object SpillableHostColumnarBatchHandle {
 class SpillableHostColumnarBatchHandle private (
     val sizeInBytes: Long,
     val numRows: Int,
-    override var host: Option[ColumnarBatch],
-    var disk: Option[DiskHandle] = None)
+    private[spill] override var host: Option[ColumnarBatch],
+    private[spill] var disk: Option[DiskHandle] = None)
   extends HostSpillableHandle[ColumnarBatch] {
 
   override def spillable: Boolean = synchronized {
@@ -1116,7 +1118,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None)
     HostAlloc.tryAlloc(handle.sizeInBytes).foreach { hmb =>
       withResource(hmb) { _ =>
         if (trackInternal(handle)) {
-          hmb.incRefCount
+          hmb.incRefCount()
           // the host store made room or fit this buffer
           builder = Some(new SpillableHostBufferHandleBuilderForHost(handle, hmb))
         }
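The one-character fix above (`incRefCount` to `incRefCount()`) is just Scala style for a side-effecting call, but the call it decorates matters: `withResource(hmb)` closes (decrements) the buffer when the block exits, so the store must take its own reference first or the handle would outlive the memory. A self-contained sketch of the discipline, with a toy refcounted resource:

```scala
object RefCountDiscipline extends App {
  // Minimal withResource, matching the plugin's Arm helper in spirit.
  def withResource[T <: AutoCloseable, V](r: T)(body: T => V): V =
    try body(r) finally r.close()

  class Counted extends AutoCloseable {
    private var refs = 1
    def incRefCount(): Unit = synchronized { refs += 1 }
    override def close(): Unit = synchronized { refs -= 1 }
    def live: Boolean = synchronized { refs > 0 }
  }

  val r = new Counted
  withResource(r) { _ =>
    r.incRefCount() // the store keeps one reference past this block
  }
  assert(r.live) // the block's reference was released; the store's survives
}
```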
@@ -21,6 +21,7 @@ import java.util.concurrent.{ExecutionException, Future, LinkedBlockingQueue, Ti
 import ai.rapids.cudf.{HostMemoryBuffer, PinnedMemoryPool, Rmm, RmmAllocationMode}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.jni.{RmmSpark, RmmSparkThreadState}
+import com.nvidia.spark.rapids.spill._
 import org.mockito.Mockito.when
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 import org.scalatest.concurrent.{Signaler, TimeLimits}
@@ -16,6 +16,7 @@
 
 package com.nvidia.spark.rapids
 
+import com.nvidia.spark.rapids.spill.SpillableDeviceStore
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito.when
 import org.scalatestplus.mockito.MockitoSugar
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
 import ai.rapids.cudf.Table
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.jni.{GpuSplitAndRetryOOM, RmmSpark}
+import com.nvidia.spark.rapids.spill.{SpillableColumnarBatchHandle, SpillableDeviceStore, SpillFramework}
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito.{doAnswer, spy}
 import org.mockito.invocation.InvocationOnMock
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.{Rmm, RmmAllocationMode, RmmEventHandler}
 import com.nvidia.spark.rapids.jni.RmmSpark
+import com.nvidia.spark.rapids.spill.SpillFramework
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream,
 import ai.rapids.cudf.{Rmm, RmmAllocationMode, Table}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.spill.SpillFramework
 import org.apache.commons.lang3.SerializationUtils
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuite
@@ -21,6 +21,7 @@ import com.nvidia.spark.Retryable
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfGpu, withRestoreOnRetry, withRetry, withRetryNoSplit}
 import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM, RmmSpark}
+import com.nvidia.spark.rapids.spill.SpillFramework
 import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
@@ -16,8 +16,9 @@
 
 package com.nvidia.spark.rapids.shuffle
 
-import com.nvidia.spark.rapids.{RapidsShuffleHandle, SpillableDeviceBufferHandle}
+import com.nvidia.spark.rapids.RapidsShuffleHandle
 import com.nvidia.spark.rapids.jni.RmmSpark
+import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle
 import org.mockito.ArgumentCaptor
 import org.mockito.ArgumentMatchers._
 import org.mockito.Mockito._
@@ -20,9 +20,10 @@ import java.io.IOException
 import java.nio.ByteBuffer
 
 import ai.rapids.cudf.{DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer}
-import com.nvidia.spark.rapids.{MetaUtils, RapidsShuffleHandle, ShuffleMetadata, SpillableDeviceBufferHandle}
+import com.nvidia.spark.rapids.{MetaUtils, RapidsShuffleHandle, ShuffleMetadata}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.format.TableMeta
+import com.nvidia.spark.rapids.spill.SpillableDeviceBufferHandle
 import org.mockito.{ArgumentCaptor, ArgumentMatchers}
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
@@ -233,7 +234,11 @@ class RapidsShuffleServerSuite extends RapidsShuffleTestHelper {
       // acquire once at the beginning, and closed at the end
       verify(mockRequestHandler, times(1))
         .getShuffleHandle(ArgumentMatchers.eq(1))
-      assertResult(1)(rapidsBuffer.spillable.dev.get.getRefCount)
+      withResource(rapidsBuffer.spillable.materialize) { dmb =>
+        // refcount = 2: the device store holds one reference and materialize
+        // added one; closing ours here leaves no leaks.
+        assertResult(2)(dmb.getRefCount)
+      }
     }
   }
 }
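The new assertion is just refcount arithmetic. A hedged sketch with cudf's `DeviceMemoryBuffer` (needs a GPU to run; `allocate` yields refcount 1, which here models the reference held by the device store):

```scala
import ai.rapids.cudf.DeviceMemoryBuffer

object RefCountMath extends App {
  val buf = DeviceMemoryBuffer.allocate(128) // refcount = 1: the "store's" reference
  buf.incRefCount()                          // materialize: refcount = 2
  assert(buf.getRefCount == 2)               // what the test asserts
  buf.close()                                // caller's reference released: back to 1
  buf.close()                                // store's reference released: freed
}
```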
@@ -427,7 +432,11 @@
 
       // this handle materializes, so make sure we close it
       verify(rapidsHandle2.spillable, times(1)).materialize
-      verify(rapidsHandle2.spillable.dev.get, times(1)).close()
+      withResource(rapidsHandle2.spillable.materialize) { dmb =>
+        // refcount = 2: the device store holds one reference and materialize
+        // added one; closing ours here leaves no leaks.
+        assertResult(2)(dmb.getRefCount)
+      }
     }
   }
 }
Expand Down
@@ -14,12 +14,13 @@
  * limitations under the License.
  */
 
-package com.nvidia.spark.rapids
+package com.nvidia.spark.rapids.spill
 
 import java.io.File
 import java.math.RoundingMode
 
-import ai.rapids.cudf.{ColumnVector, ContiguousTable, Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, Table}
+import ai.rapids.cudf._
+import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.format.CodecType
 import org.mockito.Mockito.when
@@ -28,7 +29,7 @@ import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.rapids.RapidsDiskBlockManager
-import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class SpillFrameworkSuite
@@ -16,9 +16,10 @@
 package org.apache.spark.sql.rapids
 
 import ai.rapids.cudf.{Rmm, RmmAllocationMode, TableWriter}
-import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsConf, ScalableTaskCompletion, SpillFramework}
+import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuColumnVector, GpuLiteral, RapidsConf, ScalableTaskCompletion}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.jni.{GpuRetryOOM, GpuSplitAndRetryOOM}
+import com.nvidia.spark.rapids.spill.SpillFramework
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FSDataOutputStream
 import org.apache.hadoop.mapred.TaskAttemptContext
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.rapids
 
 import ai.rapids.cudf.DeviceMemoryBuffer
-import com.nvidia.spark.rapids.{RapidsConf, SpillableBuffer, SpillFramework}
+import com.nvidia.spark.rapids.{RapidsConf, SpillableBuffer}
+import com.nvidia.spark.rapids.spill.SpillFramework
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuite
 
