Exclude the semaphore waiting time from the deserialization metric #38

Merged (2 commits, May 7, 2024)
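Background for the change: before this PR, SerializedTableIterator wrapped the entire tableDeserializer.readFromStream(dIn) call in deserTime.ns(...). Since readFromStream can block in GpuSemaphore.acquireIfNecessary while waiting for the GPU, that wait was counted as deserialization time. The PR pushes the deserTime.ns wrappers down into readFromStream, timing only the metadata read, the host-buffer read, the host-to-device copy, and the final unpack. GpuMetric.ns follows the usual timing-wrapper pattern; a minimal sketch of that pattern, assuming a simplified mutable metric rather than the plugin's actual GpuMetric class:

class GpuMetric(var value: Long = 0L) {
  def add(v: Long): Unit = value += v

  // Run a block, measure its wall-clock duration, and accumulate
  // the elapsed nanoseconds into this metric, returning the result.
  def ns[T](f: => T): T = {
    val start = System.nanoTime()
    try f finally add(System.nanoTime() - start)
  }
}

Whatever runs inside the ns block is charged to the metric, which is exactly why the semaphore acquisition has to sit outside of it.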
@@ -71,9 +71,7 @@ trait GpuPartitioning extends Partitioning {
     if (_serializingOnGPU) {
       table =>
         withResource(new NvtxRange("Table to Host", NvtxColor.BLUE)) { _ =>
-          withResource(table) { _ =>
-            PackedTableHostColumnVector.from(table)
-          }
+          PackedTableHostColumnVector.from(table)
         }
     } else {
       GpuCompressedColumnVector.from
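The hunks in this PR lean on the plugin's loan-pattern resource helpers. As a rough sketch of their contracts, with simplified stand-ins (the actual spark-rapids Arm utilities may differ in detail): withResource always closes the resource, while closeOnExcept closes it only when the body throws, leaving ownership with the caller on success.

// Simplified stand-ins for the Arm-style helpers seen in these hunks;
// the real spark-rapids implementations may differ.
def withResource[R <: AutoCloseable, T](r: R)(body: R => T): T = {
  try {
    body(r)
  } finally {
    r.close() // closed on success and on failure
  }
}

def closeOnExcept[R <: AutoCloseable, T](r: R)(body: R => T): T = {
  try {
    body(r)
  } catch {
    case t: Throwable =>
      r.close() // closed only on failure; caller keeps ownership on success
      throw t
  }
}

Note that the hunk above stops closing table unconditionally inside the NVTX range; its lifetime is presumably managed elsewhere, outside what this hunk shows.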
@@ -120,7 +120,9 @@ private[rapids] class SimpleTableSerializer extends TableSerde {
   }
 }

-private[rapids] class SimpleTableDeserializer(sparkTypes: Array[DataType]) extends TableSerde {
+private[rapids] class SimpleTableDeserializer(
+    sparkTypes: Array[DataType],
+    deserTime: GpuMetric) extends TableSerde {
   private def readProtocolHeader(dIn: DataInputStream): Unit = {
     val magicNum = dIn.readInt()
     if (magicNum != P_MAGIC_NUM) {
@@ -172,10 +174,12 @@ private[rapids] class SimpleTableDeserializer(sparkTypes: Array[DataType]) extends TableSerde {
   def readFromStream(dIn: DataInputStream): ColumnarBatch = {
     // IO operation is coming, so leave GPU for a while.
     GpuSemaphore.releaseIfNecessary(TaskContext.get())
-    // 1) read and check header
-    readProtocolHeader(dIn)
-    // 2) read table metadata
-    val tableMeta = TableMeta.getRootAsTableMeta(readByteBufferFromStream(dIn))
+    val tableMeta = deserTime.ns {
+      // 1) read and check header
+      readProtocolHeader(dIn)
+      // 2) read table metadata
+      TableMeta.getRootAsTableMeta(readByteBufferFromStream(dIn))
+    }
     if (tableMeta.packedMetaAsByteBuffer() == null) {
       // no packed metadata, must be a table with zero columns
       // Acquiring the GPU even the coming batch is empty, because the downstream
@@ -186,39 +190,42 @@
     } else {
       // 3) read table data
       val hostBuf = withResource(new NvtxRange("Read Host Table", NvtxColor.ORANGE)) { _ =>
-        readHostBufferFromStream(dIn)
+        deserTime.ns(readHostBufferFromStream(dIn))
       }
       val data = withResource(hostBuf) { _ =>
         // Begin to use GPU
         GpuSemaphore.acquireIfNecessary(TaskContext.get())
         withResource(new NvtxRange("Table to Device", NvtxColor.YELLOW)) { _ =>
-          closeOnExcept(DeviceMemoryBuffer.allocate(hostBuf.getLength)) { devBuf =>
-            devBuf.copyFromHostBuffer(hostBuf)
-            devBuf
+          deserTime.ns {
+            closeOnExcept(DeviceMemoryBuffer.allocate(hostBuf.getLength)) { devBuf =>
+              devBuf.copyFromHostBuffer(hostBuf)
+              devBuf
+            }
           }
         }
       }
       withResource(new NvtxRange("Deserialize Table", NvtxColor.RED)) { _ =>
-        withResource(data) { _ =>
-          val bufferMeta = tableMeta.bufferMeta()
-          if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) {
-            MetaUtils.getBatchFromMeta(data, tableMeta, sparkTypes)
-          } else {
-            // Compressed table is not supported by the write side, but ok to
-            // put it here for the read side. Since compression will be supported later.
-            GpuCompressedColumnVector.from(data, tableMeta)
-          }
+        deserTime.ns {
+          withResource(data) { _ =>
+            val bufferMeta = tableMeta.bufferMeta()
+            if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) {
+              MetaUtils.getBatchFromMeta(data, tableMeta, sparkTypes)
+            } else {
+              GpuCompressedColumnVector.from(data, tableMeta)
+            }
+          }
         }
       }
     }
   }

 }

 private[rapids] class SerializedTableIterator(dIn: DataInputStream,
     sparkTypes: Array[DataType],
     deserTime: GpuMetric) extends Iterator[(Int, ColumnarBatch)] {

-  private val tableDeserializer = new SimpleTableDeserializer(sparkTypes)
+  private val tableDeserializer = new SimpleTableDeserializer(sparkTypes, deserTime)
   private var closed = false
   private var onDeck: Option[SpillableColumnarBatch] = None
   Option(TaskContext.get()).foreach { tc =>
@@ -255,10 +262,8 @@ private[rapids] class SerializedTableIterator(dIn: DataInputStream,
       return
     }
     try {
-      onDeck = deserTime.ns(
-        Some(SpillableColumnarBatch(tableDeserializer.readFromStream(dIn),
-          SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
-      )
+      onDeck = Some(SpillableColumnarBatch(tableDeserializer.readFromStream(dIn),
+        SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
     } catch {
       case _: EOFException => // we reach the end
         dIn.close()
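The pitfall this PR fixes is a general one: if a timed region encloses a call that can block on an unrelated resource, the metric conflates waiting with work. A hypothetical before/after illustration of the scoping change, reusing the simplified GpuMetric sketch above (the stub functions are stand-ins, not plugin code):

// Stand-ins for the blocking wait and the actual work.
def waitForSemaphore(): Unit = Thread.sleep(5) // e.g. GpuSemaphore.acquireIfNecessary
def doDeserializeWork(): Unit = ()             // e.g. reading and unpacking a table

// Before: the blocking wait sits inside the timed region and
// inflates the deserialization metric.
def readBefore(m: GpuMetric): Unit = m.ns {
  waitForSemaphore()
  doDeserializeWork()
}

// After: the wait happens outside the timed region, so the
// metric reflects only real deserialization work.
def readAfter(m: GpuMetric): Unit = {
  waitForSemaphore()
  m.ns(doDeserializeWork())
}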