Skip to content

Commit

Permalink
Add flow control for multithreaded shuffle writer (#9678)
Browse files Browse the repository at this point in the history
Signed-off-by: Jim Brennan <[email protected]>
  • Loading branch information
jbrennan333 authored Nov 13, 2023
1 parent c20a843 commit 792b4c1
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/additional-functionality/advanced_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Name | Description | Default Value | Applicable at
<a name="python.memory.gpu.pooling.enabled"></a>spark.rapids.python.memory.gpu.pooling.enabled|Should RMM in Python workers act as a pooling allocator for GPU memory, or should it just pass through to CUDA memory allocation directly. When not specified, It will honor the value of config 'spark.rapids.memory.gpu.pooling.enabled'|None|Runtime
<a name="shuffle.enabled"></a>spark.rapids.shuffle.enabled|Enable or disable the RAPIDS Shuffle Manager at runtime. The [RAPIDS Shuffle Manager](https://docs.nvidia.com/spark-rapids/user-guide/latest/additional-functionality/rapids-shuffle.html) must already be configured. When set to `false`, the built-in Spark shuffle will be used. |true|Runtime
<a name="shuffle.mode"></a>spark.rapids.shuffle.mode|RAPIDS Shuffle Manager mode. "MULTITHREADED": shuffle file writes and reads are parallelized using a thread pool. "UCX": (requires UCX installation) uses accelerated transports for transferring shuffle blocks. "CACHE_ONLY": use when running a single executor, for short-circuit cached shuffle (for testing purposes).|MULTITHREADED|Startup
<a name="shuffle.multiThreaded.maxBytesInFlight"></a>spark.rapids.shuffle.multiThreaded.maxBytesInFlight|The size limit, in bytes, that the RAPIDS shuffle manager configured in "MULTITHREADED" mode will allow to be deserialized concurrently per task. This is also the maximum amount of memory that will be used per task. This should be set larger than Spark's default maxBytesInFlight (48MB). The larger this setting is, the more compressed shuffle chunks are processed concurrently. In practice, care needs to be taken to not go over the amount of off-heap memory that Netty has available. See https://github.com/NVIDIA/spark-rapids/issues/9153.|134217728|Startup
<a name="shuffle.multiThreaded.maxBytesInFlight"></a>spark.rapids.shuffle.multiThreaded.maxBytesInFlight|The size limit, in bytes, that the RAPIDS shuffle manager configured in "MULTITHREADED" mode will allow to be serialized or deserialized concurrently per task. This is also the maximum amount of memory that will be used per task. This should be set larger than Spark's default maxBytesInFlight (48MB). The larger this setting is, the more compressed shuffle chunks are processed concurrently. In practice, care needs to be taken to not go over the amount of off-heap memory that Netty has available. See https://github.com/NVIDIA/spark-rapids/issues/9153.|134217728|Startup
<a name="shuffle.multiThreaded.reader.threads"></a>spark.rapids.shuffle.multiThreaded.reader.threads|The number of threads to use for reading shuffle blocks per executor in the RAPIDS shuffle manager configured in "MULTITHREADED" mode. There are two special values: 0 = feature is disabled, falls back to Spark built-in shuffle reader; 1 = our implementation of Spark's built-in shuffle reader with extra metrics.|20|Startup
<a name="shuffle.multiThreaded.writer.threads"></a>spark.rapids.shuffle.multiThreaded.writer.threads|The number of threads to use for writing shuffle blocks per executor in the RAPIDS shuffle manager configured in "MULTITHREADED" mode. There are two special values: 0 = feature is disabled, falls back to Spark built-in shuffle writer; 1 = our implementation of Spark's built-in shuffle writer with extra metrics.|20|Startup
<a name="shuffle.transport.earlyStart"></a>spark.rapids.shuffle.transport.earlyStart|Enable early connection establishment for RAPIDS Shuffle|true|Startup
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1661,8 +1661,9 @@ object RapidsConf {
conf("spark.rapids.shuffle.multiThreaded.maxBytesInFlight")
.doc(
"The size limit, in bytes, that the RAPIDS shuffle manager configured in " +
"\"MULTITHREADED\" mode will allow to be deserialized concurrently per task. This is " +
"also the maximum amount of memory that will be used per task. This should be set larger " +
"\"MULTITHREADED\" mode will allow to be serialized or deserialized concurrently " +
"per task. This is also the maximum amount of memory that will be used per task. " +
"This should be set larger " +
"than Spark's default maxBytesInFlight (48MB). The larger this setting is, the " +
"more compressed shuffle chunks are processed concurrently. In practice, " +
"care needs to be taken to not go over the amount of off-heap memory that Netty has " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
mapId: Long,
sparkConf: SparkConf,
writeMetrics: ShuffleWriteMetricsReporter,
maxBytesInFlight: Long,
shuffleExecutorComponents: ShuffleExecutorComponents,
numWriterThreads: Int)
extends ShuffleWriter[K, V]
Expand All @@ -262,6 +263,7 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
private val serializer = dep.serializer.newInstance()
private val transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true)
private val fileBufferSize = sparkConf.get(config.SHUFFLE_FILE_BUFFER_SIZE).toInt * 1024
private val limiter = new BytesInFlightLimiter(maxBytesInFlight)
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
Expand Down Expand Up @@ -323,17 +325,23 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
} else {
// we close batches actively in the `records` iterator as we get the next batch
// this makes sure it is kept alive while a task is able to handle it.
val cb = value match {
val (cb, size) = value match {
case columnarBatch: ColumnarBatch =>
SlicedGpuColumnVector.incRefCount(columnarBatch)
(SlicedGpuColumnVector.incRefCount(columnarBatch),
SlicedGpuColumnVector.getTotalHostMemoryUsed(columnarBatch))
case _ =>
null
(null, 0L)
}
limiter.acquireOrBlock(size)
writeFutures += RapidsShuffleInternalManagerBase.queueWriteTask(slotNum, () => {
withResource(cb) { _ =>
val recordWriteTimeStart = System.nanoTime()
myWriter.write(key, value)
recordWriteTime.getAndAdd(System.nanoTime() - recordWriteTimeStart)
try {
val recordWriteTimeStart = System.nanoTime()
myWriter.write(key, value)
recordWriteTime.getAndAdd(System.nanoTime() - recordWriteTimeStart)
} finally {
limiter.release(size)
}
}
})
}
Expand Down Expand Up @@ -513,6 +521,48 @@ abstract class RapidsShuffleThreadedWriterBase[K, V](
diskBlockObjectWriters.clear()
}
}

def getBytesInFlight: Long = limiter.getBytesInFlight
}

