Add in basic support for scalar maps and allow nesting in named_struct #2498

Merged · 1 commit · May 25, 2021
docs/supported_ops.md (10 changes: 5 additions & 5 deletions)
@@ -4110,8 +4110,8 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
-<td><b>NS</b></td>
-<td><b>NS</b></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
@@ -4133,7 +4133,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
-<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)</em></td>
<td> </td>
</tr>
<tr>
@@ -9588,8 +9588,8 @@ Accelerator support is described below.
<td>S</td>
<td><b>NS</b></td>
<td>S</td>
-<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
-<td><b>NS</b></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
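The rows above move Literal's MAP entries from NS to PS* and widen what CreateNamedStruct may nest. For reference, a minimal Spark job (Scala) exercising the newly covered expressions, mirroring the integration tests below; the object name, app name, and local master are illustrative, and the RAPIDS plugin is assumed to be on the classpath and enabled for these projections to actually run on the GPU:

    import org.apache.spark.sql.SparkSession

    object MapLiteralSketch {
      def main(args: Array[String]): Unit = {
        // Illustrative local session; without the plugin the query still
        // runs, just on the CPU.
        val spark = SparkSession.builder()
          .appName("map-literal-sketch")
          .master("local[1]")
          .getOrCreate()
        spark.range(2).selectExpr(
          "map(1, 2, 3, 4) as i",                      // scalar map literal
          "map('a', 'b', 'c', 'd') as s",              // string-to-string map
          "named_struct('foo', map('a', 'b')) as st")  // map nested in named_struct
          .show(truncate = false)
        spark.stop()
      }
    }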
integration_tests/src/main/python/map_test.py (10 changes: 9 additions & 1 deletion)
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,6 +32,14 @@ def test_simple_get_map_value(data_gen):
'a["NOT_FOUND"]',
'a["key_5"]'))

+def test_map_scalar_project():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : spark.range(2).selectExpr(
+            "map(1, 2, 3, 4) as i",
+            "map('a', 'b', 'c', 'd') as s",
+            "map('a', named_struct('foo', 10, 'bar', 'bar')) as st",
+            "id"))

@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value_ansi_fail(data_gen):
integration_tests/src/main/python/struct_test.py (7 changes: 5 additions & 2 deletions)
@@ -32,12 +32,15 @@ def test_struct_get_item(data_gen):
'a.third'))


-@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
+@pytest.mark.parametrize('data_gen', all_basic_gens + [null_gen, decimal_gen_default, decimal_gen_scale_precision, simple_string_to_string_map_gen] + single_level_array_gens, ids=idfn)
def test_make_struct(data_gen):
+    # Spark has no good way to create a map literal without the map function
+    # so we are inserting one.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : binary_op_df(spark, data_gen).selectExpr(
            'struct(a, b)',
-            'named_struct("foo", b, "bar", 5, "end", a)'))
+            'named_struct("foo", b, "m", map("a", "b"), "n", null, "bar", 5, "end", a)'),
+        conf = allow_negative_scale_of_decimal_conf)


@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]),
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -773,9 +773,9 @@ object GpuOverrides {
expr[Literal](
"Holds a static value from the query",
ExprChecks.projectNotLambda(
-        TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
-          + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
-          + TypeSig.ARRAY + TypeSig.STRUCT),
+        (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
+          + TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL
+          + TypeSig.DECIMAL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.all),
(lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
expr[Signum](
sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -827,7 +827,8 @@ object WindowSpecCheck extends ExprChecks {
object CreateNamedStructCheck extends ExprChecks {
val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
-  val valueSig: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
+  val valueSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+      TypeSig.ARRAY + TypeSig.MAP).nested()
val sparkValueSig: TypeSig = TypeSig.all
val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig)
val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig)
sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala (41 changes: 37 additions & 4 deletions)
@@ -16,8 +16,7 @@

package com.nvidia.spark.rapids

-import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat,
-  Long => JLong, Short => JShort}
+import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
import java.util
import java.util.{List => JList, Objects}
import javax.xml.bind.DatatypeConverter
@@ -31,7 +30,7 @@ import org.json4s.JsonAST.{JField, JNull, JString}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, MapData, TimestampFormatter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types._
@@ -76,6 +75,11 @@ object GpuScalar extends Arm with Logging {
new HostColumnVector.ListType(true, resolveElementType(elementType))
case StructType(fields) =>
new HostColumnVector.StructType(true, fields.map(f => resolveElementType(f.dataType)): _*)
+    case MapType(keyType, valueType, _) =>
+      new HostColumnVector.ListType(true,
+        new HostColumnVector.StructType(true,
+          resolveElementType(keyType),
+          resolveElementType(valueType)))
case other =>
new HostColumnVector.BasicType(true, GpuColumnVector.getNonNestedRapidsType(other))
}
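The new MapType case follows the usual cuDF encoding of a map: a LIST column whose child is a STRUCT of the key and value types. A minimal sketch of the host type this resolves for MapType(IntegerType, StringType), using only the ai.rapids.cudf classes already referenced above (the object name is illustrative):

    import ai.rapids.cudf.{DType, HostColumnVector}

    object MapHostTypeSketch {
      def main(args: Array[String]): Unit = {
        // MapType(IntegerType, StringType) becomes LIST<STRUCT<INT32, STRING>>;
        // the boolean flags mark each level as nullable.
        val mapHostType = new HostColumnVector.ListType(true,
          new HostColumnVector.StructType(true,
            new HostColumnVector.BasicType(true, DType.INT32),
            new HostColumnVector.BasicType(true, DType.STRING)))
        println(mapHostType)
      }
    }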
@@ -122,6 +126,15 @@
convertElementTo(data.get(id, f.dataType), f.dataType)
}
new HostColumnVector.StructData(row.asInstanceOf[Array[Object]]: _*)
+    case MapType(keyType, valueType, _) =>
+      val data = element.asInstanceOf[MapData]
+      val keys = data.keyArray.array.map(convertElementTo(_, keyType)).toList
+      val values = data.valueArray.array.map(convertElementTo(_, valueType)).toList
+      keys.zip(values).map {
+        case (k, v) => new HostColumnVector.StructData(
+          k.asInstanceOf[Object],
+          v.asInstanceOf[Object])
+      }.asJava
case _ => element
}
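Catalyst's MapData holds keys and values as two parallel arrays, so the new case zips them into one key/value StructData row per map entry. A small self-contained sketch of that step, with plain Seqs standing in for MapData's arrays (names are illustrative; the MapData plumbing itself is omitted):

    import java.util.{List => JList}
    import scala.collection.JavaConverters._
    import ai.rapids.cudf.HostColumnVector

    object MapRowsSketch {
      def main(args: Array[String]): Unit = {
        // Stand-ins for MapData.keyArray.array and MapData.valueArray.array.
        val keys: Seq[AnyRef] = Seq("a", "c")
        val values: Seq[AnyRef] = Seq("b", "d")
        // Zip the parallel arrays into one struct row per map entry, matching
        // the LIST<STRUCT<key, value>> layout used for maps.
        val rows: JList[HostColumnVector.StructData] = keys.zip(values).map {
          case (k, v) => new HostColumnVector.StructData(k, v)
        }.asJava
        println(s"${rows.size()} entries") // prints: 2 entries
      }
    }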

@@ -176,6 +189,10 @@
val colType = resolveElementType(elementType)
val rows = seq.map(convertElementTo(_, elementType))
ColumnVector.fromStructs(colType, rows.asInstanceOf[Seq[HostColumnVector.StructData]]: _*)
+      case MapType(_, _, _) =>
+        val colType = resolveElementType(elementType)
+        val rows = seq.map(convertElementTo(_, elementType))
+        ColumnVector.fromLists(colType, rows.asInstanceOf[Seq[JList[_]]]: _*)
case NullType =>
GpuColumnVector.columnVectorFromNull(seq.size, NullType)
case u =>
@@ -198,6 +215,9 @@
def from(v: Any, t: DataType): Scalar = t match {
case nullType if v == null => nullType match {
case ArrayType(elementType, _) => Scalar.listFromNull(resolveElementType(elementType))
+      case MapType(keyType, valueType, _) => Scalar.listFromNull(
+        resolveElementType(StructType(
+          Seq(StructField("key", keyType), StructField("value", valueType)))))
case _ => Scalar.fromNull(GpuColumnVector.getNonNestedRapidsType(nullType))
}
case decType: DecimalType =>
@@ -283,6 +303,19 @@
case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
s" for ArrayType, expecting ArrayData")
}
+    case MapType(keyType, valueType, _) => v match {
+      case map: MapData =>
+        val struct = withResource(columnVectorFromLiterals(map.keyArray().array, keyType)) { keys =>
+          withResource(columnVectorFromLiterals(map.valueArray().array, valueType)) { values =>
+            ColumnVector.makeStruct(map.numElements(), keys, values)
+          }
+        }
+        withResource(struct) { struct =>
+          Scalar.listFromColumnView(struct)
+        }
+      case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
+        s" for MapType, expecting MapData")
+    }
case _ => throw new UnsupportedOperationException(s"${v.getClass} '$v' is not supported" +
s" as a Scalar yet")
}
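A non-null map scalar is therefore built in two steps: interleave the materialized key and value children into a STRUCT column, then wrap that column in a single LIST value. A hedged end-to-end sketch with literal inputs standing in for MapData (the struct/list construction calls are the ones used in the hunk above; fromInts merely builds the example children, and the object name is illustrative):

    import ai.rapids.cudf.{ColumnVector, Scalar}

    object MapScalarSketch {
      def main(args: Array[String]): Unit = {
        // Children for the example map {1 -> 2, 3 -> 4}.
        val keys = ColumnVector.fromInts(1, 3)
        val values = ColumnVector.fromInts(2, 4)
        try {
          // One STRUCT row per map entry, matching LIST<STRUCT<key, value>>.
          val struct = ColumnVector.makeStruct(2, keys, values)
          try {
            // Wrap the struct rows in a single LIST value: the map scalar.
            val mapScalar = Scalar.listFromColumnView(struct)
            try println(mapScalar) finally mapScalar.close()
          } finally struct.close()
        } finally {
          values.close()
          keys.close()
        }
      }
    }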
@@ -376,7 +409,7 @@ class GpuScalar private(

private var refCount: Int = 0

-  if(scalar.isEmpty && value.isEmpty) {
+  if (scalar.isEmpty && value.isEmpty) {
throw new IllegalArgumentException("GpuScalar requires at least a value or a Scalar")
}
if (value.isDefined && value.get.isInstanceOf[Scalar]) {
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
@@ -148,10 +148,7 @@ case class GpuCreateNamedStruct(children: Seq[Expression]) extends GpuExpression
withResource(new Array[ColumnVector](valExprs.size)) { columns =>
val numRows = batch.numRows()
valExprs.indices.foreach { index =>
-        val dt = dataType.fields(index).dataType
-        val ret = valExprs(index).columnarEval(batch)
-        columns(index) =
-          GpuExpressionsUtils.resolveColumnVector(ret, numRows, dt).getBase
+        columns(index) = GpuExpressionsUtils.columnarEvalToColumn(valExprs(index), batch).getBase
}
GpuColumnVector.from(ColumnVector.makeStruct(numRows, columns: _*), dataType)
}
Expand Down