From 9f947b79f27a629f53d793fc74d04dcf58fe1cd2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 28 Jul 2020 16:40:13 -0600 Subject: [PATCH] Enforce NoScalaDoc rule in scalastyle checks (#449) * fix rule * fix violations * fix violations * fix violations * remove comment * re-order rules, add comment back * fix violations in tests module --- scalastyle-config.xml | 18 +- .../main/scala/ai/rapids/cudf/CudaUtil.scala | 18 +- ...arPartitionReaderWithPartitionValues.scala | 8 +- .../com/nvidia/spark/rapids/GpuOrcScan.scala | 140 ++--- .../nvidia/spark/rapids/GpuParquetScan.scala | 78 +-- .../spark/rapids/GpuRangePartitioner.scala | 26 +- .../spark/rapids/GpuRangePartitioning.scala | 16 +- .../spark/rapids/GpuSinglePartitioning.scala | 20 +- .../spark/rapids/HostColumnarToGpu.scala | 16 +- .../spark/rapids/HostMemoryStreams.scala | 26 +- .../com/nvidia/spark/rapids/MetaUtils.scala | 16 +- .../nvidia/spark/rapids/NvtxWithMetrics.scala | 6 +- .../com/nvidia/spark/rapids/Plugin.scala | 4 +- .../nvidia/spark/rapids/SamplingUtils.scala | 34 +- .../spark/rapids/ShuffleBufferCatalog.scala | 10 +- .../com/nvidia/spark/rapids/aggregate.scala | 78 +-- .../com/nvidia/spark/rapids/implicits.scala | 180 +++--- .../nvidia/spark/rapids/nullExpressions.scala | 28 +- .../rapids/shuffle/BounceBufferManager.scala | 60 +- .../rapids/shuffle/RapidsShuffleClient.scala | 426 ++++++------- .../shuffle/RapidsShuffleIterator.scala | 62 +- .../rapids/shuffle/RapidsShuffleServer.scala | 276 ++++----- .../shuffle/RapidsShuffleTransport.scala | 576 +++++++++--------- .../apache/spark/sql/rapids/OrcFilters.scala | 54 +- .../sql/rapids/RapidsCachingReader.scala | 16 +- .../rapids/RapidsShuffleInternalManager.scala | 4 +- .../rapids/storage/RapidsStorageUtils.scala | 20 +- .../spark/sql/rapids/stringFunctions.scala | 20 +- .../nvidia/spark/rapids/CsvScanSuite.scala | 4 +- .../rapids/SparkQueryCompareTestSuite.scala | 14 +- 30 files changed, 1127 insertions(+), 1127 deletions(-) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 981ecc560d4..845a11a69e5 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -96,19 +96,19 @@ You can also disable only one rule, by specifying its rule id, as specified in: - - - - - - - + enabled="true"> - (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] + (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] Use Javadoc style indentation for multiline comments + + + + + + + diff --git a/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala b/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala index 071547dcee0..db5a0746515 100755 --- a/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala +++ b/sql-plugin/src/main/scala/ai/rapids/cudf/CudaUtil.scala @@ -18,15 +18,15 @@ package ai.rapids.cudf object CudaUtil { /** - * Copy from `src` buffer, starting at `srcOffset`, - * to a destination buffer `dst` starting at `dstOffset`, - * `length` bytes, in the default stream. - * @param src source buffer - * @param srcOffset source offset - * @param dst destination buffer - * @param dstOffset destination offset - * @param length amount to copy - */ + * Copy from `src` buffer, starting at `srcOffset`, + * to a destination buffer `dst` starting at `dstOffset`, + * `length` bytes, in the default stream. 
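As a reference for the rule being re-enabled above, here is a minimal standalone sketch of the two indentation styles involved (not part of this patch; the add helpers are hypothetical). Every hunk that follows rewrites comments from the first form to the second:

object CommentStyleExample {
  // ScalaDoc-style indentation, which the NoScalaDoc check rejects: the
  // continuation lines are indented so each '*' sits under the second '*' of '/**'.
  /**
    * Adds two integers (hypothetical helper, for illustration only).
    */
  def scalaDocStyleAdd(a: Int, b: Int): Int = a + b

  // Javadoc-style indentation, which the rule's custom message asks for: each
  // continuation '*' sits under the first '*' of '/**'.
  /**
   * Adds two integers (hypothetical helper, for illustration only).
   */
  def javadocStyleAdd(a: Int, b: Int): Int = a + b
}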
+ * @param src source buffer + * @param srcOffset source offset + * @param dst destination buffer + * @param dstOffset destination offset + * @param length amount to copy + */ def copy(src: MemoryBuffer, srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, length: Long): Unit = { Cuda.memcpy( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala index 2489902097f..4bb2c3ca017 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarPartitionReaderWithPartitionValues.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch /** - * A wrapper reader that always appends partition values to the ColumnarBatch produced by the input - * reader `fileReader`. Each scalar value is splatted to a column with the same number of - * rows as the batch returned by the reader. - */ + * A wrapper reader that always appends partition values to the ColumnarBatch produced by the input + * reader `fileReader`. Each scalar value is splatted to a column with the same number of + * rows as the batch returned by the reader. + */ class ColumnarPartitionReaderWithPartitionValues( fileReader: PartitionReader[ColumnarBatch], partitionValues: Array[Scalar]) extends PartitionReader[ColumnarBatch] { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index 256bc2c1122..4fa0d7c3901 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -179,14 +179,14 @@ case class GpuOrcPartitionReaderFactory( object GpuOrcPartitionReader { /** - * This class describes a stripe that will appear in the ORC output memory file. - * - * @param infoBuilder builder for output stripe info that has been populated with - * all fields except those that can only be known when the file - * is being written (e.g.: file offset, compressed footer length) - * @param footer stripe footer - * @param inputDataRanges input file ranges (based at file offset 0) of stripe data - */ + * This class describes a stripe that will appear in the ORC output memory file. 
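Before moving on to the ORC changes, a quick hedged usage sketch of the CudaUtil.copy helper re-documented above. It is not code from this patch; the buffer sizes are arbitrary and it assumes cudf's HostMemoryBuffer.allocate plus the copy signature shown in the hunk:

import ai.rapids.cudf.{CudaUtil, HostMemoryBuffer}

object CudaUtilCopyExample {
  def main(args: Array[String]): Unit = {
    // Two host scratch buffers; the sizes are made up for the example.
    val src = HostMemoryBuffer.allocate(64)
    val dst = HostMemoryBuffer.allocate(64)
    try {
      // Copy 32 bytes starting at offset 16 of `src` into `dst` at offset 0,
      // on the default stream as the comment above describes.
      CudaUtil.copy(src, 16, dst, 0, 32)
    } finally {
      src.close()
      dst.close()
    }
  }
}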
+ * + * @param infoBuilder builder for output stripe info that has been populated with + * all fields except those that can only be known when the file + * is being written (e.g.: file offset, compressed footer length) + * @param footer stripe footer + * @param inputDataRanges input file ranges (based at file offset 0) of stripe data + */ private case class OrcOutputStripe( infoBuilder: OrcProto.StripeInformation.Builder, footer: OrcProto.StripeFooter, @@ -200,32 +200,32 @@ object GpuOrcPartitionReader { OrcProto.Stream.Kind.ROW_INDEX) /** - * This class holds fields needed to read and iterate over the OrcFile - * - * @param updatedReadSchema read schema mapped to the file's field names - * @param evolution ORC SchemaEvolution - * @param dataReader ORC DataReader - * @param orcReader ORC Input File Reader - * @param blockIterator An iterator over the ORC output stripes - */ + * This class holds fields needed to read and iterate over the OrcFile + * + * @param updatedReadSchema read schema mapped to the file's field names + * @param evolution ORC SchemaEvolution + * @param dataReader ORC DataReader + * @param orcReader ORC Input File Reader + * @param blockIterator An iterator over the ORC output stripes + */ private case class OrcPartitionReaderContext(updatedReadSchema: TypeDescription, evolution: SchemaEvolution, dataReader: DataReader, orcReader: Reader, blockIterator: BufferedIterator[OrcOutputStripe]) } /** - * A PartitionReader that reads an ORC file split on the GPU. - * - * Efficiently reading an ORC split on the GPU requires rebuilding the ORC file - * in memory such that only relevant data is present in the memory file. - * This avoids sending unnecessary data to the GPU and saves GPU memory. - * - * @param conf Hadoop configuration - * @param partFile file split to read - * @param dataSchema Spark schema of the file - * @param readDataSchema Spark schema of what will be read from the file - * @param debugDumpPrefix path prefix for dumping the memory file or null - */ + * A PartitionReader that reads an ORC file split on the GPU. + * + * Efficiently reading an ORC split on the GPU requires rebuilding the ORC file + * in memory such that only relevant data is present in the memory file. + * This avoids sending unnecessary data to the GPU and saves GPU memory. + * + * @param conf Hadoop configuration + * @param partFile file split to read + * @param dataSchema Spark schema of the file + * @param readDataSchema Spark schema of what will be read from the file + * @param debugDumpPrefix path prefix for dumping the memory file or null + */ class GpuOrcPartitionReader( conf: Configuration, partFile: PartitionedFile, @@ -319,13 +319,13 @@ class GpuOrcPartitionReader( } /** - * Build an integer array that maps the original ORC file's column IDs - * to column IDs in the memory file. Columns that are not present in - * the memory file will have a mapping of -1. - * - * @param evolution ORC SchemaEvolution - * @return column mapping array - */ + * Build an integer array that maps the original ORC file's column IDs + * to column IDs in the memory file. Columns that are not present in + * the memory file will have a mapping of -1. + * + * @param evolution ORC SchemaEvolution + * @return column mapping array + */ private def columnRemap(evolution: SchemaEvolution): Array[Int] = { val fileIncluded = evolution.getFileIncluded if (fileIncluded != null) { @@ -346,17 +346,17 @@ class GpuOrcPartitionReader( } /** - * Build the output stripe descriptors for what will appear in the ORC memory file. 
- * - * @param stripes descriptors for the ORC input stripes, filtered to what is in the split - * @param evolution ORC SchemaEvolution - * @param sargApp ORC search argument applier - * @param sargColumns mapping of ORC search argument columns - * @param ignoreNonUtf8BloomFilter true if bloom filters other than UTF8 should be ignored - * @param writerVersion writer version from the original ORC input file - * @param dataReader ORC DataReader - * @return output stripes descriptors - */ + * Build the output stripe descriptors for what will appear in the ORC memory file. + * + * @param stripes descriptors for the ORC input stripes, filtered to what is in the split + * @param evolution ORC SchemaEvolution + * @param sargApp ORC search argument applier + * @param sargColumns mapping of ORC search argument columns + * @param ignoreNonUtf8BloomFilter true if bloom filters other than UTF8 should be ignored + * @param writerVersion writer version from the original ORC input file + * @param dataReader ORC DataReader + * @return output stripes descriptors + */ private def buildOutputStripes( stripes: Seq[StripeInformation], evolution: SchemaEvolution, @@ -392,14 +392,14 @@ class GpuOrcPartitionReader( } /** - * Build the output stripe descriptor for a corresponding input stripe - * that should be copied to the ORC memory file. - * - * @param inputStripe input stripe descriptor - * @param inputFooter input stripe footer - * @param columnMapping mapping of input column IDs to output column IDs - * @return output stripe descriptor - */ + * Build the output stripe descriptor for a corresponding input stripe + * that should be copied to the ORC memory file. + * + * @param inputStripe input stripe descriptor + * @param inputFooter input stripe footer + * @param columnMapping mapping of input column IDs to output column IDs + * @return output stripe descriptor + */ private def buildOutputStripe( inputStripe: StripeInformation, inputFooter: OrcProto.StripeFooter, @@ -564,13 +564,13 @@ class GpuOrcPartitionReader( } /** - * Check if the read schema is compatible with the file schema. - * - * @param fileSchema input file's ORC schema - * @param readSchema ORC schema for what will be read - * @param isCaseAware true if field names are case-sensitive - * @return read schema mapped to the file's field names - */ + * Check if the read schema is compatible with the file schema. + * + * @param fileSchema input file's ORC schema + * @param readSchema ORC schema for what will be read + * @param isCaseAware true if field names are case-sensitive + * @return read schema mapped to the file's field names + */ private def checkSchemaCompatibility( fileSchema: TypeDescription, readSchema: TypeDescription, @@ -602,15 +602,15 @@ class GpuOrcPartitionReader( } /** - * Build an ORC search argument applier that can filter input file splits - * when predicate push-down filters have been specified. - * - * @param orcReader ORC input file reader - * @param readerOpts ORC reader options - * @param evolution ORC SchemaEvolution - * @param useUTCTimestamp true if timestamps are UTC - * @return the search argument applier and search argument column mapping - */ + * Build an ORC search argument applier that can filter input file splits + * when predicate push-down filters have been specified. 
+ * + * @param orcReader ORC input file reader + * @param readerOpts ORC reader options + * @param evolution ORC SchemaEvolution + * @param useUTCTimestamp true if timestamps are UTC + * @return the search argument applier and search argument column mapping + */ private def getSearchApplier( orcReader: Reader, readerOpts: Reader.Options, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index d780fefcf4d..0e16dc3bf9b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -288,22 +288,22 @@ case class GpuParquetPartitionReaderFactory( } /** - * A PartitionReader that reads a Parquet file split on the GPU. - * - * Efficiently reading a Parquet split on the GPU requires re-constructing the Parquet file - * in memory that contains just the column chunks that are needed. This avoids sending - * unnecessary data to the GPU and saves GPU memory. - * - * @param conf the Hadoop configuration - * @param split the file split to read - * @param filePath the path to the Parquet file - * @param clippedBlocks the block metadata from the original Parquet file that has been clipped - * to only contain the column chunks to be read - * @param clippedParquetSchema the Parquet schema from the original Parquet file that has been - * clipped to contain only the columns to be read - * @param readDataSchema the Spark schema describing what will be read - * @param debugDumpPrefix a path prefix to use for dumping the fabricated Parquet data or null - */ + * A PartitionReader that reads a Parquet file split on the GPU. + * + * Efficiently reading a Parquet split on the GPU requires re-constructing the Parquet file + * in memory that contains just the column chunks that are needed. This avoids sending + * unnecessary data to the GPU and saves GPU memory. + * + * @param conf the Hadoop configuration + * @param split the file split to read + * @param filePath the path to the Parquet file + * @param clippedBlocks the block metadata from the original Parquet file that has been clipped + * to only contain the column chunks to be read + * @param clippedParquetSchema the Parquet schema from the original Parquet file that has been + * clipped to contain only the columns to be read + * @param readDataSchema the Spark schema describing what will be read + * @param debugDumpPrefix a path prefix to use for dumping the fabricated Parquet data or null + */ class ParquetPartitionReader( conf: Configuration, split: PartitionedFile, @@ -433,15 +433,15 @@ class ParquetPartitionReader( } /** - * Copies the data corresponding to the clipped blocks in the original file and compute the - * block metadata for the output. The output blocks will contain the same column chunk - * metadata but with the file offsets updated to reflect the new position of the column data - * as written to the output. - * - * @param in the input stream for the original Parquet file - * @param out the output stream to receive the data - * @return updated block metadata corresponding to the output - */ + * Copies the data corresponding to the clipped blocks in the original file and compute the + * block metadata for the output. The output blocks will contain the same column chunk + * metadata but with the file offsets updated to reflect the new position of the column data + * as written to the output. 
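As an aside to the copy routine described above, a tiny standalone sketch (not the plugin's code) of what "file offsets updated to reflect the new position" means in practice: each chunk's new offset is simply the running total of everything written before it.

object RebaseOffsetsExample {
  // Given the lengths of the chunks being copied and where output begins,
  // return each chunk's offset in the newly written file.
  def rebase(chunkLengths: Seq[Long], outputStart: Long): Seq[Long] =
    chunkLengths.scanLeft(outputStart)(_ + _).init

  def main(args: Array[String]): Unit = {
    // Three column chunks of 100, 250 and 75 bytes written after a 4-byte header.
    println(rebase(Seq(100L, 250L, 75L), outputStart = 4L)) // List(4, 104, 354)
  }
}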
+ * + * @param in the input stream for the original Parquet file + * @param out the output stream to receive the data + * @return updated block metadata corresponding to the output + */ private def copyBlocksData( in: FSDataInputStream, out: HostMemoryOutputStream, @@ -675,12 +675,12 @@ object ParquetPartitionReader { private case class CopyRange(offset: Long, length: Long) /** - * Build a new BlockMetaData - * - * @param rowCount the number of rows in this block - * @param columns the new column chunks to reference in the new BlockMetaData - * @return the new BlockMetaData - */ + * Build a new BlockMetaData + * + * @param rowCount the number of rows in this block + * @param columns the new column chunks to reference in the new BlockMetaData + * @return the new BlockMetaData + */ private def newParquetBlock( rowCount: Long, columns: Seq[ColumnChunkMetaData]): BlockMetaData = { @@ -698,14 +698,14 @@ object ParquetPartitionReader { } /** - * Trim block metadata to contain only the column chunks that occur in the specified columns. - * The column chunks that are returned are preserved verbatim - * (i.e.: file offsets remain unchanged). - * - * @param columnPaths the paths of columns to preserve - * @param blocks the block metadata from the original Parquet file - * @return the updated block metadata with undesired column chunks removed - */ + * Trim block metadata to contain only the column chunks that occur in the specified columns. + * The column chunks that are returned are preserved verbatim + * (i.e.: file offsets remain unchanged). + * + * @param columnPaths the paths of columns to preserve + * @param blocks the block metadata from the original Parquet file + * @return the updated block metadata with undesired column chunks removed + */ private[spark] def clipBlocks(columnPaths: Seq[ColumnPath], blocks: Seq[BlockMetaData]): Seq[BlockMetaData] = { val pathSet = columnPaths.toSet diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala index c24f2060a88..c6775d1073b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala @@ -33,12 +33,12 @@ import org.apache.spark.util.MutablePair class GpuRangePartitioner extends Serializable { var rangeBounds: Array[InternalRow] = _ /** - * Sketches the input RDD via reservoir sampling on each partition. - * - * @param rdd the input RDD to sketch - * @param sampleSizePerPartition max sample size per partition - * @return (total number of items, an array of (partitionId, number of items, sample)) - */ + * Sketches the input RDD via reservoir sampling on each partition. + * + * @param rdd the input RDD to sketch + * @param sampleSizePerPartition max sample size per partition + * @return (total number of items, an array of (partitionId, number of items, sample)) + */ def sketch[K: ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = { @@ -55,13 +55,13 @@ class GpuRangePartitioner extends Serializable { } /** - * Determines the bounds for range partitioning from candidates with weights indicating how many - * items each represents. Usually this is 1 over the probability used to sample this candidate. 
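For the weighted-candidate selection described above, a simplified standalone sketch (not the plugin's implementation) that walks the sorted candidates accumulating weight and emits a bound each time another 1/partitions share of the total weight has been covered:

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object WeightedBoundsExample {
  def determineBoundsSketch[K: Ordering : ClassTag](
      candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = {
    val ordered = candidates.sortBy(_._1)
    val total = ordered.map(_._2.toDouble).sum
    val step = total / partitions
    val bounds = ArrayBuffer.empty[K]
    var cumWeight = 0.0
    var target = step
    for ((key, weight) <- ordered if bounds.size < partitions - 1) {
      cumWeight += weight
      if (cumWeight >= target) {
        bounds += key
        target += step
      }
    }
    bounds.toArray
  }

  def main(args: Array[String]): Unit = {
    val candidates = ArrayBuffer(5 -> 1.0f, 1 -> 1.0f, 9 -> 1.0f, 3 -> 1.0f)
    // Two partitions need one split point; with equal weights that is the key 3.
    println(determineBoundsSketch(candidates, partitions = 2).toSeq)
  }
}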
- * - * @param candidates unordered candidates with weights - * @param partitions number of partitions - * @return selected bounds - */ + * Determines the bounds for range partitioning from candidates with weights indicating how many + * items each represents. Usually this is 1 over the probability used to sample this candidate. + * + * @param candidates unordered candidates with weights + * @param partitions number of partitions + * @return selected bounds + */ def determineBounds[K: Ordering : ClassTag](candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala index deff3bee23b..ee23233e5b3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioning.scala @@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructTyp import org.apache.spark.sql.vectorized.ColumnarBatch /** - * A GPU accelerated `org.apache.spark.sql.catalyst.plans.physical.Partitioning` that partitions - * sortable records by range into roughly equal ranges. The ranges are determined by sampling - * the content of the RDD passed in. - * - * @note The actual number of partitions created might not be the same - * as the `numPartitions` parameter, in the case where the number of sampled records is less than - * the value of `partitions`. - */ + * A GPU accelerated `org.apache.spark.sql.catalyst.plans.physical.Partitioning` that partitions + * sortable records by range into roughly equal ranges. The ranges are determined by sampling + * the content of the RDD passed in. + * + * @note The actual number of partitions created might not be the same + * as the `numPartitions` parameter, in the case where the number of sampled records is less than + * the value of `partitions`. + */ case class GpuRangePartitioning( gpuOrdering: Seq[SortOrder], diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala index 6caab17460a..cea112c3c5b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala @@ -23,16 +23,16 @@ import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuSinglePartitioning(expressions: Seq[Expression]) extends GpuExpression with GpuPartitioning { /** - * Returns the result of evaluating this expression on the entire `ColumnarBatch`. - * The result of calling this may be a single [[GpuColumnVector]] or a scalar value. - * Scalar values typically happen if they are a part of the expression i.e. col("a") + 100. - * In this case the 100 is a literal that Add would have to be able to handle. - * - * By convention any [[GpuColumnVector]] returned by [[columnarEval]] - * is owned by the caller and will need to be closed by them. This can happen by putting it - * into a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a - * temporary value. - */ + * Returns the result of evaluating this expression on the entire `ColumnarBatch`. + * The result of calling this may be a single [[GpuColumnVector]] or a scalar value. + * Scalar values typically happen if they are a part of the expression i.e. col("a") + 100. 
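A hedged caller-side sketch of the ownership convention spelled out above. It assumes only the documented contract: columnarEval returns either a GpuColumnVector, which the caller owns and must close, or a plain scalar value:

import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression}
import org.apache.spark.sql.vectorized.ColumnarBatch

object ColumnarEvalCallerSketch {
  def evalAndRelease(expr: GpuExpression, batch: ColumnarBatch): Unit = {
    expr.columnarEval(batch) match {
      case vector: GpuColumnVector =>
        try {
          // ... use the column vector, e.g. put it into a new ColumnarBatch ...
        } finally {
          vector.close() // the caller owns it, so the caller closes it
        }
      case scalar =>
        // e.g. the literal 100 in col("a") + 100; nothing to close here
        println(s"scalar result: $scalar")
    }
  }
}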
+ * In this case the 100 is a literal that Add would have to be able to handle. + * + * By convention any [[GpuColumnVector]] returned by [[columnarEval]] + * is owned by the caller and will need to be closed by them. This can happen by putting it + * into a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a + * temporary value. + */ override def columnarEval(batch: ColumnarBatch): Any = { if (batch.numCols == 0) { Array(batch).zipWithIndex diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala index 7c04daea502..3f916d9e87d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala @@ -244,14 +244,14 @@ case class HostColumnarToGpu(child: SparkPlan, goal: CoalesceGoal) } /** - * Returns an RDD[ColumnarBatch] that when mapped over will produce GPU-side column vectors - * that are expected to be closed by its caller, not [[HostColumnarToGpu]]. - * - * The expectation is that the only valid instantiation of this node is - * as a child of a GPU exec node. - * - * @return an RDD of `ColumnarBatch` - */ + * Returns an RDD[ColumnarBatch] that when mapped over will produce GPU-side column vectors + * that are expected to be closed by its caller, not [[HostColumnarToGpu]]. + * + * The expectation is that the only valid instantiation of this node is + * as a child of a GPU exec node. + * + * @return an RDD of `ColumnarBatch` + */ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { val numInputRows = longMetric(NUM_INPUT_ROWS) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala index 5434bbc4741..02a7f8991ed 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostMemoryStreams.scala @@ -21,12 +21,12 @@ import java.io.{InputStream, IOException, OutputStream} import ai.rapids.cudf.HostMemoryBuffer /** - * An implementation of OutputStream that writes to a HostMemoryBuffer. - * - * NOTE: Closing this output stream does NOT close the buffer! - * - * @param buffer the buffer to receive written data - */ + * An implementation of OutputStream that writes to a HostMemoryBuffer. + * + * NOTE: Closing this output stream does NOT close the buffer! + * + * @param buffer the buffer to receive written data + */ class HostMemoryOutputStream(buffer: HostMemoryBuffer) extends OutputStream { private var pos: Long = 0 @@ -49,13 +49,13 @@ class HostMemoryOutputStream(buffer: HostMemoryBuffer) extends OutputStream { } /** - * An implementation of InputStream that reads from a HostMemoryBuffer. - * - * NOTE: Closing this input stream does NOT close the buffer! - * - * @param hmb the buffer from which to read data - * @param hmbLength the amount of data available in the buffer - */ + * An implementation of InputStream that reads from a HostMemoryBuffer. + * + * NOTE: Closing this input stream does NOT close the buffer! 
+ * + * @param hmb the buffer from which to read data + * @param hmbLength the amount of data available in the buffer + */ class HostMemoryInputStream( hmb: HostMemoryBuffer, hmbLength: Long) extends InputStream { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index 63d6a2f8bef..bf352633e74 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -270,11 +270,11 @@ object ShuffleMetadata extends Logging{ } /** - * Given a sequence of `TableMeta`, re-lay the metas using the flat buffer builder in `fbb`. - * @param fbb builder to use - * @param tables sequence of `TableMeta` to copy - * @return an array of flat buffer offsets for the copied `TableMeta`s - */ + * Given a sequence of `TableMeta`, re-lay the metas using the flat buffer builder in `fbb`. + * @param fbb builder to use + * @param tables sequence of `TableMeta` to copy + * @return an array of flat buffer offsets for the copied `TableMeta`s + */ def copyTables(fbb: FlatBufferBuilder, tables: Seq[TableMeta]): Array[Int] = { tables.map { tableMeta => val buffMeta = tableMeta.bufferMeta() @@ -474,9 +474,9 @@ object ShuffleMetadata extends Logging{ /** - * Utility function to transfer a `TableMeta` to the heap, - * @todo we would like to look for an easier way, perhaps just a memcpy will do. - */ + * Utility function to transfer a `TableMeta` to the heap, + * @todo we would like to look for an easier way, perhaps just a memcpy will do. + */ def copyTableMetaToHeap(meta: TableMeta): TableMeta = { val fbb = ShuffleMetadata.getHeapBuilder val tables = ShuffleMetadata.copyTables(fbb, Seq(meta)) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala index a4f01c0ff94..ea942d1339f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala @@ -21,9 +21,9 @@ import ai.rapids.cudf.{NvtxColor, NvtxRange} import org.apache.spark.sql.execution.metric.SQLMetric /** - * NvtxRange with option to pass one or more nano timing metric(s) that are updated upon close - * by the amount of time spent in the range - */ + * NvtxRange with option to pass one or more nano timing metric(s) that are updated upon close + * by the amount of time spent in the range + */ class NvtxWithMetrics(name: String, color: NvtxColor, val metric: SQLMetric) extends NvtxRange(name, color) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 4705aa25348..a94ec1b7597 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -43,8 +43,8 @@ case class ColumnarOverrideRules() extends ColumnarRule with Logging { } /** - * Extension point to enable GPU SQL processing. - */ + * Extension point to enable GPU SQL processing. + */ class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging { override def apply(extensions: SparkSessionExtensions): Unit = { logWarning("Installing extensions to enable rapids GPU SQL support." 
+ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala index cb043ffedb9..af982b79f41 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SamplingUtils.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow object SamplingUtils { /** - * Reservoir sampling implementation that also returns the input size. - * - * @param input input size - * @param k reservoir size - * @param seed random seed - * @return (samples, input size) - */ + * Reservoir sampling implementation that also returns the input size. + * + * @param input input size + * @param k reservoir size + * @param seed random seed + * @return (samples, input size) + */ def reservoirSampleAndCount[T: ClassTag]( input: Iterator[T], k: Int, @@ -77,16 +77,16 @@ object SamplingUtils { /** - * This class implements a XORShift random number generator algorithm - * Source: - * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. - * @see Paper - * This implementation is approximately 3.5 times faster than - * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due - * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class - * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG - * for each thread. - */ + * This class implements a XORShift random number generator algorithm + * Source: + * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. + * @see Paper + * This implementation is approximately 3.5 times faster than + * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG + * for each thread. + */ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { def this() = this(System.nanoTime) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index 0ac39d17c28..5856a9c7e30 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -203,11 +203,11 @@ class ShuffleBufferCatalog( } /** - * Remove a buffer and table given a buffer ID - * NOTE: This function is not thread safe! The caller should only invoke if - * the [[ShuffleBufferId]] being removed is not being utilized by another thread. - * @param id buffer identifier - */ + * Remove a buffer and table given a buffer ID + * NOTE: This function is not thread safe! The caller should only invoke if + * the [[ShuffleBufferId]] being removed is not being utilized by another thread. 
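Looping back to the SamplingUtils change earlier in this patch, a standalone sketch (not the plugin's code) of the reservoir-sampling idea behind reservoirSampleAndCount: keep the first k items, then replace a random slot with probability k over the number of items seen so far.

import java.util.Random
import scala.reflect.ClassTag

object ReservoirSampleExample {
  def reservoirSampleAndCountSketch[T: ClassTag](
      input: Iterator[T], k: Int, seed: Long = 42L): (Array[T], Long) = {
    val reservoir = new Array[T](k)
    val rng = new Random(seed)
    var count = 0L
    input.foreach { item =>
      if (count < k) {
        reservoir(count.toInt) = item
      } else {
        // Replace a random slot with probability k / (count + 1).
        val slot = (rng.nextDouble() * (count + 1)).toLong
        if (slot < k) reservoir(slot.toInt) = item
      }
      count += 1
    }
    if (count < k) (reservoir.take(count.toInt), count) else (reservoir, count)
  }

  def main(args: Array[String]): Unit = {
    val (sample, total) = reservoirSampleAndCountSketch((1 to 1000).iterator, k = 10)
    println(s"sampled ${sample.mkString(",")} out of $total items")
  }
}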
+ * @param id buffer identifier + */ def removeBuffer(id: ShuffleBufferId): Unit = { tableMap.remove(id.tableId) catalog.removeBuffer(id) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index d36e29f167f..50b61b24008 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -206,24 +206,24 @@ class GpuSortAggregateMeta( } /** - * GpuHashAggregateExec - is the GPU version of HashAggregateExec, with some major differences: - * - it doesn't support spilling to disk - * - it doesn't support strings in the grouping key - * - it doesn't support count(col1, col2, ..., colN) - * - it doesn't support distinct - * @param requiredChildDistributionExpressions this is unchanged by the GPU. It is used in - * EnsureRequirements to be able to add shuffle nodes - * @param groupingExpressions The expressions that, when applied to the input batch, return the - * grouping key - * @param aggregateExpressions The GpuAggregateExpression instances for this node - * @param aggregateAttributes References to each GpuAggregateExpression (attribute references) - * @param initialInputBufferOffset this is not used in the GPU version, but it's used to offset - * the slot in the aggregation buffer that aggregates should - * start referencing - * @param resultExpressions the expected output expression of this hash aggregate (which this - * node should project) - * @param child incoming plan (where we get input columns from) - */ + * GpuHashAggregateExec - is the GPU version of HashAggregateExec, with some major differences: + * - it doesn't support spilling to disk + * - it doesn't support strings in the grouping key + * - it doesn't support count(col1, col2, ..., colN) + * - it doesn't support distinct + * @param requiredChildDistributionExpressions this is unchanged by the GPU. It is used in + * EnsureRequirements to be able to add shuffle nodes + * @param groupingExpressions The expressions that, when applied to the input batch, return the + * grouping key + * @param aggregateExpressions The GpuAggregateExpression instances for this node + * @param aggregateAttributes References to each GpuAggregateExpression (attribute references) + * @param initialInputBufferOffset this is not used in the GPU version, but it's used to offset + * the slot in the aggregation buffer that aggregates should + * start referencing + * @param resultExpressions the expected output expression of this hash aggregate (which this + * node should project) + * @param child incoming plan (where we get input columns from) + */ case class GpuHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[Expression], @@ -531,12 +531,12 @@ case class GpuHashAggregateExec( } /** - * concatenateBatches - given two ColumnarBatch instances, return a sequence of GpuColumnVector - * that is the concatenated columns of the two input batches. - * @param aggregatedInputCb this is an incoming batch - * @param aggregatedCb this is a batch that was kept for concatenation - * @return Seq[GpuColumnVector] with concatenated vectors - */ + * concatenateBatches - given two ColumnarBatch instances, return a sequence of GpuColumnVector + * that is the concatenated columns of the two input batches. 
+ * @param aggregatedInputCb this is an incoming batch + * @param aggregatedCb this is a batch that was kept for concatenation + * @return Seq[GpuColumnVector] with concatenated vectors + */ private def concatenateBatches(aggregatedInputCb: ColumnarBatch, aggregatedCb: ColumnarBatch, concatTime: SQLMetric): Seq[GpuColumnVector] = { val nvtxRange = new NvtxWithMetrics("concatenateBatches", NvtxColor.BLUE, concatTime) @@ -568,19 +568,19 @@ case class GpuHashAggregateExec( private lazy val completeMode = uniqueModes.contains(Complete) /** - * getCudfAggregates returns a sequence of `cudf.Aggregate`, given the current mode - * `AggregateMode`, and a sequence of all expressions for this [[GpuHashAggregateExec]] - * node, we get all the expressions as that's important for us to be able to resolve the current - * ordinal for this cudf aggregate. - * - * Examples: - * fn = sum, min, max will always be Seq(fn) - * avg will be Seq(sum, count) for Partial mode, but Seq(sum, sum) for other modes - * count will be Seq(count) for Partial mode, but Seq(sum) for other modes - * - * @return Seq of `cudf.Aggregate`, with one or more aggregates that correspond to each - * expression in allExpressions - */ + * getCudfAggregates returns a sequence of `cudf.Aggregate`, given the current mode + * `AggregateMode`, and a sequence of all expressions for this [[GpuHashAggregateExec]] + * node, we get all the expressions as that's important for us to be able to resolve the current + * ordinal for this cudf aggregate. + * + * Examples: + * fn = sum, min, max will always be Seq(fn) + * avg will be Seq(sum, count) for Partial mode, but Seq(sum, sum) for other modes + * count will be Seq(count) for Partial mode, but Seq(sum) for other modes + * + * @return Seq of `cudf.Aggregate`, with one or more aggregates that correspond to each + * expression in allExpressions + */ def setupReferences(childAttr: AttributeSeq, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[GpuAggregateExpression]): BoundExpressionsModeAggregates = { @@ -866,8 +866,8 @@ case class GpuHashAggregateExec( "Row-based execution should not occur for this class") /** - * All the attributes that are used for this plan. NOT used for aggregation - */ + * All the attributes that are used for this plan. NOT used for aggregation + */ override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala index 799b3c5be32..ae484940de1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/implicits.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.vectorized.ColumnarBatch /** - * RapidsPluginImplicits, adds implicit functions for ColumnarBatch, Seq, Seq[AutoCloseable], - * and Array[AutoCloseable] that help make resource management easier within the project. - */ + * RapidsPluginImplicits, adds implicit functions for ColumnarBatch, Seq, Seq[AutoCloseable], + * and Array[AutoCloseable] that help make resource management easier within the project. 
+ */ object RapidsPluginImplicits { import scala.language.implicitConversions @@ -60,11 +60,11 @@ object RapidsPluginImplicits { implicit class AutoCloseableSeq[A <: AutoCloseable](val in: SeqLike[A, _]) { /** - * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each - * element of the sequence, even if prior close calls fail. In case of failure in any of the - * close calls, an Exception is thrown containing the suppressed exceptions (getSuppressed), - * if any. - */ + * safeClose: Is an implicit on a sequence of AutoCloseable classes that tries to close each + * element of the sequence, even if prior close calls fail. In case of failure in any of the + * close calls, an Exception is thrown containing the suppressed exceptions (getSuppressed), + * if any. + */ def safeClose(): Unit = if (in != null) { var closeException: Throwable = null in.foreach { element => @@ -93,57 +93,57 @@ object RapidsPluginImplicits { class MapsSafely[A, Repr] { /** - * safeMap: safeMap implementation that is leveraged by other type-specific implicits. - * - * safeMap has the added safety net that as you produce AutoCloseable values they are - * tracked, and if an exception were to occur within the maps's body, it will make every - * attempt to close each produced value. - * - * Note: safeMap will close in case of errors, without any knowledge of whether it should - * or not. - * Use safeMap only in these circumstances if `fn` increases the reference count, - * producing an AutoCloseable, and nothing else is tracking these references: - * a) seq.safeMap(x => {...; x.incRefCount; x}) - * b) seq.safeMap(x => GpuColumnVector.from(...)) - * - * Usage of safeMap chained with other maps is a bit confusing: - * - * seq.map(GpuColumnVector.from).safeMap(couldThrow) - * - * Will close the column vectors produced from couldThrow up until the time where safeMap - * throws. - * - * The correct pattern of usage in cases like this is: - * - * val closeTheseLater = seq.safeMap(GpuColumnVector.from) - * closeTheseLater.safeMap{ x => - * var success = false - * try { - * val res = couldThrow(x.incRefCount()) - * success = true - * res // return a ref count of 2 - * } finally { - * if (!success) { - * // in case of an error, we close x as part of normal error handling - * // the exception will be caught by the safeMap, and it will close all - * // AutoCloseables produced before x - * // - Sequence looks like: [2, 2, 2, ..., 2] + x, which has also has a refcount of 2 - * x.close() // x now has a ref count of 1, the rest of the sequence has 2s - * } - * } - * } // safeMap cleaned, and now everything has 1s for ref counts (as they were before) - * - * closeTheseLater.safeClose() // go from 1 to 0 in all things inside closeTheseLater - * - * @param in the Seq[A] to map on - * @param fn a function that takes A, and produces B (a subclass of AutoCloseable) - * @tparam A the type of the elements in Seq - * @tparam B the type of the elements produced in the safeMap (should be subclasses of - * AutoCloseable) - * @tparam Repr the type of the input collection (needed by builder) - * @tparam That the type of the output collection (needed by builder) - * @return a sequence of B, in the success case - */ + * safeMap: safeMap implementation that is leveraged by other type-specific implicits. 
+ * + * safeMap has the added safety net that as you produce AutoCloseable values they are + * tracked, and if an exception were to occur within the maps's body, it will make every + * attempt to close each produced value. + * + * Note: safeMap will close in case of errors, without any knowledge of whether it should + * or not. + * Use safeMap only in these circumstances if `fn` increases the reference count, + * producing an AutoCloseable, and nothing else is tracking these references: + * a) seq.safeMap(x => {...; x.incRefCount; x}) + * b) seq.safeMap(x => GpuColumnVector.from(...)) + * + * Usage of safeMap chained with other maps is a bit confusing: + * + * seq.map(GpuColumnVector.from).safeMap(couldThrow) + * + * Will close the column vectors produced from couldThrow up until the time where safeMap + * throws. + * + * The correct pattern of usage in cases like this is: + * + * val closeTheseLater = seq.safeMap(GpuColumnVector.from) + * closeTheseLater.safeMap{ x => + * var success = false + * try { + * val res = couldThrow(x.incRefCount()) + * success = true + * res // return a ref count of 2 + * } finally { + * if (!success) { + * // in case of an error, we close x as part of normal error handling + * // the exception will be caught by the safeMap, and it will close all + * // AutoCloseables produced before x + * // - Sequence looks like: [2, 2, 2, ..., 2] + x, which has also has a refcount of 2 + * x.close() // x now has a ref count of 1, the rest of the sequence has 2s + * } + * } + * } // safeMap cleaned, and now everything has 1s for ref counts (as they were before) + * + * closeTheseLater.safeClose() // go from 1 to 0 in all things inside closeTheseLater + * + * @param in the Seq[A] to map on + * @param fn a function that takes A, and produces B (a subclass of AutoCloseable) + * @tparam A the type of the elements in Seq + * @tparam B the type of the elements produced in the safeMap (should be subclasses of + * AutoCloseable) + * @tparam Repr the type of the input collection (needed by builder) + * @tparam That the type of the output collection (needed by builder) + * @return a sequence of B, in the success case + */ protected def safeMap[B <: AutoCloseable, That]( in: SeqLike[A, Repr], fn: A => B) @@ -179,48 +179,48 @@ object RapidsPluginImplicits { implicit class AutoCloseableProducingSeq[A](val in: Seq[A]) extends MapsSafely[A, Seq[A]] { /** - * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of - * AutoCloseable. - * See [[MapsSafely.safeMap]] for a more detailed explanation. - * - * @param fn a function that takes A, and produces B (a subclass of AutoCloseable) - * @tparam A the type of the elements in Seq - * @tparam B the type of the elements produced in the safeMap (should be subclasses of - * AutoCloseable) - * @return a sequence of B, in the success case - */ + * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of + * AutoCloseable. + * See [[MapsSafely.safeMap]] for a more detailed explanation. 
+ * + * @param fn a function that takes A, and produces B (a subclass of AutoCloseable) + * @tparam A the type of the elements in Seq + * @tparam B the type of the elements produced in the safeMap (should be subclasses of + * AutoCloseable) + * @return a sequence of B, in the success case + */ def safeMap[B <: AutoCloseable](fn: A => B): Seq[B] = super.safeMap(in, fn) } implicit class AutoCloseableProducingArray[A](val in: Array[A]) extends MapsSafely[A, Array[A]] { /** - * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of - * AutoCloseable. - * See [[MapsSafely.safeMap]] for a more detailed explanation. - * - * @param fn a function that takes A, and produces B (a subclass of AutoCloseable) - * @tparam A the type of the elements in Seq - * @tparam B the type of the elements produced in the safeMap (should be subclasses of - * AutoCloseable) - * @return a sequence of B, in the success case - */ + * safeMap: implicit map on a Seq[A] that produces Seq[B], where B is a subclass of + * AutoCloseable. + * See [[MapsSafely.safeMap]] for a more detailed explanation. + * + * @param fn a function that takes A, and produces B (a subclass of AutoCloseable) + * @tparam A the type of the elements in Seq + * @tparam B the type of the elements produced in the safeMap (should be subclasses of + * AutoCloseable) + * @return a sequence of B, in the success case + */ def safeMap[B <: AutoCloseable : ClassTag](fn: A => B): Array[B] = super.safeMap(in, fn) } implicit class AutoCloseableFromBatchColumns(val in: ColumnarBatch) extends MapsSafely[Int, Seq[Int]] { /** - * safeMap: Is an implicit on ColumnarBatch, that lets you map over the columns - * of a batch as if the batch was a Seq[GpuColumnVector], iff safeMap's body is producing - * AutoCloseable (otherwise, it is not defined). - * - * See [[MapsSafely.safeMap]] for a more detailed explanation. - * - * @param fn a function that takes GpuColumnVector, and returns a subclass of AutoCloseable - * @tparam B the type of the elements produced in the safeMap (should be subclasses of - * AutoCloseable) - * @return a sequence of B, in the success case - */ + * safeMap: Is an implicit on ColumnarBatch, that lets you map over the columns + * of a batch as if the batch was a Seq[GpuColumnVector], iff safeMap's body is producing + * AutoCloseable (otherwise, it is not defined). + * + * See [[MapsSafely.safeMap]] for a more detailed explanation. + * + * @param fn a function that takes GpuColumnVector, and returns a subclass of AutoCloseable + * @tparam B the type of the elements produced in the safeMap (should be subclasses of + * AutoCloseable) + * @return a sequence of B, in the success case + */ def safeMap[B <: AutoCloseable](fn: GpuColumnVector => B): Seq[B] = { val colIds: Seq[Int] = 0 until in.numCols super.safeMap(colIds, (i: Int) => fn(in.column(i).asInstanceOf[GpuColumnVector])) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala index 7ba2da5d9ec..8aea6eda62b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala @@ -147,9 +147,9 @@ case class GpuIsNan(child: Expression) extends GpuUnaryExpression with Predicate } /** - * A GPU accelerated predicate that is evaluated to be true if there are at least `n` non-null - * and non-NaN values. 
- */ + * A GPU accelerated predicate that is evaluated to be true if there are at least `n` non-null + * and non-NaN values. + */ case class GpuAtLeastNNonNulls( n: Int, exprs: Seq[Expression]) @@ -160,17 +160,17 @@ case class GpuAtLeastNNonNulls( override def toString: String = s"GpuAtLeastNNulls(n, ${children.mkString(",")})" override def children: Seq[Expression] = exprs /** - * Returns the result of evaluating this expression on the entire - * `ColumnarBatch`. The result of calling this may be a single [[GpuColumnVector]] or a scalar - * value. Scalar values typically happen if they are a part of the expression - * i.e. col("a") + 100. - * In this case the 100 is a literal that Add would have to be able to handle. - * - * By convention any [[GpuColumnVector]] returned by [[columnarEval]] - * is owned by the caller and will need to be closed by them. This can happen by putting it into - * a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a - * temporary value. - */ + * Returns the result of evaluating this expression on the entire + * `ColumnarBatch`. The result of calling this may be a single [[GpuColumnVector]] or a scalar + * value. Scalar values typically happen if they are a part of the expression + * i.e. col("a") + 100. + * In this case the 100 is a literal that Add would have to be able to handle. + * + * By convention any [[GpuColumnVector]] returned by [[columnarEval]] + * is owned by the caller and will need to be closed by them. This can happen by putting it into + * a `ColumnarBatch` and closing the batch or by closing the vector directly if it is a + * temporary value. + */ override def columnarEval(batch: ColumnarBatch): Any = { val nonNullNanCounts : mutable.Queue[ColumnVector] = new mutable.Queue[ColumnVector]() try { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala index e05c7b6fb57..c7f0ed8ff58 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BounceBufferManager.scala @@ -23,15 +23,15 @@ import ai.rapids.cudf.MemoryBuffer import org.apache.spark.internal.Logging /** - * This classes manages a set of bounce buffers, that are instances of `MemoryBuffer`. - * The size/quantity of buffers is configurable, and so is the allocator. - * @param poolName a human-friendly name to use for debug logs - * @param bufferSize the size of buffer to use - * @param numBuffers the number of buffers to allocate on instantiation - * @param allocator function that takes a size, and returns a `MemoryBuffer` instance. - * @tparam T the specific type of MemoryBuffer i.e. `DeviceMemoryBuffer`, - * `HostMemoryBuffer`, etc. - */ + * This classes manages a set of bounce buffers, that are instances of `MemoryBuffer`. + * The size/quantity of buffers is configurable, and so is the allocator. + * @param poolName a human-friendly name to use for debug logs + * @param bufferSize the size of buffer to use + * @param numBuffers the number of buffers to allocate on instantiation + * @param allocator function that takes a size, and returns a `MemoryBuffer` instance. + * @tparam T the specific type of MemoryBuffer i.e. `DeviceMemoryBuffer`, + * `HostMemoryBuffer`, etc. 
+ */ class BounceBufferManager[T <: MemoryBuffer]( poolName: String, val bufferSize: Long, @@ -47,11 +47,11 @@ class BounceBufferManager[T <: MemoryBuffer]( freeBufferMap.set(0, numBuffers) /** - * Acquires a [[MemoryBuffer]] from the pool. Blocks if the pool is empty. - * - * @note calls to this function should have a lock on this [[BounceBufferManager]] - * @return the acquired `MemoryBuffer` - */ + * Acquires a [[MemoryBuffer]] from the pool. Blocks if the pool is empty. + * + * @note calls to this function should have a lock on this [[BounceBufferManager]] + * @return the acquired `MemoryBuffer` + */ private def acquireBuffer(): MemoryBuffer = { val start = System.currentTimeMillis() var bufferIndex = freeBufferMap.nextSetBit(0) @@ -73,10 +73,10 @@ class BounceBufferManager[T <: MemoryBuffer]( } /** - * Acquire `possibleNumBuffers` buffers from the pool. This method will not block. - * @param possibleNumBuffers number of buffers to acquire - * @return a sequence of `MemoryBuffer`s, or empty if the request can't be satisfied - */ + * Acquire `possibleNumBuffers` buffers from the pool. This method will not block. + * @param possibleNumBuffers number of buffers to acquire + * @return a sequence of `MemoryBuffer`s, or empty if the request can't be satisfied + */ def acquireBuffersNonBlocking(possibleNumBuffers: Int): Seq[MemoryBuffer] = synchronized { if (numFree < possibleNumBuffers) { // would block @@ -89,11 +89,11 @@ class BounceBufferManager[T <: MemoryBuffer]( } /** - * Acquire `possibleNumBuffers` buffers from the pool. This method will block until - * it can get the buffers requested. - * @param possibleNumBuffers number of buffers to acquire - * @return a sequence of `MemoryBuffer`s - */ + * Acquire `possibleNumBuffers` buffers from the pool. This method will block until + * it can get the buffers requested. + * @param possibleNumBuffers number of buffers to acquire + * @return a sequence of `MemoryBuffer`s + */ def acquireBuffersBlocking(possibleNumBuffers: Int): Seq[MemoryBuffer] = synchronized { val res = (0 until possibleNumBuffers).map(_ => acquireBuffer()) logDebug(s"$poolName at acquire. Has numFree ${numFree}") @@ -101,9 +101,9 @@ class BounceBufferManager[T <: MemoryBuffer]( } /** - * Free a `MemoryBuffer`, putting it back into the pool. - * @param buffer the memory buffer to free - */ + * Free a `MemoryBuffer`, putting it back into the pool. + * @param buffer the memory buffer to free + */ def freeBuffer(buffer: MemoryBuffer): Unit = synchronized { require(buffer.getAddress >= rootBuffer.getAddress && (buffer.getAddress - rootBuffer.getAddress) % bufferSize == 0, @@ -119,10 +119,10 @@ class BounceBufferManager[T <: MemoryBuffer]( } /** - * Returns the root (backing) `MemoryBuffer`. This is used for a transport - * that wants to register the bounce buffers against hardware, for pinning purposes. - * @return the root (backing) memory buffer - */ + * Returns the root (backing) `MemoryBuffer`. This is used for a transport + * that wants to register the bounce buffers against hardware, for pinning purposes. 
+ * @return the root (backing) memory buffer + */ def getRootBuffer(): MemoryBuffer = rootBuffer override def close(): Unit = rootBuffer.close() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index c6d913c37d4..9b1c10a6be1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -29,40 +29,40 @@ import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.storage.ShuffleBlockBatchId /** - * trait used by client consumers ([[RapidsShuffleIterator]]) to gather what the - * expected number of batches (Tables) is, and the ids for each table as they are received. - */ + * trait used by client consumers ([[RapidsShuffleIterator]]) to gather what the + * expected number of batches (Tables) is, and the ids for each table as they are received. + */ trait RapidsShuffleFetchHandler { /** - * After a response for shuffle metadata is received, the expected number of columnar - * batches is communicated back to the caller via this method. - * @param expectedBatches - number of columnar batches - */ + * After a response for shuffle metadata is received, the expected number of columnar + * batches is communicated back to the caller via this method. + * @param expectedBatches - number of columnar batches + */ def start(expectedBatches: Int): Unit /** - * Called when a buffer is received and has been handed off to the catalog. - * @param bufferId - a tracked shuffle buffer id - */ + * Called when a buffer is received and has been handed off to the catalog. + * @param bufferId - a tracked shuffle buffer id + */ def batchReceived(bufferId: ShuffleReceivedBufferId): Unit /** - * Called when the transport layer is not able to handle a fetch error for metadata - * or buffer fetches. - * - * @param errorMessage - a string containing an error message - */ + * Called when the transport layer is not able to handle a fetch error for metadata + * or buffer fetches. + * + * @param errorMessage - a string containing an error message + */ def transferError(errorMessage: String): Unit } /** - * A helper case class that describes a pending table. It allows - * the transport to schedule/throttle these requests to fit within the maximum bytes in flight. - * @param client client used to issue the requests - * @param tableMeta shuffle metadata describing the table - * @param tag a transport specific tag to use for this transfer - * @param handler a specific handler that is waiting for this batch - */ + * A helper case class that describes a pending table. It allows + * the transport to schedule/throttle these requests to fit within the maximum bytes in flight. + * @param client client used to issue the requests + * @param tableMeta shuffle metadata describing the table + * @param tag a transport specific tag to use for this transfer + * @param handler a specific handler that is waiting for this batch + */ case class PendingTransferRequest(client: RapidsShuffleClient, tableMeta: TableMeta, tag: Long, @@ -71,40 +71,40 @@ case class PendingTransferRequest(client: RapidsShuffleClient, } /** - * Class describing the state of a set of [[PendingTransferRequest]]s. - * - * This class is *not thread safe*. The way the code is currently designed, bounce buffers being - * used to receive, or copied from, are acted on a sequential basis, in time and in space. 
- * - * Callers use this class, like so: - * - * 1. [[getRequest]] - * - first call: - * - Initializes the state tracking the progress of `currentRequest` and our place in - * `requests` (e.g. offset, bytes remaining) - * - * - subsequent calls: - * - if `currentRequest` is not done, it will return the current request. And perform - * sanity checks. - * - if `currentRequest` is done, it will perform sanity checks, and advance to the next - * request in `requests` - * - * 2. [[consumeBuffers]] - * - first call: - * - The first time around for `currentRequest` it will allocate the actual full device - * buffer that we will copy to, copy data sequentially from the bounce buffer to the - * target buffer. - * - subsequent calls: - * - continue copy data sequentially from the bounce buffers passed in. - * - When `currentRequest` has been fully received, an optional `DeviceMemoryBuffer` is - * set, and returned. - * - * 3. [[close]] - * - once the caller calls close, the bounce buffers are returned to the pool. - * - * @param transport a transport, which in this case is used to free bounce buffers - * @param bounceMemoryBuffers a sequence of `MemoryBuffer` buffers to use for receives - */ + * Class describing the state of a set of [[PendingTransferRequest]]s. + * + * This class is *not thread safe*. The way the code is currently designed, bounce buffers being + * used to receive, or copied from, are acted on a sequential basis, in time and in space. + * + * Callers use this class, like so: + * + * 1. [[getRequest]] + * - first call: + * - Initializes the state tracking the progress of `currentRequest` and our place in + * `requests` (e.g. offset, bytes remaining) + * + * - subsequent calls: + * - if `currentRequest` is not done, it will return the current request. And perform + * sanity checks. + * - if `currentRequest` is done, it will perform sanity checks, and advance to the next + * request in `requests` + * + * 2. [[consumeBuffers]] + * - first call: + * - The first time around for `currentRequest` it will allocate the actual full device + * buffer that we will copy to, copy data sequentially from the bounce buffer to the + * target buffer. + * - subsequent calls: + * - continue copy data sequentially from the bounce buffers passed in. + * - When `currentRequest` has been fully received, an optional `DeviceMemoryBuffer` is + * set, and returned. + * + * 3. [[close]] + * - once the caller calls close, the bounce buffers are returned to the pool. + * + * @param transport a transport, which in this case is used to free bounce buffers + * @param bounceMemoryBuffers a sequence of `MemoryBuffer` buffers to use for receives + */ class BufferReceiveState( transport: RapidsShuffleTransport, val bounceMemoryBuffers: Seq[MemoryBuffer]) @@ -115,63 +115,63 @@ class BufferReceiveState( private[this] val requests = new ArrayBuffer[PendingTransferRequest]() /** - * Use by the transport to add to this [[BufferReceiveState]] requests it needs to handle. - * @param pendingTransferRequest request to add to this [[BufferReceiveState]] - */ + * Use by the transport to add to this [[BufferReceiveState]] requests it needs to handle. + * @param pendingTransferRequest request to add to this [[BufferReceiveState]] + */ def addRequest(pendingTransferRequest: PendingTransferRequest): Unit = synchronized { requests.append(pendingTransferRequest) } /** - * Holds the target device memory buffer. 
It is allocated at [[consumeBuffers]] when the first - * bounce buffer resolves, and it is also handed off to the caller in the same function. - */ + * Holds the target device memory buffer. It is allocated at [[consumeBuffers]] when the first + * bounce buffer resolves, and it is also handed off to the caller in the same function. + */ private[this] var buff: DeviceMemoryBuffer = null /** - * This is the address/length/tag object for the target buffer. It is used only to access a cuda - * synchronous copy method. We should do a regular copy from [[buff]] once we use the async - * method and add the cuda event synchronization. - */ + * This is the address/length/tag object for the target buffer. It is used only to access a cuda + * synchronous copy method. We should do a regular copy from [[buff]] once we use the async + * method and add the cuda event synchronization. + */ private[this] var alt: AddressLengthTag = null /** - * True iff this is the last request we are handling. It is used in [[isDone]] to find the - * stopping point. - */ + * True iff this is the last request we are handling. It is used in [[isDone]] to find the + * stopping point. + */ private[this] var lastRequest: Boolean = false /** - * The index into the [[requests]] sequence pointing to the current request we are handling. - * We handle requests sequentially. - */ + * The index into the [[requests]] sequence pointing to the current request we are handling. + * We handle requests sequentially. + */ private[this] var currentRequestIndex = -1 /** - * Amount of bytes left in the current request. - */ + * Amount of bytes left in the current request. + */ private[this] var currentRequestRemaining: Long = 0L /** - * Byte offset we are currently at. [[getRequest]] uses and resets it, and [[consumeBuffers]] - * updates it. - */ + * Byte offset we are currently at. [[getRequest]] uses and resets it, and [[consumeBuffers]] + * updates it. + */ private[this] var currentRequestOffset: Long = 0L /** - * The current request (TableMeta, Handler (iterator)) - */ + * The current request (TableMeta, Handler (iterator)) + */ private[this] var currentRequest: PendingTransferRequest = null /** - * To help debug "closed multiple times" issues - */ + * To help debug "closed multiple times" issues + */ private[this] var isClosed = false /** - * Becomes true when there is an error detected, allowing the client to close this - * [[BufferReceiveState]] prematurely. - */ + * Becomes true when there is an error detected, allowing the client to close this + * [[BufferReceiveState]] prematurely. + */ private[this] var errorOcurred = false override def toString: String = { @@ -182,14 +182,14 @@ class BufferReceiveState( } /** - * When a receive transaction is successful, this function is called to consume the bounce - * buffers received. - * - * @note If the target device buffer is not allocated, this function does so. - * @param bounceBuffers sequence of buffers that have been received - * @return if the current request is complete, returns a shallow copy of the target - * device buffer as an `Option`. Callers will need to close the buffer. - */ + * When a receive transaction is successful, this function is called to consume the bounce + * buffers received. + * + * @note If the target device buffer is not allocated, this function does so. + * @param bounceBuffers sequence of buffers that have been received + * @return if the current request is complete, returns a shallow copy of the target + * device buffer as an `Option`. 
Callers will need to close the buffer. + */ def consumeBuffers( bounceBuffers: Seq[AddressLengthTag]): Option[DeviceMemoryBuffer] = synchronized { var needsCleanup = true @@ -238,20 +238,20 @@ class BufferReceiveState( } /** - * Signals whether this [[BufferReceiveState]] is complete. - * @return boolean when at the last request, and that request is fully received - */ + * Signals whether this [[BufferReceiveState]] is complete. + * @return boolean when at the last request, and that request is fully received + */ def isDone: Boolean = synchronized { lastRequest && currentRequestDone } /** - * Return the current (making the next request current, if we are done) request and - * whether we advanced (want to kick off a transfer request to the server) - * - * @return returns the currently working transfer request, and - * true if this is a new request we should be asking the server to trigger - */ + * Return the current (making the next request current, if we are done) request and + * whether we advanced (want to kick off a transfer request to the server) + * + * @return returns the currently working transfer request, and + * true if this is a new request we should be asking the server to trigger + */ def getRequest: (PendingTransferRequest, Boolean) = synchronized { require(currentRequestIndex < requests.size, "Something went wrong while handling buffer receives. Asking for more buffers than expected") @@ -301,18 +301,18 @@ class BufferReceiveState( } /** - * Used to cut the subset of bounce buffers the client will need to issue receives with. - * - * This is a subset, because at the moment, the full target length worth of bounce buffers - * could have been acquired. If we are at the tail end of receive, and it could be fulfilled - * with 1 bounce buffer, for example, we would return 1 bounce buffer here, rather than the - * number of buffers in acquired. - * - * @note that these extra buffers should really just be freed as soon as we realize - * they are of no use. - * - * @return sequence of [[AddressLengthTag]] pointing to the receive bounce buffers. - */ + * Used to cut the subset of bounce buffers the client will need to issue receives with. + * + * This is a subset, because at the moment, the full target length worth of bounce buffers + * could have been acquired. If we are at the tail end of receive, and it could be fulfilled + * with 1 bounce buffer, for example, we would return 1 bounce buffer here, rather than the + * number of buffers in acquired. + * + * @note that these extra buffers should really just be freed as soon as we realize + * they are of no use. + * + * @return sequence of [[AddressLengthTag]] pointing to the receive bounce buffers. + */ def getBounceBuffersForReceive(): Seq[AddressLengthTag] = synchronized { var bounceBufferIx = 0 var bounceBuffersForTransfer = Seq[AddressLengthTag]() @@ -351,28 +351,28 @@ class BufferReceiveState( } /** - * The client makes requests via a [[Connection]] obtained from the [[RapidsShuffleTransport]]. - * - * The [[Connection]] follows a single threaded callback model, so this class posts operations - * to an `Executor` as quickly as it gets them from the [[Connection]]. - * - * This class handles fetch requests from [[RapidsShuffleIterator]], turning them into - * [[ShuffleMetadata]] messages, and shuffle `TransferRequest`s. - * - * Its counterpart is the [[RapidsShuffleServer]] on a specific peer executor, specified by - * `connection`. 
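Editorial sketch, not part of the patch: the BufferReceiveState lifecycle described above (getRequest, receive into bounce buffers, consumeBuffers, close) can be summarized as a loop. The real client drives these calls from transport callbacks rather than a loop, and `drainReceiveState` plus its `receiveInto` parameter are hypothetical; the methods used are the ones documented in this file, and close() is assumed to be callable by the owner as item 3 above states.

// Hypothetical linear driver over a BufferReceiveState, for illustration only.
def drainReceiveState(
    state: BufferReceiveState)(receiveInto: Seq[AddressLengthTag] => Unit): Unit = {
  try {
    while (!state.isDone) {
      val (pending, startedNewRequest) = state.getRequest
      if (startedNewRequest) {
        // a new request just became current; the client would ask the server
        // to start transferring it at this point
        println(s"issuing transfer for tag ${pending.tag}")
      }
      val bounceBuffers = state.getBounceBuffersForReceive()
      receiveInto(bounceBuffers) // e.g. a blocking receive on the connection
      state.consumeBuffers(bounceBuffers).foreach { fullTable =>
        // the current request is complete; the shallow copy returned is ours
        // to hand off and eventually close
        fullTable.close()
      }
    }
  } finally {
    state.close() // returns the bounce buffers to the transport's pool
  }
}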
- * - * @param localExecutorId this id is sent to the server, it is required for the protocol as - * the server needs to pick an endpoint to send a response back to this - * executor. - * @param connection a connection object against a remote executor - * @param transport used to get metadata buffers and to work with the throttle mechanism - * @param exec Executor used to handle tasks that take time, and should not be in the - * transport's thread - * @param clientCopyExecutor Executors used to handle synchronous mem copies - * @param maximumMetadataSize The maximum metadata buffer size we are able to request - * TODO: this should go away - */ + * The client makes requests via a [[Connection]] obtained from the [[RapidsShuffleTransport]]. + * + * The [[Connection]] follows a single threaded callback model, so this class posts operations + * to an `Executor` as quickly as it gets them from the [[Connection]]. + * + * This class handles fetch requests from [[RapidsShuffleIterator]], turning them into + * [[ShuffleMetadata]] messages, and shuffle `TransferRequest`s. + * + * Its counterpart is the [[RapidsShuffleServer]] on a specific peer executor, specified by + * `connection`. + * + * @param localExecutorId this id is sent to the server, it is required for the protocol as + * the server needs to pick an endpoint to send a response back to this + * executor. + * @param connection a connection object against a remote executor + * @param transport used to get metadata buffers and to work with the throttle mechanism + * @param exec Executor used to handle tasks that take time, and should not be in the + * transport's thread + * @param clientCopyExecutor Executors used to handle synchronous mem copies + * @param maximumMetadataSize The maximum metadata buffer size we are able to request + * TODO: this should go away + */ class RapidsShuffleClient( localExecutorId: Long, connection: ClientConnection, @@ -385,48 +385,48 @@ class RapidsShuffleClient( object ShuffleClientOps { /** - * When a metadata response is received, this event is issued to handle it. - * @param tx the [[Transaction]] to be closed after consuming the response - * @param resp the response metadata buffer - * @param shuffleRequests blocks to be requested - * @param rapidsShuffleFetchHandler the handler (iterator) to callback to - */ + * When a metadata response is received, this event is issued to handle it. + * @param tx the [[Transaction]] to be closed after consuming the response + * @param resp the response metadata buffer + * @param shuffleRequests blocks to be requested + * @param rapidsShuffleFetchHandler the handler (iterator) to callback to + */ case class HandleMetadataResponse(tx: Transaction, resp: RefCountedDirectByteBuffer, shuffleRequests: Seq[ShuffleBlockBatchId], rapidsShuffleFetchHandler: RapidsShuffleFetchHandler) /** - * Represents retry due to metadata being larger than expected. - * - * @param shuffleRequests request to retry - * @param rapidsShuffleFetchHandler the handler (iterator) to callback to - * @param fullResponseSize response size to allocate to fit the server's response in full - */ + * Represents retry due to metadata being larger than expected. 
+ * + * @param shuffleRequests request to retry + * @param rapidsShuffleFetchHandler the handler (iterator) to callback to + * @param fullResponseSize response size to allocate to fit the server's response in full + */ case class FetchRetry(shuffleRequests: Seq[ShuffleBlockBatchId], rapidsShuffleFetchHandler: RapidsShuffleFetchHandler, fullResponseSize: Long) /** - * Used to have this client handle the enclosed [[BufferReceiveState]] asynchronously. - * - * Until [[bufferReceiveState]] completes, this event will continue to get posted. - * - * @param bufferReceiveState object containing the state of pending requests to the peer - */ + * Used to have this client handle the enclosed [[BufferReceiveState]] asynchronously. + * + * Until [[bufferReceiveState]] completes, this event will continue to get posted. + * + * @param bufferReceiveState object containing the state of pending requests to the peer + */ case class IssueBufferReceives(bufferReceiveState: BufferReceiveState) /** - * When a buffer is received, this event is posted to remove from the progress thread - * the copy, and the callback into the iterator. - * - * Currently not used. There is a TODO below. - * - * @param tx live transaction for the buffer, to be closed after the buffer is handled - * @param bufferReceiveState the object maintaining state for receives - * @param currentRequest the request these bounce buffers belong to - * @param bounceBuffers buffers used in the transfer, which contain the fragment of the data - */ + * When a buffer is received, this event is posted to remove from the progress thread + * the copy, and the callback into the iterator. + * + * Currently not used. There is a TODO below. + * + * @param tx live transaction for the buffer, to be closed after the buffer is handled + * @param bufferReceiveState the object maintaining state for receives + * @param currentRequest the request these bounce buffers belong to + * @param bounceBuffers buffers used in the transfer, which contain the fragment of the data + */ case class HandleBounceBufferReceive(tx: Transaction, bufferReceiveState: BufferReceiveState, currentRequest: PendingTransferRequest, @@ -459,24 +459,24 @@ class RapidsShuffleClient( } /** - * Pushes a task onto the queue to be handled by the client's copy executor. - * - * @note - at this stage, tasks in this pool can block (it will grow as needed) - * - * @param op One of the case classes in [[ShuffleClientOps]] - */ + * Pushes a task onto the queue to be handled by the client's copy executor. + * + * @note - at this stage, tasks in this pool can block (it will grow as needed) + * + * @param op One of the case classes in [[ShuffleClientOps]] + */ private[this] def asyncOnCopyThread(op: Any): Unit = { clientCopyExecutor.execute(() => handleOp(op)) } /** - * Starts a fetch request for all the shuffleRequests, using `handler` to communicate - * events back to the iterator. - * - * @param shuffleRequests blocks to fetch - * @param handler iterator to callback to - * @param metadataSize metadata size to use for this fetch - */ + * Starts a fetch request for all the shuffleRequests, using `handler` to communicate + * events back to the iterator. 
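Editorial sketch, not part of the patch: doFetch takes the blocks to fetch and a RapidsShuffleFetchHandler whose three callbacks are documented earlier in this file. The snippet below wires a do-nothing handler to a fetch; `fetchBlocks` and the println bodies are hypothetical, and in practice the handler is the RapidsShuffleIterator.

import org.apache.spark.storage.ShuffleBlockBatchId

// Hypothetical fetch kickoff with a logging-only handler.
def fetchBlocks(client: RapidsShuffleClient, blocks: Seq[ShuffleBlockBatchId]): Unit = {
  val handler = new RapidsShuffleFetchHandler {
    override def start(expectedBatches: Int): Unit =
      println(s"expecting $expectedBatches batches")
    override def batchReceived(bufferId: ShuffleReceivedBufferId): Unit =
      println(s"batch $bufferId handed off to the catalog")
    override def transferError(errorMessage: String): Unit =
      println(s"fetch failed: $errorMessage")
  }
  // metadataSize defaults to the client's configured maximum when omitted
  client.doFetch(blocks, handler)
}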
+ * + * @param shuffleRequests blocks to fetch + * @param handler iterator to callback to + * @param metadataSize metadata size to use for this fetch + */ def doFetch(shuffleRequests: Seq[ShuffleBlockBatchId], handler: RapidsShuffleFetchHandler, metadataSize: Long = maximumMetadataSize): Unit = { @@ -522,13 +522,13 @@ class RapidsShuffleClient( } /** - * Function to handle MetadataResponses, as a result of the [[HandleMetadataResponse]] event. - * - * @param tx live metadata response transaction to be closed in this handler - * @param resp response buffer, to be closed in this handler - * @param shuffleRequests blocks to fetch - * @param handler iterator to callback to - */ + * Function to handle MetadataResponses, as a result of the [[HandleMetadataResponse]] event. + * + * @param tx live metadata response transaction to be closed in this handler + * @param resp response buffer, to be closed in this handler + * @param shuffleRequests blocks to fetch + * @param handler iterator to callback to + */ private[this] def doHandleMetadataResponse( tx: Transaction, resp: RefCountedDirectByteBuffer, @@ -574,21 +574,21 @@ class RapidsShuffleClient( } /** - * Used by the transport, to schedule receives. The requests are sent to the executor for this - * client. - * @param bufferReceiveState object tracking the state of pending TransferRequests - */ + * Used by the transport, to schedule receives. The requests are sent to the executor for this + * client. + * @param bufferReceiveState object tracking the state of pending TransferRequests + */ def issueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = { asyncOnCopyThread(IssueBufferReceives(bufferReceiveState)) } /** - * Issues transfers requests (if the state of [[bufferReceiveState]] advances), or continue to - * work a current request (continue receiving bounce buffer sized chunks from a larger receive). - * @param bufferReceiveState object maintaining state of requests to be issued (current or - * future). The requests included in this state object originated in - * the transport's throttle logic. - */ + * Issues transfers requests (if the state of [[bufferReceiveState]] advances), or continue to + * work a current request (continue receiving bounce buffer sized chunks from a larger receive). + * @param bufferReceiveState object maintaining state of requests to be issued (current or + * future). The requests included in this state object originated in + * the transport's throttle logic. + */ private[shuffle] def doIssueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = { logDebug(s"At issue for ${bufferReceiveState}, " + s"remaining: ${bufferReceiveState.getCurrentRequestRemaining}, " + @@ -691,12 +691,12 @@ class RapidsShuffleClient( } /** - * Feed into the throttle thread in the transport [[PendingTransferRequest]], to be - * issued later via the [[doIssueBufferReceives]] method. - * - * @param metaResponse metadata response flat buffer - * @param handler callback trait (the iterator implements this) - */ + * Feed into the throttle thread in the transport [[PendingTransferRequest]], to be + * issued later via the [[doIssueBufferReceives]] method. 
+ * + * @param metaResponse metadata response flat buffer + * @param handler callback trait (the iterator implements this) + */ private def queueTransferRequests(metaResponse: MetadataResponse, handler: RapidsShuffleFetchHandler): Unit = { val allTables = metaResponse.tableMetasLength() @@ -725,14 +725,14 @@ class RapidsShuffleClient( } /** - * This function handles data received in `bounceBuffers`. The data should be copied out - * of the buffers, and the function should call into `bufferReceiveState` to advance its - * state (consumeBuffers) - * @param tx live transaction for these bounce buffers, it should be closed in this function - * @param bufferReceiveState state management objects for live transfer requests - * @param currentRequest current transfer request being worked on - * @param bounceBuffers bounce buffers (just received) containing data to be consumed - */ + * This function handles data received in `bounceBuffers`. The data should be copied out + * of the buffers, and the function should call into `bufferReceiveState` to advance its + * state (consumeBuffers) + * @param tx live transaction for these bounce buffers, it should be closed in this function + * @param bufferReceiveState state management objects for live transfer requests + * @param currentRequest current transfer request being worked on + * @param bounceBuffers bounce buffers (just received) containing data to be consumed + */ def doHandleBounceBufferReceive(tx: Transaction, bufferReceiveState: BufferReceiveState, currentRequest: PendingTransferRequest, @@ -783,12 +783,12 @@ class RapidsShuffleClient( } /** - * Hands [[table]] and [[buffer]] to the device storage/catalog, obtaining an id that can be - * used to look up the buffer from the catalog going (e.g. from the iterator) - * @param buffer contiguous [[DeviceMemoryBuffer]] with the tables' data - * @param meta [[TableMeta]] describing [[buffer]] - * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog - */ + * Hands [[table]] and [[buffer]] to the device storage/catalog, obtaining an id that can be + * used to look up the buffer from the catalog going (e.g. from the iterator) + * @param buffer contiguous [[DeviceMemoryBuffer]] with the tables' data + * @param meta [[TableMeta]] describing [[buffer]] + * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog + */ private[shuffle] def track(buffer: DeviceMemoryBuffer, meta: TableMeta): RapidsBufferId = { val id: ShuffleReceivedBufferId = catalog.nextShuffleReceivedBufferId() logDebug(s"Adding buffer id ${id} to catalog") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index 53e829aa2dd..a81d98537ec 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -31,18 +31,18 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId} /** - * An Iterator over columnar batches that fetches blocks using [[RapidsShuffleClient]]s. - * - * A `transport` instance is used to make [[RapidsShuffleClient]]s that are able to fetch - * blocks. 
- * - * @param localBlockManagerId the `BlockManagerId` for the local executor - * @param rapidsConf plugin configuration - * @param transport transport to use to fetch blocks - * @param blocksByAddress blocks to fetch - * @param metricsUpdater instance of `ShuffleMetricsUpdater` to update the Spark - * shuffle metrics - */ + * An Iterator over columnar batches that fetches blocks using [[RapidsShuffleClient]]s. + * + * A `transport` instance is used to make [[RapidsShuffleClient]]s that are able to fetch + * blocks. + * + * @param localBlockManagerId the `BlockManagerId` for the local executor + * @param rapidsConf plugin configuration + * @param transport transport to use to fetch blocks + * @param blocksByAddress blocks to fetch + * @param metricsUpdater instance of `ShuffleMetricsUpdater` to update the Spark + * shuffle metrics + */ class RapidsShuffleIterator( localBlockManagerId: BlockManagerId, rapidsConf: RapidsConf, @@ -54,26 +54,26 @@ class RapidsShuffleIterator( with Logging { /** - * General trait encapsulating either a buffer or an error. Used to hand off batches - * to tasks (in the good case), or exceptions (in the bad case) - */ + * General trait encapsulating either a buffer or an error. Used to hand off batches + * to tasks (in the good case), or exceptions (in the bad case) + */ trait ShuffleClientResult /** - * A result for a successful buffer received - * @param bufferId - the shuffle received buffer id as tracked in the catalog - */ + * A result for a successful buffer received + * @param bufferId - the shuffle received buffer id as tracked in the catalog + */ case class BufferReceived( bufferId: ShuffleReceivedBufferId) extends ShuffleClientResult /** - * A result for a failed attempt at receiving block metadata, or corresponding batches. - * @param blockManagerId - the offending peer block manager id - * @param blockId - shuffle block id that we were fetching - * @param mapIndex - the mapIndex (as returned by the `MapOutputTracker` in - * `blocksByAddress` - * @param errorMessage - a human-friendly error to report - */ + * A result for a failed attempt at receiving block metadata, or corresponding batches. + * @param blockManagerId - the offending peer block manager id + * @param blockId - shuffle block id that we were fetching + * @param mapIndex - the mapIndex (as returned by the `MapOutputTracker` in + * `blocksByAddress` + * @param errorMessage - a human-friendly error to report + */ case class TransferError( blockManagerId: BlockManagerId, blockId: ShuffleBlockBatchId, @@ -153,12 +153,12 @@ class RapidsShuffleIterator( val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] = blockIds.map { case (blockId, _, mapIndex) => /** - * [[ShuffleBlockBatchId]] is an internal optimization in Spark, which will likely - * never see it unless explicitly enabled. - * - * There are other things that can turn it off, but we really don't care too much - * about it. - */ + * [[ShuffleBlockBatchId]] is an internal optimization in Spark, which will likely + * never see it unless explicitly enabled. + * + * There are other things that can turn it off, but we really don't care too much + * about it. 
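Editorial sketch, not part of the patch: the ShuffleClientResult hand-off described above resolves to either a BufferReceived id (good case) or a TransferError (bad case). The dispatch below is written as it would appear inside RapidsShuffleIterator, where these case classes are defined, and it assumes TransferError also extends ShuffleClientResult, as the trait description indicates; `describeResult` itself is hypothetical.

// Hypothetical dispatch over a popped result, for illustration only.
def describeResult(result: ShuffleClientResult): String = result match {
  case BufferReceived(bufferId) =>
    s"batch ready, catalog id $bufferId"
  case TransferError(peer, blockId, mapIndex, message) =>
    s"fetch of $blockId (mapIndex=$mapIndex) from $peer failed: $message"
}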
+ */ blockId match { case sbbid: ShuffleBlockBatchId => BlockIdMapIndex(sbbid, mapIndex) case sbid: ShuffleBlockId => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala index 5edffb76551..e408749586c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleServer.scala @@ -26,44 +26,44 @@ import org.apache.spark.internal.Logging import org.apache.spark.storage.{BlockManagerId, ShuffleBlockBatchId} /** - * Trait used for the server to get buffer metadata (for metadata requests), and - * also to acquire a buffer (for transfer requests) - */ + * Trait used for the server to get buffer metadata (for metadata requests), and + * also to acquire a buffer (for transfer requests) + */ trait RapidsShuffleRequestHandler { /** - * This is a query into the manager to get the `TableMeta` corresponding to a - * shuffle block. - * @param shuffleBlockBatchId `ShuffleBlockBatchId` with (shuffleId, mapId, - * startReduceId, endReduceId) - * @return a sequence of `TableMeta` describing batches corresponding to a block. - */ + * This is a query into the manager to get the `TableMeta` corresponding to a + * shuffle block. + * @param shuffleBlockBatchId `ShuffleBlockBatchId` with (shuffleId, mapId, + * startReduceId, endReduceId) + * @return a sequence of `TableMeta` describing batches corresponding to a block. + */ def getShuffleBufferMetas(shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] /** - * Acquires (locks w.r.t. the memory tier) a [[RapidsBuffer]] corresponding to a table id. - * @param tableId the unique id for a table in the catalog - * @return a [[RapidsBuffer]] which is reference counted, and should be closed by the acquirer - */ + * Acquires (locks w.r.t. the memory tier) a [[RapidsBuffer]] corresponding to a table id. + * @param tableId the unique id for a table in the catalog + * @return a [[RapidsBuffer]] which is reference counted, and should be closed by the acquirer + */ def acquireShuffleBuffer(tableId: Int): RapidsBuffer } /** - * A server that replies to shuffle metadata messages, and issues device/host memory sends. - * - * A single command thread is used to orchestrate sends/receives and to remove - * from transport's progress thread. - * - * @param transport the transport we were configured with - * @param serverConnection a connection object, which contains functions to send/receive - * @param originalShuffleServerId spark's `BlockManagerId` for this executor - * @param requestHandler instance of [[RapidsShuffleRequestHandler]] - * @param exec Executor used to handle tasks that take time, and should not be in the - * transport's thread - * @param copyExec Executor used to handle synchronous mem copies - * @param bssExec Executor used to handle [[BufferSendState]]s that are waiting - * for bounce buffers to become available - * @param rapidsConf plugin configuration instance - */ + * A server that replies to shuffle metadata messages, and issues device/host memory sends. + * + * A single command thread is used to orchestrate sends/receives and to remove + * from transport's progress thread. 
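Editorial sketch, not part of the patch: RapidsShuffleRequestHandler, documented above, is the server's only view into the shuffle buffer catalog. The stub below shows the shape of an implementation (for example, a test double); the class name is hypothetical, the method signatures come from this file, and imports for the plugin's own types (TableMeta, RapidsBuffer) are assumed to be in scope.

import org.apache.spark.storage.ShuffleBlockBatchId

// Hypothetical no-op handler: knows about no blocks and owns no buffers.
class NoopRequestHandler extends RapidsShuffleRequestHandler {
  override def getShuffleBufferMetas(
      shuffleBlockBatchId: ShuffleBlockBatchId): Seq[TableMeta] = Seq.empty

  override def acquireShuffleBuffer(tableId: Int): RapidsBuffer =
    throw new IllegalStateException(s"table $tableId is not tracked by this stub")
}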
+ * + * @param transport the transport we were configured with + * @param serverConnection a connection object, which contains functions to send/receive + * @param originalShuffleServerId spark's `BlockManagerId` for this executor + * @param requestHandler instance of [[RapidsShuffleRequestHandler]] + * @param exec Executor used to handle tasks that take time, and should not be in the + * transport's thread + * @param copyExec Executor used to handle synchronous mem copies + * @param bssExec Executor used to handle [[BufferSendState]]s that are waiting + * for bounce buffers to become available + * @param rapidsConf plugin configuration instance + */ class RapidsShuffleServer(transport: RapidsShuffleTransport, serverConnection: ServerConnection, val originalShuffleServerId: BlockManagerId, @@ -73,27 +73,27 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, bssExec: Executor, rapidsConf: RapidsConf) extends AutoCloseable with Logging { /** - * On close, this is set to false to indicate that the server is shutting down. - */ + * On close, this is set to false to indicate that the server is shutting down. + */ private[this] var started = true private object ShuffleServerOps { /** - * When a transfer request is received during a callback, the handle code is offloaded via this - * event to the server thread. - * @param tx the live transaction that should be closed by the handler - * @param metaRequestBuffer contains the metadata request that should be closed by the - * handler - */ + * When a transfer request is received during a callback, the handle code is offloaded via this + * event to the server thread. + * @param tx the live transaction that should be closed by the handler + * @param metaRequestBuffer contains the metadata request that should be closed by the + * handler + */ case class HandleMeta(tx: Transaction, metaRequestBuffer: RefCountedDirectByteBuffer) /** - * When transfer request is received (to begin sending buffers), the handling is offloaded via - * this event on the server thread. Note that, [[BufferSendState]] encapsulates one more more - * requests to send buffers, and [[HandleTransferRequest]] may be posted multiple times - * in order to handle the request fully. - * @param sendState instance of [[BufferSendState]] used to complete a transfer request. - */ + * When transfer request is received (to begin sending buffers), the handling is offloaded via + * this event on the server thread. Note that, [[BufferSendState]] encapsulates one more more + * requests to send buffers, and [[HandleTransferRequest]] may be posted multiple times + * in order to handle the request fully. + * @param sendState instance of [[BufferSendState]] used to complete a transfer request. + */ case class HandleTransferRequest(sendState: BufferSendState) } @@ -102,10 +102,10 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, private var port: Int = -1 /** - * Returns a TCP port that is expected to respond to rapids shuffle protocol. - * Throws if this server is not started yet, which is an illegal state. - * @return the port - */ + * Returns a TCP port that is expected to respond to rapids shuffle protocol. + * Throws if this server is not started yet, which is an illegal state. 
+ * @return the port + */ def getPort: Int = { if (port == -1) { throw new IllegalStateException("RapidsShuffleServer port is not initialized") @@ -114,8 +114,8 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * Kick off the underlying connection, and listen for initial requests. - */ + * Kick off the underlying connection, and listen for initial requests. + */ def start(): Unit = { port = serverConnection.startManagementPort(originalShuffleServerId.host) // kick off our first receives @@ -139,37 +139,37 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * Pushes a task onto the queue to be handled by the server executor. - * - * All callbacks handled in the server (from the transport) need to be offloaded into - * this pool. Note, if this thread blocks we are blocking the progress thread of the transport. - * - * @param op One of the case classes in `ShuffleServerOps` - */ + * Pushes a task onto the queue to be handled by the server executor. + * + * All callbacks handled in the server (from the transport) need to be offloaded into + * this pool. Note, if this thread blocks we are blocking the progress thread of the transport. + * + * @param op One of the case classes in `ShuffleServerOps` + */ def asyncOrBlock(op: Any): Unit = { exec.execute(() => handleOp(op)) } /** - * Pushes a task onto the queue to be handled by the server's copy executor. - * - * @note - at this stage, tasks in this pool can block (it will grow as needed) - * - * @param op One of the case classes in [[ShuffleServerOps]] - */ + * Pushes a task onto the queue to be handled by the server's copy executor. + * + * @note - at this stage, tasks in this pool can block (it will grow as needed) + * + * @param op One of the case classes in [[ShuffleServerOps]] + */ private[this] def asyncOnCopyThread(op: Any): Unit = { copyExec.execute(() => handleOp(op)) } /** - * Keep a list of BufferSendState that are waiting for bounce buffers. - */ + * Keep a list of BufferSendState that are waiting for bounce buffers. + */ private[this] val bssQueue = new ConcurrentLinkedQueue[BufferSendState]() /** - * Executor that loops until it finds bounce buffers for [[BufferSendState]], - * and when it does it hands them off to a thread pool for handling. - */ + * Executor that loops until it finds bounce buffers for [[BufferSendState]], + * and when it does it hands them off to a thread pool for handling. + */ bssExec.execute(() => { while (started) { var bss: BufferSendState = null @@ -203,14 +203,14 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } }) /** - * Handler for a metadata request. It queues request handlers for either - * [[RequestType.MetadataRequest]] or [[RequestType.TransferRequest]], and re-issues - * receives for either type of request. - * - * NOTE: This call must be non-blocking. It is called from the progress thread. - * - * @param requestType The request type received - */ + * Handler for a metadata request. It queues request handlers for either + * [[RequestType.MetadataRequest]] or [[RequestType.TransferRequest]], and re-issues + * receives for either type of request. + * + * NOTE: This call must be non-blocking. It is called from the progress thread. + * + * @param requestType The request type received + */ private def doIssueReceive(requestType: RequestType.Value): Unit = { logDebug(s"Waiting for a new connection. 
Posting ${requestType} receive.") val metaRequest = transport.getMetaBuffer(rapidsConf.shuffleMaxMetadataSize) @@ -245,12 +245,12 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * Function to handle `MetadataRequest`s. It will populate and issue a - * `MetadataResponse` response for the appropriate client. - * - * @param tx the inbound [[Transaction]] - * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message. - */ + * Function to handle `MetadataRequest`s. It will populate and issue a + * `MetadataResponse` response for the appropriate client. + * + * @param tx the inbound [[Transaction]] + * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message. + */ def doHandleMeta(tx: Transaction, metaRequest: RefCountedDirectByteBuffer): Unit = { val doHandleMetaRange = new NvtxRange("doHandleMeta", NvtxColor.PURPLE) val start = System.currentTimeMillis() @@ -278,9 +278,9 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * Handles the very first message that a client will send, in order to request Table/Buffer info. - * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message. - */ + * Handles the very first message that a client will send, in order to request Table/Buffer info. + * @param metaRequest a [[RefCountedDirectByteBuffer]] holding a `MetadataRequest` message. + */ def handleMetadataRequest(metaRequest: RefCountedDirectByteBuffer): Unit = { try { val req = ShuffleMetadata.getMetadataRequest(metaRequest.getBuffer()) @@ -340,43 +340,43 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * A helper case class to maintain the state associated with a transfer request initiated by - * a `TransferRequest` metadata message. - * - * This class is *not thread safe*. The way the code is currently designed, bounce buffers - * being used to send, or copied to, are acted on a sequential basis, in time and in space. - * - * Callers use this class, like so: - * - * 1) [[getBuffersToSend]]: is used to get bounce buffers that the server should .send on. - * -- first time: - * a) the corresponding catalog table is acquired, - * b) bounce buffers are acquired, - * c) data is copied from the original catalog table into the bounce buffers available - * d) the length of the last bounce buffer is adjusted if it would satisfy the full - * length of the catalog-backed buffer. - * e) bounce buffers are returned - * - * -- subsequent times: - * if we are not done sending the acquired table: - * a) data is copied from the original catalog table into the bounce buffers available - * at sequentially incrementing offsets. - * b) the length of the last bounce buffer is adjusted if it would satisfy the full - * length of the catalog-backed buffer. - * c) bounce buffers are returned - * - * 2) [[freeBounceBuffersIfNecessary]]: called when a send finishes, in order to free bounce - * buffers, if the current table is done sending. - * - * 3) [[close]]: used to free state as the [[BufferSendState]] object is no longer needed - * - * In terms of the lifecycle of this object, it begins with the client asking for transfers to - * start, it lasts through all buffers being transmitted, and ultimately finishes when a - * `TransferResponse` is sent back to the client. - * - * @param tx the original `Transaction` from the `TransferRequest`. 
- * @param request a transfer request - */ + * A helper case class to maintain the state associated with a transfer request initiated by + * a `TransferRequest` metadata message. + * + * This class is *not thread safe*. The way the code is currently designed, bounce buffers + * being used to send, or copied to, are acted on a sequential basis, in time and in space. + * + * Callers use this class, like so: + * + * 1) [[getBuffersToSend]]: is used to get bounce buffers that the server should .send on. + * -- first time: + * a) the corresponding catalog table is acquired, + * b) bounce buffers are acquired, + * c) data is copied from the original catalog table into the bounce buffers available + * d) the length of the last bounce buffer is adjusted if it would satisfy the full + * length of the catalog-backed buffer. + * e) bounce buffers are returned + * + * -- subsequent times: + * if we are not done sending the acquired table: + * a) data is copied from the original catalog table into the bounce buffers available + * at sequentially incrementing offsets. + * b) the length of the last bounce buffer is adjusted if it would satisfy the full + * length of the catalog-backed buffer. + * c) bounce buffers are returned + * + * 2) [[freeBounceBuffersIfNecessary]]: called when a send finishes, in order to free bounce + * buffers, if the current table is done sending. + * + * 3) [[close]]: used to free state as the [[BufferSendState]] object is no longer needed + * + * In terms of the lifecycle of this object, it begins with the client asking for transfers to + * start, it lasts through all buffers being transmitted, and ultimately finishes when a + * `TransferResponse` is sent back to the client. + * + * @param tx the original `Transaction` from the `TransferRequest`. + * @param request a transfer request + */ class BufferSendState(tx: Transaction, request: RefCountedDirectByteBuffer) extends AutoCloseable { private[this] var currentTableIndex = -1 @@ -424,10 +424,10 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * Used to pop a [[BufferSendState]] from its queue if and only if there are bounce - * buffers available - * @return true if bounce buffers are available to proceed - */ + * Used to pop a [[BufferSendState]] from its queue if and only if there are bounce + * buffers available + * @return true if bounce buffers are available to proceed + */ def acquireBounceBuffersNonBlocking: Boolean = { // we need to secure the table we are about to send, in order to get the correct flavor of // bounce buffer @@ -542,14 +542,14 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * This function returns bounce buffers that are ready to be sent. To get there, - * it will: - * 1) acquire the bounce buffers in the first place (if it hasn't already) - * 2) copy data from the source buffer to the bounce buffers, updating the offset accordingly - * 3) return either the full set of bounce buffers, or a subset, depending on how much is - * left to send. - * @return bounce buffers ready to be sent. - */ + * This function returns bounce buffers that are ready to be sent. To get there, + * it will: + * 1) acquire the bounce buffers in the first place (if it hasn't already) + * 2) copy data from the source buffer to the bounce buffers, updating the offset accordingly + * 3) return either the full set of bounce buffers, or a subset, depending on how much is + * left to send. + * @return bounce buffers ready to be sent. 
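Editorial sketch, not part of the patch: one send step of a BufferSendState pairs getBuffersToSend with ServerConnection.send, per the lifecycle described above. `sendNextChunk` and `peerExecutorId` are hypothetical, and the snippet assumes a TransactionCallback can be supplied as a `Transaction => Unit` lambda; freeing bounce buffers and scheduling the next chunk are left as comments because their exact signatures are not shown in this excerpt.

// Hypothetical single send step, for illustration only.
def sendNextChunk(
    connection: ServerConnection,
    peerExecutorId: Long,
    sendState: BufferSendState): Transaction = {
  // copies the next slice of the acquired table into bounce buffers and trims
  // the last buffer's length if the remaining data is smaller
  val bounceBuffers: Seq[AddressLengthTag] = sendState.getBuffersToSend()
  connection.send(peerExecutorId, bounceBuffers, (tx: Transaction) => {
    // on completion the real server frees bounce buffers if the table is fully
    // sent and re-posts HandleTransferRequest while more data remains
    tx.close()
  })
}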
+ */ def getBuffersToSend(): Seq[AddressLengthTag] = synchronized { val alt = acquireTable() @@ -604,11 +604,11 @@ class RapidsShuffleServer(transport: RapidsShuffleTransport, } /** - * This will kick off, or continue to work, a [[BufferSendState]] object - * until all tables are fully transmitted. - * - * @param bufferSendState state object tracking sends needed to fulfill a TransferRequest - */ + * This will kick off, or continue to work, a [[BufferSendState]] object + * until all tables are fully transmitted. + * + * @param bufferSendState state object tracking sends needed to fulfill a TransferRequest + */ def doHandleTransferRequest(bufferSendState: BufferSendState): Unit = { val doHandleTransferRequest = new NvtxRange("doHandleTransferRequest", NvtxColor.CYAN) val start = System.currentTimeMillis() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala index 90d5b65b9a1..9dcff909b78 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala @@ -28,19 +28,19 @@ import org.apache.spark.sql.rapids.storage.RapidsStorageUtils import org.apache.spark.storage.BlockManagerId /** - * Class representing a memory location (address), length (in bytes), and a tag, for - * tag based transports. - * @param address the raw native address, used for transfers (from/to this buffer) - * @param length the amount of bytes used to indicate to the transport how much to send/receive - * @param tag a numeric tag identifying this buffer - * @param memoryBuffer an optional `MemoryBuffer` - */ + * Class representing a memory location (address), length (in bytes), and a tag, for + * tag based transports. + * @param address the raw native address, used for transfers (from/to this buffer) + * @param length the amount of bytes used to indicate to the transport how much to send/receive + * @param tag a numeric tag identifying this buffer + * @param memoryBuffer an optional `MemoryBuffer` + */ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, var memoryBuffer: Option[MemoryBuffer] = None) extends AutoCloseable with Logging { /** - * If this is a device memory buffer, we return true here. - * @return whether this is a device memory buffer - */ + * If this is a device memory buffer, we return true here. 
+ * @return whether this is a device memory buffer + */ def isDeviceBuffer(): Boolean = { require(memoryBuffer.nonEmpty, s"$this does not have a memory buffer") @@ -48,10 +48,10 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, } /** - * Get a device memory buffer, this function will throw if it is not backed by a - * `DeviceMemoryBuffer` - * @return the backing device memory buffer - */ + * Get a device memory buffer, this function will throw if it is not backed by a + * `DeviceMemoryBuffer` + * @return the backing device memory buffer + */ def releaseDeviceMemoryBuffer(): DeviceMemoryBuffer = { require(isDeviceBuffer(), s"$this does not have a device memory buffer") @@ -61,12 +61,12 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, } /** - * Copy to a destination [[AddressLengthTag]] - * @param dstAlt the destination [[AddressLengthTag]] - * @param srcOffset the offset to start copying from - * @param toCopy amount to copy to dstAlt - * @return amount of bytes copied - */ + * Copy to a destination [[AddressLengthTag]] + * @param dstAlt the destination [[AddressLengthTag]] + * @param srcOffset the offset to start copying from + * @param toCopy amount to copy to dstAlt + * @return amount of bytes copied + */ def cudaCopyTo(dstAlt: AddressLengthTag, srcOffset: Long, toCopy: Long): Long = { require(srcOffset + toCopy <= length, "Attempting to copy more bytes than the source buffer provides") @@ -82,11 +82,11 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, } /** - * Copy from a source [[AddressLengthTag]] - * @param srcAlt the source [[AddressLengthTag]] - * @param dstOffset the offset in which to start copying from - * @return amount of bytes copied - */ + * Copy from a source [[AddressLengthTag]] + * @param srcAlt the source [[AddressLengthTag]] + * @param dstOffset the offset in which to start copying from + * @return amount of bytes copied + */ def cudaCopyFrom(srcAlt: AddressLengthTag, dstOffset: Long): Long = { require(dstOffset + srcAlt.length <= length, s"Attempting to copy to a buffer that isn't big enough! $dstOffset + ${srcAlt.length} " + @@ -102,17 +102,17 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, } /** - * A bounce buffer at the end of the transfer needs to be sized with the remaining - * amount. This method is used to communicate that last length. - * @param newLength the truncated length of this buffer for the transfer - */ + * A bounce buffer at the end of the transfer needs to be sized with the remaining + * amount. This method is used to communicate that last length. + * @param newLength the truncated length of this buffer for the transfer + */ def resetLength(newLength: Long): Unit = { length = newLength } /** - * Reset the length to the length supported by the backing buffer. - */ + * Reset the length to the length supported by the backing buffer. 
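Editorial sketch, not part of the patch: cudaCopyTo plus resetLength is how the transport slices a large buffer into fixed-size bounce-buffer chunks. The buffer sizes and tag values below are made up, AddressLengthTag.from is the factory documented later in this same file, and DeviceMemoryBuffer.allocate is the cudf allocation call.

import ai.rapids.cudf.DeviceMemoryBuffer

val source = DeviceMemoryBuffer.allocate(1024 * 1024) // 1 MiB payload
val bounce = DeviceMemoryBuffer.allocate(256 * 1024)  // 256 KiB bounce buffer
val sourceAlt = AddressLengthTag.from(source, tag = 0x100L)
val bounceAlt = AddressLengthTag.from(bounce, tag = 0x200L)

var offset = 0L
while (offset < 1024 * 1024) {
  val toCopy = math.min(256 * 1024, 1024 * 1024 - offset)
  bounceAlt.resetLength(toCopy)                   // shrink only the tail chunk
  sourceAlt.cudaCopyTo(bounceAlt, offset, toCopy) // synchronous device copy
  // ... the bounce buffer would be handed to a send here ...
  offset += toCopy
}
bounceAlt.resetLength() // restore the length backed by the full buffer
sourceAlt.close()       // closing the tags closes the owned buffers
bounceAlt.close()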
+ */ def resetLength(): Unit = { require(memoryBuffer.nonEmpty, "Attempted to reset using an undefined memory buffer") @@ -124,9 +124,9 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, } /** - * The Server will call close as the memory buffer stored here is owned by this - * [[AddressLengthTag]] - */ + * The Server will call close as the memory buffer stored here is owned by this + * [[AddressLengthTag]] + */ override def close(): Unit = { memoryBuffer.foreach(_.close()) } @@ -134,11 +134,11 @@ class AddressLengthTag(val address: Long, var length: Long, val tag: Long, object AddressLengthTag { /** - * Construct an [[AddressLengthTag]] given a `MemoryBuffer` that is on the device, or host. - * @param memoryBuffer the buffer the [[AddressLengthTag]] should point to - * @param tag the transport tag to use to send/receive this buffer - * @return an instance of [[AddressLengthTag]] - */ + * Construct an [[AddressLengthTag]] given a `MemoryBuffer` that is on the device, or host. + * @param memoryBuffer the buffer the [[AddressLengthTag]] should point to + * @param tag the transport tag to use to send/receive this buffer + * @return an instance of [[AddressLengthTag]] + */ def from(memoryBuffer: MemoryBuffer, tag: Long): AddressLengthTag = { new AddressLengthTag( memoryBuffer.getAddress, @@ -148,12 +148,12 @@ object AddressLengthTag { } /** - * Construct an [[AddressLengthTag]] given a `ByteBuffer` - * This is used for metadata messages, and the buffers are direct. - * @param byteBuffer the buffer the [[AddressLengthTag]] should point to - * @param tag the transport tag to use to send/receive this buffer - * @return an instance of [[AddressLengthTag]] - */ + * Construct an [[AddressLengthTag]] given a `ByteBuffer` + * This is used for metadata messages, and the buffers are direct. + * @param byteBuffer the buffer the [[AddressLengthTag]] should point to + * @param tag the transport tag to use to send/receive this buffer + * @return an instance of [[AddressLengthTag]] + */ def from(byteBuffer: ByteBuffer, tag: Long): AddressLengthTag = { new AddressLengthTag( TransportUtils.getAddress(byteBuffer), @@ -171,139 +171,139 @@ trait MemoryRegistrationCallback { } /** - * A server-side interface to the transport. - * - * The [[RapidsShuffleServer]] uses a [[ServerConnection]] to start the management port - * in the transport (in order to allow for incoming connections) - * - * Note that [[ServerConnection]] is a [[Connection]], and so it inherits methods to send/receive - * messages. - */ + * A server-side interface to the transport. + * + * The [[RapidsShuffleServer]] uses a [[ServerConnection]] to start the management port + * in the transport (in order to allow for incoming connections) + * + * Note that [[ServerConnection]] is a [[Connection]], and so it inherits methods to send/receive + * messages. 
+ */ trait ServerConnection extends Connection { /** - * Starts a TCP management port, bound to `host`, on an ephemeral port (returned) - * @param host host to bind to - * @return integer ephemeral port that was bound - */ + * Starts a TCP management port, bound to `host`, on an ephemeral port (returned) + * @param host host to bind to + * @return integer ephemeral port that was bound + */ def startManagementPort(host: String): Int /** - * Function to send bounce buffers to a peer - * @param peerExecutorId peer's executor id to target - * @param bounceBuffers bounce buffers to send - * @param cb callback to trigger once done - * @return the [[Transaction]], which can be used to block wait for this send. - */ + * Function to send bounce buffers to a peer + * @param peerExecutorId peer's executor id to target + * @param bounceBuffers bounce buffers to send + * @param cb callback to trigger once done + * @return the [[Transaction]], which can be used to block wait for this send. + */ def send(peerExecutorId: Long, bounceBuffers: Seq[AddressLengthTag], cb: TransactionCallback): Transaction /** - * Function to send bounce buffers to a peer - * @param peerExecutorId peer's executor id to target - * @param header an [[AddressLengthTag]] containing a metadata message to send - * @param cb callback to trigger once done - * @return the [[Transaction]], which can be used to block wait for this send. - */ + * Function to send bounce buffers to a peer + * @param peerExecutorId peer's executor id to target + * @param header an [[AddressLengthTag]] containing a metadata message to send + * @param cb callback to trigger once done + * @return the [[Transaction]], which can be used to block wait for this send. + */ def send(peerExecutorId: Long, header: AddressLengthTag, cb: TransactionCallback): Transaction } /** - * Currently supported request types in the transport - */ + * Currently supported request types in the transport + */ object RequestType extends Enumeration { /** - * A client will issue: `MetadataRequest` - * A server will respond with: `MetadataResponse` - */ + * A client will issue: `MetadataRequest` + * A server will respond with: `MetadataResponse` + */ val MetadataRequest = Value /** - * A client will issue: `TransferRequest` - * A server will respond with: `TransferResponse` - */ + * A client will issue: `TransferRequest` + * A server will respond with: `TransferResponse` + */ val TransferRequest = Value } /** - * Trait used by the clients to interact with the transport. - * - * Note that this subclasses from [[Connection]]. - */ + * Trait used by the clients to interact with the transport. + * + * Note that this subclasses from [[Connection]]. + */ trait ClientConnection extends Connection { /** - * This performs a request/response, where the request is read from one - * `AddressLengthTag` `request`, and the response is populated at the memory - * described by `response`. - * - * @param request the populated request buffer [[AddressLengthTag]] - * @param response the response buffer [[AddressLengthTag]] where the response will be - * stored when the request succeeds. - * @param cb callback to handle transaction status. If successful the memory described - * using "response" will hold the response as expected, otherwise its contents - * are not defined. - * @return a transaction representing the request - */ + * This performs a request/response, where the request is read from one + * `AddressLengthTag` `request`, and the response is populated at the memory + * described by `response`. 
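Editorial sketch, not part of the patch: a metadata round trip over a ClientConnection uses two direct ByteBuffers wrapped as AddressLengthTags, tagged with composeRequestTag and assignResponseTag, and passed to request(). `roundTrip`, the buffer sizes, and the callback body are hypothetical, the flatbuffer encoding of the request is elided, and the callback is assumed to be expressible as a `Transaction => Unit` lambda.

import java.nio.ByteBuffer

// Hypothetical request/response round trip, for illustration only.
def roundTrip(conn: ClientConnection): Unit = {
  val requestBuf  = ByteBuffer.allocateDirect(1024) // would hold a MetadataRequest
  val responseBuf = ByteBuffer.allocateDirect(8192) // sized for the expected response

  val requestAlt = AddressLengthTag.from(
    requestBuf, conn.composeRequestTag(RequestType.MetadataRequest))
  val responseAlt = AddressLengthTag.from(responseBuf, conn.assignResponseTag)

  val tx = conn.request(requestAlt, responseAlt, (t: Transaction) =>
    // on success `responseBuf` holds the MetadataResponse; on error its
    // contents are undefined
    println(s"request finished with status ${t.getStatus}"))
  try {
    tx.waitForCompletion() // block until the callback above has run
  } finally {
    tx.close()
  }
}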
+ * + * @param request the populated request buffer [[AddressLengthTag]] + * @param response the response buffer [[AddressLengthTag]] where the response will be + * stored when the request succeeds. + * @param cb callback to handle transaction status. If successful the memory described + * using "response" will hold the response as expected, otherwise its contents + * are not defined. + * @return a transaction representing the request + */ def request(request: AddressLengthTag, response: AddressLengthTag, cb: TransactionCallback): Transaction /** - * This function assigns tags that are valid for responses in this connection. - * @return a Long tag to use for a response - */ + * This function assigns tags that are valid for responses in this connection. + * @return a Long tag to use for a response + */ def assignResponseTag: Long /** - * This function assigns tags for individual buffers to be received in this connection. - * @param msgId an application-level id that should be unique to the peer. - * @return a Long tag to use for a buffer - */ + * This function assigns tags for individual buffers to be received in this connection. + * @param msgId an application-level id that should be unique to the peer. + * @return a Long tag to use for a buffer + */ def assignBufferTag(msgId: Int): Long /** - * Get a long representing the executorId for the peer of this connection. - * @return the executorId as a long - */ + * Get a long representing the executorId for the peer of this connection. + * @return the executorId as a long + */ def getPeerExecutorId: Long } /** - * [[Connection]] trait defines what a "connection" must support. - * - * [[ServerConnection]] and [[ClientConnection]] extend this, adding a few methods needed - * in each case. - */ + * [[Connection]] trait defines what a "connection" must support. + * + * [[ServerConnection]] and [[ClientConnection]] extend this, adding a few methods needed + * in each case. + */ trait Connection { /** - * Both the client and the server need to compose request tags depending on the - * type of request being sent or handled. - * - * Note it is up to the implemented to compose the tag in whatever way makes - * most sense for the underlying transport. - * - * @param requestType the type of request this tag is for - * @return a Long tag to be used for this request - */ + * Both the client and the server need to compose request tags depending on the + * type of request being sent or handled. + * + * Note it is up to the implemented to compose the tag in whatever way makes + * most sense for the underlying transport. 
+ * + * @param requestType the type of request this tag is for + * @return a Long tag to be used for this request + */ def composeRequestTag(requestType: RequestType.Value): Long /** - * Function to receive a buffer - * @param header an [[AddressLengthTag]] to receive a metadata message - * @param cb callback to trigger once this receive completes - * @return a [[Transaction]] that can be used to block while this transaction is not done - */ + * Function to receive a buffer + * @param header an [[AddressLengthTag]] to receive a metadata message + * @param cb callback to trigger once this receive completes + * @return a [[Transaction]] that can be used to block while this transaction is not done + */ def receive(header: AddressLengthTag, cb: TransactionCallback): Transaction /** - * Function to receive a buffer - * @param bounceBuffers a sequence of [[AddressLengthTag]] where to receive data - * @param cb callback to trigger once this receive completes - * @return a [[Transaction]] that can be used to block while this transaction is not done - */ + * Function to receive a buffer + * @param bounceBuffers a sequence of [[AddressLengthTag]] where to receive data + * @param cb callback to trigger once this receive completes + * @return a [[Transaction]] that can be used to block while this transaction is not done + */ def receive(bounceBuffers: Seq[AddressLengthTag], cb: TransactionCallback): Transaction } @@ -313,13 +313,13 @@ object TransactionStatus extends Enumeration { } /** - * Case class representing stats for the a transaction - * @param txTimeMs amount of time this [[Transaction]] took - * @param sendSize amount of bytes sent - * @param receiveSize amount of bytes received - * @param sendThroughput send throughput in GB/sec - * @param recvThroughput receive throughput in GB/sec - */ + * Case class representing stats for the a transaction + * @param txTimeMs amount of time this [[Transaction]] took + * @param sendSize amount of bytes sent + * @param receiveSize amount of bytes received + * @param sendThroughput send throughput in GB/sec + * @param recvThroughput receive throughput in GB/sec + */ case class TransactionStats(txTimeMs: Double, sendSize: Long, receiveSize: Long, @@ -327,179 +327,179 @@ case class TransactionStats(txTimeMs: Double, recvThroughput: Double) /** - * This trait represents a shuffle "transaction", and it is specific to a transfer (or set of - * transfers). - * - * It is useful in that it groups a set of sends and receives requires in order to carry an action - * against a peer. It can be used to find statistics about the transfer (bytes send/received, - * throughput), - * and it also can be waited on, for blocking clients. - * - * NOTE: a Transaction is thread safe w.r.t. a connection's callback. Calling methods on the - * transaction - * outside of [[waitForCompletion]] produces undefined behavior. - */ + * This trait represents a shuffle "transaction", and it is specific to a transfer (or set of + * transfers). + * + * It is useful in that it groups a set of sends and receives requires in order to carry an action + * against a peer. It can be used to find statistics about the transfer (bytes send/received, + * throughput), + * and it also can be waited on, for blocking clients. + * + * NOTE: a Transaction is thread safe w.r.t. a connection's callback. Calling methods on the + * transaction + * outside of [[waitForCompletion]] produces undefined behavior. + */ trait Transaction extends AutoCloseable { /** - * Get the status this transaction is in. 
Callbacks use this to handle various transaction states - * (e.g. success, error, etc.) - */ + * Get the status this transaction is in. Callbacks use this to handle various transaction states + * (e.g. success, error, etc.) + */ def getStatus: TransactionStatus.Value /** - * Get error messages that could have occurred during the transaction. - * @return returns an optional error message - */ + * Get error messages that could have occurred during the transaction. + * @return returns an optional error message + */ def getErrorMessage: Option[String] /** - * Get the statistics object (bytes sent/recv, tx time, and throughput are available) - */ + * Get the statistics object (bytes sent/recv, tx time, and throughput are available) + */ def getStats: TransactionStats /** - * Block until this transaction is completed. - * - * NOTE: not only does this include the transaction time, but it also includes the - * code performed in the callback. If the callback would block, you could end up in a situation - * of - * deadlock. - */ + * Block until this transaction is completed. + * + * NOTE: not only does this include the transaction time, but it also includes the + * code performed in the callback. If the callback would block, you could end up in a situation + * of + * deadlock. + */ def waitForCompletion(): Unit } /** - * This defines what a "transport" should support. The intention is to allow for - * various transport implementations to exist, for different communication frameworks. - * - * It is an `AutoCloseable` and so the caller should close when the transport is no longer - * needed. - */ + * This defines what a "transport" should support. The intention is to allow for + * various transport implementations to exist, for different communication frameworks. + * + * It is an `AutoCloseable` and so the caller should close when the transport is no longer + * needed. + */ trait RapidsShuffleTransport extends AutoCloseable { /** - * This function will connect (if not connected already) to a peer - * described by `blockManagerId`. Connections are cached. - * - * @param localExecutorId the local executor id - * @param blockManagerId the peer's block manager id - * @return RapidsShuffleClient instance that can be used to interact with the peer - */ + * This function will connect (if not connected already) to a peer + * described by `blockManagerId`. Connections are cached. + * + * @param localExecutorId the local executor id + * @param blockManagerId the peer's block manager id + * @return RapidsShuffleClient instance that can be used to interact with the peer + */ def makeClient(localExecutorId: Long, blockManagerId: BlockManagerId): RapidsShuffleClient /** - * This function should only be needed once. The caller creates *a* server and it is used - * for the duration of the process. - * @param requestHandler used to get metadata info, and acquire tables used in the shuffle. - * @return the server instance - */ + * This function should only be needed once. The caller creates *a* server and it is used + * for the duration of the process. + * @param requestHandler used to get metadata info, and acquire tables used in the shuffle. + * @return the server instance + */ def makeServer(requestHandler: RapidsShuffleRequestHandler): RapidsShuffleServer /** - * Returns a wrapped buffer of size Long. The buffer may or may not be pooled. - * - * The caller should call .close() on the returned [[RefCountedDirectByteBuffer]] - * when done. 
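
To make the intended lifecycle of the factory methods above concrete, a small sketch: one server per executor, one (cached) client per peer. `TransportBootstrapSketch` and its fields are invented names; the request handler and peer block manager ids are assumed to come from the surrounding shuffle machinery.

    import org.apache.spark.storage.BlockManagerId
    import com.nvidia.spark.rapids.shuffle.{RapidsShuffleRequestHandler, RapidsShuffleTransport}

    class TransportBootstrapSketch(
        transport: RapidsShuffleTransport,
        requestHandler: RapidsShuffleRequestHandler,
        localExecutorId: Long) {

      // makeServer is only expected to be called once for the life of the process.
      val server = transport.makeServer(requestHandler)

      // Connections are cached by the transport, so asking repeatedly for the
      // same peer is expected to be cheap.
      def clientFor(peer: BlockManagerId) = transport.makeClient(localExecutorId, peer)
    }
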
- * - * @param size size of buffer required - * @return the ref counted buffer - */ + * Returns a wrapped buffer of size Long. The buffer may or may not be pooled. + * + * The caller should call .close() on the returned [[RefCountedDirectByteBuffer]] + * when done. + * + * @param size size of buffer required + * @return the ref counted buffer + */ def getMetaBuffer(size: Long): RefCountedDirectByteBuffer /** - * (throttle) Adds a set of requests to be throttled as limits allowed. - * @param reqs requests to add to the throttle queue - */ + * (throttle) Adds a set of requests to be throttled as limits allowed. + * @param reqs requests to add to the throttle queue + */ def queuePending(reqs: Seq[PendingTransferRequest]) /** - * (throttle) Signals that `bytesCompleted` are done, allowing more requests through the - * throttle. - * @param bytesCompleted amount of bytes handled - */ + * (throttle) Signals that `bytesCompleted` are done, allowing more requests through the + * throttle. + * @param bytesCompleted amount of bytes handled + */ def doneBytesInFlight(bytesCompleted: Long): Unit // Bounce Buffer Management /** - * Get receive bounce buffers needed for a receive, limited by the amount of bytes - * to be received, and a hard limit on the number of buffers set by the caller - * using `totalRequired`. - * - * This function blocks if it can't satisfy the bounce buffer request. - * - * @param remaining amount of bytes remaining in the receive - * @param totalRequired maximum no. of buffers that should be returned - * @return a sequence of bounce buffers - */ + * Get receive bounce buffers needed for a receive, limited by the amount of bytes + * to be received, and a hard limit on the number of buffers set by the caller + * using `totalRequired`. + * + * This function blocks if it can't satisfy the bounce buffer request. + * + * @param remaining amount of bytes remaining in the receive + * @param totalRequired maximum no. of buffers that should be returned + * @return a sequence of bounce buffers + */ def getReceiveBounceBuffers(remaining: Long, totalRequired: Int): Seq[MemoryBuffer] /** - * Get receive bounce buffers needed for a receive, limited by the amount of bytes - * to be received, and a hard limit on the number of buffers set by the caller - * using `totalRequired`. - * - * This function is non blocking. If it can't satisfy the bounce buffer request, an empty - * sequence is returned. - * - * @param remaining amount of bytes remaining in the receive - * @param totalRequired maximum no. of buffers that should be returned - * @return a sequence of bounce buffers, or empty if the request can't be satisfied - */ + * Get receive bounce buffers needed for a receive, limited by the amount of bytes + * to be received, and a hard limit on the number of buffers set by the caller + * using `totalRequired`. + * + * This function is non blocking. If it can't satisfy the bounce buffer request, an empty + * sequence is returned. + * + * @param remaining amount of bytes remaining in the receive + * @param totalRequired maximum no. of buffers that should be returned + * @return a sequence of bounce buffers, or empty if the request can't be satisfied + */ def tryGetReceiveBounceBuffers(remaining: Long, totalRequired: Int): Seq[MemoryBuffer] /** - * Free receive bounce buffers. These may be pooled and so reused by other requests. - * @param bounceBuffers the bounce buffers to free - */ + * Free receive bounce buffers. These may be pooled and so reused by other requests. 
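
A usage sketch for the receive-side bounce buffer calls above: try the non-blocking variant first, fall back to the blocking one, and always hand the buffers back. `withReceiveBuffers` is a made-up helper and the actual receive loop is omitted.

    import ai.rapids.cudf.MemoryBuffer
    import com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport

    object ReceiveBounceBufferSketch {
      def withReceiveBuffers[T](
          transport: RapidsShuffleTransport,
          remaining: Long,
          totalRequired: Int)(body: Seq[MemoryBuffer] => T): T = {
        // The non-blocking call returns an empty sequence when it cannot be satisfied.
        val nonBlocking = transport.tryGetReceiveBounceBuffers(remaining, totalRequired)
        val buffers = if (nonBlocking.nonEmpty) {
          nonBlocking
        } else {
          transport.getReceiveBounceBuffers(remaining, totalRequired) // may block
        }
        try {
          body(buffers)
        } finally {
          transport.freeReceiveBounceBuffers(buffers) // buffers may be pooled and reused
        }
      }
    }
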
+ * @param bounceBuffers the bounce buffers to free + */ def freeReceiveBounceBuffers(bounceBuffers: Seq[MemoryBuffer]): Unit /** - * Get send bounce buffers needed for a receive, limited by the amount of bytes - * to be sent, and a hard limit on the number of buffers set by the caller - * using `totalRequired`. - * - * This function blocks if it can't satisfy the bounce buffer request. - * - * @param deviceMemory true: returns a device buffer, false: returns a host buffer - * @param remaining amount of bytes remaining in the receive - * @param totalRequired maximum no. of buffers that should be returned - * @return a sequence of bounce buffers - */ + * Get send bounce buffers needed for a receive, limited by the amount of bytes + * to be sent, and a hard limit on the number of buffers set by the caller + * using `totalRequired`. + * + * This function blocks if it can't satisfy the bounce buffer request. + * + * @param deviceMemory true: returns a device buffer, false: returns a host buffer + * @param remaining amount of bytes remaining in the receive + * @param totalRequired maximum no. of buffers that should be returned + * @return a sequence of bounce buffers + */ def getSendBounceBuffers(deviceMemory: Boolean, remaining: Long, totalRequired: Int): Seq[MemoryBuffer] /** - * Get send bounce buffers needed for a receive, limited by the amount of bytes - * to be sent, and a hard limit on the number of buffers set by the caller - * using `totalRequired`. - * - * This function is non blocking. If it can't satisfy the bounce buffer request, an empty - * sequence is returned. - * - * @param deviceMemory true: returns a device buffer, false: returns a host buffer - * @param remaining amount of bytes remaining in the receive - * @param totalRequired maximum no. of buffers that should be returned - * @return a sequence of bounce buffers, or empty if the request can't be satisfied - */ + * Get send bounce buffers needed for a receive, limited by the amount of bytes + * to be sent, and a hard limit on the number of buffers set by the caller + * using `totalRequired`. + * + * This function is non blocking. If it can't satisfy the bounce buffer request, an empty + * sequence is returned. + * + * @param deviceMemory true: returns a device buffer, false: returns a host buffer + * @param remaining amount of bytes remaining in the receive + * @param totalRequired maximum no. of buffers that should be returned + * @return a sequence of bounce buffers, or empty if the request can't be satisfied + */ def tryGetSendBounceBuffers(deviceMemory: Boolean, remaining: Long, totalRequired: Int): Seq[MemoryBuffer] /** - * Free send bounce buffers. These may be pooled and so reused by other requests. - * @param bounceBuffers the bounce buffers to free - */ + * Free send bounce buffers. These may be pooled and so reused by other requests. + * @param bounceBuffers the bounce buffers to free + */ def freeSendBounceBuffers(bounceBuffers: Seq[MemoryBuffer]): Unit } /** - * A pool of direct byte buffers, sized to be `bufferSize`. - * This is a controlled leak at the moment, there is no reclaiming of buffers. - * - * NOTE: this is used for metadata messages. - * - * @param bufferSize the size of direct `ByteBuffer` to allocate. - */ + * A pool of direct byte buffers, sized to be `bufferSize`. + * This is a controlled leak at the moment, there is no reclaiming of buffers. + * + * NOTE: this is used for metadata messages. + * + * @param bufferSize the size of direct `ByteBuffer` to allocate. 
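
The send side follows the same acquire/free discipline; the only addition is the `deviceMemory` flag, which selects device rather than host bounce buffers. A short sketch, with the copy into the buffers left out and the names invented:

    import ai.rapids.cudf.MemoryBuffer
    import com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport

    object SendBounceBufferSketch {
      // deviceMemory = true requests device buffers, false requests host buffers.
      def acquireForSend(
          transport: RapidsShuffleTransport,
          deviceMemory: Boolean,
          remaining: Long,
          totalRequired: Int): Seq[MemoryBuffer] =
        transport.getSendBounceBuffers(deviceMemory, remaining, totalRequired)

      def release(transport: RapidsShuffleTransport, buffers: Seq[MemoryBuffer]): Unit =
        transport.freeSendBounceBuffers(buffers)
    }
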
+ */ class DirectByteBufferPool(bufferSize: Long) extends Logging { val buffers = new ConcurrentLinkedQueue[RefCountedDirectByteBuffer]() val high = new AtomicInteger(0) @@ -526,18 +526,18 @@ class DirectByteBufferPool(bufferSize: Long) extends Logging { } /** - * [[RefCountedDirectByteBuffer]] is a simple wrapper on top of a `ByteBuffer` that has been - * allocated in direct mode. - * - * The pool is used to return the `ByteBuffer` to be reused, but not all of these buffers - * are pooled (hence the argument is optional) - * - * The user should always close a [[RefCountedDirectByteBuffer]]. The close could hard destroy - * the buffer, or return the object to the pool - * - * @param bb buffer to wrap - * @param pool optional pool - */ + * [[RefCountedDirectByteBuffer]] is a simple wrapper on top of a `ByteBuffer` that has been + * allocated in direct mode. + * + * The pool is used to return the `ByteBuffer` to be reused, but not all of these buffers + * are pooled (hence the argument is optional) + * + * The user should always close a [[RefCountedDirectByteBuffer]]. The close could hard destroy + * the buffer, or return the object to the pool + * + * @param bb buffer to wrap + * @param pool optional pool + */ class RefCountedDirectByteBuffer( bb: ByteBuffer, pool: Option[DirectByteBufferPool] = None) extends AutoCloseable { @@ -546,24 +546,24 @@ class RefCountedDirectByteBuffer( var refCount: Int = 0 /** - * Adds one to the ref count. Caller should call .close() when done - * @return wrapped buffer - */ + * Adds one to the ref count. Caller should call .close() when done + * @return wrapped buffer + */ def acquire(): ByteBuffer = synchronized { refCount = refCount + 1 bb } /** - * Peeks into the wrapped buffer, without changing the ref count. - * @return wrapped buffer - */ + * Peeks into the wrapped buffer, without changing the ref count. + * @return wrapped buffer + */ def getBuffer(): ByteBuffer = bb /** - * Decrements the ref count. If the ref count reaches 0, the buffer is - * either returned to the (optional) pool or destroyed. - */ + * Decrements the ref count. If the ref count reaches 0, the buffer is + * either returned to the (optional) pool or destroyed. + */ override def close(): Unit = synchronized { refCount = refCount - 1 if (refCount <= 0) { @@ -576,10 +576,10 @@ class RefCountedDirectByteBuffer( } /** - * Destroys the direct byte buffer forcefully, rather than wait for GC - * to do it later. This helps with fragmentation issues with nearly depleted - * native heap. - */ + * Destroys the direct byte buffer forcefully, rather than wait for GC + * to do it later. This helps with fragmentation issues with nearly depleted + * native heap. + */ def unsafeDestroy(): Unit = synchronized { RapidsStorageUtils.dispose(bb) } @@ -590,8 +590,8 @@ class RefCountedDirectByteBuffer( } /** - * A set of util functions used throughout - */ + * A set of util functions used throughout + */ object TransportUtils { def formatTag(tag: Long): String = { f"0x$tag%016X" @@ -622,19 +622,19 @@ object TransportUtils { object RapidsShuffleTransport extends Logging { /** - * Used in `BlockManagerId`s when returning a map status after a shuffle write to - * let the readers know what TCP port to use to establish a transport connection. - */ + * Used in `BlockManagerId`s when returning a map status after a shuffle write to + * let the readers know what TCP port to use to establish a transport connection. 
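
The metadata buffers returned by `getMetaBuffer` follow the ref-counting contract documented here: the caller closes the wrapper exactly once, and the final close either returns the `ByteBuffer` to its pool or destroys it. A sketch, with the metadata write as a stand-in:

    import com.nvidia.spark.rapids.shuffle.RapidsShuffleTransport

    object MetaBufferSketch {
      def writeSomeMetadata(transport: RapidsShuffleTransport, size: Long): Unit = {
        val refCounted = transport.getMetaBuffer(size)
        try {
          val bb = refCounted.getBuffer() // peek only, ref count unchanged
          bb.putInt(0, 42)                // stand-in for real metadata serialization
        } finally {
          refCounted.close() // returned to the pool if pooled, otherwise destroyed
        }
      }
    }
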
+ */ val BLOCK_MANAGER_ID_TOPO_PREFIX: String = "rapids" /** - * Returns an instance of `RapidsShuffleTransport`. - * @note the single current implementation is `UCXShuffleTransport`. - * @param shuffleServerId this is the original `BlockManagerId` that Spark has for this - * executor - * @param rapidsConf instance of `RapidsConf` - * @return a transport instance to be used to create a server and clients. - */ + * Returns an instance of `RapidsShuffleTransport`. + * @note the single current implementation is `UCXShuffleTransport`. + * @param shuffleServerId this is the original `BlockManagerId` that Spark has for this + * executor + * @param rapidsConf instance of `RapidsConf` + * @return a transport instance to be used to create a server and clients. + */ def makeTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsConf): RapidsShuffleTransport = { val transportClass = try { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala index 33519eb1fc5..24eb1620f8c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/OrcFilters.scala @@ -37,12 +37,12 @@ import org.apache.spark.sql.types._ object OrcFilters extends OrcFiltersBase { /** - * Create ORC filter as a SearchArgument instance. - * - * NOTE: These filters should be pre-filtered by Spark to only contain the - * filters convertible to ORC, so checking what is convertible is - * not necessary here. - */ + * Create ORC filter as a SearchArgument instance. + * + * NOTE: These filters should be pre-filtered by Spark to only contain the + * filters convertible to ORC, so checking what is convertible is + * not necessary here. + */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap // Combines all filters using `And` to produce a single conjunction @@ -56,8 +56,8 @@ object OrcFilters extends OrcFiltersBase { } /** - * Get PredicateLeafType which is corresponding to the given DataType. - */ + * Get PredicateLeafType which is corresponding to the given DataType. + */ private def getPredicateLeafType(dataType: DataType) = dataType match { case BooleanType => PredicateLeaf.Type.BOOLEAN case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG @@ -70,11 +70,11 @@ object OrcFilters extends OrcFiltersBase { } /** - * Cast literal values for filters. - * - * We need to cast to long because ORC raises exceptions - * at 'checkLiteralType' of SearchArgumentImpl.java. - */ + * Cast literal values for filters. + * + * We need to cast to long because ORC raises exceptions + * at 'checkLiteralType' of SearchArgumentImpl.java. + */ private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { case ByteType | ShortType | IntegerType | LongType => value.asInstanceOf[Number].longValue @@ -86,13 +86,13 @@ object OrcFilters extends OrcFiltersBase { } /** - * Build a SearchArgument and return the builder so far. - * - * @param dataTypeMap a map from the attribute name to its data type. - * @param expression the input predicates, which should be fully convertible to SearchArgument. - * @param builder the input SearchArgument.Builder. - * @return the builder so far. - */ + * Build a SearchArgument and return the builder so far. + * + * @param dataTypeMap a map from the attribute name to its data type. 
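
To make the `createFilter` contract above concrete, a hypothetical call with two filters that are already ORC-convertible (per the note, Spark is assumed to have pre-filtered anything ORC cannot handle); the schema and values are made up:

    import org.apache.spark.sql.rapids.OrcFilters
    import org.apache.spark.sql.sources.{EqualTo, GreaterThan}
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    object OrcFilterSketch {
      val schema = StructType(Seq(
        StructField("id", IntegerType),
        StructField("name", StringType)))

      // Both predicates are convertible, so they are combined with And into a
      // single SearchArgument.
      val searchArg = OrcFilters.createFilter(
        schema, Seq(GreaterThan("id", 10), EqualTo("name", "abc")))
    }
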
+ * @param expression the input predicates, which should be fully convertible to SearchArgument. + * @param builder the input SearchArgument.Builder. + * @return the builder so far. + */ private def buildSearchArgument( dataTypeMap: Map[String, DataType], expression: Filter, @@ -122,13 +122,13 @@ object OrcFilters extends OrcFiltersBase { } /** - * Build a SearchArgument for a leaf predicate and return the builder so far. - * - * @param dataTypeMap a map from the attribute name to its data type. - * @param expression the input filter predicates. - * @param builder the input SearchArgument.Builder. - * @return the builder so far. - */ + * Build a SearchArgument for a leaf predicate and return the builder so far. + * + * @param dataTypeMap a map from the attribute name to its data type. + * @param expression the input filter predicates. + * @param builder the input SearchArgument.Builder. + * @return the builder so far. + */ private def buildLeafSearchArgument( dataTypeMap: Map[String, DataType], expression: Filter, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala index 08822cd041a..59097606276 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala @@ -31,14 +31,14 @@ import org.apache.spark.util.CompletionIterator trait ShuffleMetricsUpdater { /** - * Trait used as a way to expose the `ShuffleReadMetricsReporter` to the iterator. - * @param fetchWaitTimeInMs this matches the CPU name (except for the units) but it is actually - * the aggreagate amount of time a task is blocked, not working on - * anything, waiting for data. - * @param remoteBlocksFetched aggregate of number of `ShuffleBlockId`s fetched. - * @param remoteBytesRead aggregate size of all contiguous buffers received - * @param rowsFetched aggregate of number of rows received - */ + * Trait used as a way to expose the `ShuffleReadMetricsReporter` to the iterator. + * @param fetchWaitTimeInMs this matches the CPU name (except for the units) but it is actually + * the aggreagate amount of time a task is blocked, not working on + * anything, waiting for data. + * @param remoteBlocksFetched aggregate of number of `ShuffleBlockId`s fetched. 
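
A skeletal implementation of the `ShuffleMetricsUpdater` trait documented above; the tail of the `update` signature is assumed to match the @param list, and the hypothetical `AccumulatingMetricsUpdater` simply sums the values instead of forwarding them to Spark's reporter.

    import java.util.concurrent.atomic.AtomicLong

    import org.apache.spark.sql.rapids.ShuffleMetricsUpdater

    class AccumulatingMetricsUpdater extends ShuffleMetricsUpdater {
      val totalWaitTimeMs = new AtomicLong(0)
      val totalBlocksFetched = new AtomicLong(0)
      val totalBytesRead = new AtomicLong(0)
      val totalRowsFetched = new AtomicLong(0)

      override def update(
          fetchWaitTimeInMs: Long,
          remoteBlocksFetched: Long,
          remoteBytesRead: Long,
          rowsFetched: Long): Unit = {
        totalWaitTimeMs.addAndGet(fetchWaitTimeInMs)
        totalBlocksFetched.addAndGet(remoteBlocksFetched)
        totalBytesRead.addAndGet(remoteBytesRead)
        totalRowsFetched.addAndGet(rowsFetched)
      }
    }
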
+ * @param remoteBytesRead aggregate size of all contiguous buffers received + * @param rowsFetched aggregate of number of rows received + */ def update( fetchWaitTimeInMs: Long, remoteBlocksFetched: Long, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala index dc8a847a098..d8d602205e9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala @@ -141,8 +141,8 @@ class RapidsCachingWriter[K, V]( } /** - * Used to remove shuffle buffers when the writing task detects an error, calling `stop(false)` - */ + * Used to remove shuffle buffers when the writing task detects an error, calling `stop(false)` + */ private def cleanStorage(): Unit = { writtenBufferIds.foreach(catalog.removeBuffer) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala index 2f6db8f7fdd..b1c20fbd66d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/storage/RapidsStorageUtils.scala @@ -23,16 +23,16 @@ import org.apache.spark.storage.StorageUtils object RapidsStorageUtils { // scalastyle:off line.size.limit /** - * Calls into spark's `StorageUtils` to expose the [[dispose]] method. - * - * NOTE: This the spark code as of the writing of this function is: - * https://github.com/apache/spark/blob/e9f3f62b2c0f521f3cc23fef381fc6754853ad4f/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala#L206 - * - * If the implementation in Spark later on breaks our build, we may need to replicate - * the dispose method here. - * - * @param buffer byte buffer to dispose - */ + * Calls into spark's `StorageUtils` to expose the [[dispose]] method. + * + * NOTE: This the spark code as of the writing of this function is: + * https://github.com/apache/spark/blob/e9f3f62b2c0f521f3cc23fef381fc6754853ad4f/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala#L206 + * + * If the implementation in Spark later on breaks our build, we may need to replicate + * the dispose method here. + * + * @param buffer byte buffer to dispose + */ // scalastyle:on line.size.limit def dispose(buffer: ByteBuffer): Unit = StorageUtils.dispose(buffer) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 1fbfe8cb571..ba7e5863226 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -519,16 +519,16 @@ case class GpuLike(left: Expression, right: Expression, escapeChar: Char) override def dataType: DataType = BooleanType /** - * Validate and convert SQL 'like' pattern to a cuDF regular expression. - * - * Underscores (_) are converted to '.' including newlines and percent signs (%) - * are converted to '.*' including newlines, other characters are quoted literally or escaped. - * An invalid pattern will throw an `IllegalArgumentException`. - * - * @param pattern the SQL pattern to convert - * @param escapeChar the escape string contains one character. 
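
The conversion rules described above can be illustrated with a deliberately simplified converter. This is not the plugin's implementation (which must emit cuDF-compatible syntax and handle newline matching); it only sketches the documented mapping: `_` becomes `.`, `%` becomes `.*`, escaped characters are taken literally, everything else is quoted, and a dangling escape character is rejected.

    import java.util.regex.Pattern

    object LikeToRegexSketch {
      // Simplified illustration of the mapping documented above, using plain Java
      // regex quoting and anchoring the result so it must match the whole string.
      def simpleEscapeLikeRegex(pattern: String, escapeChar: Char): String = {
        val out = new StringBuilder("^")
        var i = 0
        while (i < pattern.length) {
          pattern.charAt(i) match {
            case c if c == escapeChar && i + 1 < pattern.length =>
              out.append(Pattern.quote(pattern.charAt(i + 1).toString))
              i += 1 // consume the escaped character as well
            case c if c == escapeChar =>
              throw new IllegalArgumentException(s"the pattern '$pattern' is invalid")
            case '_' => out.append('.')
            case '%' => out.append(".*")
            case c   => out.append(Pattern.quote(c.toString))
          }
          i += 1
        }
        out.append('$').toString
      }
    }
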
- * @return the equivalent cuDF regular expression of the pattern - */ + * Validate and convert SQL 'like' pattern to a cuDF regular expression. + * + * Underscores (_) are converted to '.' including newlines and percent signs (%) + * are converted to '.*' including newlines, other characters are quoted literally or escaped. + * An invalid pattern will throw an `IllegalArgumentException`. + * + * @param pattern the SQL pattern to convert + * @param escapeChar the escape string contains one character. + * @return the equivalent cuDF regular expression of the pattern + */ def escapeLikeRegex(pattern: String, escapeChar: Char): String = { val in = pattern.toIterator val out = new StringBuilder() diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala index 5947da4684f..25ed1764cf6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CsvScanSuite.scala @@ -47,8 +47,8 @@ class CsvScanSuite extends SparkQueryCompareTestSuite { } /** - * Running with an inferred schema results in running things that are not columnar optimized. - */ + * Running with an inferred schema results in running things that are not columnar optimized. + */ ALLOW_NON_GPU_testSparkResultsAreEqual("Test CSV inferred schema", intsFromCsvInferredSchema, Seq("FileSourceScanExec", "FilterExec", "CollectLimitExec", "GreaterThan", "Length", "StringTrim", "LocalTableScanExec", "DeserializeToObjectExec", diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index bc921edd3c7..32ddd666067 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -294,13 +294,13 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { } /** - * Runs a test defined by fun, using dataframe df. - * - * @param df the DataFrame to use as input - * @param fun the function to transform the DataFrame (produces another DataFrame) - * @param conf spark conf - * @return tuple of (cpu results, gpu results) as arrays of Row - */ + * Runs a test defined by fun, using dataframe df. + * + * @param df the DataFrame to use as input + * @param fun the function to transform the DataFrame (produces another DataFrame) + * @param conf spark conf + * @return tuple of (cpu results, gpu results) as arrays of Row + */ def runOnCpuAndGpu( df: SparkSession => DataFrame, fun: DataFrame => DataFrame,