class BytesInFlightLimiter(maxBytesInFlight: Long) {
private var inFlight: Long = 0L

def acquire(sz: Long): Boolean = {
if (sz == 0) {
true
} else {
synchronized {
if (inFlight == 0 || sz + inFlight < maxBytesInFlight) {
inFlight += sz
true
} else {
false
}
}
}
}

def acquireOrBlock(sz: Long): Unit = {
var acquired = acquire(sz)
if (!acquired) {
synchronized {
while (!acquired) {
acquired = acquire(sz)
if (!acquired) {
wait()
}
}
}
}
}

def release(sz: Long): Unit = synchronized {
inFlight -= sz
notifyAll()
}

def getBytesInFlight: Long = inFlight
}

abstract class RapidsShuffleThreadedReaderBase[K, C](
Expand Down Expand Up @@ -585,28 +635,6 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
doBatchFetch
}

class BytesInFlightLimiter(maxBytesInFlight: Long) {
private var inFlight: Long = 0L

def acquire(sz: Long): Boolean = {
if (sz == 0) {
true
} else {
synchronized {
if (inFlight == 0 || sz + inFlight < maxBytesInFlight) {
inFlight += sz
true
} else {
false
}
}
}
}

def release(sz: Long): Unit = synchronized {
inFlight -= sz
}
}

