Skip to content

Commit

Permalink
[SPARK-31521][CORE] Correct the fetch size when merging blocks into a…
Browse files Browse the repository at this point in the history
… merged block

### What changes were proposed in this pull request?

Fix the wrong fetch size.

### Why are the changes needed?

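When additional blocks are merged into an already-merged (batch) block, the fetch size of the result should be the size of that already-merged block plus the total size of the blocks being merged into it. Previously, the size of the already-merged block was left out of the sum.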

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added a regression test.

Closes #28301 from Ngone51/fix_merged_block_size.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
Ngone51 authored and dongjoon-hyun committed Apr 25, 2020
1 parent 3e83ccc commit ab8cada
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,15 @@ final class ShuffleBlockFetcherIterator(
if (address.executorId == blockManager.blockManagerId.executorId) {
checkBlockSizes(blockInfos)
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)))
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
localBlockBytes += mergedBlockInfos.map(_.size).sum
} else if (hostLocalDirReadingEnabled && address.host == blockManager.blockManagerId.host) {
checkBlockSizes(blockInfos)
val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)))
blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
numBlocksToFetch += mergedBlockInfos.size
val blocksForAddress =
mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
hostLocalBlocksByExecutor += address -> blocksForAddress
Expand Down Expand Up @@ -340,7 +342,8 @@ final class ShuffleBlockFetcherIterator(
address: BlockManagerId,
isLast: Boolean,
collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo] = {
val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch)
numBlocksToFetch += mergedBlocks.size
var retBlocks = Seq.empty[FetchBlockInfo]
if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
collectedRemoteRequests += createFetchRequest(mergedBlocks, address)
Expand Down Expand Up @@ -400,73 +403,6 @@ final class ShuffleBlockFetcherIterator(
blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) }
}

/**
 * Merges the given blocks into batch blocks when `doBatchFetch` is set, grouping consecutive
 * input blocks that share the same `mapId`. When `doBatchFetch` is not set the input is used
 * unchanged. In both cases `numBlocksToFetch` is increased by the number of resulting blocks.
 *
 * NOTE(review): when new blocks are merged into a previous batch block, the size computed
 * below is the sum of `toBeMerged` only; the size of the removed previous batch block is not
 * added to it.
 *
 * @param blocks blocks to be merged if possible; may already contain batch blocks.
 * @return the input blocks, or the merged blocks when `doBatchFetch` is set.
 */
private[this] def mergeContinuousShuffleBlockIdsIfNeeded(
blocks: Seq[FetchBlockInfo]): Seq[FetchBlockInfo] = {
val result = if (doBatchFetch) {
// Blocks of the map output currently being accumulated (all share one mapId).
var curBlocks = new ArrayBuffer[FetchBlockInfo]
// Batch blocks produced so far, plus any batch blocks passed through from the input.
val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]

// Collapses `toBeMerged` into a single batch block, extending the last batch block in
// `mergedBlockInfo` (and removing it from there) when it has the same map id.
def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = {
val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId]

// The last merged block may come from the input, and we can merge more blocks
// into it, if the map id is the same.
def shouldMergeIntoPreviousBatchBlockId =
mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId

val startReduceId = if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) {
// Remove the previous batch block id as we will add a new one to replace it.
// Only its startReduceId is kept; its size is dropped here.
mergedBlockInfo.remove(mergedBlockInfo.length - 1).blockId
.asInstanceOf[ShuffleBlockBatchId].startReduceId
} else {
startBlockId.reduceId
}

FetchBlockInfo(
ShuffleBlockBatchId(
startBlockId.shuffleId,
startBlockId.mapId,
startReduceId,
// The end of the batch range is one past the last merged reduce id.
toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1),
toBeMerged.map(_.size).sum,
toBeMerged.head.mapIndex)
}

val iter = blocks.iterator
while (iter.hasNext) {
val info = iter.next()
// It's possible that the input block id is already a batch ID. For example, we merge some
// blocks, and then make fetch requests with the merged blocks according to "max blocks per
// request". The last fetch request may be too small, and we give up and put the remaining
// merged blocks back to the input list.
if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) {
mergedBlockInfo += info
} else {
if (curBlocks.isEmpty) {
curBlocks += info
} else {
val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId]
val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId
// A different map id closes the current group: emit it as one batch and start over.
if (curBlockId.mapId != currentMapId) {
mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
curBlocks.clear()
}
curBlocks += info
}
}
}
// Emit whatever is left in the last group.
if (curBlocks.nonEmpty) {
mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
}
mergedBlockInfo
} else {
blocks
}
// update metrics
numBlocksToFetch += result.size
result
}

/**
* Fetch the local blocks while we are fetching remote blocks. This is ok because
* `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
Expand Down Expand Up @@ -915,6 +851,86 @@ private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterat
private[storage]
object ShuffleBlockFetcherIterator {

/**
 * Merges blocks when `doBatchFetch` is true. Consecutive input blocks which have the same
 * `mapId` are merged into one block batch. The block batch is specified by a range of
 * reduceId, which implies the continuous shuffle blocks that we can fetch in a batch.
 * For example, input blocks like (shuffle_0_0_0, shuffle_0_0_1, shuffle_0_1_0) can be
 * merged into (shuffle_0_0_0_2, shuffle_0_1_0_1), and input blocks like (shuffle_0_0_0_2,
 * shuffle_0_0_2, shuffle_0_0_3) can be merged into (shuffle_0_0_0_4).
 *
 * The size of a merged block is the sum of the sizes of everything folded into it, including
 * the size of an already-merged block that is being extended.
 *
 * This function only compares `mapId`s; it does not itself verify that the reduce ids of the
 * merged blocks are contiguous.
 *
 * @param blocks blocks to be merged if possible. May contain already merged blocks.
 * @param doBatchFetch whether to merge blocks.
 * @return the input blocks if doBatchFetch=false, or the merged blocks if doBatchFetch=true.
 */
def mergeContinuousShuffleBlockIdsIfNeeded(
    blocks: Seq[FetchBlockInfo],
    doBatchFetch: Boolean): Seq[FetchBlockInfo] = {
  if (doBatchFetch) {
    // Blocks of the map output currently being accumulated (all share one mapId).
    val curBlocks = new ArrayBuffer[FetchBlockInfo]
    // Batch blocks produced so far, plus any batch blocks passed through from the input.
    val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]

    // Collapses `toBeMerged` into a single batch block. If the last entry of
    // `mergedBlockInfo` has the same map id, that entry is removed and folded in as well.
    def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = {
      val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId]
      val toBeMergedSize = toBeMerged.map(_.size).sum

      // The last merged block may come from the input, and we can merge more blocks
      // into it, if the map id is the same.
      def shouldMergeIntoPreviousBatchBlockId =
        mergedBlockInfo.last.blockId.asInstanceOf[ShuffleBlockBatchId].mapId == startBlockId.mapId

      val (startReduceId, size) =
        if (mergedBlockInfo.nonEmpty && shouldMergeIntoPreviousBatchBlockId) {
          // Remove the previous batch block id as we will add a new one to replace it.
          // Its size must be carried over into the replacement.
          val removed = mergedBlockInfo.remove(mergedBlockInfo.length - 1)
          (removed.blockId.asInstanceOf[ShuffleBlockBatchId].startReduceId,
            removed.size + toBeMergedSize)
        } else {
          (startBlockId.reduceId, toBeMergedSize)
        }

      FetchBlockInfo(
        ShuffleBlockBatchId(
          startBlockId.shuffleId,
          startBlockId.mapId,
          startReduceId,
          // The end of the batch range is one past the last merged reduce id.
          toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1),
        size,
        toBeMerged.head.mapIndex)
    }

    val iter = blocks.iterator
    while (iter.hasNext) {
      val info = iter.next()
      // It's possible that the input block id is already a batch ID. For example, we merge some
      // blocks, and then make fetch requests with the merged blocks according to "max blocks per
      // request". The last fetch request may be too small, and we give up and put the remaining
      // merged blocks back to the input list.
      if (info.blockId.isInstanceOf[ShuffleBlockBatchId]) {
        mergedBlockInfo += info
      } else {
        if (curBlocks.nonEmpty) {
          val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId]
          val currentMapId = curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId
          // A different map id closes the current group: emit it as one batch and start over.
          if (curBlockId.mapId != currentMapId) {
            mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
            curBlocks.clear()
          }
        }
        curBlocks += info
      }
    }
    // Emit whatever is left in the last group.
    if (curBlocks.nonEmpty) {
      mergedBlockInfo += mergeFetchBlockInfo(curBlocks)
    }
    mergedBlockInfo
  } else {
    blocks
  }
}

/**
* The block information to fetch used in FetchRequest.
* @param blockId block id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
import org.apache.spark.util.Utils


Expand Down Expand Up @@ -1071,4 +1072,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val e = intercept[FetchFailedException] { iterator.next() }
assert(e.getMessage.contains("Received a zero-size buffer"))
}

// Regression test for SPARK-31521: when plain shuffle blocks are merged into an existing batch
// block, the resulting batch block's size must include the size of that existing batch block.
test("SPARK-31521: correct the fetch size when merging blocks into a merged block") {
  // The end reduce id of a batch is exclusive, so this batch covers reduce ids 0 to 4 and the
  // next continuous blocks are reduce ids 5 and 6.
  val bId1 = ShuffleBlockBatchId(0, 0, 0, 5)
  val bId2 = ShuffleBlockId(0, 0, 5)
  val bId3 = ShuffleBlockId(0, 0, 6)
  val block1 = FetchBlockInfo(bId1, 40, 0)
  val block2 = FetchBlockInfo(bId2, 50, 0)
  val block3 = FetchBlockInfo(bId3, 60, 0)
  val inputBlocks = Seq(block1, block2, block3)

  val mergedBlocks = ShuffleBlockFetcherIterator.
    mergeContinuousShuffleBlockIdsIfNeeded(inputBlocks, doBatchFetch = true)
  // All three inputs share the same map id, so they collapse into a single batch block.
  assert(mergedBlocks.size === 1)
  val mergedBlock = mergedBlocks.head
  val mergedBlockId = mergedBlock.blockId.asInstanceOf[ShuffleBlockBatchId]
  assert(mergedBlockId.startReduceId === bId1.startReduceId)
  assert(mergedBlockId.endReduceId === bId3.reduceId + 1)
  // The merged size covers the pre-existing batch block as well as the newly merged blocks.
  assert(mergedBlock.size === inputBlocks.map(_.size).sum)
}
}

0 comments on commit ab8cada

Please sign in to comment.