Added in support for RangeExec (NVIDIA#389)
revans2 authored Jul 21, 2020
1 parent 0035ac3 commit d96bce1
Showing 5 changed files with 196 additions and 7 deletions.
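With this change a spark.range(...) query can be planned as GpuRangeExec and run on the GPU instead of falling back to the CPU RangeExec. As a rough usage sketch (not part of the commit, and assuming the RAPIDS Accelerator jar is already deployed on the cluster), a PySpark session would exercise the new operator like this:

    # Sketch only: assumes the plugin jar is on the classpath and enabled via spark.plugins.
    from pyspark.sql import SparkSession

    spark = (SparkSession.builder
             .appName("range-on-gpu")
             .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
             .getOrCreate())

    df = spark.range(-100, 100, 7)  # candidate for GpuRangeExec
    df.explain()                    # the physical plan should show GpuRange (...)
    print(df.count())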
1 change: 1 addition & 0 deletions docs/configs.md
@@ -230,6 +230,7 @@ Name | Description | Default Value | Notes
<a name="sql.exec.GlobalLimitExec"></a>spark.rapids.sql.exec.GlobalLimitExec|Limiting of results across partitions|true|None|
<a name="sql.exec.LocalLimitExec"></a>spark.rapids.sql.exec.LocalLimitExec|Per-partition limiting of results|true|None|
<a name="sql.exec.ProjectExec"></a>spark.rapids.sql.exec.ProjectExec|The backend for most select, withColumn and dropColumn statements|true|None|
<a name="sql.exec.RangeExec"></a>spark.rapids.sql.exec.RangeExec|The backend for range operator|true|None|
<a name="sql.exec.SortExec"></a>spark.rapids.sql.exec.SortExec|The backend for the sort operator|true|None|
<a name="sql.exec.UnionExec"></a>spark.rapids.sql.exec.UnionExec|The backend for the union operator|true|None|
<a name="sql.exec.HashAggregateExec"></a>spark.rapids.sql.exec.HashAggregateExec|The backend for hash based aggregations|true|None|
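The new operator is enabled by default (true above). If it needs to be turned off, for example to compare against the CPU implementation, the documented key can be changed at runtime; a minimal sketch:

    # Sketch only: forces range queries back onto Spark's CPU RangeExec.
    spark.conf.set("spark.rapids.sql.exec.RangeExec", "false")
    spark.range(1000).collect()  # now planned as the CPU RangeExec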
54 changes: 54 additions & 0 deletions integration_tests/src/main/python/range_test.py
@@ -0,0 +1,54 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from data_gen import LONG_MAX, LONG_MIN
from pyspark.sql.types import *
import pyspark.sql.functions as f

def test_simple_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(100))

def test_start_end_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(-100, 100))

def test_step_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(-100, 100, 7))

def test_neg_step_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(100, -100, -7))

def test_partitioned_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(1000, numPartitions=2))

def test_large_corner_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(LONG_MAX - 100, LONG_MAX, step=3))

def test_small_corner_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(LONG_MIN + 100, LONG_MIN, step=-3))

def test_wrong_step_corner_range():
    assert_gpu_and_cpu_are_equal_collect(
            lambda spark : spark.range(100, -100, 7))


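The helper assert_gpu_and_cpu_are_equal_collect comes from the integration-test framework (asserts.py); conceptually it runs the supplied lambda once with the plugin enabled and once with it disabled, collects both results, and asserts they match. A simplified stand-in, for illustration only (the real helper also handles session setup, config overrides, and result ordering):

    # Illustrative approximation, not the real asserts.py implementation.
    def assert_gpu_and_cpu_are_equal_collect_sketch(func, spark):
        spark.conf.set("spark.rapids.sql.enabled", "true")
        from_gpu = func(spark).collect()
        spark.conf.set("spark.rapids.sql.enabled", "false")
        from_cpu = func(spark).collect()
        assert from_cpu == from_gpu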
@@ -1631,6 +1631,14 @@ object GpuOverrides {
            GpuProjectExec(childExprs.map(_.convertToGpu()), childPlans(0).convertIfNeeded())
        }
      }),
    exec[RangeExec](
      "The backend for range operator",
      (range, conf, p, r) => {
        new SparkPlanMeta[RangeExec](range, conf, p, r) {
          override def convertToGpu(): GpuExec =
            GpuRangeExec(range.range, conf.gpuTargetBatchSizeBytes)
        }
      }),
    exec[BatchScanExec](
      "The backend for most file input",
      (p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) {
@@ -17,16 +17,16 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf
import ai.rapids.cudf.NvtxColor
import ai.rapids.cudf.{NvtxColor, Scalar, Table}
import com.nvidia.spark.rapids.GpuMetricNames._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, IsNotNull, NamedExpression, NullIntolerant, PredicateHelper, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, RangePartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -63,7 +63,7 @@ object GpuProjectExec {
}

case class GpuProjectExec(projectList: Seq[Expression], child: SparkPlan)
  extends UnaryExecNode with GpuExec {

  private val sparkProjectList = projectList.asInstanceOf[Seq[NamedExpression]]

@@ -124,7 +124,7 @@ object GpuFilter {
}

case class GpuFilterExec(condition: Expression, child: SparkPlan)
  extends UnaryExecNode with PredicateHelper with GpuExec {

  // Split out all the IsNotNulls from condition.
  private val (notNullPreds, _) = splitConjunctivePredicates(condition).partition {
@@ -170,11 +170,133 @@ case class GpuFilterExec(condition: Expression, child: SparkPlan)
    val boundCondition = GpuBindReferences.bindReference(condition, child.output)
    val rdd = child.executeColumnar()
    rdd.map { batch =>
      GpuFilter(batch, boundCondition, numOutputRows, numOutputBatches, totalTime)
    }
  }
}

/**
 * Physical plan for range (generating a range of 64 bit numbers).
 */
case class GpuRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range,
    targetSizeBytes: Long)
  extends LeafExecNode with GpuExec {

  val start: Long = range.start
  val end: Long = range.end
  val step: Long = range.step
  val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
  val numElements: BigInt = range.numElements

  override val output: Seq[Attribute] = range.output

  override def outputOrdering: Seq[SortOrder] = range.outputOrdering

  override def outputPartitioning: Partitioning = {
    if (numElements > 0) {
      if (numSlices == 1) {
        SinglePartition
      } else {
        RangePartitioning(outputOrdering, numSlices)
      }
    } else {
      UnknownPartitioning(0)
    }
  }

  override def doCanonicalize(): SparkPlan = {
    GpuRangeExec(
      range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range],
      targetSizeBytes)
  }

  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
    val numOutputRows = longMetric(NUM_OUTPUT_ROWS)
    val numOutputBatches = longMetric(NUM_OUTPUT_BATCHES)
    val totalTime = longMetric(TOTAL_TIME)
    val maxRowCountPerBatch = Math.min(targetSizeBytes / 8, Int.MaxValue)

    sqlContext
      .sparkContext
      .parallelize(0 until numSlices, numSlices)
      .mapPartitionsWithIndex { (i, _) =>
        val partitionStart = (i * numElements) / numSlices * step + start
        val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
        def getSafeMargin(bi: BigInt): Long =
          if (bi.isValidLong) {
            bi.toLong
          } else if (bi > 0) {
            Long.MaxValue
          } else {
            Long.MinValue
          }
        val safePartitionStart = getSafeMargin(partitionStart) // inclusive
        val safePartitionEnd = getSafeMargin(partitionEnd) // exclusive, unless start == this
        val taskContext = TaskContext.get()

        val iter: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {
          private[this] var number: Long = safePartitionStart
          private[this] var done: Boolean = false
          private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics

          override def hasNext: Boolean =
            if (!done) {
              if (step > 0) {
                number < safePartitionEnd
              } else {
                number > safePartitionEnd
              }
            } else false

          override def next(): ColumnarBatch =
            withResource(new NvtxWithMetrics("GpuRange", NvtxColor.DARK_GREEN, totalTime)) {
              _ =>
                GpuSemaphore.acquireIfNecessary(taskContext)
                val start = number
                val remainingSteps = (safePartitionEnd - start) / step
                // Start is inclusive so we need to produce at least one row
                val rowsThisBatch = Math.max(1, Math.min(remainingSteps, maxRowCountPerBatch))
                val endInclusive = start + ((rowsThisBatch - 1) * step)
                number = endInclusive + step
                if (number < endInclusive ^ step < 0) {
                  // we have Long.MaxValue + Long.MaxValue < Long.MaxValue
                  // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a
                  // step back, we are pretty sure that we have an overflow.
                  done = true
                }

                val ret = withResource(Scalar.fromLong(start)) { startScalar =>
                  withResource(Scalar.fromLong(step)) { stepScalar =>
                    withResource(
                      ai.rapids.cudf.ColumnVector.sequence(
                        startScalar, stepScalar, rowsThisBatch.toInt)) { vec =>
                      withResource(new Table(vec)) { tab =>
                        GpuColumnVector.from(tab)
                      }
                    }
                  }
                }

                assert(rowsThisBatch == ret.numRows())
                numOutputRows += rowsThisBatch
                TrampolineUtil.incInputRecordsRows(inputMetrics, rowsThisBatch)
                numOutputBatches += 1
                ret
            }
        }
        new InterruptibleIterator(taskContext, iter)
      }
  }

  override def simpleString(maxFields: Int): String = {
    s"GpuRange ($start, $end, step=$step, splits=$numSlices)"
  }

  override protected def doExecute(): RDD[InternalRow] =
    throw new IllegalStateException(s"Row-based execution should not occur for $this")
}


case class GpuUnionExec(children: Seq[SparkPlan]) extends SparkPlan with GpuExec {
  // updating nullability to make all the children consistent
  override def output: Seq[Attribute] = {
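The batching and slice-boundary arithmetic in GpuRangeExec above is easy to check by hand: every output row is a single 64-bit long, so targetSizeBytes / 8 bounds the rows per batch, and slice i of numSlices covers [i * numElements / numSlices, (i + 1) * numElements / numSlices) scaled by step and offset by start. A small Python check of that arithmetic (mirroring the Scala, with an assumed target batch size):

    # Mirrors the slice-boundary math in GpuRangeExec.doExecuteColumnar for
    # range(-100, 100, step=7) split across 2 slices.
    start, step, num_slices = -100, 7, 2
    num_elements = 29  # ceil((100 - (-100)) / 7)
    for i in range(num_slices):
        part_start = (i * num_elements) // num_slices * step + start
        part_end = ((i + 1) * num_elements) // num_slices * step + start
        print(i, part_start, part_end)  # slice 0: [-100, -2), slice 1: [-2, 103)

    target_size_bytes = 512 * 1024 * 1024  # assumed value for illustration
    max_rows_per_batch = min(target_size_bytes // 8, 2**31 - 1)
    print(max_rows_per_batch)  # at most 67108864 longs per batch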
@@ -20,6 +20,7 @@ import org.json4s.JsonAST

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.InputMetrics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
@@ -58,4 +59,7 @@ object TrampolineUtil {
  def dataTypeExistsRecursively(dt: DataType, f: DataType => Boolean): Boolean = {
    dt.existsRecursively(f)
  }

  def incInputRecordsRows(inputMetrics: InputMetrics, rows: Long): Unit =
    inputMetrics.incRecordsRead(rows)
}