class RapidsShuffleThreadedBlockIterator(
fetcherIterator: RapidsShuffleBlockFetcherIterator,
Expand Down Expand Up @@ -1275,6 +1303,7 @@ class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: Boolean)
mapId,
conf,
new ThreadSafeShuffleWriteMetricsReporter(metricsReporter),
rapidsConf.shuffleMultiThreadedMaxBytesInFlight,
execComponents.get,
rapidsConf.shuffleMultiThreadedWriterThreads)
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class RapidsShuffleThreadedWriter[K, V](
mapId: Long,
sparkConf: SparkConf,
writeMetrics: ShuffleWriteMetricsReporter,
maxBytesInFlight: Long,
shuffleExecutorComponents: ShuffleExecutorComponents,
numWriterThreads: Int)
extends RapidsShuffleThreadedWriterBase[K, V](
Expand All @@ -41,6 +42,7 @@ class RapidsShuffleThreadedWriter[K, V](
mapId,
sparkConf,
writeMetrics,
maxBytesInFlight,
shuffleExecutorComponents,
numWriterThreads) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class RapidsShuffleThreadedWriter[K, V](
mapId: Long,
sparkConf: SparkConf,
writeMetrics: ShuffleWriteMetricsReporter,
maxBytesInFlight: Long,
shuffleExecutorComponents: ShuffleExecutorComponents,
numWriterThreads: Int)
extends RapidsShuffleThreadedWriterBase[K, V](
Expand All @@ -63,6 +64,7 @@ class RapidsShuffleThreadedWriter[K, V](
mapId,
sparkConf,
writeMetrics,
maxBytesInFlight,
shuffleExecutorComponents,
numWriterThreads)
with org.apache.spark.shuffle.checksum.ShuffleChecksumSupport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,13 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
0L, // MapId
conf,
taskContext.taskMetrics().shuffleWriteMetrics,
1024 * 1024,
shuffleExecutorComponents,
numWriterThreads)
writer.write(Iterator.empty)
writer.stop( /* success = */ true)
assert(writer.getPartitionLengths.sum === 0)
assert(writer.getBytesInFlight == 0)
assert(outputFile.exists())
assert(outputFile.length() === 0)
assert(temporaryFilesCreated.isEmpty)
Expand All @@ -313,13 +315,15 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
0L, // MapId
transferConf,
new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics),
1024 * 1024,
shuffleExecutorComponents,
numWriterThreads)
writer.write(records)
writer.stop( /* success = */ true)
assert(temporaryFilesCreated.nonEmpty)
assert(writer.getPartitionLengths.sum === outputFile.length())
assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files
assert(writer.getBytesInFlight == 0)
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted
val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
Expand Down Expand Up @@ -349,6 +353,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
0L, // MapId
conf,
taskContext.taskMetrics().shuffleWriteMetrics,
1024 * 1024,
shuffleExecutorComponents,
numWriterThreads)

Expand All @@ -362,6 +367,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite

writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
assert(writer.getBytesInFlight == 0)
}

test("cleanup of intermediate files after errors") {
Expand All @@ -371,6 +377,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
0L, // MapId
conf,
taskContext.taskMetrics().shuffleWriteMetrics,
1024 * 1024,
shuffleExecutorComponents,
numWriterThreads)
intercept[SparkException] {
Expand All @@ -384,6 +391,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
assert(temporaryFilesCreated.nonEmpty)
writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
assert(writer.getBytesInFlight == 0)
}

test("write checksum file") {
Expand Down Expand Up @@ -426,11 +434,13 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
mapId,
conf,
new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics),
1024 * 1024,
new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver),
numWriterThreads)

writer.write(Iterator((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)))
writer.stop( /* success = */ true)
assert(writer.getBytesInFlight == 0)
assert(checksumFile.exists())
assert(checksumFile.length() === 8 * numPartition)
compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile)
Expand All @@ -455,6 +465,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
0L, // MapId
conf,
new ThreadSafeShuffleWriteMetricsReporter(taskContext.taskMetrics().shuffleWriteMetrics),
1024 * 1024,
shuffleExecutorComponents,
numWriterThreads)
assertThrows[IOException] {
Expand All @@ -469,6 +480,7 @@ class RapidsShuffleThreadedWriterSuite extends AnyFunSuite
}
assert(temporaryFilesCreated.nonEmpty)
assert(writer.getPartitionLengths == null)
assert(writer.getBytesInFlight == 0)
assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted
}
}
Expand Down

0 comments on commit 792b4c1

Please sign in to comment.