diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 981ecc560d4..845a11a69e5 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -96,19 +96,19 @@ You can also disable only one rule, by specifying its rule id, as specified in:
-
-
-
-
-
-
-
+ enabled="true">
- (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*]
+ (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*]
Use Javadoc style indentation for multiline comments
+
+
+
+
+
+
+
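
For reference, the rule's message and the comment changes throughout the rest of this diff point at the same convention: continuation asterisks indented by a single space so they line up under the first '*' of the opening "/**". A minimal Scala illustration (not part of the config change itself):

object CommentStyleIllustration {
  /**
   * Javadoc-style indentation: each continuation line starts with " *",
   * aligning under the first '*' of the opening "/**".
   */
  def javadocStyle(): Unit = ()

  /** Scala-style indentation (the style being migrated away from):
    * continuation lines start with two spaces and a '*'.
    */
  def scalaStyle(): Unit = ()
}
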
diff --git a/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala b/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala
index 071547dcee0..db5a0746515 100755
--- a/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala
+++ b/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala
@@ -18,15 +18,15 @@ package ai.rapids.cudf
object CudaUtil {
/**
- * Copy from `src` buffer, starting at `srcOffset`,
- * to a destination buffer `dst` starting at `dstOffset`,
- * `length` bytes, in the default stream.
- * @param src source buffer
- * @param srcOffset source offset
- * @param dst destination buffer
- * @param dstOffset destination offset
- * @param length amount to copy
- */
+ * Copy `length` bytes from the `src` buffer, starting at `srcOffset`,
+ * to the destination buffer `dst`, starting at `dstOffset`,
+ * in the default stream.
+ * @param src source buffer
+ * @param srcOffset source offset
+ * @param dst destination buffer
+ * @param dstOffset destination offset
+ * @param length amount to copy
+ */
def copy(src: MemoryBuffer, srcOffset: Long, dst: MemoryBuffer,
dstOffset: Long, length: Long): Unit = {
Cuda.memcpy(
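
A minimal usage sketch of the helper above. The `HostMemoryBuffer.allocate`/`setByte`/`getByte` calls are assumptions about the cudf Java API, and the copy goes through `Cuda.memcpy`, so a CUDA-capable environment is assumed; the helper itself accepts any pair of `MemoryBuffer`s:

import ai.rapids.cudf.{CudaUtil, HostMemoryBuffer}

object CudaUtilCopyExample {
  def main(args: Array[String]): Unit = {
    val src = HostMemoryBuffer.allocate(4)
    val dst = HostMemoryBuffer.allocate(4)
    try {
      (0 until 4).foreach(i => src.setByte(i.toLong, i.toByte))
      // Copy all 4 bytes from src (offset 0) into dst (offset 0) on the default stream.
      CudaUtil.copy(src, 0, dst, 0, 4)
      assert(dst.getByte(3) == 3.toByte)
    } finally {
      src.close()
      dst.close()
    }
  }
}
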
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala
index 2489902097f..4bb2c3ca017 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala
@@ -25,10 +25,10 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
/**
- * A wrapper reader that always appends partition values to the ColumnarBatch produced by the input
- * reader `fileReader`. Each scalar value is splatted to a column with the same number of
- * rows as the batch returned by the reader.
- */
+ * A wrapper reader that always appends partition values to the ColumnarBatch produced by the input
+ * reader `fileReader`. Each scalar value is splatted to a column with the same number of
+ * rows as the batch returned by the reader.
+ */
class ColumnarPartitionReaderWithPartitionValues(
fileReader: PartitionReader[ColumnarBatch],
partitionValues: Array[Scalar]) extends PartitionReader[ColumnarBatch] {
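
The "splatting" mentioned above can be pictured as building a constant column whose length matches the batch; a sketch assuming cudf's `ColumnVector.fromScalar` and `Scalar.fromInt`:

import ai.rapids.cudf.{ColumnVector, Scalar}

object SplatSketch {
  // Expand one partition value into a column with `numRows` identical entries.
  def splat(value: Scalar, numRows: Int): ColumnVector =
    ColumnVector.fromScalar(value, numRows)

  // e.g. splat(Scalar.fromInt(7), batch.numRows()) yields a column filled with 7s.
}
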
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
index 256bc2c1122..4fa0d7c3901 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -179,14 +179,14 @@ case class GpuOrcPartitionReaderFactory(
object GpuOrcPartitionReader {
/**
- * This class describes a stripe that will appear in the ORC output memory file.
- *
- * @param infoBuilder builder for output stripe info that has been populated with
- * all fields except those that can only be known when the file
- * is being written (e.g.: file offset, compressed footer length)
- * @param footer stripe footer
- * @param inputDataRanges input file ranges (based at file offset 0) of stripe data
- */
+ * This class describes a stripe that will appear in the ORC output memory file.
+ *
+ * @param infoBuilder builder for output stripe info that has been populated with
+ * all fields except those that can only be known when the file
+ * is being written (e.g.: file offset, compressed footer length)
+ * @param footer stripe footer
+ * @param inputDataRanges input file ranges (based at file offset 0) of stripe data
+ */
private case class OrcOutputStripe(
infoBuilder: OrcProto.StripeInformation.Builder,
footer: OrcProto.StripeFooter,
@@ -200,32 +200,32 @@ object GpuOrcPartitionReader {
OrcProto.Stream.Kind.ROW_INDEX)
/**
- * This class holds fields needed to read and iterate over the OrcFile
- *
- * @param updatedReadSchema read schema mapped to the file's field names
- * @param evolution ORC SchemaEvolution
- * @param dataReader ORC DataReader
- * @param orcReader ORC Input File Reader
- * @param blockIterator An iterator over the ORC output stripes
- */
+ * This class holds fields needed to read and iterate over the OrcFile
+ *
+ * @param updatedReadSchema read schema mapped to the file's field names
+ * @param evolution ORC SchemaEvolution
+ * @param dataReader ORC DataReader
+ * @param orcReader ORC Input File Reader
+ * @param blockIterator An iterator over the ORC output stripes
+ */
private case class OrcPartitionReaderContext(updatedReadSchema: TypeDescription,
evolution: SchemaEvolution, dataReader: DataReader, orcReader: Reader,
blockIterator: BufferedIterator[OrcOutputStripe])
}
/**
- * A PartitionReader that reads an ORC file split on the GPU.
- *
- * Efficiently reading an ORC split on the GPU requires rebuilding the ORC file
- * in memory such that only relevant data is present in the memory file.
- * This avoids sending unnecessary data to the GPU and saves GPU memory.
- *
- * @param conf Hadoop configuration
- * @param partFile file split to read
- * @param dataSchema Spark schema of the file
- * @param readDataSchema Spark schema of what will be read from the file
- * @param debugDumpPrefix path prefix for dumping the memory file or null
- */
+ * A PartitionReader that reads an ORC file split on the GPU.
+ *
+ * Efficiently reading an ORC split on the GPU requires rebuilding the ORC file
+ * in memory such that only relevant data is present in the memory file.
+ * This avoids sending unnecessary data to the GPU and saves GPU memory.
+ *
+ * @param conf Hadoop configuration
+ * @param partFile file split to read
+ * @param dataSchema Spark schema of the file
+ * @param readDataSchema Spark schema of what will be read from the file
+ * @param debugDumpPrefix path prefix for dumping the memory file or null
+ */
class GpuOrcPartitionReader(
conf: Configuration,
partFile: PartitionedFile,
@@ -319,13 +319,13 @@ class GpuOrcPartitionReader(
}
/**
- * Build an integer array that maps the original ORC file's column IDs
- * to column IDs in the memory file. Columns that are not present in
- * the memory file will have a mapping of -1.
- *
- * @param evolution ORC SchemaEvolution
- * @return column mapping array
- */
+ * Build an integer array that maps the original ORC file's column IDs
+ * to column IDs in the memory file. Columns that are not present in
+ * the memory file will have a mapping of -1.
+ *
+ * @param evolution ORC SchemaEvolution
+ * @return column mapping array
+ */
private def columnRemap(evolution: SchemaEvolution): Array[Int] = {
val fileIncluded = evolution.getFileIncluded
if (fileIncluded != null) {
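
The remapping described above can be sketched independently of ORC: columns present in the memory file receive consecutive new IDs and pruned columns map to -1 (a standalone sketch, not the code in this file):

object ColumnRemapSketch {
  def remap(fileIncluded: Array[Boolean]): Array[Int] = {
    var nextId = 0
    fileIncluded.map { included =>
      if (included) { val id = nextId; nextId += 1; id } else -1
    }
  }
  // remap(Array(true, false, true)) == Array(0, -1, 1)
}
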
@@ -346,17 +346,17 @@ class GpuOrcPartitionReader(
}
/**
- * Build the output stripe descriptors for what will appear in the ORC memory file.
- *
- * @param stripes descriptors for the ORC input stripes, filtered to what is in the split
- * @param evolution ORC SchemaEvolution
- * @param sargApp ORC search argument applier
- * @param sargColumns mapping of ORC search argument columns
- * @param ignoreNonUtf8BloomFilter true if bloom filters other than UTF8 should be ignored
- * @param writerVersion writer version from the original ORC input file
- * @param dataReader ORC DataReader
- * @return output stripes descriptors
- */
+ * Build the output stripe descriptors for what will appear in the ORC memory file.
+ *
+ * @param stripes descriptors for the ORC input stripes, filtered to what is in the split
+ * @param evolution ORC SchemaEvolution
+ * @param sargApp ORC search argument applier
+ * @param sargColumns mapping of ORC search argument columns
+ * @param ignoreNonUtf8BloomFilter true if bloom filters other than UTF8 should be ignored
+ * @param writerVersion writer version from the original ORC input file
+ * @param dataReader ORC DataReader
+ * @return output stripes descriptors
+ */
private def buildOutputStripes(
stripes: Seq[StripeInformation],
evolution: SchemaEvolution,
@@ -392,14 +392,14 @@ class GpuOrcPartitionReader(
}
/**
- * Build the output stripe descriptor for a corresponding input stripe
- * that should be copied to the ORC memory file.
- *
- * @param inputStripe input stripe descriptor
- * @param inputFooter input stripe footer
- * @param columnMapping mapping of input column IDs to output column IDs
- * @return output stripe descriptor
- */
+ * Build the output stripe descriptor for a corresponding input stripe
+ * that should be copied to the ORC memory file.
+ *
+ * @param inputStripe input stripe descriptor
+ * @param inputFooter input stripe footer
+ * @param columnMapping mapping of input column IDs to output column IDs
+ * @return output stripe descriptor
+ */
private def buildOutputStripe(
inputStripe: StripeInformation,
inputFooter: OrcProto.StripeFooter,
@@ -564,13 +564,13 @@ class GpuOrcPartitionReader(
}
/**
- * Check if the read schema is compatible with the file schema.
- *
- * @param fileSchema input file's ORC schema
- * @param readSchema ORC schema for what will be read
- * @param isCaseAware true if field names are case-sensitive
- * @return read schema mapped to the file's field names
- */
+ * Check if the read schema is compatible with the file schema.
+ *
+ * @param fileSchema input file's ORC schema
+ * @param readSchema ORC schema for what will be read
+ * @param isCaseAware true if field names are case-sensitive
+ * @return read schema mapped to the file's field names
+ */
private def checkSchemaCompatibility(
fileSchema: TypeDescription,
readSchema: TypeDescription,
@@ -602,15 +602,15 @@ class GpuOrcPartitionReader(
}
/**
- * Build an ORC search argument applier that can filter input file splits
- * when predicate push-down filters have been specified.
- *
- * @param orcReader ORC input file reader
- * @param readerOpts ORC reader options
- * @param evolution ORC SchemaEvolution
- * @param useUTCTimestamp true if timestamps are UTC
- * @return the search argument applier and search argument column mapping
- */
+ * Build an ORC search argument applier that can filter input file splits
+ * when predicate push-down filters have been specified.
+ *
+ * @param orcReader ORC input file reader
+ * @param readerOpts ORC reader options
+ * @param evolution ORC SchemaEvolution
+ * @param useUTCTimestamp true if timestamps are UTC
+ * @return the search argument applier and search argument column mapping
+ */
private def getSearchApplier(
orcReader: Reader,
readerOpts: Reader.Options,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index d780fefcf4d..0e16dc3bf9b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -288,22 +288,22 @@ case class GpuParquetPartitionReaderFactory(
}
/**
- * A PartitionReader that reads a Parquet file split on the GPU.
- *
- * Efficiently reading a Parquet split on the GPU requires re-constructing the Parquet file
- * in memory that contains just the column chunks that are needed. This avoids sending
- * unnecessary data to the GPU and saves GPU memory.
- *
- * @param conf the Hadoop configuration
- * @param split the file split to read
- * @param filePath the path to the Parquet file
- * @param clippedBlocks the block metadata from the original Parquet file that has been clipped
- * to only contain the column chunks to be read
- * @param clippedParquetSchema the Parquet schema from the original Parquet file that has been
- * clipped to contain only the columns to be read
- * @param readDataSchema the Spark schema describing what will be read
- * @param debugDumpPrefix a path prefix to use for dumping the fabricated Parquet data or null
- */
+ * A PartitionReader that reads a Parquet file split on the GPU.
+ *
+ * Efficiently reading a Parquet split on the GPU requires re-constructing the Parquet file
+ * in memory that contains just the column chunks that are needed. This avoids sending
+ * unnecessary data to the GPU and saves GPU memory.
+ *
+ * @param conf the Hadoop configuration
+ * @param split the file split to read
+ * @param filePath the path to the Parquet file
+ * @param clippedBlocks the block metadata from the original Parquet file that has been clipped
+ * to only contain the column chunks to be read
+ * @param clippedParquetSchema the Parquet schema from the original Parquet file that has been
+ * clipped to contain only the columns to be read
+ * @param readDataSchema the Spark schema describing what will be read
+ * @param debugDumpPrefix a path prefix to use for dumping the fabricated Parquet data or null
+ */
class ParquetPartitionReader(
conf: Configuration,
split: PartitionedFile,
@@ -433,15 +433,15 @@ class ParquetPartitionReader(
}
/**
- * Copies the data corresponding to the clipped blocks in the original file and compute the
- * block metadata for the output. The output blocks will contain the same column chunk
- * metadata but with the file offsets updated to reflect the new position of the column data
- * as written to the output.
- *
- * @param in the input stream for the original Parquet file
- * @param out the output stream to receive the data
- * @return updated block metadata corresponding to the output
- */
+ * Copies the data corresponding to the clipped blocks in the original file and compute the
+ * block metadata for the output. The output blocks will contain the same column chunk
+ * metadata but with the file offsets updated to reflect the new position of the column data
+ * as written to the output.
+ *
+ * @param in the input stream for the original Parquet file
+ * @param out the output stream to receive the data
+ * @return updated block metadata corresponding to the output
+ */
private def copyBlocksData(
in: FSDataInputStream,
out: HostMemoryOutputStream,
@@ -675,12 +675,12 @@ object ParquetPartitionReader {
private case class CopyRange(offset: Long, length: Long)
/**
- * Build a new BlockMetaData
- *
- * @param rowCount the number of rows in this block
- * @param columns the new column chunks to reference in the new BlockMetaData
- * @return the new BlockMetaData
- */
+ * Build a new BlockMetaData
+ *
+ * @param rowCount the number of rows in this block
+ * @param columns the new column chunks to reference in the new BlockMetaData
+ * @return the new BlockMetaData
+ */
private def newParquetBlock(
rowCount: Long,
columns: Seq[ColumnChunkMetaData]): BlockMetaData = {
@@ -698,14 +698,14 @@ object ParquetPartitionReader {
}
/**
- * Trim block metadata to contain only the column chunks that occur in the specified columns.
- * The column chunks that are returned are preserved verbatim
- * (i.e.: file offsets remain unchanged).
- *
- * @param columnPaths the paths of columns to preserve
- * @param blocks the block metadata from the original Parquet file
- * @return the updated block metadata with undesired column chunks removed
- */
+ * Trim block metadata to contain only the column chunks that occur in the specified columns.
+ * The column chunks that are returned are preserved verbatim
+ * (i.e.: file offsets remain unchanged).
+ *
+ * @param columnPaths the paths of columns to preserve
+ * @param blocks the block metadata from the original Parquet file
+ * @return the updated block metadata with undesired column chunks removed
+ */
private[spark] def clipBlocks(columnPaths: Seq[ColumnPath],
blocks: Seq[BlockMetaData]): Seq[BlockMetaData] = {
val pathSet = columnPaths.toSet
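
A standalone sketch of the trimming described for `clipBlocks`, assuming parquet-mr's `BlockMetaData` mutators (`setRowCount`/`addColumn`); the column chunks that survive keep their original file offsets:

import scala.collection.JavaConverters._
import org.apache.parquet.hadoop.metadata.{BlockMetaData, ColumnPath}

object ClipBlocksSketch {
  def clip(columnPaths: Seq[ColumnPath], blocks: Seq[BlockMetaData]): Seq[BlockMetaData] = {
    val pathSet = columnPaths.toSet
    blocks.map { block =>
      val out = new BlockMetaData
      out.setRowCount(block.getRowCount)
      // Keep only the chunks whose path was requested, verbatim.
      block.getColumns.asScala.filter(c => pathSet.contains(c.getPath)).foreach(out.addColumn)
      out
    }
  }
}
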
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala
index c24f2060a88..c6775d1073b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala
@@ -33,12 +33,12 @@ import org.apache.spark.util.MutablePair
class GpuRangePartitioner extends Serializable {
var rangeBounds: Array[InternalRow] = _
/**
- * Sketches the input RDD via reservoir sampling on each partition.
- *
- * @param rdd the input RDD to sketch
- * @param sampleSizePerPartition max sample size per partition
- * @return (total number of items, an array of (partitionId, number of items, sample))
- */
+ * Sketches the input RDD via reservoir sampling on each partition.
+ *
+ * @param rdd the input RDD to sketch
+ * @param sampleSizePerPartition max sample size per partition
+ * @return (total number of items, an array of (partitionId, number of items, sample))
+ */
def sketch[K: ClassTag](
rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
@@ -55,13 +55,13 @@ class GpuRangePartitioner extends Serializable {
}
/**
- * Determines the bounds for range partitioning from candidates with weights indicating how many
- * items each represents. Usually this is 1 over the probability used to sample this candidate.
- *
- * @param candidates unordered candidates with weights
- * @param partitions number of partitions
- * @return selected bounds
- */
+ * Determines the bounds for range partitioning from candidates with weights indicating how many
+ * items each represents. Usually this is 1 over the probability used to sample this candidate.
+ *
+ * @param candidates unordered candidates with weights
+ * @param partitions number of partitions
+ * @return selected bounds
+ */
def determineBounds[K: Ordering : ClassTag](candidates: ArrayBuffer[(K, Float)],
partitions: Int): Array[K] = {
val ordering = implicitly[Ordering[K]]
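
A sketch of the weighted bound selection `determineBounds` describes: order the candidates, accumulate their weights, and emit a bound each time the running weight crosses the next 1/partitions step (a simplified sketch; the production code may handle ties and other edge cases differently):

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object DetermineBoundsSketch {
  def determineBounds[K: Ordering : ClassTag](
      candidates: ArrayBuffer[(K, Float)],
      partitions: Int): Array[K] = {
    val ordered = candidates.sortBy(_._1)
    val sumWeights = ordered.map(_._2.toDouble).sum
    val step = sumWeights / partitions
    val bounds = ArrayBuffer.empty[K]
    var cumWeight = 0.0
    var target = step
    var i = 0
    while (i < ordered.length && bounds.length < partitions - 1) {
      cumWeight += ordered(i)._2
      if (cumWeight >= target) {
        bounds += ordered(i)._1  // this candidate becomes the next partition bound
        target += step
      }
      i += 1
    }
    bounds.toArray
  }
}
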
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala
index deff3bee23b..ee23233e5b3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala
@@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructTyp
import org.apache.spark.sql.vectorized.ColumnarBatch
/**
- * A GPU accelerated `org.apache.spark.sql.catalyst.plans.physical.Partitioning` that partitions
- * sortable records by range into roughly equal ranges. The ranges are determined by sampling
- * the content of the RDD passed in.
- *
- * @note The actual number of partitions created might not be the same
- * as the `numPartitions` parameter, in the case where the number of sampled records is less than
- * the value of `partitions`.
- */
+ * A GPU accelerated `org.apache.spark.sql.catalyst.plans.physical.Partitioning` that partitions
+ * sortable records by range into roughly equal ranges. The ranges are determined by sampling
+ * the content of the RDD passed in.
+ *
+ * @note The actual number of partitions created might not be the same
+ * as the `numPartitions` parameter, in the case where the number of sampled records is less than
+ * the value of `partitions`.
+ */
case class GpuRangePartitioning(
gpuOrdering: Seq[SortOrder],
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala
index 6caab17460a..cea112c3c5b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala
@@ -23,16 +23,16 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
case class GpuSinglePartitioning(expressions: Seq[Expression])
extends GpuExpression with GpuPartitioning {
/**
- * Returns the result of evaluating this expression on the entire `ColumnarBatch`.
- * The result of calling this may be a single [[GpuColumnVector]] or a scalar value.
- * Scalar values typically happen if they are a part of the expression i.e. col("a") + 100.
- * In this case the 100 is a literal that Add would have to be able to handle.
- *
- * By convention any [[GpuColumnVector]] returned by [[columnarEval]]
- * is owned by the caller and will need to be closed by them. This can happen by putting it
- * into a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a
- * temporary value.
- */
+ * Returns the result of evaluating this expression on the entire `ColumnarBatch`.
+ * The result of calling this may be a single [[GpuColumnVector]] or a scalar value.
+ * Scalar values typically happen if they are a part of the expression i.e. col("a") + 100.
+ * In this case the 100 is a literal that Add would have to be able to handle.
+ *
+ * By convention any [[GpuColumnVector]] returned by [[columnarEval]]
+ * is owned by the caller and will need to be closed by them. This can happen by putting it
+ * into a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a
+ * temporary value.
+ */
override def columnarEval(batch: ColumnarBatch): Any = {
if (batch.numCols == 0) {
Array(batch).zipWithIndex
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
index 7c04daea502..3f916d9e87d 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala
@@ -244,14 +244,14 @@ case class HostColumnarToGpu(child: SparkPlan, goal: CoalesceGoal)
}
/**
- * Returns an RDD[ColumnarBatch] that when mapped over will produce GPU-side column vectors
- * that are expected to be closed by its caller, not [[HostColumnarToGpu]].
- *
- * The expectation is that the only valid instantiation of this node is
- * as a child of a GPU exec node.
- *
- * @return an RDD of `ColumnarBatch`
- */
+ * Returns an RDD[ColumnarBatch] that when mapped over will produce GPU-side column vectors
+ * that are expected to be closed by its caller, not [[HostColumnarToGpu]].
+ *
+ * The expectation is that the only valid instantiation of this node is
+ * as a child of a GPU exec node.
+ *
+ * @return an RDD of `ColumnarBatch`
+ */
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
val numInputRows = longMetric(NUM_INPUT_ROWS)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala
index 5434bbc4741..02a7f8991ed 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala
@@ -21,12 +21,12 @@ import java.io.{InputStream, IOException, OutputStream}
import ai.rapids.cudf.HostMemoryBuffer
/**
- * An implementation of OutputStream that writes to a HostMemoryBuffer.
- *
- * NOTE: Closing this output stream does NOT close the buffer!
- *
- * @param buffer the buffer to receive written data
- */
+ * An implementation of OutputStream that writes to a HostMemoryBuffer.
+ *
+ * NOTE: Closing this output stream does NOT close the buffer!
+ *
+ * @param buffer the buffer to receive written data
+ */
class HostMemoryOutputStream(buffer: HostMemoryBuffer) extends OutputStream {
private var pos: Long = 0
@@ -49,13 +49,13 @@ class HostMemoryOutputStream(buffer: HostMemoryBuffer) extends OutputStream {
}
/**
- * An implementation of InputStream that reads from a HostMemoryBuffer.
- *
- * NOTE: Closing this input stream does NOT close the buffer!
- *
- * @param hmb the buffer from which to read data
- * @param hmbLength the amount of data available in the buffer
- */
+ * An implementation of InputStream that reads from a HostMemoryBuffer.
+ *
+ * NOTE: Closing this input stream does NOT close the buffer!
+ *
+ * @param hmb the buffer from which to read data
+ * @param hmbLength the amount of data available in the buffer
+ */
class HostMemoryInputStream(
hmb: HostMemoryBuffer,
hmbLength: Long) extends InputStream {
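
A round-trip usage sketch of the two streams above; the buffer is closed separately, since closing the streams does not close it (`HostMemoryBuffer.allocate` is an assumed cudf call):

import ai.rapids.cudf.HostMemoryBuffer
import com.nvidia.spark.rapids.{HostMemoryInputStream, HostMemoryOutputStream}

object HostMemoryStreamsExample {
  def main(args: Array[String]): Unit = {
    val buffer = HostMemoryBuffer.allocate(16)
    try {
      val out = new HostMemoryOutputStream(buffer)
      out.write(Array[Byte](1, 2, 3, 4))
      val in = new HostMemoryInputStream(buffer, 4)
      val readBack = new Array[Byte](4)
      val numRead = in.read(readBack)
      assert(numRead == 4 && readBack.sameElements(Array[Byte](1, 2, 3, 4)))
    } finally {
      buffer.close() // the streams never close the buffer; we own it here
    }
  }
}
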
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala
index 63d6a2f8bef..bf352633e74 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala
@@ -270,11 +270,11 @@ object ShuffleMetadata extends Logging{
}
/**
- * Given a sequence of `TableMeta`, re-lay the metas using the flat buffer builder in `fbb`.
- * @param fbb builder to use
- * @param tables sequence of `TableMeta` to copy
- * @return an array of flat buffer offsets for the copied `TableMeta`s
- */
+ * Given a sequence of `TableMeta`, re-lay the metas using the flat buffer builder in `fbb`.
+ * @param fbb builder to use
+ * @param tables sequence of `TableMeta` to copy
+ * @return an array of flat buffer offsets for the copied `TableMeta`s
+ */
def copyTables(fbb: FlatBufferBuilder, tables: Seq[TableMeta]): Array[Int] = {
tables.map { tableMeta =>
val buffMeta = tableMeta.bufferMeta()
@@ -474,9 +474,9 @@ object ShuffleMetadata extends Logging{
/**
- * Utility function to transfer a `TableMeta` to the heap,
- * @todo we would like to look for an easier way, perhaps just a memcpy will do.
- */
+ * Utility function to transfer a `TableMeta` to the heap.
+ * @todo we would like to look for an easier way, perhaps just a memcpy will do.
+ */
def copyTableMetaToHeap(meta: TableMeta): TableMeta = {
val fbb = ShuffleMetadata.getHeapBuilder
val tables = ShuffleMetadata.copyTables(fbb, Seq(meta))
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala
index a4f01c0ff94..ea942d1339f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala
@@ -21,9 +21,9 @@ import ai.rapids.cudf.{NvtxColor, NvtxRange}
import org.apache.spark.sql.execution.metric.SQLMetric
/**
- * NvtxRange with option to pass one or more nano timing metric(s) that are updated upon close
- * by the amount of time spent in the range
- */
+ * NvtxRange with the option to pass one or more nano timing metric(s) that are updated upon close
+ * by the amount of time spent in the range.
+ */
class NvtxWithMetrics(name: String, color: NvtxColor, val metric: SQLMetric)
extends NvtxRange(name, color) {
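
A usage sketch of the constructor shown above: wrap a block of work so the elapsed time lands in the metric when the range closes:

import ai.rapids.cudf.NvtxColor
import com.nvidia.spark.rapids.NvtxWithMetrics
import org.apache.spark.sql.execution.metric.SQLMetric

object NvtxTimedExample {
  def timed[T](name: String, metric: SQLMetric)(body: => T): T = {
    val range = new NvtxWithMetrics(name, NvtxColor.GREEN, metric)
    try {
      body
    } finally {
      range.close() // closing the range adds the elapsed nanoseconds to `metric`
    }
  }
}
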
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
index 4705aa25348..a94ec1b7597 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -43,8 +43,8 @@ case class ColumnarOverrideRules() extends ColumnarRule with Logging {
}
/**
- * Extension point to enable GPU SQL processing.
- */
+ * Extension point to enable GPU SQL processing.
+ */
class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
override def apply(extensions: SparkSessionExtensions): Unit = {
logWarning("Installing extensions to enable rapids GPU SQL support." +
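
A sketch of wiring the extension point into a session via Spark's standard `spark.sql.extensions` mechanism; the fully qualified class name below is inferred from this file's package, and the plugin may document a different user-facing entry point:

import org.apache.spark.sql.SparkSession

object EnableGpuSqlSketch {
  def build(): SparkSession = SparkSession.builder()
    .config("spark.sql.extensions", "com.nvidia.spark.rapids.SQLExecPlugin")
    .getOrCreate()
}
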
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala
index cb043ffedb9..af982b79f41 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala
@@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
object SamplingUtils {
/**
- * Reservoir sampling implementation that also returns the input size.
- *
- * @param input input size
- * @param k reservoir size
- * @param seed random seed
- * @return (samples, input size)
- */
+ * Reservoir sampling implementation that also returns the input size.
+ *
+ * @param input the input data to sample
+ * @param k reservoir size
+ * @param seed random seed
+ * @return (samples, input size)
+ */
def reservoirSampleAndCount[T: ClassTag](
input: Iterator[T],
k: Int,
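
A self-contained sketch of the reservoir sampling described above (Algorithm R style): the first k items fill the reservoir, later items replace a random slot with decreasing probability, and the total item count is returned alongside the sample:

import scala.reflect.ClassTag
import scala.util.Random

object ReservoirSampleSketch {
  def reservoirSampleAndCount[T: ClassTag](
      input: Iterator[T],
      k: Int,
      seed: Long): (Array[T], Long) = {
    val reservoir = new Array[T](k)
    val rand = new Random(seed)
    var count = 0L
    input.foreach { item =>
      if (count < k) {
        reservoir(count.toInt) = item
      } else {
        // item survives with probability k / (count + 1)
        val j = (rand.nextDouble() * (count + 1)).toLong
        if (j < k) reservoir(j.toInt) = item
      }
      count += 1
    }
    if (count < k) (reservoir.take(count.toInt), count) else (reservoir, count)
  }
}
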
@@ -77,16 +77,16 @@ object SamplingUtils {
/**
- * This class implements a XORShift random number generator algorithm
- * Source:
- * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
- * @see Paper
- * This implementation is approximately 3.5 times faster than
- * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
- * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
- * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
- * for each thread.
- */
+ * This class implements a XORShift random number generator algorithm
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see Paper
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
+ * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
+ * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
+ * for each thread.
+ */
private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
def this() = this(System.nanoTime)
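
The xorshift step itself is tiny, which is where the speed advantage over `java.util.Random` comes from; a sketch using one of Marsaglia's shift triples (the exact triple used by this class is an assumption):

object XorShiftSketch {
  // One 64-bit xorshift step; a zero seed would stay zero, so real
  // implementations scramble the initial seed first.
  def nextSeed(seed: Long): Long = {
    var x = seed
    x ^= x << 21
    x ^= x >>> 35
    x ^= x << 4
    x
  }
}
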
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
index 0ac39d17c28..5856a9c7e30 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
@@ -203,11 +203,11 @@ class ShuffleBufferCatalog(
}
/**
- * Remove a buffer and table given a buffer ID
- * NOTE: This function is not thread safe! The caller should only invoke if
- * the [[ShuffleBufferId]] being removed is not being utilized by another thread.
- * @param id buffer identifier
- */
+ * Remove a buffer and table given a buffer ID
+ * NOTE: This function is not thread safe! The caller should only invoke it if
+ * the [[ShuffleBufferId]] being removed is not being utilized by another thread.
+ * @param id buffer identifier
+ */
def removeBuffer(id: ShuffleBufferId): Unit = {
tableMap.remove(id.tableId)
catalog.removeBuffer(id)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index d36e29f167f..50b61b24008 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -206,24 +206,24 @@ class GpuSortAggregateMeta(
}
/**
- * GpuHashAggregateExec - is the GPU version of HashAggregateExec, with some major differences:
- * - it doesn't support spilling to disk
- * - it doesn't support strings in the grouping key
- * - it doesn't support count(col1, col2, ..., colN)
- * - it doesn't support distinct
- * @param requiredChildDistributionExpressions this is unchanged by the GPU. It is used in
- * EnsureRequirements to be able to add shuffle nodes
- * @param groupingExpressions The expressions that, when applied to the input batch, return the
- * grouping key
- * @param aggregateExpressions The GpuAggregateExpression instances for this node
- * @param aggregateAttributes References to each GpuAggregateExpression (attribute references)
- * @param initialInputBufferOffset this is not used in the GPU version, but it's used to offset
- * the slot in the aggregation buffer that aggregates should
- * start referencing
- * @param resultExpressions the expected output expression of this hash aggregate (which this
- * node should project)
- * @param child incoming plan (where we get input columns from)
- */
+ * GpuHashAggregateExec is the GPU version of HashAggregateExec, with some major differences:
+ * - it doesn't support spilling to disk
+ * - it doesn't support strings in the grouping key
+ * - it doesn't support count(col1, col2, ..., colN)
+ * - it doesn't support distinct
+ * @param requiredChildDistributionExpressions this is unchanged by the GPU. It is used in
+ * EnsureRequirements to be able to add shuffle nodes
+ * @param groupingExpressions The expressions that, when applied to the input batch, return the
+ * grouping key
+ * @param aggregateExpressions The GpuAggregateExpression instances for this node
+ * @param aggregateAttributes References to each GpuAggregateExpression (attribute references)
+ * @param initialInputBufferOffset this is not used in the GPU version, but it's used to offset
+ * the slot in the aggregation buffer that aggregates should
+ * start referencing
+ * @param resultExpressions the expected output expression of this hash aggregate (which this
+ * node should project)
+ * @param child incoming plan (where we get input columns from)
+ */
case class GpuHashAggregateExec(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[Expression],
@@ -531,12 +531,12 @@ case class GpuHashAggregateExec(
}
/**
- * concatenateBatches - given two ColumnarBatch instances, return a sequence of GpuColumnVector
- * that is the concatenated columns of the two input batches.
- * @param aggregatedInputCb this is an incoming batch
- * @param aggregatedCb this is a batch that was kept for concatenation
- * @return Seq[GpuColumnVector] with concatenated vectors
- */
+ * concatenateBatches - given two ColumnarBatch instances, return a sequence of GpuColumnVector
+ * that is the concatenated columns of the two input batches.
+ * @param aggregatedInputCb this is an incoming batch
+ * @param aggregatedCb this is a batch that was kept for concatenation
+ * @return Seq[GpuColumnVector] with concatenated vectors
+ */
private def concatenateBatches(aggregatedInputCb: ColumnarBatch, aggregatedCb: ColumnarBatch,
concatTime: SQLMetric): Seq[GpuColumnVector] = {
val nvtxRange = new NvtxWithMetrics("concatenateBatches", NvtxColor.BLUE, concatTime)
@@ -568,19 +568,19 @@ case class GpuHashAggregateExec(
private lazy val completeMode = uniqueModes.contains(Complete)
/**
- * getCudfAggregates returns a sequence of `cudf.Aggregate`, given the current mode
- * `AggregateMode`, and a sequence of all expressions for this [[GpuHashAggregateExec]]
- * node, we get all the expressions as that's important for us to be able to resolve the current
- * ordinal for this cudf aggregate.
- *
- * Examples:
- * fn = sum, min, max will always be Seq(fn)
- * avg will be Seq(sum, count) for Partial mode, but Seq(sum, sum) for other modes
- * count will be Seq(count) for Partial mode, but Seq(sum) for other modes
- *
- * @return Seq of `cudf.Aggregate`, with one or more aggregates that correspond to each
- * expression in allExpressions
- */
+ * getCudfAggregates returns a sequence of `cudf.Aggregate` for the current
+ * `AggregateMode`, given the sequence of all expressions for this [[GpuHashAggregateExec]]
+ * node. We need all of the expressions in order to resolve the current
+ * ordinal for each cudf aggregate.
+ *
+ * Examples:
+ * fn = sum, min, max will always be Seq(fn)
+ * avg will be Seq(sum, count) for Partial mode, but Seq(sum, sum) for other modes
+ * count will be Seq(count) for Partial mode, but Seq(sum) for other modes
+ *
+ * @return Seq of `cudf.Aggregate`, with one or more aggregates that correspond to each
+ * expression in allExpressions
+ */
def setupReferences(childAttr: AttributeSeq,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[GpuAggregateExpression]): BoundExpressionsModeAggregates = {
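
The mapping the comment describes can be pictured as a small lookup keyed by (function, mode); a standalone sketch using plain strings rather than the real expression and cudf types:

object CudfAggregateMappingSketch {
  // Which cudf aggregates back a Spark aggregate, per the comment above:
  // sum/min/max are mode-independent, avg and count change shape after Partial.
  def cudfAggregatesFor(fn: String, isPartial: Boolean): Seq[String] = fn match {
    case "sum" | "min" | "max" => Seq(fn)
    case "avg" => if (isPartial) Seq("sum", "count") else Seq("sum", "sum")
    case "count" => if (isPartial) Seq("count") else Seq("sum")
    case other => throw new IllegalArgumentException(s"unhandled aggregate: $other")
  }
}
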
@@ -866,8 +866,8 @@ case class GpuHashAggregateExec(
"Row-based execution should not occur for this class")
/**
- * All the attributes that are used for this plan. NOT used for aggregation
- */
+ * All the attributes that are used for this plan. NOT used for aggregation
+ */
override lazy val allAttributes: AttributeSeq =
child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala
index 799b3c5be32..ae484940de1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala
@@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.vectorized.ColumnarBatch
/**
- * RapidsPluginImplicits, adds implicit functions for ColumnarBatch, Seq, Seq[AutoCloseable],
- * and Array[AutoCloseable] that help make resource management easier within the project.
- */
+ * RapidsPluginImplicits adds implicit functions for ColumnarBatch, Seq, Seq[AutoCloseable],
+ * and Array[AutoCloseable] that help make resource management easier within the project.
+ */
object RapidsPluginImplicits {
import scala.language.implicitConversions
@@ -60,11 +60,11 @@ object RapidsPluginImplicits {
implicit class AutoCloseableSeq[A <: AutoCloseable](val in: SeqLike[A, _]) {
/**
- * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each
- * element of the sequence, even if prior close calls fail. In case of failure in any of the
- * close calls, an Exception is thrown containing the suppressed exceptions (getSuppressed),
- * if any.
- */
+ * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each
+ * element of the sequence, even if prior close calls fail. In case of failure in any of the
+ * close calls, an Exception is thrown containing the suppressed exceptions (getSuppressed),
+ * if any.
+ */
def safeClose(): Unit = if (in != null) {
var closeException: Throwable = null
in.foreach { element =>
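
The behavior described above amounts to the following pattern: attempt to close everything, remember the first failure, attach later failures as suppressed exceptions, and only then throw (standalone sketch):

object SafeCloseSketch {
  def closeAll(resources: Seq[AutoCloseable]): Unit = {
    var first: Throwable = null
    resources.foreach { r =>
      try {
        r.close()
      } catch {
        case t: Throwable =>
          if (first == null) first = t else first.addSuppressed(t)
      }
    }
    if (first != null) throw first
  }
}
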
@@ -93,57 +93,57 @@ object RapidsPluginImplicits {
class MapsSafely[A, Repr] {
/**
- * safeMap: safeMap implementation that is leveraged by other type-specific implicits.
- *
- * safeMap has the added safety net that as you produce AutoCloseable values they are
- * tracked, and if an exception were to occur within the maps's body, it will make every
- * attempt to close each produced value.
- *
- * Note: safeMap will close in case of errors, without any knowledge of whether it should
- * or not.
- * Use safeMap only in these circumstances if `fn` increases the reference count,
- * producing an AutoCloseable, and nothing else is tracking these references:
- * a) seq.safeMap(x => {...; x.incRefCount; x})
- * b) seq.safeMap(x => GpuColumnVector.from(...))
- *
- * Usage of safeMap chained with other maps is a bit confusing:
- *
- * seq.map(GpuColumnVector.from).safeMap(couldThrow)
- *
- * Will close the column vectors produced from couldThrow up until the time where safeMap
- * throws.
- *
- * The correct pattern of usage in cases like this is:
- *
- * val closeTheseLater = seq.safeMap(GpuColumnVector.from)
- * closeTheseLater.safeMap{ x =>
- * var success = false
- * try {
- * val res = couldThrow(x.incRefCount())
- * success = true
- * res // return a ref count of 2
- * } finally {
- * if (!success) {
- * // in case of an error, we close x as part of normal error handling
- * // the exception will be caught by the safeMap, and it will close all
- * // AutoCloseables produced before x
- * // - Sequence looks like: [2, 2, 2, ..., 2] + x, which has also has a refcount of 2
- * x.close() // x now has a ref count of 1, the rest of the sequence has 2s
- * }
- * }
- * } // safeMap cleaned, and now everything has 1s for ref counts (as they were before)
- *
- * closeTheseLater.safeClose() // go from 1 to 0 in all things inside closeTheseLater
- *
- * @param in the Seq[A] to map on
- * @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
- * @tparam A the type of the elements in Seq
- * @tparam B the type of the elements produced in the safeMap (should be subclasses of
- * AutoCloseable)
- * @tparam Repr the type of the input collection (needed by builder)
- * @tparam That the type of the output collection (needed by builder)
- * @return a sequence of B, in the success case
- */
+ * safeMap: safeMap implementation that is leveraged by other type-specific implicits.
+ *
+ * safeMap has the added safety net that as you produce AutoCloseable values they are
+ * tracked, and if an exception were to occur within the map's body, it will make every
+ * attempt to close each produced value.
+ *
+ * Note: safeMap will close in case of errors, without any knowledge of whether it should
+ * or not.
+ * Use safeMap only in circumstances where `fn` increases the reference count,
+ * producing an AutoCloseable, and nothing else is tracking these references:
+ * a) seq.safeMap(x => {...; x.incRefCount; x})
+ * b) seq.safeMap(x => GpuColumnVector.from(...))
+ *
+ * Usage of safeMap chained with other maps is a bit confusing:
+ *
+ * seq.map(GpuColumnVector.from).safeMap(couldThrow)
+ *
+ * Will close the column vectors produced from couldThrow up until the time where safeMap
+ * throws.
+ *
+ * The correct pattern of usage in cases like this is:
+ *
+ * val closeTheseLater = seq.safeMap(GpuColumnVector.from)
+ * closeTheseLater.safeMap{ x =>
+ * var success = false
+ * try {
+ * val res = couldThrow(x.incRefCount())
+ * success = true
+ * res // return a ref count of 2
+ * } finally {
+ * if (!success) {
+ * // in case of an error, we close x as part of normal error handling
+ * // the exception will be caught by the safeMap, and it will close all
+ * // AutoCloseables produced before x
+ * // - Sequence looks like: [2, 2, 2, ..., 2] + x, which also has a refcount of 2
+ * x.close() // x now has a ref count of 1, the rest of the sequence has 2s
+ * }
+ * }
+ * } // safeMap cleaned, and now everything has 1s for ref counts (as they were before)
+ *
+ * closeTheseLater.safeClose() // go from 1 to 0 in all things inside closeTheseLater
+ *
+ * @param in the Seq[A] to map on
+ * @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
+ * @tparam A the type of the elements in Seq
+ * @tparam B the type of the elements produced in the safeMap (should be subclasses of
+ * AutoCloseable)
+ * @tparam Repr the type of the input collection (needed by builder)
+ * @tparam That the type of the output collection (needed by builder)
+ * @return a sequence of B, in the success case
+ */
protected def safeMap[B <: AutoCloseable, That](
in: SeqLike[A, Repr],
fn: A => B)
@@ -179,48 +179,48 @@ object RapidsPluginImplicits {
implicit class AutoCloseableProducingSeq[A](val in: Seq[A]) extends MapsSafely[A, Seq[A]] {
/**
- * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of
- * AutoCloseable.
- * See [[MapsSafely.safeMap]] for a more detailed explanation.
- *
- * @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
- * @tparam A the type of the elements in Seq
- * @tparam B the type of the elements produced in the safeMap (should be subclasses of
- * AutoCloseable)
- * @return a sequence of B, in the success case
- */
+ * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of
+ * AutoCloseable.
+ * See [[MapsSafely.safeMap]] for a more detailed explanation.
+ *
+ * @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
+ * @tparam A the type of the elements in Seq
+ * @tparam B the type of the elements produced in the safeMap (should be subclasses of
+ * AutoCloseable)
+ * @return a sequence of B, in the success case
+ */
def safeMap[B <: AutoCloseable](fn: A => B): Seq[B] = super.safeMap(in, fn)
}
implicit class AutoCloseableProducingArray[A](val in: Array[A]) extends MapsSafely[A, Array[A]] {
/**
- * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of
- * AutoCloseable.
- * See [[MapsSafely.safeMap]] for a more detailed explanation.
- *
- * @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
- * @tparam A the type of the elements in Seq
- * @tparam B the type of the elements produced in the safeMap (should be subclasses of
- * AutoCloseable)
- * @return a sequence of B, in the success case
- */
+ * safeMap: implicit map on an Array[A] that produces Array[B], where B is a subclass of
+ * AutoCloseable.
+ * See [[MapsSafely.safeMap]] for a more detailed explanation.
+ *
+ * @param fn a function that takes A, and produces B (a subclass of AutoCloseable)
+ * @tparam A the type of the elements in Seq
+ * @tparam B the type of the elements produced in the safeMap (should be subclasses of
+ * AutoCloseable)
+ * @return a sequence of B, in the success case
+ */
def safeMap[B <: AutoCloseable : ClassTag](fn: A => B): Array[B] = super.safeMap(in, fn)
}
implicit class AutoCloseableFromBatchColumns(val in: ColumnarBatch)
extends MapsSafely[Int, Seq[Int]] {
/**
- * safeMap: Is an implicit on ColumnarBatch, that lets you map over the columns
- * of a batch as if the batch was a Seq[GpuColumnVector], iff safeMap's body is producing
- * AutoCloseable (otherwise, it is not defined).
- *
- * See [[MapsSafely.safeMap]] for a more detailed explanation.
- *
- * @param fn a function that takes GpuColumnVector, and returns a subclass of AutoCloseable
- * @tparam B the type of the elements produced in the safeMap (should be subclasses of
- * AutoCloseable)
- * @return a sequence of B, in the success case
- */
+ * safeMap: Is an implicit on ColumnarBatch that lets you map over the columns
+ * of a batch as if the batch were a Seq[GpuColumnVector], iff safeMap's body is producing
+ * AutoCloseable (otherwise, it is not defined).
+ *
+ * See [[MapsSafely.safeMap]] for a more detailed explanation.
+ *
+ * @param fn a function that takes GpuColumnVector, and returns a subclass of AutoCloseable
+ * @tparam B the type of the elements produced in the safeMap (should be subclasses of
+ * AutoCloseable)
+ * @return a sequence of B, in the success case
+ */
def safeMap[B <: AutoCloseable](fn: GpuColumnVector => B): Seq[B] = {
val colIds: Seq[Int] = 0 until in.numCols
super.safeMap(colIds, (i: Int) => fn(in.column(i).asInstanceOf[GpuColumnVector]))
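
Putting the two implicits together, a usage sketch (`Scalar.fromInt` and `ColumnVector.fromScalar` are assumed cudf calls): if producing any column threw, safeMap would close the columns already built, and safeClose releases the ones we kept:

import ai.rapids.cudf.{ColumnVector, Scalar}
import com.nvidia.spark.rapids.RapidsPluginImplicits._

object SafeMapUsageSketch {
  def splatAll(values: Seq[Int], rows: Int): Unit = {
    val cols: Seq[ColumnVector] = values.safeMap { v =>
      val s = Scalar.fromInt(v)
      try ColumnVector.fromScalar(s, rows) finally s.close()
    }
    cols.safeClose()
  }
}
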
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala
index 7ba2da5d9ec..8aea6eda62b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala
@@ -147,9 +147,9 @@ case class GpuIsNan(child: Expression) extends GpuUnaryExpression with Predicate
}
/**
- * A GPU accelerated predicate that is evaluated to be true if there are at least `n` non-null
- * and non-NaN values.
- */
+ * A GPU accelerated predicate that is evaluated to be true if there are at least `n` non-null
+ * and non-NaN values.
+ */
case class GpuAtLeastNNonNulls(
n: Int,
exprs: Seq[Expression])
@@ -160,17 +160,17 @@ case class GpuAtLeastNNonNulls(
override def toString: String = s"GpuAtLeastNNulls(n, ${children.mkString(",")})"
override def children: Seq[Expression] = exprs
/**
- * Returns the result of evaluating this expression on the entire
- * `ColumnarBatch`. The result of calling this may be a single [[GpuColumnVector]] or a scalar
- * value. Scalar values typically happen if they are a part of the expression
- * i.e. col("a") + 100.
- * In this case the 100 is a literal that Add would have to be able to handle.
- *
- * By convention any [[GpuColumnVector]] returned by [[columnarEval]]
- * is owned by the caller and will need to be closed by them. This can happen by putting it into
- * a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a
- * temporary value.
- */
+ * Returns the result of evaluating this expression on the entire
+ * `ColumnarBatch`. The result of calling this may be a single [[GpuColumnVector]] or a scalar
+ * value. Scalar values typically happen if they are a part of the expression
+ * i.e. col("a") + 100.
+ * In this case the 100 is a literal that Add would have to be able to handle.
+ *
+ * By convention any [[GpuColumnVector]] returned by [[columnarEval]]
+ * is owned by the caller and will need to be closed by them. This can happen by putting it into
+ * a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a
+ * temporary value.
+ */
override def columnarEval(batch: ColumnarBatch): Any = {
val nonNullNanCounts : mutable.Queue[ColumnVector] = new mutable.Queue[ColumnVector]()
try {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala
index e05c7b6fb57..c7f0ed8ff58 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala
@@ -23,15 +23,15 @@ import ai.rapids.cudf.MemoryBuffer
import org.apache.spark.internal.Logging
/**
- * This classes manages a set of bounce buffers, that are instances of `MemoryBuffer`.
- * The size/quantity of buffers is configurable, and so is the allocator.
- * @param poolName a human-friendly name to use for debug logs
- * @param bufferSize the size of buffer to use
- * @param numBuffers the number of buffers to allocate on instantiation
- * @param allocator function that takes a size, and returns a `MemoryBuffer` instance.
- * @tparam T the specific type of MemoryBuffer i.e. `DeviceMemoryBuffer`,
- * `HostMemoryBuffer`, etc.
- */
+ * This class manages a set of bounce buffers that are instances of `MemoryBuffer`.
+ * The size/quantity of buffers is configurable, and so is the allocator.
+ * @param poolName a human-friendly name to use for debug logs
+ * @param bufferSize the size of buffer to use
+ * @param numBuffers the number of buffers to allocate on instantiation
+ * @param allocator function that takes a size, and returns a `MemoryBuffer` instance.
+ * @tparam T the specific type of MemoryBuffer i.e. `DeviceMemoryBuffer`,
+ * `HostMemoryBuffer`, etc.
+ */
class BounceBufferManager[T <: MemoryBuffer](
poolName: String,
val bufferSize: Long,
@@ -47,11 +47,11 @@ class BounceBufferManager[T <: MemoryBuffer](
freeBufferMap.set(0, numBuffers)
/**
- * Acquires a [[MemoryBuffer]] from the pool. Blocks if the pool is empty.
- *
- * @note calls to this function should have a lock on this [[BounceBufferManager]]
- * @return the acquired `MemoryBuffer`
- */
+ * Acquires a [[MemoryBuffer]] from the pool. Blocks if the pool is empty.
+ *
+ * @note calls to this function should have a lock on this [[BounceBufferManager]]
+ * @return the acquired `MemoryBuffer`
+ */
private def acquireBuffer(): MemoryBuffer = {
val start = System.currentTimeMillis()
var bufferIndex = freeBufferMap.nextSetBit(0)
@@ -73,10 +73,10 @@ class BounceBufferManager[T <: MemoryBuffer](
}
/**
- * Acquire `possibleNumBuffers` buffers from the pool. This method will not block.
- * @param possibleNumBuffers number of buffers to acquire
- * @return a sequence of `MemoryBuffer`s, or empty if the request can't be satisfied
- */
+ * Acquire `possibleNumBuffers` buffers from the pool. This method will not block.
+ * @param possibleNumBuffers number of buffers to acquire
+ * @return a sequence of `MemoryBuffer`s, or empty if the request can't be satisfied
+ */
def acquireBuffersNonBlocking(possibleNumBuffers: Int): Seq[MemoryBuffer] = synchronized {
if (numFree < possibleNumBuffers) {
// would block
@@ -89,11 +89,11 @@ class BounceBufferManager[T <: MemoryBuffer](
}
/**
- * Acquire `possibleNumBuffers` buffers from the pool. This method will block until
- * it can get the buffers requested.
- * @param possibleNumBuffers number of buffers to acquire
- * @return a sequence of `MemoryBuffer`s
- */
+ * Acquire `possibleNumBuffers` buffers from the pool. This method will block until
+ * it can get the buffers requested.
+ * @param possibleNumBuffers number of buffers to acquire
+ * @return a sequence of `MemoryBuffer`s
+ */
def acquireBuffersBlocking(possibleNumBuffers: Int): Seq[MemoryBuffer] = synchronized {
val res = (0 until possibleNumBuffers).map(_ => acquireBuffer())
logDebug(s"$poolName at acquire. Has numFree ${numFree}")
@@ -101,9 +101,9 @@ class BounceBufferManager[T <: MemoryBuffer](
}
/**
- * Free a `MemoryBuffer`, putting it back into the pool.
- * @param buffer the memory buffer to free
- */
+ * Free a `MemoryBuffer`, putting it back into the pool.
+ * @param buffer the memory buffer to free
+ */
def freeBuffer(buffer: MemoryBuffer): Unit = synchronized {
require(buffer.getAddress >= rootBuffer.getAddress
&& (buffer.getAddress - rootBuffer.getAddress) % bufferSize == 0,
@@ -119,10 +119,10 @@ class BounceBufferManager[T <: MemoryBuffer](
}
/**
- * Returns the root (backing) `MemoryBuffer`. This is used for a transport
- * that wants to register the bounce buffers against hardware, for pinning purposes.
- * @return the root (backing) memory buffer
- */
+ * Returns the root (backing) `MemoryBuffer`. This is used for a transport
+ * that wants to register the bounce buffers against hardware, for pinning purposes.
+ * @return the root (backing) memory buffer
+ */
def getRootBuffer(): MemoryBuffer = rootBuffer
override def close(): Unit = rootBuffer.close()
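
The free-slot bookkeeping the manager relies on can be sketched with just a `java.util.BitSet`, where a set bit means the slot is free (the real manager additionally blocks when the pool is empty and slices the slots out of the root buffer):

import java.util.BitSet

object SlotPoolSketch {
  class SlotPool(numSlots: Int) {
    private val free = new BitSet(numSlots)
    free.set(0, numSlots) // all slots start out free

    def acquire(): Option[Int] = synchronized {
      val idx = free.nextSetBit(0)
      if (idx < 0) None else { free.clear(idx); Some(idx) }
    }

    def release(idx: Int): Unit = synchronized { free.set(idx) }
  }
}
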
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
index c6d913c37d4..9b1c10a6be1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala
@@ -29,40 +29,40 @@ import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.storage.ShuffleBlockBatchId
/**
- * trait used by client consumers ([[RapidsShuffleIterator]]) to gather what the
- * expected number of batches (Tables) is, and the ids for each table as they are received.
- */
+ * A trait used by client consumers ([[RapidsShuffleIterator]]) to track the
+ * expected number of batches (Tables), and the ids for each table as they are received.
+ */
trait RapidsShuffleFetchHandler {
/**
- * After a response for shuffle metadata is received, the expected number of columnar
- * batches is communicated back to the caller via this method.
- * @param expectedBatches - number of columnar batches
- */
+ * After a response for shuffle metadata is received, the expected number of columnar
+ * batches is communicated back to the caller via this method.
+ * @param expectedBatches - number of columnar batches
+ */
def start(expectedBatches: Int): Unit
/**
- * Called when a buffer is received and has been handed off to the catalog.
- * @param bufferId - a tracked shuffle buffer id
- */
+ * Called when a buffer is received and has been handed off to the catalog.
+ * @param bufferId - a tracked shuffle buffer id
+ */
def batchReceived(bufferId: ShuffleReceivedBufferId): Unit
/**
- * Called when the transport layer is not able to handle a fetch error for metadata
- * or buffer fetches.
- *
- * @param errorMessage - a string containing an error message
- */
+ * Called when the transport layer is not able to handle a fetch error for metadata
+ * or buffer fetches.
+ *
+ * @param errorMessage - a string containing an error message
+ */
def transferError(errorMessage: String): Unit
}
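
A minimal sketch of an implementation of the trait above, assumed to live alongside the trait so the referenced types resolve without extra imports; everything beyond the three trait methods is illustrative only:

object FetchHandlerSketch {
  class CountingHandler extends RapidsShuffleFetchHandler {
    private var expected = -1
    private var received = 0

    override def start(expectedBatches: Int): Unit = expected = expectedBatches
    override def batchReceived(bufferId: ShuffleReceivedBufferId): Unit = received += 1
    override def transferError(errorMessage: String): Unit =
      throw new IllegalStateException(errorMessage)

    def done: Boolean = expected >= 0 && received == expected
  }
}
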
/**
- * A helper case class that describes a pending table. It allows
- * the transport to schedule/throttle these requests to fit within the maximum bytes in flight.
- * @param client client used to issue the requests
- * @param tableMeta shuffle metadata describing the table
- * @param tag a transport specific tag to use for this transfer
- * @param handler a specific handler that is waiting for this batch
- */
+ * A helper case class that describes a pending table. It allows
+ * the transport to schedule/throttle these requests to fit within the maximum bytes in flight.
+ * @param client client used to issue the requests
+ * @param tableMeta shuffle metadata describing the table
+ * @param tag a transport specific tag to use for this transfer
+ * @param handler a specific handler that is waiting for this batch
+ */
case class PendingTransferRequest(client: RapidsShuffleClient,
tableMeta: TableMeta,
tag: Long,
@@ -71,40 +71,40 @@ case class PendingTransferRequest(client: RapidsShuffleClient,
}
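The throttling mentioned above (fitting pending requests within the maximum bytes in flight) can be sketched as follows; `requestSize` is an assumed way to estimate a request's size, which the real transport derives from the request's `TableMeta`:

```scala
import scala.collection.mutable

object ThrottleSketch {
  // Admit pending requests in order while they still fit under the limit.
  def admit(
      pending: mutable.Queue[PendingTransferRequest],
      requestSize: PendingTransferRequest => Long,
      maxBytesInFlight: Long,
      bytesAlreadyInFlight: Long): Seq[PendingTransferRequest] = {
    val admitted = mutable.ArrayBuffer[PendingTransferRequest]()
    var inFlight = bytesAlreadyInFlight
    while (pending.nonEmpty && inFlight + requestSize(pending.head) <= maxBytesInFlight) {
      val req = pending.dequeue()
      inFlight += requestSize(req)
      admitted += req
    }
    admitted.toSeq
  }
}
```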
/**
- * Class describing the state of a set of [[PendingTransferRequest]]s.
- *
- * This class is *not thread safe*. The way the code is currently designed, bounce buffers being
- * used to receive, or copied from, are acted on a sequential basis, in time and in space.
- *
- * Callers use this class, like so:
- *
- * 1. [[getRequest]]
- * - first call:
- * - Initializes the state tracking the progress of `currentRequest` and our place in
- * `requests` (e.g. offset, bytes remaining)
- *
- * - subsequent calls:
- * - if `currentRequest` is not done, it will return the current request. And perform
- * sanity checks.
- * - if `currentRequest` is done, it will perform sanity checks, and advance to the next
- * request in `requests`
- *
- * 2. [[consumeBuffers]]
- * - first call:
- * - The first time around for `currentRequest` it will allocate the actual full device
- * buffer that we will copy to, copy data sequentially from the bounce buffer to the
- * target buffer.
- * - subsequent calls:
- * - continue copy data sequentially from the bounce buffers passed in.
- * - When `currentRequest` has been fully received, an optional `DeviceMemoryBuffer` is
- * set, and returned.
- *
- * 3. [[close]]
- * - once the caller calls close, the bounce buffers are returned to the pool.
- *
- * @param transport a transport, which in this case is used to free bounce buffers
- * @param bounceMemoryBuffers a sequence of `MemoryBuffer` buffers to use for receives
- */
+ * Class describing the state of a set of [[PendingTransferRequest]]s.
+ *
+ * This class is *not thread safe*. The way the code is currently designed, bounce buffers being
+ * used to receive, or copied from, are acted on sequentially, in time and in space.
+ *
+ * Callers use this class, like so:
+ *
+ * 1. [[getRequest]]
+ * - first call:
+ * - Initializes the state tracking the progress of `currentRequest` and our place in
+ * `requests` (e.g. offset, bytes remaining)
+ *
+ * - subsequent calls:
+ * - if `currentRequest` is not done, it will perform sanity checks and return the
+ * current request.
+ * - if `currentRequest` is done, it will perform sanity checks, and advance to the next
+ * request in `requests`
+ *
+ * 2. [[consumeBuffers]]
+ * - first call:
+ * - The first time around for `currentRequest` it will allocate the actual full device
+ * buffer that we will copy to, copy data sequentially from the bounce buffer to the
+ * target buffer.
+ * - subsequent calls:
+ * - continue copying data sequentially from the bounce buffers passed in.
+ * - When `currentRequest` has been fully received, an optional `DeviceMemoryBuffer` is
+ * set, and returned.
+ *
+ * 3. [[close]]
+ * - once the caller calls close, the bounce buffers are returned to the pool.
+ *
+ * @param transport a transport, which in this case is used to free bounce buffers
+ * @param bounceMemoryBuffers a sequence of `MemoryBuffer` buffers to use for receives
+ */
class BufferReceiveState(
transport: RapidsShuffleTransport,
val bounceMemoryBuffers: Seq[MemoryBuffer])
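A synchronous sketch of the caller protocol documented above (steps 1-3), assuming a hypothetical `receiveSync` helper that blocks until the given bounce buffers are filled; the real client drives this asynchronously from `doIssueBufferReceives`:

```scala
import ai.rapids.cudf.DeviceMemoryBuffer

object BufferReceiveStateSketch {
  def drain(
      state: BufferReceiveState,
      receiveSync: Seq[AddressLengthTag] => Unit,
      onBatch: DeviceMemoryBuffer => Unit): Unit = {
    try {
      while (!state.isDone) {
        val (request, isNew) = state.getRequest
        if (isNew) {
          // A new request just became current: this is the point where the client
          // would ask the server (via a TransferRequest) to start sending `request`.
        }
        val bounceBuffers = state.getBounceBuffersForReceive()
        receiveSync(bounceBuffers)            // wait for this chunk to arrive
        state.consumeBuffers(bounceBuffers) match {
          case Some(fullBuffer) => onBatch(fullBuffer) // request fully received
          case None => // more chunks still pending for the current request
        }
      }
    } finally {
      state.close() // returns the bounce buffers to the pool
    }
  }
}
```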
@@ -115,63 +115,63 @@ class BufferReceiveState(
private[this] val requests = new ArrayBuffer[PendingTransferRequest]()
/**
- * Use by the transport to add to this [[BufferReceiveState]] requests it needs to handle.
- * @param pendingTransferRequest request to add to this [[BufferReceiveState]]
- */
+ * Used by the transport to add requests it needs to handle to this [[BufferReceiveState]].
+ * @param pendingTransferRequest request to add to this [[BufferReceiveState]]
+ */
def addRequest(pendingTransferRequest: PendingTransferRequest): Unit = synchronized {
requests.append(pendingTransferRequest)
}
/**
- * Holds the target device memory buffer. It is allocated at [[consumeBuffers]] when the first
- * bounce buffer resolves, and it is also handed off to the caller in the same function.
- */
+ * Holds the target device memory buffer. It is allocated at [[consumeBuffers]] when the first
+ * bounce buffer resolves, and it is also handed off to the caller in the same function.
+ */
private[this] var buff: DeviceMemoryBuffer = null
/**
- * This is the address/length/tag object for the target buffer. It is used only to access a cuda
- * synchronous copy method. We should do a regular copy from [[buff]] once we use the async
- * method and add the cuda event synchronization.
- */
+ * This is the address/length/tag object for the target buffer. It is used only to access a cuda
+ * synchronous copy method. We should do a regular copy from [[buff]] once we use the async
+ * method and add the cuda event synchronization.
+ */
private[this] var alt: AddressLengthTag = null
/**
- * True iff this is the last request we are handling. It is used in [[isDone]] to find the
- * stopping point.
- */
+ * True iff this is the last request we are handling. It is used in [[isDone]] to find the
+ * stopping point.
+ */
private[this] var lastRequest: Boolean = false
/**
- * The index into the [[requests]] sequence pointing to the current request we are handling.
- * We handle requests sequentially.
- */
+ * The index into the [[requests]] sequence pointing to the current request we are handling.
+ * We handle requests sequentially.
+ */
private[this] var currentRequestIndex = -1
/**
- * Amount of bytes left in the current request.
- */
+ * Amount of bytes left in the current request.
+ */
private[this] var currentRequestRemaining: Long = 0L
/**
- * Byte offset we are currently at. [[getRequest]] uses and resets it, and [[consumeBuffers]]
- * updates it.
- */
+ * Byte offset we are currently at. [[getRequest]] uses and resets it, and [[consumeBuffers]]
+ * updates it.
+ */
private[this] var currentRequestOffset: Long = 0L
/**
- * The current request (TableMeta, Handler (iterator))
- */
+ * The current request (TableMeta, Handler (iterator))
+ */
private[this] var currentRequest: PendingTransferRequest = null
/**
- * To help debug "closed multiple times" issues
- */
+ * To help debug "closed multiple times" issues
+ */
private[this] var isClosed = false
/**
- * Becomes true when there is an error detected, allowing the client to close this
- * [[BufferReceiveState]] prematurely.
- */
+ * Becomes true when there is an error detected, allowing the client to close this
+ * [[BufferReceiveState]] prematurely.
+ */
private[this] var errorOcurred = false
override def toString: String = {
@@ -182,14 +182,14 @@ class BufferReceiveState(
}
/**
- * When a receive transaction is successful, this function is called to consume the bounce
- * buffers received.
- *
- * @note If the target device buffer is not allocated, this function does so.
- * @param bounceBuffers sequence of buffers that have been received
- * @return if the current request is complete, returns a shallow copy of the target
- * device buffer as an `Option`. Callers will need to close the buffer.
- */
+ * When a receive transaction is successful, this function is called to consume the bounce
+ * buffers received.
+ *
+ * @note If the target device buffer is not allocated, this function does so.
+ * @param bounceBuffers sequence of buffers that have been received
+ * @return if the current request is complete, returns a shallow copy of the target
+ * device buffer as an `Option`. Callers will need to close the buffer.
+ */
def consumeBuffers(
bounceBuffers: Seq[AddressLengthTag]): Option[DeviceMemoryBuffer] = synchronized {
var needsCleanup = true
@@ -238,20 +238,20 @@ class BufferReceiveState(
}
/**
- * Signals whether this [[BufferReceiveState]] is complete.
- * @return boolean when at the last request, and that request is fully received
- */
+ * Signals whether this [[BufferReceiveState]] is complete.
+ * @return boolean when at the last request, and that request is fully received
+ */
def isDone: Boolean = synchronized {
lastRequest && currentRequestDone
}
/**
- * Return the current (making the next request current, if we are done) request and
- * whether we advanced (want to kick off a transfer request to the server)
- *
- * @return returns the currently working transfer request, and
- * true if this is a new request we should be asking the server to trigger
- */
+ * Returns the current request (making the next request current, if we are done) and
+ * whether we advanced (i.e. we want to kick off a transfer request to the server).
+ *
+ * @return returns the currently working transfer request, and
+ * true if this is a new request we should be asking the server to trigger
+ */
def getRequest: (PendingTransferRequest, Boolean) = synchronized {
require(currentRequestIndex < requests.size,
"Something went wrong while handling buffer receives. Asking for more buffers than expected")
@@ -301,18 +301,18 @@ class BufferReceiveState(
}
/**
- * Used to cut the subset of bounce buffers the client will need to issue receives with.
- *
- * This is a subset, because at the moment, the full target length worth of bounce buffers
- * could have been acquired. If we are at the tail end of receive, and it could be fulfilled
- * with 1 bounce buffer, for example, we would return 1 bounce buffer here, rather than the
- * number of buffers in acquired.
- *
- * @note that these extra buffers should really just be freed as soon as we realize
- * they are of no use.
- *
- * @return sequence of [[AddressLengthTag]] pointing to the receive bounce buffers.
- */
+ * Used to cut the subset of bounce buffers the client will need to issue receives with.
+ *
+ * This is a subset, because at the moment, the full target length worth of bounce buffers
+ * could have been acquired. If we are at the tail end of a receive, and it could be fulfilled
+ * with 1 bounce buffer, for example, we would return 1 bounce buffer here, rather than the
+ * number of buffers in acquired.
+ *
+ * @note that these extra buffers should really just be freed as soon as we realize
+ * they are of no use.
+ *
+ * @return sequence of [[AddressLengthTag]] pointing to the receive bounce buffers.
+ */
def getBounceBuffersForReceive(): Seq[AddressLengthTag] = synchronized {
var bounceBufferIx = 0
var bounceBuffersForTransfer = Seq[AddressLengthTag]()
@@ -351,28 +351,28 @@ class BufferReceiveState(
}
/**
- * The client makes requests via a [[Connection]] obtained from the [[RapidsShuffleTransport]].
- *
- * The [[Connection]] follows a single threaded callback model, so this class posts operations
- * to an `Executor` as quickly as it gets them from the [[Connection]].
- *
- * This class handles fetch requests from [[RapidsShuffleIterator]], turning them into
- * [[ShuffleMetadata]] messages, and shuffle `TransferRequest`s.
- *
- * Its counterpart is the [[RapidsShuffleServer]] on a specific peer executor, specified by
- * `connection`.
- *
- * @param localExecutorId this id is sent to the server, it is required for the protocol as
- * the server needs to pick an endpoint to send a response back to this
- * executor.
- * @param connection a connection object against a remote executor
- * @param transport used to get metadata buffers and to work with the throttle mechanism
- * @param exec Executor used to handle tasks that take time, and should not be in the
- * transport's thread
- * @param clientCopyExecutor Executors used to handle synchronous mem copies
- * @param maximumMetadataSize The maximum metadata buffer size we are able to request
- * TODO: this should go away
- */
+ * The client makes requests via a [[Connection]] obtained from the [[RapidsShuffleTransport]].
+ *
+ * The [[Connection]] follows a single threaded callback model, so this class posts operations
+ * to an `Executor` as quickly as it gets them from the [[Connection]].
+ *
+ * This class handles fetch requests from [[RapidsShuffleIterator]], turning them into
+ * [[ShuffleMetadata]] messages, and shuffle `TransferRequest`s.
+ *
+ * Its counterpart is the [[RapidsShuffleServer]] on a specific peer executor, specified by
+ * `connection`.
+ *
+ * @param localExecutorId this id is sent to the server; it is required by the protocol, as
+ * the server needs to pick an endpoint to send a response back to this
+ * executor.
+ * @param connection a connection object against a remote executor
+ * @param transport used to get metadata buffers and to work with the throttle mechanism
+ * @param exec Executor used to handle tasks that take time, and should not be in the
+ * transport's thread
+ * @param clientCopyExecutor Executors used to handle synchronous mem copies
+ * @param maximumMetadataSize The maximum metadata buffer size we are able to request
+ * TODO: this should go away
+ */
class RapidsShuffleClient(
localExecutorId: Long,
connection: ClientConnection,
@@ -385,48 +385,48 @@ class RapidsShuffleClient(
object ShuffleClientOps {
/**
- * When a metadata response is received, this event is issued to handle it.
- * @param tx the [[Transaction]] to be closed after consuming the response
- * @param resp the response metadata buffer
- * @param shuffleRequests blocks to be requested
- * @param rapidsShuffleFetchHandler the handler (iterator) to callback to
- */
+ * When a metadata response is received, this event is issued to handle it.
+ * @param tx the [[Transaction]] to be closed after consuming the response
+ * @param resp the response metadata buffer
+ * @param shuffleRequests blocks to be requested
+ * @param rapidsShuffleFetchHandler the handler (iterator) to callback to
+ */
case class HandleMetadataResponse(tx: Transaction,
resp: RefCountedDirectByteBuffer,
shuffleRequests: Seq[ShuffleBlockBatchId],
rapidsShuffleFetchHandler: RapidsShuffleFetchHandler)
/**
- * Represents retry due to metadata being larger than expected.
- *
- * @param shuffleRequests request to retry
- * @param rapidsShuffleFetchHandler the handler (iterator) to callback to
- * @param fullResponseSize response size to allocate to fit the server's response in full
- */
+ * Represents retry due to metadata being larger than expected.
+ *
+ * @param shuffleRequests request to retry
+ * @param rapidsShuffleFetchHandler the handler (iterator) to callback to
+ * @param fullResponseSize response size to allocate to fit the server's response in full
+ */
case class FetchRetry(shuffleRequests: Seq[ShuffleBlockBatchId],
rapidsShuffleFetchHandler: RapidsShuffleFetchHandler,
fullResponseSize: Long)
/**
- * Used to have this client handle the enclosed [[BufferReceiveState]] asynchronously.
- *
- * Until [[bufferReceiveState]] completes, this event will continue to get posted.
- *
- * @param bufferReceiveState object containing the state of pending requests to the peer
- */
+ * Used to have this client handle the enclosed [[BufferReceiveState]] asynchronously.
+ *
+ * Until [[bufferReceiveState]] completes, this event will continue to get posted.
+ *
+ * @param bufferReceiveState object containing the state of pending requests to the peer
+ */
case class IssueBufferReceives(bufferReceiveState: BufferReceiveState)
/**
- * When a buffer is received, this event is posted to remove from the progress thread
- * the copy, and the callback into the iterator.
- *
- * Currently not used. There is a TODO below.
- *
- * @param tx live transaction for the buffer, to be closed after the buffer is handled
- * @param bufferReceiveState the object maintaining state for receives
- * @param currentRequest the request these bounce buffers belong to
- * @param bounceBuffers buffers used in the transfer, which contain the fragment of the data
- */
+ * When a buffer is received, this event is posted to move the copy, and the callback
+ * into the iterator, off of the progress thread.
+ *
+ * Currently not used. There is a TODO below.
+ *
+ * @param tx live transaction for the buffer, to be closed after the buffer is handled
+ * @param bufferReceiveState the object maintaining state for receives
+ * @param currentRequest the request these bounce buffers belong to
+ * @param bounceBuffers buffers used in the transfer, which contain the fragment of the data
+ */
case class HandleBounceBufferReceive(tx: Transaction,
bufferReceiveState: BufferReceiveState,
currentRequest: PendingTransferRequest,
@@ -459,24 +459,24 @@ class RapidsShuffleClient(
}
/**
- * Pushes a task onto the queue to be handled by the client's copy executor.
- *
- * @note - at this stage, tasks in this pool can block (it will grow as needed)
- *
- * @param op One of the case classes in [[ShuffleClientOps]]
- */
+ * Pushes a task onto the queue to be handled by the client's copy executor.
+ *
+ * @note - at this stage, tasks in this pool can block (it will grow as needed)
+ *
+ * @param op One of the case classes in [[ShuffleClientOps]]
+ */
private[this] def asyncOnCopyThread(op: Any): Unit = {
clientCopyExecutor.execute(() => handleOp(op))
}
/**
- * Starts a fetch request for all the shuffleRequests, using `handler` to communicate
- * events back to the iterator.
- *
- * @param shuffleRequests blocks to fetch
- * @param handler iterator to callback to
- * @param metadataSize metadata size to use for this fetch
- */
+ * Starts a fetch request for all the shuffleRequests, using `handler` to communicate
+ * events back to the iterator.
+ *
+ * @param shuffleRequests blocks to fetch
+ * @param handler iterator to callback to
+ * @param metadataSize metadata size to use for this fetch
+ */
def doFetch(shuffleRequests: Seq[ShuffleBlockBatchId],
handler: RapidsShuffleFetchHandler,
metadataSize: Long = maximumMetadataSize): Unit = {
@@ -522,13 +522,13 @@ class RapidsShuffleClient(
}
/**
- * Function to handle MetadataResponses, as a result of the [[HandleMetadataResponse]] event.
- *
- * @param tx live metadata response transaction to be closed in this handler
- * @param resp response buffer, to be closed in this handler
- * @param shuffleRequests blocks to fetch
- * @param handler iterator to callback to
- */
+ * Function to handle MetadataResponses, as a result of the [[HandleMetadataResponse]] event.
+ *
+ * @param tx live metadata response transaction to be closed in this handler
+ * @param resp response buffer, to be closed in this handler
+ * @param shuffleRequests blocks to fetch
+ * @param handler iterator to callback to
+ */
private[this] def doHandleMetadataResponse(
tx: Transaction,
resp: RefCountedDirectByteBuffer,
@@ -574,21 +574,21 @@ class RapidsShuffleClient(
}
/**
- * Used by the transport, to schedule receives. The requests are sent to the executor for this
- * client.
- * @param bufferReceiveState object tracking the state of pending TransferRequests
- */
+ * Used by the transport to schedule receives. The requests are sent to the executor for this
+ * client.
+ * @param bufferReceiveState object tracking the state of pending TransferRequests
+ */
def issueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = {
asyncOnCopyThread(IssueBufferReceives(bufferReceiveState))
}
/**
- * Issues transfers requests (if the state of [[bufferReceiveState]] advances), or continue to
- * work a current request (continue receiving bounce buffer sized chunks from a larger receive).
- * @param bufferReceiveState object maintaining state of requests to be issued (current or
- * future). The requests included in this state object originated in
- * the transport's throttle logic.
- */
+ * Issues transfer requests (if the state of [[bufferReceiveState]] advances), or continues to
+ * work the current request (receiving bounce buffer sized chunks from a larger receive).
+ * @param bufferReceiveState object maintaining state of requests to be issued (current or
+ * future). The requests included in this state object originated in
+ * the transport's throttle logic.
+ */
private[shuffle] def doIssueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = {
logDebug(s"At issue for ${bufferReceiveState}, " +
s"remaining: ${bufferReceiveState.getCurrentRequestRemaining}, " +
@@ -691,12 +691,12 @@ class RapidsShuffleClient(
}
/**
- * Feed into the throttle thread in the transport [[PendingTransferRequest]], to be
- * issued later via the [[doIssueBufferReceives]] method.
- *
- * @param metaResponse metadata response flat buffer
- * @param handler callback trait (the iterator implements this)
- */
+ * Feeds [[PendingTransferRequest]]s into the transport's throttle thread, to be
+ * issued later via the [[doIssueBufferReceives]] method.
+ *
+ * @param metaResponse metadata response flat buffer
+ * @param handler callback trait (the iterator implements this)
+ */
private def queueTransferRequests(metaResponse: MetadataResponse,
handler: RapidsShuffleFetchHandler): Unit = {
val allTables = metaResponse.tableMetasLength()
@@ -725,14 +725,14 @@ class RapidsShuffleClient(
}
/**
- * This function handles data received in `bounceBuffers`. The data should be copied out
- * of the buffers, and the function should call into `bufferReceiveState` to advance its
- * state (consumeBuffers)
- * @param tx live transaction for these bounce buffers, it should be closed in this function
- * @param bufferReceiveState state management objects for live transfer requests
- * @param currentRequest current transfer request being worked on
- * @param bounceBuffers bounce buffers (just received) containing data to be consumed
- */
+ * This function handles data received in `bounceBuffers`. The data should be copied out
+ * of the buffers, and the function should call into `bufferReceiveState` to advance its
+ * state (consumeBuffers)
+ * @param tx live transaction for these bounce buffers, it should be closed in this function
+ * @param bufferReceiveState state management objects for live transfer requests
+ * @param currentRequest current transfer request being worked on
+ * @param bounceBuffers bounce buffers (just received) containing data to be consumed
+ */
def doHandleBounceBufferReceive(tx: Transaction,
bufferReceiveState: BufferReceiveState,
currentRequest: PendingTransferRequest,
@@ -783,12 +783,12 @@ class RapidsShuffleClient(
}
/**
- * Hands [[table]] and [[buffer]] to the device storage/catalog, obtaining an id that can be
- * used to look up the buffer from the catalog going (e.g. from the iterator)
- * @param buffer contiguous [[DeviceMemoryBuffer]] with the tables' data
- * @param meta [[TableMeta]] describing [[buffer]]
- * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog
- */
+ * Hands [[table]] and [[buffer]] to the device storage/catalog, obtaining an id that can be
+ * used to look up the buffer in the catalog going forward (e.g. from the iterator)
+ * @param buffer contiguous [[DeviceMemoryBuffer]] with the tables' data
+ * @param meta [[TableMeta]] describing [[buffer]]
+ * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog
+ */
private[shuffle] def track(buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferId = {
val id: ShuffleReceivedBufferId = catalog.nextShuffleReceivedBufferId()
logDebug(s"Adding buffer id ${id} to catalog")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala
index 53e829aa2dd..a81d98537ec 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala
@@ -31,18 +31,18 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId}
/**
- * An Iterator over columnar batches that fetches blocks using [[RapidsShuffleClient]]s.
- *
- * A `transport` instance is used to make [[RapidsShuffleClient]]s that are able to fetch
- * blocks.
- *
- * @param localBlockManagerId the `BlockManagerId` for the local executor
- * @param rapidsConf plugin configuration
- * @param transport transport to use to fetch blocks
- * @param blocksByAddress blocks to fetch
- * @param metricsUpdater instance of `ShuffleMetricsUpdater` to update the Spark
- * shuffle metrics
- */
+ * An Iterator over columnar batches that fetches blocks using [[RapidsShuffleClient]]s.
+ *
+ * A `transport` instance is used to make [[RapidsShuffleClient]]s that are able to fetch
+ * blocks.
+ *
+ * @param localBlockManagerId the `BlockManagerId` for the local executor
+ * @param rapidsConf plugin configuration
+ * @param transport transport to use to fetch blocks
+ * @param blocksByAddress blocks to fetch
+ * @param metricsUpdater instance of `ShuffleMetricsUpdater` to update the Spark
+ * shuffle metrics
+ */
class RapidsShuffleIterator(
localBlockManagerId: BlockManagerId,
rapidsConf: RapidsConf,
@@ -54,26 +54,26 @@ class RapidsShuffleIterator(
with Logging {
/**
- * General trait encapsulating either a buffer or an error. Used to hand off batches
- * to tasks (in the good case), or exceptions (in the bad case)
- */
+ * General trait encapsulating either a buffer or an error. Used to hand off batches
+ * to tasks (in the good case), or exceptions (in the bad case)
+ */
trait ShuffleClientResult
/**
- * A result for a successful buffer received
- * @param bufferId - the shuffle received buffer id as tracked in the catalog
- */
+ * A result for a successfully received buffer
+ * @param bufferId - the shuffle received buffer id as tracked in the catalog
+ */
case class BufferReceived(
bufferId: ShuffleReceivedBufferId) extends ShuffleClientResult
/**
- * A result for a failed attempt at receiving block metadata, or corresponding batches.
- * @param blockManagerId - the offending peer block manager id
- * @param blockId - shuffle block id that we were fetching
- * @param mapIndex - the mapIndex (as returned by the `MapOutputTracker` in
- * `blocksByAddress`
- * @param errorMessage - a human-friendly error to report
- */
+ * A result for a failed attempt at receiving block metadata, or corresponding batches.
+ * @param blockManagerId - the offending peer block manager id
+ * @param blockId - shuffle block id that we were fetching
+ * @param mapIndex - the mapIndex (as returned by the `MapOutputTracker` in
+ * `blocksByAddress`)
+ * @param errorMessage - a human-friendly error to report
+ */
case class TransferError(
blockManagerId: BlockManagerId,
blockId: ShuffleBlockBatchId,
@@ -153,12 +153,12 @@ class RapidsShuffleIterator(
val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] =
blockIds.map { case (blockId, _, mapIndex) =>
/**
- * [[ShuffleBlockBatchId]] is an internal optimization in Spark, which will likely
- * never see it unless explicitly enabled.
- *
- * There are other things that can turn it off, but we really don't care too much
- * about it.
- */
+ * [[ShuffleBlockBatchId]] is an internal optimization in Spark, which we will likely
+ * never see unless explicitly enabled.
+ *
+ * There are other things that can turn it off, but we really don't care too much
+ * about it.
+ */
blockId match {
case sbbid: ShuffleBlockBatchId => BlockIdMapIndex(sbbid, mapIndex)
case sbid: ShuffleBlockId =>
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala
index 5edffb76551..e408749586c 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala
@@ -26,44 +26,44 @@ import org.apache.spark.internal.Logging
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockBatchId}
/**
- * Trait used for the server to get buffer metadata (for metadata requests), and
- * also to acquire a buffer (for transfer requests)
- */
+ * Trait used for the server to get buffer metadata (for metadata requests), and
+ * also to acquire a buffer (for transfer requests)
+ */
trait RapidsShuffleRequestHandler {
/**
- * This is a query into the manager to get the `TableMeta` corresponding to a
- * shuffle block.
- * @param shuffleBlockBatchId `ShuffleBlockBatchId` with (shuffleId, mapId,
- * startReduceId, endReduceId)
- * @return a sequence of `TableMeta` describing batches corresponding to a block.
- */
+ * This is a query into the manager to get the `TableMeta` corresponding to a
+ * shuffle block.
+ * @param shuffleBlockBatchId `ShuffleBlockBatchId` with (shuffleId, mapId,
+ * startReduceId, endReduceId)
+ * @return a sequence of `TableMeta` describing batches corresponding to a block.
+ */
def getShuffleBufferMetas(shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta]
/**
- * Acquires (locks w.r.t. the memory tier) a [[RapidsBuffer]] corresponding to a table id.
- * @param tableId the unique id for a table in the catalog
- * @return a [[RapidsBuffer]] which is reference counted, and should be closed by the acquirer
- */
+ * Acquires (locks w.r.t. the memory tier) a [[RapidsBuffer]] corresponding to a table id.
+ * @param tableId the unique id for a table in the catalog
+ * @return a [[RapidsBuffer]] which is reference counted, and should be closed by the acquirer
+ */
def acquireShuffleBuffer(tableId: Int): RapidsBuffer
}
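As a rough illustration of this trait, a toy map-backed handler might look like the following sketch; the real handler delegates to the shuffle buffer catalog and locks buffers in their memory tier:

```scala
import org.apache.spark.storage.ShuffleBlockBatchId

// Hypothetical handler; assumes TableMeta, RapidsBuffer and the trait above are in scope.
class MapBackedRequestHandler(
    metasByBlock: Map[ShuffleBlockBatchId, Seq[TableMeta]],
    buffersByTable: Map[Int, RapidsBuffer]) extends RapidsShuffleRequestHandler {

  override def getShuffleBufferMetas(
      shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] =
    metasByBlock.getOrElse(shuffleBlockBatchId, Seq.empty)

  override def acquireShuffleBuffer(tableId: Int): RapidsBuffer =
    // The returned buffer is reference counted; the acquirer must close it.
    buffersByTable(tableId)
}
```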
/**
- * A server that replies to shuffle metadata messages, and issues device/host memory sends.
- *
- * A single command thread is used to orchestrate sends/receives and to remove
- * from transport's progress thread.
- *
- * @param transport the transport we were configured with
- * @param serverConnection a connection object, which contains functions to send/receive
- * @param originalShuffleServerId spark's `BlockManagerId` for this executor
- * @param requestHandler instance of [[RapidsShuffleRequestHandler]]
- * @param exec Executor used to handle tasks that take time, and should not be in the
- * transport's thread
- * @param copyExec Executor used to handle synchronous mem copies
- * @param bssExec Executor used to handle [[BufferSendState]]s that are waiting
- * for bounce buffers to become available
- * @param rapidsConf plugin configuration instance
- */
+ * A server that replies to shuffle metadata messages, and issues device/host memory sends.
+ *
+ * A single command thread is used to orchestrate sends/receives and to move work
+ * off of the transport's progress thread.
+ *
+ * @param transport the transport we were configured with
+ * @param serverConnection a connection object, which contains functions to send/receive
+ * @param originalShuffleServerId spark's `BlockManagerId` for this executor
+ * @param requestHandler instance of [[RapidsShuffleRequestHandler]]
+ * @param exec Executor used to handle tasks that take time, and should not be in the
+ * transport's thread
+ * @param copyExec Executor used to handle synchronous mem copies
+ * @param bssExec Executor used to handle [[BufferSendState]]s that are waiting
+ * for bounce buffers to become available
+ * @param rapidsConf plugin configuration instance
+ */
class RapidsShuffleServer(transport: RapidsShuffleTransport,
serverConnection: ServerConnection,
val originalShuffleServerId: BlockManagerId,
@@ -73,27 +73,27 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
bssExec: Executor,
rapidsConf: RapidsConf) extends AutoCloseable with Logging {
/**
- * On close, this is set to false to indicate that the server is shutting down.
- */
+ * On close, this is set to false to indicate that the server is shutting down.
+ */
private[this] var started = true
private object ShuffleServerOps {
/**
- * When a transfer request is received during a callback, the handle code is offloaded via this
- * event to the server thread.
- * @param tx the live transaction that should be closed by the handler
- * @param metaRequestBuffer contains the metadata request that should be closed by the
- * handler
- */
+ * When a metadata request is received during a callback, its handling is offloaded via this
+ * event to the server thread.
+ * @param tx the live transaction that should be closed by the handler
+ * @param metaRequestBuffer contains the metadata request that should be closed by the
+ * handler
+ */
case class HandleMeta(tx: Transaction, metaRequestBuffer: RefCountedDirectByteBuffer)
/**
- * When transfer request is received (to begin sending buffers), the handling is offloaded via
- * this event on the server thread. Note that, [[BufferSendState]] encapsulates one more more
- * requests to send buffers, and [[HandleTransferRequest]] may be posted multiple times
- * in order to handle the request fully.
- * @param sendState instance of [[BufferSendState]] used to complete a transfer request.
- */
+ * When a transfer request is received (to begin sending buffers), the handling is offloaded via
+ * this event to the server thread. Note that [[BufferSendState]] encapsulates one or more
+ * requests to send buffers, and [[HandleTransferRequest]] may be posted multiple times
+ * in order to handle the request fully.
+ * @param sendState instance of [[BufferSendState]] used to complete a transfer request.
+ */
case class HandleTransferRequest(sendState: BufferSendState)
}
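The ops above are dispatched on the server's executors; below is a minimal sketch of that dispatch, assuming it lives inside `RapidsShuffleServer` so that `doHandleMeta` and `doHandleTransferRequest` are in scope:

```scala
// Sketch only: the server's actual dispatch may differ in details.
private def handleOpSketch(op: Any): Unit = op match {
  case ShuffleServerOps.HandleMeta(tx, metaRequestBuffer) =>
    doHandleMeta(tx, metaRequestBuffer)
  case ShuffleServerOps.HandleTransferRequest(sendState) =>
    doHandleTransferRequest(sendState)
  case other =>
    throw new IllegalStateException(s"Unknown shuffle server op: $other")
}
```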
@@ -102,10 +102,10 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
private var port: Int = -1
/**
- * Returns a TCP port that is expected to respond to rapids shuffle protocol.
- * Throws if this server is not started yet, which is an illegal state.
- * @return the port
- */
+ * Returns a TCP port that is expected to respond to the rapids shuffle protocol.
+ * Throws if this server is not started yet, which is an illegal state.
+ * @return the port
+ */
def getPort: Int = {
if (port == -1) {
throw new IllegalStateException("RapidsShuffleServer port is not initialized")
@@ -114,8 +114,8 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * Kick off the underlying connection, and listen for initial requests.
- */
+ * Kick off the underlying connection, and listen for initial requests.
+ */
def start(): Unit = {
port = serverConnection.startManagementPort(originalShuffleServerId.host)
// kick off our first receives
@@ -139,37 +139,37 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * Pushes a task onto the queue to be handled by the server executor.
- *
- * All callbacks handled in the server (from the transport) need to be offloaded into
- * this pool. Note, if this thread blocks we are blocking the progress thread of the transport.
- *
- * @param op One of the case classes in `ShuffleServerOps`
- */
+ * Pushes a task onto the queue to be handled by the server executor.
+ *
+ * All callbacks handled in the server (from the transport) need to be offloaded into
+ * this pool. Note: if this thread blocks, we are blocking the progress thread of the transport.
+ *
+ * @param op One of the case classes in `ShuffleServerOps`
+ */
def asyncOrBlock(op: Any): Unit = {
exec.execute(() => handleOp(op))
}
/**
- * Pushes a task onto the queue to be handled by the server's copy executor.
- *
- * @note - at this stage, tasks in this pool can block (it will grow as needed)
- *
- * @param op One of the case classes in [[ShuffleServerOps]]
- */
+ * Pushes a task onto the queue to be handled by the server's copy executor.
+ *
+ * @note - at this stage, tasks in this pool can block (it will grow as needed)
+ *
+ * @param op One of the case classes in [[ShuffleServerOps]]
+ */
private[this] def asyncOnCopyThread(op: Any): Unit = {
copyExec.execute(() => handleOp(op))
}
/**
- * Keep a list of BufferSendState that are waiting for bounce buffers.
- */
+ * Keep a list of BufferSendState that are waiting for bounce buffers.
+ */
private[this] val bssQueue = new ConcurrentLinkedQueue[BufferSendState]()
/**
- * Executor that loops until it finds bounce buffers for [[BufferSendState]],
- * and when it does it hands them off to a thread pool for handling.
- */
+ * Executor that loops until it finds bounce buffers for [[BufferSendState]],
+ * and when it does, it hands them off to a thread pool for handling.
+ */
bssExec.execute(() => {
while (started) {
var bss: BufferSendState = null
@@ -203,14 +203,14 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
})
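The hunk above only shows part of this loop. As an illustration of the pattern it implements, a stripped-down version could look like the sketch below (waiting and error handling are elided, and `Thread.sleep` is only a stand-in for the real back-off):

```scala
// Sketch: peek at queued BufferSendStates and, only when their bounce buffers
// can be acquired, hand them off to the copy executor.
bssExec.execute(() => {
  while (started) {
    val bss = bssQueue.peek()
    if (bss != null && bss.acquireBounceBuffersNonBlocking) {
      bssQueue.remove(bss)
      asyncOnCopyThread(ShuffleServerOps.HandleTransferRequest(bss))
    } else {
      Thread.sleep(10) // assumption: back off briefly instead of busy-spinning
    }
  }
})
```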
/**
- * Handler for a metadata request. It queues request handlers for either
- * [[RequestType.MetadataRequest]] or [[RequestType.TransferRequest]], and re-issues
- * receives for either type of request.
- *
- * NOTE: This call must be non-blocking. It is called from the progress thread.
- *
- * @param requestType The request type received
- */
+ * Handler for a metadata request. It queues request handlers for either
+ * [[RequestType.MetadataRequest]] or [[RequestType.TransferRequest]], and re-issues
+ * receives for either type of request.
+ *
+ * NOTE: This call must be non-blocking. It is called from the progress thread.
+ *
+ * @param requestType The request type received
+ */
private def doIssueReceive(requestType: RequestType.Value): Unit = {
logDebug(s"Waiting for a new connection. Posting ${requestType} receive.")
val metaRequest = transport.getMetaBuffer(rapidsConf.shuffleMaxMetadataSize)
@@ -245,12 +245,12 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * Function to handle `MetadataRequest`s. It will populate and issue a
- * `MetadataResponse` response for the appropriate client.
- *
- * @param tx the inbound [[Transaction]]
- * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message.
- */
+ * Function to handle `MetadataRequest`s. It will populate and issue a
+ * `MetadataResponse` response for the appropriate client.
+ *
+ * @param tx the inbound [[Transaction]]
+ * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message.
+ */
def doHandleMeta(tx: Transaction, metaRequest: RefCountedDirectByteBuffer): Unit = {
val doHandleMetaRange = new NvtxRange("doHandleMeta", NvtxColor.PURPLE)
val start = System.currentTimeMillis()
@@ -278,9 +278,9 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * Handles the very first message that a client will send, in order to request Table/Buffer info.
- * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message.
- */
+ * Handles the very first message that a client will send, in order to request Table/Buffer info.
+ * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message.
+ */
def handleMetadataRequest(metaRequest: RefCountedDirectByteBuffer): Unit = {
try {
val req = ShuffleMetadata.getMetadataRequest(metaRequest.getBuffer())
@@ -340,43 +340,43 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * A helper case class to maintain the state associated with a transfer request initiated by
- * a `TransferRequest` metadata message.
- *
- * This class is *not thread safe*. The way the code is currently designed, bounce buffers
- * being used to send, or copied to, are acted on a sequential basis, in time and in space.
- *
- * Callers use this class, like so:
- *
- * 1) [[getBuffersToSend]]: is used to get bounce buffers that the server should .send on.
- * -- first time:
- * a) the corresponding catalog table is acquired,
- * b) bounce buffers are acquired,
- * c) data is copied from the original catalog table into the bounce buffers available
- * d) the length of the last bounce buffer is adjusted if it would satisfy the full
- * length of the catalog-backed buffer.
- * e) bounce buffers are returned
- *
- * -- subsequent times:
- * if we are not done sending the acquired table:
- * a) data is copied from the original catalog table into the bounce buffers available
- * at sequentially incrementing offsets.
- * b) the length of the last bounce buffer is adjusted if it would satisfy the full
- * length of the catalog-backed buffer.
- * c) bounce buffers are returned
- *
- * 2) [[freeBounceBuffersIfNecessary]]: called when a send finishes, in order to free bounce
- * buffers, if the current table is done sending.
- *
- * 3) [[close]]: used to free state as the [[BufferSendState]] object is no longer needed
- *
- * In terms of the lifecycle of this object, it begins with the client asking for transfers to
- * start, it lasts through all buffers being transmitted, and ultimately finishes when a
- * `TransferResponse` is sent back to the client.
- *
- * @param tx the original `Transaction` from the `TransferRequest`.
- * @param request a transfer request
- */
+ * A helper case class to maintain the state associated with a transfer request initiated by
+ * a `TransferRequest` metadata message.
+ *
+ * This class is *not thread safe*. The way the code is currently designed, bounce buffers
+ * being used to send, or copied to, are acted on a sequential basis, in time and in space.
+ *
+ * Callers use this class, like so:
+ *
+ * 1) [[getBuffersToSend]]: is used to get bounce buffers that the server should call `send` on.
+ * -- first time:
+ * a) the corresponding catalog table is acquired,
+ * b) bounce buffers are acquired,
+ * c) data is copied from the original catalog table into the bounce buffers available
+ * d) the length of the last bounce buffer is adjusted if it would satisfy the full
+ * length of the catalog-backed buffer.
+ * e) bounce buffers are returned
+ *
+ * -- subsequent times:
+ * if we are not done sending the acquired table:
+ * a) data is copied from the original catalog table into the bounce buffers available
+ * at sequentially incrementing offsets.
+ * b) the length of the last bounce buffer is adjusted if it would satisfy the full
+ * length of the catalog-backed buffer.
+ * c) bounce buffers are returned
+ *
+ * 2) [[freeBounceBuffersIfNecessary]]: called when a send finishes, in order to free bounce
+ * buffers, if the current table is done sending.
+ *
+ * 3) [[close]]: used to free state as the [[BufferSendState]] object is no longer needed
+ *
+ * In terms of the lifecycle of this object, it begins with the client asking for transfers to
+ * start, it lasts through all buffers being transmitted, and ultimately finishes when a
+ * `TransferResponse` is sent back to the client.
+ *
+ * @param tx the original `Transaction` from the `TransferRequest`.
+ * @param request a transfer request
+ */
class BufferSendState(tx: Transaction, request: RefCountedDirectByteBuffer)
extends AutoCloseable {
private[this] var currentTableIndex = -1
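A synchronous sketch of the caller protocol documented above, assuming it lives inside `RapidsShuffleServer` (since [[BufferSendState]] is an inner class); `sendSync` and `allTablesSent` are assumed helpers, and the real server drives this asynchronously from `doHandleTransferRequest`:

```scala
private def drainSendsSketch(
    sendState: BufferSendState,
    sendSync: Seq[AddressLengthTag] => Unit,
    allTablesSent: BufferSendState => Boolean): Unit = {
  try {
    while (!allTablesSent(sendState)) {
      // Copies from the acquired catalog buffer into the bounce buffers and
      // returns the subset that should go on the wire next.
      val toSend = sendState.getBuffersToSend()
      sendSync(toSend)
      // Returns bounce buffers to the pool if the current table is done sending.
      sendState.freeBounceBuffersIfNecessary
    }
  } finally {
    sendState.close()
  }
}
```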
@@ -424,10 +424,10 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * Used to pop a [[BufferSendState]] from its queue if and only if there are bounce
- * buffers available
- * @return true if bounce buffers are available to proceed
- */
+ * Used to pop a [[BufferSendState]] from its queue if and only if there are bounce
+ * buffers available
+ * @return true if bounce buffers are available to proceed
+ */
def acquireBounceBuffersNonBlocking: Boolean = {
// we need to secure the table we are about to send, in order to get the correct flavor of
// bounce buffer
@@ -542,14 +542,14 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * This function returns bounce buffers that are ready to be sent. To get there,
- * it will:
- * 1) acquire the bounce buffers in the first place (if it hasn't already)
- * 2) copy data from the source buffer to the bounce buffers, updating the offset accordingly
- * 3) return either the full set of bounce buffers, or a subset, depending on how much is
- * left to send.
- * @return bounce buffers ready to be sent.
- */
+ * This function returns bounce buffers that are ready to be sent. To get there,
+ * it will:
+ * 1) acquire the bounce buffers in the first place (if it hasn't already)
+ * 2) copy data from the source buffer to the bounce buffers, updating the offset accordingly
+ * 3) return either the full set of bounce buffers, or a subset, depending on how much is
+ * left to send.
+ * @return bounce buffers ready to be sent.
+ */
def getBuffersToSend(): Seq[AddressLengthTag] = synchronized {
val alt = acquireTable()
@@ -604,11 +604,11 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport,
}
/**
- * This will kick off, or continue to work, a [[BufferSendState]] object
- * until all tables are fully transmitted.
- *
- * @param bufferSendState state object tracking sends needed to fulfill a TransferRequest
- */
+ * This will kick off, or continue to work, a [[BufferSendState]] object
+ * until all tables are fully transmitted.
+ *
+ * @param bufferSendState state object tracking sends needed to fulfill a TransferRequest
+ */
def doHandleTransferRequest(bufferSendState: BufferSendState): Unit = {
val doHandleTransferRequest = new NvtxRange("doHandleTransferRequest", NvtxColor.CYAN)
val start = System.currentTimeMillis()
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala
index 90d5b65b9a1..9dcff909b78 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala
@@ -28,19 +28,19 @@ import org.apache.spark.sql.rapids.storage.RapidsStorageUtils
import org.apache.spark.storage.BlockManagerId
/**
- * Class representing a memory location (address), length (in bytes), and a tag, for
- * tag based transports.
- * @param address the raw native address, used for transfers (from/to this buffer)
- * @param length the amount of bytes used to indicate to the transport how much to send/receive
- * @param tag a numeric tag identifying this buffer
- * @param memoryBuffer an optional `MemoryBuffer`
- */
+ * Class representing a memory location (address), length (in bytes), and a tag, for
+ * tag based transports.
+ * @param address the raw native address, used for transfers (from/to this buffer)
+ * @param length the amount of bytes used to indicate to the transport how much to send/receive
+ * @param tag a numeric tag identifying this buffer
+ * @param memoryBuffer an optional `MemoryBuffer`
+ */
class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
var memoryBuffer: Option[MemoryBuffer] = None) extends AutoCloseable with Logging {
/**
- * If this is a device memory buffer, we return true here.
- * @return whether this is a device memory buffer
- */
+ * If this is a device memory buffer, we return true here.
+ * @return whether this is a device memory buffer
+ */
def isDeviceBuffer(): Boolean = {
require(memoryBuffer.nonEmpty,
s"$this does not have a memory buffer")
@@ -48,10 +48,10 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
}
/**
- * Get a device memory buffer, this function will throw if it is not backed by a
- * `DeviceMemoryBuffer`
- * @return the backing device memory buffer
- */
+ * Get the device memory buffer; this function will throw if it is not backed by a
+ * `DeviceMemoryBuffer`
+ * @return the backing device memory buffer
+ */
def releaseDeviceMemoryBuffer(): DeviceMemoryBuffer = {
require(isDeviceBuffer(),
s"$this does not have a device memory buffer")
@@ -61,12 +61,12 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
}
/**
- * Copy to a destination [[AddressLengthTag]]
- * @param dstAlt the destination [[AddressLengthTag]]
- * @param srcOffset the offset to start copying from
- * @param toCopy amount to copy to dstAlt
- * @return amount of bytes copied
- */
+ * Copy to a destination [[AddressLengthTag]]
+ * @param dstAlt the destination [[AddressLengthTag]]
+ * @param srcOffset the offset to start copying from
+ * @param toCopy amount to copy to dstAlt
+ * @return amount of bytes copied
+ */
def cudaCopyTo(dstAlt: AddressLengthTag, srcOffset: Long, toCopy: Long): Long = {
require(srcOffset + toCopy <= length,
"Attempting to copy more bytes than the source buffer provides")
@@ -82,11 +82,11 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
}
/**
- * Copy from a source [[AddressLengthTag]]
- * @param srcAlt the source [[AddressLengthTag]]
- * @param dstOffset the offset in which to start copying from
- * @return amount of bytes copied
- */
+ * Copy from a source [[AddressLengthTag]]
+ * @param srcAlt the source [[AddressLengthTag]]
+ * @param dstOffset the offset at which to start copying
+ * @return amount of bytes copied
+ */
def cudaCopyFrom(srcAlt: AddressLengthTag, dstOffset: Long): Long = {
require(dstOffset + srcAlt.length <= length,
s"Attempting to copy to a buffer that isn't big enough! $dstOffset + ${srcAlt.length} " +
@@ -102,17 +102,17 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
}
/**
- * A bounce buffer at the end of the transfer needs to be sized with the remaining
- * amount. This method is used to communicate that last length.
- * @param newLength the truncated length of this buffer for the transfer
- */
+ * A bounce buffer at the end of the transfer needs to be sized with the remaining
+ * amount. This method is used to communicate that last length.
+ * @param newLength the truncated length of this buffer for the transfer
+ */
def resetLength(newLength: Long): Unit = {
length = newLength
}
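The copy and `resetLength` helpers above combine into a chunked-copy pattern. A sketch of staging a larger source buffer through one bounce buffer, with `sendChunk` as an assumed stand-in for a transport send:

```scala
object ChunkedCopySketch {
  def stageThroughBounceBuffer(
      src: AddressLengthTag,
      bounce: AddressLengthTag,
      sendChunk: AddressLengthTag => Unit): Unit = {
    var offset = 0L
    while (offset < src.length) {
      val remaining = src.length - offset
      if (remaining < bounce.length) {
        // Last chunk: shrink the bounce buffer so only `remaining` bytes move.
        bounce.resetLength(remaining)
      }
      val copied = src.cudaCopyTo(bounce, offset, bounce.length)
      sendChunk(bounce)
      offset += copied
    }
    // Restore the bounce buffer to its backing buffer's full length for reuse.
    bounce.resetLength()
  }
}
```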
/**
- * Reset the length to the length supported by the backing buffer.
- */
+ * Reset the length to the length supported by the backing buffer.
+ */
def resetLength(): Unit = {
require(memoryBuffer.nonEmpty,
"Attempted to reset using an undefined memory buffer")
@@ -124,9 +124,9 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
}
/**
- * The Server will call close as the memory buffer stored here is owned by this
- * [[AddressLengthTag]]
- */
+ * The server will call close, as the memory buffer stored here is owned by this
+ * [[AddressLengthTag]]
+ */
override def close(): Unit = {
memoryBuffer.foreach(_.close())
}
@@ -134,11 +134,11 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long,
object AddressLengthTag {
/**
- * Construct an [[AddressLengthTag]] given a `MemoryBuffer` that is on the device, or host.
- * @param memoryBuffer the buffer the [[AddressLengthTag]] should point to
- * @param tag the transport tag to use to send/receive this buffer
- * @return an instance of [[AddressLengthTag]]
- */
+ * Construct an [[AddressLengthTag]] given a `MemoryBuffer` that is on the device, or host.
+ * @param memoryBuffer the buffer the [[AddressLengthTag]] should point to
+ * @param tag the transport tag to use to send/receive this buffer
+ * @return an instance of [[AddressLengthTag]]
+ */
def from(memoryBuffer: MemoryBuffer, tag: Long): AddressLengthTag = {
new AddressLengthTag(
memoryBuffer.getAddress,
@@ -148,12 +148,12 @@ object AddressLengthTag {
}
/**
- * Construct an [[AddressLengthTag]] given a `ByteBuffer`
- * This is used for metadata messages, and the buffers are direct.
- * @param byteBuffer the buffer the [[AddressLengthTag]] should point to
- * @param tag the transport tag to use to send/receive this buffer
- * @return an instance of [[AddressLengthTag]]
- */
+ * Construct an [[AddressLengthTag]] given a `ByteBuffer`.
+ * This is used for metadata messages, and the buffers are direct.
+ * @param byteBuffer the buffer the [[AddressLengthTag]] should point to
+ * @param tag the transport tag to use to send/receive this buffer
+ * @return an instance of [[AddressLengthTag]]
+ */
def from(byteBuffer: ByteBuffer, tag: Long): AddressLengthTag = {
new AddressLengthTag(
TransportUtils.getAddress(byteBuffer),
@@ -171,139 +171,139 @@ trait MemoryRegistrationCallback {
}
/**
- * A server-side interface to the transport.
- *
- * The [[RapidsShuffleServer]] uses a [[ServerConnection]] to start the management port
- * in the transport (in order to allow for incoming connections)
- *
- * Note that [[ServerConnection]] is a [[Connection]], and so it inherits methods to send/receive
- * messages.
- */
+ * A server-side interface to the transport.
+ *
+ * The [[RapidsShuffleServer]] uses a [[ServerConnection]] to start the management port
+ * in the transport (in order to allow for incoming connections)
+ *
+ * Note that [[ServerConnection]] is a [[Connection]], and so it inherits methods to send/receive
+ * messages.
+ */
trait ServerConnection extends Connection {
/**
- * Starts a TCP management port, bound to `host`, on an ephemeral port (returned)
- * @param host host to bind to
- * @return integer ephemeral port that was bound
- */
+ * Starts a TCP management port, bound to `host`, on an ephemeral port (returned)
+ * @param host host to bind to
+ * @return integer ephemeral port that was bound
+ */
def startManagementPort(host: String): Int
/**
- * Function to send bounce buffers to a peer
- * @param peerExecutorId peer's executor id to target
- * @param bounceBuffers bounce buffers to send
- * @param cb callback to trigger once done
- * @return the [[Transaction]], which can be used to block wait for this send.
- */
+ * Function to send bounce buffers to a peer
+ * @param peerExecutorId peer's executor id to target
+ * @param bounceBuffers bounce buffers to send
+ * @param cb callback to trigger once done
+ * @return the [[Transaction]], which can be used to block wait for this send.
+ */
def send(peerExecutorId: Long,
bounceBuffers: Seq[AddressLengthTag],
cb: TransactionCallback): Transaction
/**
- * Function to send bounce buffers to a peer
- * @param peerExecutorId peer's executor id to target
- * @param header an [[AddressLengthTag]] containing a metadata message to send
- * @param cb callback to trigger once done
- * @return the [[Transaction]], which can be used to block wait for this send.
- */
+ * Function to send a metadata message to a peer
+ * @param peerExecutorId peer's executor id to target
+ * @param header an [[AddressLengthTag]] containing a metadata message to send
+ * @param cb callback to trigger once done
+ * @return the [[Transaction]], which can be used to block wait for this send.
+ */
def send(peerExecutorId: Long,
header: AddressLengthTag,
cb: TransactionCallback): Transaction
}
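A minimal sketch of using both `send` overloads from the server side, e.g. to send a metadata response followed by data bounce buffers; tagging and callback wiring are assumed to be handled by the caller:

```scala
object ServerSendSketch {
  def sendResponseAndData(
      serverConnection: ServerConnection,
      peerExecutorId: Long,
      responseHeader: AddressLengthTag,
      bounceBuffers: Seq[AddressLengthTag],
      onDone: TransactionCallback): Seq[Transaction] = {
    // Metadata message first, then the data bounce buffers for the same peer.
    val headerTx = serverConnection.send(peerExecutorId, responseHeader, onDone)
    val dataTx = serverConnection.send(peerExecutorId, bounceBuffers, onDone)
    Seq(headerTx, dataTx)
  }
}
```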
/**
- * Currently supported request types in the transport
- */
+ * Currently supported request types in the transport
+ */
object RequestType extends Enumeration {
/**
- * A client will issue: `MetadataRequest`
- * A server will respond with: `MetadataResponse`
- */
+ * A client will issue: `MetadataRequest`
+ * A server will respond with: `MetadataResponse`
+ */
val MetadataRequest = Value
/**
- * A client will issue: `TransferRequest`
- * A server will respond with: `TransferResponse`
- */
+ * A client will issue: `TransferRequest`
+ * A server will respond with: `TransferResponse`
+ */
val TransferRequest = Value
}
/**
- * Trait used by the clients to interact with the transport.
- *
- * Note that this subclasses from [[Connection]].
- */
+ * Trait used by the clients to interact with the transport.
+ *
+ * Note that this subclasses from [[Connection]].
+ */
trait ClientConnection extends Connection {
/**
- * This performs a request/response, where the request is read from one
- * `AddressLengthTag` `request`, and the response is populated at the memory
- * described by `response`.
- *
- * @param request the populated request buffer [[AddressLengthTag]]
- * @param response the response buffer [[AddressLengthTag]] where the response will be
- * stored when the request succeeds.
- * @param cb callback to handle transaction status. If successful the memory described
- * using "response" will hold the response as expected, otherwise its contents
- * are not defined.
- * @return a transaction representing the request
- */
+ * This performs a request/response, where the request is read from one
+ * `AddressLengthTag` `request`, and the response is populated at the memory
+ * described by `response`.
+ *
+ * @param request the populated request buffer [[AddressLengthTag]]
+ * @param response the response buffer [[AddressLengthTag]] where the response will be
+ * stored when the request succeeds.
+ * @param cb callback to handle transaction status. If successful the memory described
+ * using "response" will hold the response as expected, otherwise its contents
+ * are not defined.
+ * @return a transaction representing the request
+ */
def request(request: AddressLengthTag,
response: AddressLengthTag,
cb: TransactionCallback): Transaction
/**
- * This function assigns tags that are valid for responses in this connection.
- * @return a Long tag to use for a response
- */
+ * This function assigns tags that are valid for responses in this connection.
+ * @return a Long tag to use for a response
+ */
def assignResponseTag: Long
/**
- * This function assigns tags for individual buffers to be received in this connection.
- * @param msgId an application-level id that should be unique to the peer.
- * @return a Long tag to use for a buffer
- */
+ * This function assigns tags for individual buffers to be received in this connection.
+ * @param msgId an application-level id that should be unique to the peer.
+ * @return a Long tag to use for a buffer
+ */
def assignBufferTag(msgId: Int): Long
/**
- * Get a long representing the executorId for the peer of this connection.
- * @return the executorId as a long
- */
+ * Get a long representing the executorId for the peer of this connection.
+ * @return the executorId as a long
+ */
def getPeerExecutorId: Long
}
/**
- * [[Connection]] trait defines what a "connection" must support.
- *
- * [[ServerConnection]] and [[ClientConnection]] extend this, adding a few methods needed
- * in each case.
- */
+ * [[Connection]] trait defines what a "connection" must support.
+ *
+ * [[ServerConnection]] and [[ClientConnection]] extend this, adding a few methods needed
+ * in each case.
+ */
trait Connection {
/**
- * Both the client and the server need to compose request tags depending on the
- * type of request being sent or handled.
- *
- * Note it is up to the implemented to compose the tag in whatever way makes
- * most sense for the underlying transport.
- *
- * @param requestType the type of request this tag is for
- * @return a Long tag to be used for this request
- */
+ * Both the client and the server need to compose request tags depending on the
+ * type of request being sent or handled.
+ *
+ * Note it is up to the implementation to compose the tag in whatever way makes
+ * most sense for the underlying transport.
+ *
+ * @param requestType the type of request this tag is for
+ * @return a Long tag to be used for this request
+ */
def composeRequestTag(requestType: RequestType.Value): Long
/**
- * Function to receive a buffer
- * @param header an [[AddressLengthTag]] to receive a metadata message
- * @param cb callback to trigger once this receive completes
- * @return a [[Transaction]] that can be used to block while this transaction is not done
- */
+ * Function to receive a buffer
+ * @param header an [[AddressLengthTag]] to receive a metadata message
+ * @param cb callback to trigger once this receive completes
+ * @return a [[Transaction]] that can be used to block while this transaction is not done
+ */
def receive(header: AddressLengthTag,
cb: TransactionCallback): Transaction
/**
- * Function to receive a buffer
- * @param bounceBuffers a sequence of [[AddressLengthTag]] where to receive data
- * @param cb callback to trigger once this receive completes
- * @return a [[Transaction]] that can be used to block while this transaction is not done
- */
+ * Function to receive data into a set of bounce buffers
+ * @param bounceBuffers a sequence of [[AddressLengthTag]] describing where to receive data
+ * @param cb callback to trigger once this receive completes
+ * @return a [[Transaction]] that can be used to block while this transaction is not done
+ */
def receive(bounceBuffers: Seq[AddressLengthTag],
cb: TransactionCallback): Transaction
}
@@ -313,13 +313,13 @@ object TransactionStatus extends Enumeration {
}
/**
- * Case class representing stats for the a transaction
- * @param txTimeMs amount of time this [[Transaction]] took
- * @param sendSize amount of bytes sent
- * @param receiveSize amount of bytes received
- * @param sendThroughput send throughput in GB/sec
- * @param recvThroughput receive throughput in GB/sec
- */
+ * Case class representing stats for a transaction
+ * @param txTimeMs amount of time this [[Transaction]] took
+ * @param sendSize amount of bytes sent
+ * @param receiveSize amount of bytes received
+ * @param sendThroughput send throughput in GB/sec
+ * @param recvThroughput receive throughput in GB/sec
+ */
case class TransactionStats(txTimeMs: Double,
sendSize: Long,
receiveSize: Long,
@@ -327,179 +327,179 @@ case class TransactionStats(txTimeMs: Double,
recvThroughput: Double)
/**
- * This trait represents a shuffle "transaction", and it is specific to a transfer (or set of
- * transfers).
- *
- * It is useful in that it groups a set of sends and receives requires in order to carry an action
- * against a peer. It can be used to find statistics about the transfer (bytes send/received,
- * throughput),
- * and it also can be waited on, for blocking clients.
- *
- * NOTE: a Transaction is thread safe w.r.t. a connection's callback. Calling methods on the
- * transaction
- * outside of [[waitForCompletion]] produces undefined behavior.
- */
+ * This trait represents a shuffle "transaction", and it is specific to a transfer (or set of
+ * transfers).
+ *
+ * It is useful in that it groups the set of sends and receives required in order to carry out
+ * an action against a peer. It can be used to find statistics about the transfer (bytes
+ * sent/received, throughput), and it can also be waited on, for blocking clients.
+ *
+ * NOTE: a Transaction is thread safe w.r.t. a connection's callback. Calling methods on the
+ * transaction outside of [[waitForCompletion]] produces undefined behavior.
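+ *
+ * A sketch of a blocking caller (names are illustrative):
+ * {{{
+ *   val tx = connection.receive(headerAlt, _ => ())
+ *   tx.waitForCompletion()
+ *   val stats = tx.getStats
+ *   println(s"received ${stats.receiveSize} bytes in ${stats.txTimeMs} ms")
+ *   tx.close()
+ * }}}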
+ */
trait Transaction extends AutoCloseable {
/**
- * Get the status this transaction is in. Callbacks use this to handle various transaction states
- * (e.g. success, error, etc.)
- */
+ * Get the status this transaction is in. Callbacks use this to handle various transaction states
+ * (e.g. success, error, etc.)
+ */
def getStatus: TransactionStatus.Value
/**
- * Get error messages that could have occurred during the transaction.
- * @return returns an optional error message
- */
+ * Get the error message, if any, that occurred during the transaction.
+ * @return an optional error message
+ */
def getErrorMessage: Option[String]
/**
- * Get the statistics object (bytes sent/recv, tx time, and throughput are available)
- */
+ * Get the statistics object (bytes sent/recv, tx time, and throughput are available)
+ */
def getStats: TransactionStats
/**
- * Block until this transaction is completed.
- *
- * NOTE: not only does this include the transaction time, but it also includes the
- * code performed in the callback. If the callback would block, you could end up in a situation
- * of
- * deadlock.
- */
+ * Block until this transaction is completed.
+ *
+ * NOTE: not only does this include the transaction time, but it also includes the
+ * code performed in the callback. If the callback were to block, you could end up
+ * in a deadlock.
+ */
def waitForCompletion(): Unit
}
/**
- * This defines what a "transport" should support. The intention is to allow for
- * various transport implementations to exist, for different communication frameworks.
- *
- * It is an `AutoCloseable` and so the caller should close when the transport is no longer
- * needed.
- */
+ * This defines what a "transport" should support. The intention is to allow for
+ * various transport implementations to exist, for different communication frameworks.
+ *
+ * It is an `AutoCloseable` and so the caller should close when the transport is no longer
+ * needed.
+ */
trait RapidsShuffleTransport extends AutoCloseable {
/**
- * This function will connect (if not connected already) to a peer
- * described by `blockManagerId`. Connections are cached.
- *
- * @param localExecutorId the local executor id
- * @param blockManagerId the peer's block manager id
- * @return RapidsShuffleClient instance that can be used to interact with the peer
- */
+ * This function will connect (if not connected already) to a peer
+ * described by `blockManagerId`. Connections are cached.
+ *
+ * @param localExecutorId the local executor id
+ * @param blockManagerId the peer's block manager id
+ * @return RapidsShuffleClient instance that can be used to interact with the peer
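+ *
+ * Typical call site (sketch):
+ * {{{
+ *   val client = transport.makeClient(localExecutorId, peerBlockManagerId)
+ * }}}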
+ */
def makeClient(localExecutorId: Long,
blockManagerId: BlockManagerId): RapidsShuffleClient
/**
- * This function should only be needed once. The caller creates *a* server and it is used
- * for the duration of the process.
- * @param requestHandler used to get metadata info, and acquire tables used in the shuffle.
- * @return the server instance
- */
+ * This function should only be needed once. The caller creates *a* server and it is used
+ * for the duration of the process.
+ * @param requestHandler used to get metadata info, and acquire tables used in the shuffle.
+ * @return the server instance
+ */
def makeServer(requestHandler: RapidsShuffleRequestHandler): RapidsShuffleServer
/**
- * Returns a wrapped buffer of size Long. The buffer may or may not be pooled.
- *
- * The caller should call .close() on the returned [[RefCountedDirectByteBuffer]]
- * when done.
- *
- * @param size size of buffer required
- * @return the ref counted buffer
- */
+ * Returns a wrapped direct buffer of `size` bytes. The buffer may or may not be pooled.
+ *
+ * The caller should call .close() on the returned [[RefCountedDirectByteBuffer]]
+ * when done.
+ *
+ * @param size size of buffer required
+ * @return the ref counted buffer
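+ *
+ * Sketch of the expected acquire/close pattern (the size is illustrative):
+ * {{{
+ *   val meta = transport.getMetaBuffer(256)
+ *   try {
+ *     // write or read the metadata message through meta.getBuffer()
+ *   } finally {
+ *     meta.close() // returns the buffer to the pool, or destroys it
+ *   }
+ * }}}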
+ */
def getMetaBuffer(size: Long): RefCountedDirectByteBuffer
/**
- * (throttle) Adds a set of requests to be throttled as limits allowed.
- * @param reqs requests to add to the throttle queue
- */
+ * (throttle) Adds a set of requests to be throttled as limits allow.
+ * @param reqs requests to add to the throttle queue
+ */
def queuePending(reqs: Seq[PendingTransferRequest])
/**
- * (throttle) Signals that `bytesCompleted` are done, allowing more requests through the
- * throttle.
- * @param bytesCompleted amount of bytes handled
- */
+ * (throttle) Signals that `bytesCompleted` are done, allowing more requests through the
+ * throttle.
+ * @param bytesCompleted amount of bytes handled
+ */
def doneBytesInFlight(bytesCompleted: Long): Unit
// Bounce Buffer Management
/**
- * Get receive bounce buffers needed for a receive, limited by the amount of bytes
- * to be received, and a hard limit on the number of buffers set by the caller
- * using `totalRequired`.
- *
- * This function blocks if it can't satisfy the bounce buffer request.
- *
- * @param remaining amount of bytes remaining in the receive
- * @param totalRequired maximum no. of buffers that should be returned
- * @return a sequence of bounce buffers
- */
+ * Get receive bounce buffers needed for a receive, limited by the amount of bytes
+ * to be received, and a hard limit on the number of buffers set by the caller
+ * using `totalRequired`.
+ *
+ * This function blocks if it can't satisfy the bounce buffer request.
+ *
+ * @param remaining amount of bytes remaining in the receive
+ * @param totalRequired maximum no. of buffers that should be returned
+ * @return a sequence of bounce buffers
+ */
def getReceiveBounceBuffers(remaining: Long, totalRequired: Int): Seq[MemoryBuffer]
/**
- * Get receive bounce buffers needed for a receive, limited by the amount of bytes
- * to be received, and a hard limit on the number of buffers set by the caller
- * using `totalRequired`.
- *
- * This function is non blocking. If it can't satisfy the bounce buffer request, an empty
- * sequence is returned.
- *
- * @param remaining amount of bytes remaining in the receive
- * @param totalRequired maximum no. of buffers that should be returned
- * @return a sequence of bounce buffers, or empty if the request can't be satisfied
- */
+ * Get receive bounce buffers needed for a receive, limited by the amount of bytes
+ * to be received, and a hard limit on the number of buffers set by the caller
+ * using `totalRequired`.
+ *
+ * This function is non blocking. If it can't satisfy the bounce buffer request, an empty
+ * sequence is returned.
+ *
+ * @param remaining amount of bytes remaining in the receive
+ * @param totalRequired maximum no. of buffers that should be returned
+ * @return a sequence of bounce buffers, or empty if the request can't be satisfied
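+ *
+ * Callers must handle the empty case, e.g. (sketch, `remainingBytes` being a hypothetical local):
+ * {{{
+ *   val buffs = transport.tryGetReceiveBounceBuffers(remainingBytes, 4)
+ *   if (buffs.isEmpty) {
+ *     // bounce buffers are exhausted: back off and retry later
+ *   }
+ * }}}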
+ */
def tryGetReceiveBounceBuffers(remaining: Long, totalRequired: Int): Seq[MemoryBuffer]
/**
- * Free receive bounce buffers. These may be pooled and so reused by other requests.
- * @param bounceBuffers the bounce buffers to free
- */
+ * Free receive bounce buffers. These may be pooled and so reused by other requests.
+ * @param bounceBuffers the bounce buffers to free
+ */
def freeReceiveBounceBuffers(bounceBuffers: Seq[MemoryBuffer]): Unit
/**
- * Get send bounce buffers needed for a receive, limited by the amount of bytes
- * to be sent, and a hard limit on the number of buffers set by the caller
- * using `totalRequired`.
- *
- * This function blocks if it can't satisfy the bounce buffer request.
- *
- * @param deviceMemory true: returns a device buffer, false: returns a host buffer
- * @param remaining amount of bytes remaining in the receive
- * @param totalRequired maximum no. of buffers that should be returned
- * @return a sequence of bounce buffers
- */
+ * Get send bounce buffers needed for a send, limited by the amount of bytes
+ * to be sent, and a hard limit on the number of buffers set by the caller
+ * using `totalRequired`.
+ *
+ * This function blocks if it can't satisfy the bounce buffer request.
+ *
+ * @param deviceMemory true: returns a device buffer, false: returns a host buffer
+ * @param remaining amount of bytes remaining in the send
+ * @param totalRequired maximum no. of buffers that should be returned
+ * @return a sequence of bounce buffers
+ */
def getSendBounceBuffers(deviceMemory: Boolean, remaining: Long,
totalRequired: Int): Seq[MemoryBuffer]
/**
- * Get send bounce buffers needed for a receive, limited by the amount of bytes
- * to be sent, and a hard limit on the number of buffers set by the caller
- * using `totalRequired`.
- *
- * This function is non blocking. If it can't satisfy the bounce buffer request, an empty
- * sequence is returned.
- *
- * @param deviceMemory true: returns a device buffer, false: returns a host buffer
- * @param remaining amount of bytes remaining in the receive
- * @param totalRequired maximum no. of buffers that should be returned
- * @return a sequence of bounce buffers, or empty if the request can't be satisfied
- */
+ * Get send bounce buffers needed for a send, limited by the amount of bytes
+ * to be sent, and a hard limit on the number of buffers set by the caller
+ * using `totalRequired`.
+ *
+ * This function is non blocking. If it can't satisfy the bounce buffer request, an empty
+ * sequence is returned.
+ *
+ * @param deviceMemory true: returns a device buffer, false: returns a host buffer
+ * @param remaining amount of bytes remaining in the send
+ * @param totalRequired maximum no. of buffers that should be returned
+ * @return a sequence of bounce buffers, or empty if the request can't be satisfied
+ */
def tryGetSendBounceBuffers(deviceMemory: Boolean, remaining: Long,
totalRequired: Int): Seq[MemoryBuffer]
/**
- * Free send bounce buffers. These may be pooled and so reused by other requests.
- * @param bounceBuffers the bounce buffers to free
- */
+ * Free send bounce buffers. These may be pooled and so reused by other requests.
+ * @param bounceBuffers the bounce buffers to free
+ */
def freeSendBounceBuffers(bounceBuffers: Seq[MemoryBuffer]): Unit
}
/**
- * A pool of direct byte buffers, sized to be `bufferSize`.
- * This is a controlled leak at the moment, there is no reclaiming of buffers.
- *
- * NOTE: this is used for metadata messages.
- *
- * @param bufferSize the size of direct `ByteBuffer` to allocate.
- */
+ * A pool of direct byte buffers, sized to be `bufferSize`.
+ * This is a controlled leak at the moment; there is no reclaiming of buffers.
+ *
+ * NOTE: this is used for metadata messages.
+ *
+ * @param bufferSize the size of direct `ByteBuffer` to allocate.
+ */
class DirectByteBufferPool(bufferSize: Long) extends Logging {
val buffers = new ConcurrentLinkedQueue[RefCountedDirectByteBuffer]()
val high = new AtomicInteger(0)
@@ -526,18 +526,18 @@ class DirectByteBufferPool(bufferSize: Long) extends Logging {
}
/**
- * [[RefCountedDirectByteBuffer]] is a simple wrapper on top of a `ByteBuffer` that has been
- * allocated in direct mode.
- *
- * The pool is used to return the `ByteBuffer` to be reused, but not all of these buffers
- * are pooled (hence the argument is optional)
- *
- * The user should always close a [[RefCountedDirectByteBuffer]]. The close could hard destroy
- * the buffer, or return the object to the pool
- *
- * @param bb buffer to wrap
- * @param pool optional pool
- */
+ * [[RefCountedDirectByteBuffer]] is a simple wrapper on top of a `ByteBuffer` that has been
+ * allocated in direct mode.
+ *
+ * The pool is used to return the `ByteBuffer` to be reused, but not all of these buffers
+ * are pooled (hence the argument is optional).
+ *
+ * The user should always close a [[RefCountedDirectByteBuffer]]. The close could either hard
+ * destroy the buffer, or return the object to the pool.
+ *
+ * @param bb buffer to wrap
+ * @param pool optional pool
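+ *
+ * Sketch of the intended acquire/close pairing:
+ * {{{
+ *   val bb = refCounted.acquire() // ref count goes to 1
+ *   try {
+ *     // use bb
+ *   } finally {
+ *     refCounted.close()          // ref count back to 0: pooled or destroyed
+ *   }
+ * }}}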
+ */
class RefCountedDirectByteBuffer(
bb: ByteBuffer,
pool: Option[DirectByteBufferPool] = None) extends AutoCloseable {
@@ -546,24 +546,24 @@ class RefCountedDirectByteBuffer(
var refCount: Int = 0
/**
- * Adds one to the ref count. Caller should call .close() when done
- * @return wrapped buffer
- */
+ * Adds one to the ref count. Caller should call .close() when done
+ * @return wrapped buffer
+ */
def acquire(): ByteBuffer = synchronized {
refCount = refCount + 1
bb
}
/**
- * Peeks into the wrapped buffer, without changing the ref count.
- * @return wrapped buffer
- */
+ * Peeks into the wrapped buffer, without changing the ref count.
+ * @return wrapped buffer
+ */
def getBuffer(): ByteBuffer = bb
/**
- * Decrements the ref count. If the ref count reaches 0, the buffer is
- * either returned to the (optional) pool or destroyed.
- */
+ * Decrements the ref count. If the ref count reaches 0, the buffer is
+ * either returned to the (optional) pool or destroyed.
+ */
override def close(): Unit = synchronized {
refCount = refCount - 1
if (refCount <= 0) {
@@ -576,10 +576,10 @@ class RefCountedDirectByteBuffer(
}
/**
- * Destroys the direct byte buffer forcefully, rather than wait for GC
- * to do it later. This helps with fragmentation issues with nearly depleted
- * native heap.
- */
+ * Destroys the direct byte buffer forcefully, rather than wait for GC
+ * to do it later. This helps with fragmentation issues with nearly depleted
+ * native heap.
+ */
def unsafeDestroy(): Unit = synchronized {
RapidsStorageUtils.dispose(bb)
}
@@ -590,8 +590,8 @@ class RefCountedDirectByteBuffer(
}
/**
- * A set of util functions used throughout
- */
+ * A set of utility functions used throughout the transport.
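+ *
+ * For example, given the format string used by `formatTag` below:
+ * {{{
+ *   TransportUtils.formatTag(255L) // => "0x00000000000000FF"
+ * }}}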
+ */
object TransportUtils {
def formatTag(tag: Long): String = {
f"0x$tag%016X"
@@ -622,19 +622,19 @@ object TransportUtils {
object RapidsShuffleTransport extends Logging {
/**
- * Used in `BlockManagerId`s when returning a map status after a shuffle write to
- * let the readers know what TCP port to use to establish a transport connection.
- */
+ * Used in `BlockManagerId`s when returning a map status after a shuffle write to
+ * let the readers know what TCP port to use to establish a transport connection.
+ */
val BLOCK_MANAGER_ID_TOPO_PREFIX: String = "rapids"
/**
- * Returns an instance of `RapidsShuffleTransport`.
- * @note the single current implementation is `UCXShuffleTransport`.
- * @param shuffleServerId this is the original `BlockManagerId` that Spark has for this
- * executor
- * @param rapidsConf instance of `RapidsConf`
- * @return a transport instance to be used to create a server and clients.
- */
+ * Returns an instance of `RapidsShuffleTransport`.
+ * @note the single current implementation is `UCXShuffleTransport`.
+ * @param shuffleServerId this is the original `BlockManagerId` that Spark has for this
+ * executor
+ * @param rapidsConf instance of `RapidsConf`
+ * @return a transport instance to be used to create a server and clients.
+ */
def makeTransport(shuffleServerId: BlockManagerId,
rapidsConf: RapidsConf): RapidsShuffleTransport = {
val transportClass = try {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
index 33519eb1fc5..24eb1620f8c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala
@@ -37,12 +37,12 @@ import org.apache.spark.sql.types._
object OrcFilters extends OrcFiltersBase {
/**
- * Create ORC filter as a SearchArgument instance.
- *
- * NOTE: These filters should be pre-filtered by Spark to only contain the
- * filters convertible to ORC, so checking what is convertible is
- * not necessary here.
- */
+ * Create ORC filter as a SearchArgument instance.
+ *
+ * NOTE: These filters should be pre-filtered by Spark to only contain the
+ * filters convertible to ORC, so checking what is convertible is
+ * not necessary here.
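+ *
+ * A usage sketch (the schema and filter are illustrative only):
+ * {{{
+ *   val schema = StructType(Seq(StructField("a", IntegerType)))
+ *   val sarg = OrcFilters.createFilter(schema, Seq(GreaterThan("a", 1)))
+ * }}}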
+ */
def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
// Combines all filters using `And` to produce a single conjunction
@@ -56,8 +56,8 @@ object OrcFilters extends OrcFiltersBase {
}
/**
- * Get PredicateLeafType which is corresponding to the given DataType.
- */
+ * Get PredicateLeafType which is corresponding to the given DataType.
+ */
private def getPredicateLeafType(dataType: DataType) = dataType match {
case BooleanType => PredicateLeaf.Type.BOOLEAN
case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
@@ -70,11 +70,11 @@ object OrcFilters extends OrcFiltersBase {
}
/**
- * Cast literal values for filters.
- *
- * We need to cast to long because ORC raises exceptions
- * at 'checkLiteralType' of SearchArgumentImpl.java.
- */
+ * Cast literal values for filters.
+ *
+ * We need to cast to long because ORC raises exceptions
+ * at 'checkLiteralType' of SearchArgumentImpl.java.
+ */
private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
case ByteType | ShortType | IntegerType | LongType =>
value.asInstanceOf[Number].longValue
@@ -86,13 +86,13 @@ object OrcFilters extends OrcFiltersBase {
}
/**
- * Build a SearchArgument and return the builder so far.
- *
- * @param dataTypeMap a map from the attribute name to its data type.
- * @param expression the input predicates, which should be fully convertible to SearchArgument.
- * @param builder the input SearchArgument.Builder.
- * @return the builder so far.
- */
+ * Build a SearchArgument and return the builder so far.
+ *
+ * @param dataTypeMap a map from the attribute name to its data type.
+ * @param expression the input predicates, which should be fully convertible to SearchArgument.
+ * @param builder the input SearchArgument.Builder.
+ * @return the builder so far.
+ */
private def buildSearchArgument(
dataTypeMap: Map[String, DataType],
expression: Filter,
@@ -122,13 +122,13 @@ object OrcFilters extends OrcFiltersBase {
}
/**
- * Build a SearchArgument for a leaf predicate and return the builder so far.
- *
- * @param dataTypeMap a map from the attribute name to its data type.
- * @param expression the input filter predicates.
- * @param builder the input SearchArgument.Builder.
- * @return the builder so far.
- */
+ * Build a SearchArgument for a leaf predicate and return the builder so far.
+ *
+ * @param dataTypeMap a map from the attribute name to its data type.
+ * @param expression the input filter predicates.
+ * @param builder the input SearchArgument.Builder.
+ * @return the builder so far.
+ */
private def buildLeafSearchArgument(
dataTypeMap: Map[String, DataType],
expression: Filter,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala
index 08822cd041a..59097606276 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala
@@ -31,14 +31,14 @@ import org.apache.spark.util.CompletionIterator
trait ShuffleMetricsUpdater {
/**
- * Trait used as a way to expose the `ShuffleReadMetricsReporter` to the iterator.
- * @param fetchWaitTimeInMs this matches the CPU name (except for the units) but it is actually
- * the aggreagate amount of time a task is blocked, not working on
- * anything, waiting for data.
- * @param remoteBlocksFetched aggregate of number of `ShuffleBlockId`s fetched.
- * @param remoteBytesRead aggregate size of all contiguous buffers received
- * @param rowsFetched aggregate of number of rows received
- */
+ * Trait used as a way to expose the `ShuffleReadMetricsReporter` to the iterator.
+ * @param fetchWaitTimeInMs this matches the CPU name (except for the units) but it is actually
+ * the aggregate amount of time a task is blocked, not working on
+ * anything, waiting for data.
+ * @param remoteBlocksFetched aggregate of number of `ShuffleBlockId`s fetched.
+ * @param remoteBytesRead aggregate size of all contiguous buffers received
+ * @param rowsFetched aggregate of number of rows received
+ */
def update(
fetchWaitTimeInMs: Long,
remoteBlocksFetched: Long,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
index dc8a847a098..d8d602205e9 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
@@ -141,8 +141,8 @@ class RapidsCachingWriter[K, V](
}
/**
- * Used to remove shuffle buffers when the writing task detects an error, calling `stop(false)`
- */
+ * Used to remove shuffle buffers when the writing task detects an error, calling `stop(false)`
+ */
private def cleanStorage(): Unit = {
writtenBufferIds.foreach(catalog.removeBuffer)
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala
index 2f6db8f7fdd..b1c20fbd66d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala
@@ -23,16 +23,16 @@ import org.apache.spark.storage.StorageUtils
object RapidsStorageUtils {
// scalastyle:off line.size.limit
/**
- * Calls into spark's `StorageUtils` to expose the [[dispose]] method.
- *
- * NOTE: This the spark code as of the writing of this function is:
- * https://github.com/apache/spark/blob/e9f3f62b2c0f521f3cc23fef381fc6754853ad4f/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala#L206
- *
- * If the implementation in Spark later on breaks our build, we may need to replicate
- * the dispose method here.
- *
- * @param buffer byte buffer to dispose
- */
+ * Calls into Spark's `StorageUtils` to expose the [[dispose]] method.
+ *
+ * NOTE: The Spark code, as of the writing of this function, is:
+ * https://github.com/apache/spark/blob/e9f3f62b2c0f521f3cc23fef381fc6754853ad4f/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala#L206
+ *
+ * If the implementation in Spark later on breaks our build, we may need to replicate
+ * the dispose method here.
+ *
+ * @param buffer byte buffer to dispose
+ */
// scalastyle:on line.size.limit
def dispose(buffer: ByteBuffer): Unit = StorageUtils.dispose(buffer)
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 1fbfe8cb571..ba7e5863226 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -519,16 +519,16 @@ case class GpuLike(left: Expression, right: Expression, escapeChar: Char)
override def dataType: DataType = BooleanType
/**
- * Validate and convert SQL 'like' pattern to a cuDF regular expression.
- *
- * Underscores (_) are converted to '.' including newlines and percent signs (%)
- * are converted to '.*' including newlines, other characters are quoted literally or escaped.
- * An invalid pattern will throw an `IllegalArgumentException`.
- *
- * @param pattern the SQL pattern to convert
- * @param escapeChar the escape string contains one character.
- * @return the equivalent cuDF regular expression of the pattern
- */
+ * Validate and convert SQL 'like' pattern to a cuDF regular expression.
+ *
+ * Underscores (_) are converted to '.' (matching any character, including newlines) and
+ * percent signs (%) are converted to '.*' (also matching newlines); other characters are
+ * quoted literally or escaped.
+ * An invalid pattern will throw an `IllegalArgumentException`.
+ *
+ * @param pattern the SQL pattern to convert
+ * @param escapeChar the escape character (an escape string containing exactly one character).
+ * @return the equivalent cuDF regular expression of the pattern
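+ *
+ * For example (sketch):
+ * {{{
+ *   // '%' becomes ".*", '_' becomes ".", and 'a', 'b', 'c' are emitted as escaped literals
+ *   escapeLikeRegex("a%b_c", '\\')
+ * }}}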
+ */
def escapeLikeRegex(pattern: String, escapeChar: Char): String = {
val in = pattern.toIterator
val out = new StringBuilder()
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala
index 5947da4684f..25ed1764cf6 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala
@@ -47,8 +47,8 @@ class CsvScanSuite extends SparkQueryCompareTestSuite {
}
/**
- * Running with an inferred schema results in running things that are not columnar optimized.
- */
+ * Running with an inferred schema results in running things that are not columnar optimized.
+ */
ALLOW_NON_GPU_testSparkResultsAreEqual("Test CSV inferred schema",
intsFromCsvInferredSchema, Seq("FileSourceScanExec", "FilterExec", "CollectLimitExec",
"GreaterThan", "Length", "StringTrim", "LocalTableScanExec", "DeserializeToObjectExec",
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index bc921edd3c7..32ddd666067 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -294,13 +294,13 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
}
/**
- * Runs a test defined by fun, using dataframe df.
- *
- * @param df the DataFrame to use as input
- * @param fun the function to transform the DataFrame (produces another DataFrame)
- * @param conf spark conf
- * @return tuple of (cpu results, gpu results) as arrays of Row
- */
+ * Runs a test defined by fun, using dataframe df.
+ *
+ * @param df the DataFrame to use as input
+ * @param fun the function to transform the DataFrame (produces another DataFrame)
+ * @param conf spark conf
+ * @return tuple of (cpu results, gpu results) as arrays of Row
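+ *
+ * Sketch of a typical call (`intsFromCsv` stands in for any SparkSession => DataFrame helper):
+ * {{{
+ *   val (fromCpu, fromGpu) = runOnCpuAndGpu(
+ *     intsFromCsv,
+ *     df => df.filter("ints > 0"),
+ *     conf = new SparkConf())
+ * }}}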
+ */
def runOnCpuAndGpu(
df: SparkSession => DataFrame,
fun: DataFrame => DataFrame,