Add in basic support for scalar maps and allow nesting in named_struct (#2498)

Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored May 25, 2021
1 parent e603a1a commit 8ede4ce
Showing 7 changed files with 62 additions and 20 deletions.
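
What this buys in practice: Spark constant-folds a map() call over literals into a single map Literal, and until now such a literal (or a map or array nested inside named_struct) forced the expression off the GPU. Below is a minimal, illustrative sketch of a query this commit keeps on the GPU, mirroring the new integration tests; it assumes a Spark 3.x session with the RAPIDS Accelerator on the classpath, and the app name and config shown are placeholders:

    import org.apache.spark.sql.SparkSession

    object ScalarMapExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("scalar-map-example") // illustrative name
          .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
          .getOrCreate()

        // map(...) over literals is constant-folded into a scalar (literal) map,
        // and named_struct now accepts nested maps and arrays as values.
        spark.range(2).selectExpr(
          "map(1, 2, 3, 4) as i",
          "map('a', 'b', 'c', 'd') as s",
          "named_struct('foo', id, 'm', map('a', 'b')) as st"
        ).show(truncate = false)

        spark.stop()
      }
    }

Without the plugin the same code simply runs on the CPU and returns identical results, which is exactly what the assert_gpu_and_cpu_are_equal_collect tests below verify.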
10 changes: 5 additions & 5 deletions docs/supported_ops.md
@@ -4110,8 +4110,8 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
-<td><b>NS</b></td>
-<td><b>NS</b></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
@@ -4133,7 +4133,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
-<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td> </td>
</tr>
<tr>
@@ -9588,8 +9588,8 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td>S</td>
-<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
-<td><b>NS</b></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
10 changes: 9 additions & 1 deletion integration_tests/src/main/python/map_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,6 +32,14 @@ def test_simple_get_map_value(data_gen):
'a["NOT_FOUND"]',
'a["key_5"]'))

def test_map_scalar_project():
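    # Project literal scalar maps (constant-folded from map() over literals)
    # with int, string, and struct values; GPU results must match the CPU.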
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : spark.range(2).selectExpr(
            "map(1, 2, 3, 4) as i",
            "map('a', 'b', 'c', 'd') as s",
            "map('a', named_struct('foo', 10, 'bar', 'bar')) as st",
            "id"))

@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value_ansi_fail(data_gen):
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/struct_test.py
@@ -32,12 +32,15 @@ def test_struct_get_item(data_gen):
'a.third'))


-@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
+@pytest.mark.parametrize('data_gen', all_basic_gens + [null_gen, decimal_gen_default, decimal_gen_scale_precision, simple_string_to_string_map_gen] + single_level_array_gens, ids=idfn)
def test_make_struct(data_gen):
    # Spark has no good way to create a map literal without the map function
    # so we are inserting one.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : binary_op_df(spark, data_gen).selectExpr(
            'struct(a, b)',
-            'named_struct("foo", b, "bar", 5, "end", a)'))
+            'named_struct("foo", b, "m", map("a", "b"), "n", null, "bar", 5, "end", a)'),
+        conf = allow_negative_scale_of_decimal_conf)


@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]),
6 changes: 3 additions & 3 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -773,9 +773,9 @@ object GpuOverrides {
    expr[Literal](
      "Holds a static value from the query",
      ExprChecks.projectNotLambda(
-        TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
-          + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
-          + TypeSig.ARRAY + TypeSig.STRUCT),
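+        // Literals may now be top-level arrays or maps, and structs and maps
+        // are also accepted inside nested types.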
+        (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
+          + TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL
+          + TypeSig.DECIMAL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
        TypeSig.all),
      (lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
    expr[Signum](
3 changes: 2 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -827,7 +827,8 @@ object WindowSpecCheck extends ExprChecks {
object CreateNamedStructCheck extends ExprChecks {
  val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
  val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
-  val valueSig: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
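+  // .nested() with no arguments lets every listed type appear both at the top
+  // level and at any nesting depth inside the new ARRAY and MAP entries.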
+  val valueSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+    TypeSig.ARRAY + TypeSig.MAP).nested()
  val sparkValueSig: TypeSig = TypeSig.all
  val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig)
  val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig)
41 changes: 37 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
@@ -16,8 +16,7 @@

package com.nvidia.spark.rapids

-import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat,
-  Long => JLong, Short => JShort}
+import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
import java.util
import java.util.{List => JList, Objects}
import javax.xml.bind.DatatypeConverter
@@ -31,7 +30,7 @@ import org.json4s.JsonAST.{JField, JNull, JString}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, MapData, TimestampFormatter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types._
@@ -76,6 +75,11 @@ object GpuScalar extends Arm with Logging {
      new HostColumnVector.ListType(true, resolveElementType(elementType))
    case StructType(fields) =>
      new HostColumnVector.StructType(true, fields.map(f => resolveElementType(f.dataType)): _*)
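    // cudf has no MAP type: a Spark map is represented as a
    // LIST<STRUCT<key, value>> column, so build the matching host type here.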
    case MapType(keyType, valueType, _) =>
      new HostColumnVector.ListType(true,
        new HostColumnVector.StructType(true,
          resolveElementType(keyType),
          resolveElementType(valueType)))
    case other =>
      new HostColumnVector.BasicType(true, GpuColumnVector.getNonNestedRapidsType(other))
  }
@@ -122,6 +126,15 @@
        convertElementTo(data.get(id, f.dataType), f.dataType)
      }
      new HostColumnVector.StructData(row.asInstanceOf[Array[Object]]: _*)
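    // Flatten Spark's MapData into a Java list of two-field StructData rows,
    // matching the LIST<STRUCT<key, value>> layout used for map columns.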
    case MapType(keyType, valueType, _) =>
      val data = element.asInstanceOf[MapData]
      val keys = data.keyArray.array.map(convertElementTo(_, keyType)).toList
      val values = data.valueArray.array.map(convertElementTo(_, valueType)).toList
      keys.zip(values).map {
        case (k, v) => new HostColumnVector.StructData(
          k.asInstanceOf[Object],
          v.asInstanceOf[Object])
      }.asJava
    case _ => element
  }

@@ -176,6 +189,10 @@
      val colType = resolveElementType(elementType)
      val rows = seq.map(convertElementTo(_, elementType))
      ColumnVector.fromStructs(colType, rows.asInstanceOf[Seq[HostColumnVector.StructData]]: _*)
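    // Each map element was already converted to a Java list of StructData rows
    // above, so a map column is just a list column built from those lists.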
    case MapType(_, _, _) =>
      val colType = resolveElementType(elementType)
      val rows = seq.map(convertElementTo(_, elementType))
      ColumnVector.fromLists(colType, rows.asInstanceOf[Seq[JList[_]]]: _*)
    case NullType =>
      GpuColumnVector.columnVectorFromNull(seq.size, NullType)
    case u =>
@@ -198,6 +215,9 @@
  def from(v: Any, t: DataType): Scalar = t match {
    case nullType if v == null => nullType match {
      case ArrayType(elementType, _) => Scalar.listFromNull(resolveElementType(elementType))
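      // A null map scalar is a null LIST scalar whose element type is the
      // STRUCT<key, value> layout used for maps.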
      case MapType(keyType, valueType, _) => Scalar.listFromNull(
        resolveElementType(StructType(
          Seq(StructField("key", keyType), StructField("value", valueType)))))
      case _ => Scalar.fromNull(GpuColumnVector.getNonNestedRapidsType(nullType))
    }
    case decType: DecimalType =>
@@ -283,6 +303,19 @@
      case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
        s" for ArrayType, expecting ArrayData")
    }
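    // For a non-null map, build key and value columns, zip them into a single
    // struct column, and wrap that in a list scalar.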
    case MapType(keyType, valueType, _) => v match {
      case map: MapData =>
        val struct = withResource(columnVectorFromLiterals(map.keyArray().array, keyType)) { keys =>
          withResource(columnVectorFromLiterals(map.valueArray().array, valueType)) { values =>
            ColumnVector.makeStruct(map.numElements(), keys, values)
          }
        }
        withResource(struct) { struct =>
          Scalar.listFromColumnView(struct)
        }
      case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
        s" for MapType, expecting MapData")
    }
    case _ => throw new UnsupportedOperationException(s"${v.getClass} '$v' is not supported" +
      s" as a Scalar yet")
  }
@@ -376,7 +409,7 @@ class GpuScalar private(

private var refCount: Int = 0

-  if(scalar.isEmpty && value.isEmpty) {
+  if (scalar.isEmpty && value.isEmpty) {
    throw new IllegalArgumentException("GpuScalar requires at least a value or a Scalar")
  }
  if (value.isDefined && value.get.isInstanceOf[Scalar]) {
5 changes: 1 addition & 4 deletions sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
@@ -148,10 +148,7 @@ case class GpuCreateNamedStruct(children: Seq[Expression]) extends GpuExpression
    withResource(new Array[ColumnVector](valExprs.size)) { columns =>
      val numRows = batch.numRows()
      valExprs.indices.foreach { index =>
-        val dt = dataType.fields(index).dataType
-        val ret = valExprs(index).columnarEval(batch)
-        columns(index) =
-          GpuExpressionsUtils.resolveColumnVector(ret, numRows, dt).getBase
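+        // columnarEvalToColumn evaluates the child and wraps either a scalar or
+        // a columnar result into a GpuColumnVector, so the manual
+        // resolveColumnVector step above is no longer needed.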
+        columns(index) = GpuExpressionsUtils.columnarEvalToColumn(valExprs(index), batch).getBase
      }
      GpuColumnVector.from(ColumnVector.makeStruct(numRows, columns: _*), dataType)
    }
