From 8ede4ceffb940864f4a3b4af43d34528061ad05c Mon Sep 17 00:00:00 2001
From: "Robert (Bobby) Evans"
Date: Tue, 25 May 2021 10:23:00 -0500
Subject: [PATCH] Add in basic support for scalar maps and allow nesting in
 named_struct (#2498)

Signed-off-by: Robert (Bobby) Evans
---
 docs/supported_ops.md                         | 10 ++---
 integration_tests/src/main/python/map_test.py | 10 ++++-
 .../src/main/python/struct_test.py            |  7 +++-
 .../nvidia/spark/rapids/GpuOverrides.scala    |  6 +--
 .../com/nvidia/spark/rapids/TypeChecks.scala  |  3 +-
 .../com/nvidia/spark/rapids/literals.scala    | 41 +++++++++++++++++--
 .../spark/sql/rapids/complexTypeCreator.scala |  5 +--
 7 files changed, 62 insertions(+), 20 deletions(-)

diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 23e53f9e7a1..ece8395ef9d 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -4110,8 +4110,8 @@ Accelerator support is described below.
 S
 NS
 NS
-NS
-NS
+PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)
+PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)
 NS
 NS
@@ -4133,7 +4133,7 @@ Accelerator support is described below.
 
-PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)
+PS* (missing nested BINARY, CALENDAR, STRUCT, UDT)
 
@@ -9588,8 +9588,8 @@ Accelerator support is described below.
 S
 NS
 S
-PS* (missing nested BINARY, CALENDAR, MAP, UDT)
-NS
+PS* (missing nested BINARY, CALENDAR, UDT)
+PS* (missing nested BINARY, CALENDAR, UDT)
 NS
 NS

diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py
index c4939e0609c..e119c7b3ca3 100644
--- a/integration_tests/src/main/python/map_test.py
+++ b/integration_tests/src/main/python/map_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,6 +32,14 @@ def test_simple_get_map_value(data_gen):
                 'a["NOT_FOUND"]',
                 'a["key_5"]'))
 
+def test_map_scalar_project():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : spark.range(2).selectExpr(
+            "map(1, 2, 3, 4) as i",
+            "map('a', 'b', 'c', 'd') as s",
+            "map('a', named_struct('foo', 10, 'bar', 'bar')) as st",
+            "id"))
+
 @pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
 @pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
 def test_simple_get_map_value_ansi_fail(data_gen):
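For anyone reproducing the new test outside the integration-test harness, the following is a rough DataFrame-API equivalent of `test_map_scalar_project` (a sketch, not part of the patch; the object name and builder settings are ours). Catalyst constant-folds each `map(...)` of literals into a single map literal, which is exactly the scalar-map case the `GpuScalar` changes later in the patch handle.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, map}

object MapLiteralExample {
  def main(args: Array[String]): Unit = {
    // Add .config("spark.plugins", "com.nvidia.spark.SQLPlugin") and the RAPIDS
    // jars to exercise the GPU path; without them this runs on the CPU unchanged.
    val spark = SparkSession.builder().master("local[*]").appName("map-literal").getOrCreate()

    // DataFrame-API equivalent of the selectExpr strings in test_map_scalar_project.
    spark.range(2)
      .select(
        map(lit(1), lit(2), lit(3), lit(4)).as("i"),
        map(lit("a"), lit("b"), lit("c"), lit("d")).as("s"),
        col("id"))
      .show()

    spark.stop()
  }
}
```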
diff --git a/integration_tests/src/main/python/struct_test.py b/integration_tests/src/main/python/struct_test.py
index 7369d4df2b2..9d12727f602 100644
--- a/integration_tests/src/main/python/struct_test.py
+++ b/integration_tests/src/main/python/struct_test.py
@@ -32,12 +32,15 @@ def test_struct_get_item(data_gen):
         'a.third'))
 
 
-@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
+@pytest.mark.parametrize('data_gen', all_basic_gens + [null_gen, decimal_gen_default, decimal_gen_scale_precision, simple_string_to_string_map_gen] + single_level_array_gens, ids=idfn)
 def test_make_struct(data_gen):
+    # Spark has no good way to create a map literal without the map function
+    # so we are inserting one.
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : binary_op_df(spark, data_gen).selectExpr(
             'struct(a, b)',
-            'named_struct("foo", b, "bar", 5, "end", a)'))
+            'named_struct("foo", b, "m", map("a", "b"), "n", null, "bar", 5, "end", a)'),
+            conf = allow_negative_scale_of_decimal_conf)
 
 @pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]),

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 33973a459cf..8f5ea0bf148 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -773,9 +773,9 @@ object GpuOverrides {
     expr[Literal](
       "Holds a static value from the query",
       ExprChecks.projectNotLambda(
-        TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
-          + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
-          + TypeSig.ARRAY + TypeSig.STRUCT),
+        (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.CALENDAR
+            + TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL
+            + TypeSig.DECIMAL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
         TypeSig.all),
       (lit, conf, p, r) => new LiteralExprMeta(lit, conf, p, r)),
     expr[Signum](

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 23edb99ec02..a742bb4f30e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -827,7 +827,8 @@ object WindowSpecCheck extends ExprChecks {
 object CreateNamedStructCheck extends ExprChecks {
   val nameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
   val sparkNameSig: TypeSig = TypeSig.lit(TypeEnum.STRING)
-  val valueSig: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
+  val valueSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+      TypeSig.ARRAY + TypeSig.MAP).nested()
   val sparkValueSig: TypeSig = TypeSig.all
   val resultSig: TypeSig = TypeSig.STRUCT.nested(valueSig)
   val sparkResultSig: TypeSig = TypeSig.STRUCT.nested(sparkValueSig)
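The two signature changes above do the widening: `Literal` now accepts top-level ARRAY and MAP with the listed types nested inside them, and `CreateNamedStructCheck` accepts array and map values (nested within those types) inside `named_struct`. A minimal smoke test of the newly admitted shapes, assuming only a stock Spark session (without the RAPIDS plugin configured the query still runs, just on the CPU):

```scala
import org.apache.spark.sql.SparkSession

object NamedStructNesting {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("named-struct").getOrCreate()

    // A named_struct holding a map and a null, and a map holding a struct:
    // the shapes the widened checks are meant to admit.
    spark.range(2).selectExpr(
        "named_struct('foo', id, 'm', map('a', 'b'), 'n', null, 'bar', 5) as st",
        "map('a', named_struct('foo', 10, 'bar', 'bar')) as m")
      .show(truncate = false)

    spark.stop()
  }
}
```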
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
index 36eb4fc15c4..aabac047a33 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
@@ -16,8 +16,7 @@
 
 package com.nvidia.spark.rapids
 
-import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat,
-  Long => JLong, Short => JShort}
+import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
 import java.util
 import java.util.{List => JList, Objects}
 import javax.xml.bind.DatatypeConverter
@@ -31,7 +30,7 @@ import org.json4s.JsonAST.{JField, JNull, JString}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, DateTimeUtils, MapData, TimestampFormatter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
 import org.apache.spark.sql.types._
@@ -76,6 +75,11 @@ object GpuScalar extends Arm with Logging {
       new HostColumnVector.ListType(true, resolveElementType(elementType))
     case StructType(fields) =>
       new HostColumnVector.StructType(true, fields.map(f => resolveElementType(f.dataType)): _*)
+    case MapType(keyType, valueType, _) =>
+      new HostColumnVector.ListType(true,
+        new HostColumnVector.StructType(true,
+          resolveElementType(keyType),
+          resolveElementType(valueType)))
     case other =>
       new HostColumnVector.BasicType(true, GpuColumnVector.getNonNestedRapidsType(other))
   }
@@ -122,6 +126,15 @@ object GpuScalar extends Arm with Logging {
         convertElementTo(data.get(id, f.dataType), f.dataType)
       }
       new HostColumnVector.StructData(row.asInstanceOf[Array[Object]]: _*)
+    case MapType(keyType, valueType, _) =>
+      val data = element.asInstanceOf[MapData]
+      val keys = data.keyArray.array.map(convertElementTo(_, keyType)).toList
+      val values = data.valueArray.array.map(convertElementTo(_, valueType)).toList
+      keys.zip(values).map {
+        case (k, v) => new HostColumnVector.StructData(
+          k.asInstanceOf[Object],
+          v.asInstanceOf[Object])
+      }.asJava
     case _ => element
   }
@@ -176,6 +189,10 @@ object GpuScalar extends Arm with Logging {
       val colType = resolveElementType(elementType)
       val rows = seq.map(convertElementTo(_, elementType))
       ColumnVector.fromStructs(colType, rows.asInstanceOf[Seq[HostColumnVector.StructData]]: _*)
+    case MapType(_, _, _) =>
+      val colType = resolveElementType(elementType)
+      val rows = seq.map(convertElementTo(_, elementType))
+      ColumnVector.fromLists(colType, rows.asInstanceOf[Seq[JList[_]]]: _*)
     case NullType =>
       GpuColumnVector.columnVectorFromNull(seq.size, NullType)
     case u =>
@@ -198,6 +215,9 @@ object GpuScalar extends Arm with Logging {
   def from(v: Any, t: DataType): Scalar = t match {
     case nullType if v == null => nullType match {
       case ArrayType(elementType, _) => Scalar.listFromNull(resolveElementType(elementType))
+      case MapType(keyType, valueType, _) => Scalar.listFromNull(
+        resolveElementType(StructType(
+          Seq(StructField("key", keyType), StructField("value", valueType)))))
       case _ => Scalar.fromNull(GpuColumnVector.getNonNestedRapidsType(nullType))
     }
     case decType: DecimalType =>
@@ -283,6 +303,19 @@ object GpuScalar extends Arm with Logging {
       case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
         s" for ArrayType, expecting ArrayData")
     }
+    case MapType(keyType, valueType, _) => v match {
+      case map: MapData =>
+        val struct = withResource(columnVectorFromLiterals(map.keyArray().array, keyType)) { keys =>
+          withResource(columnVectorFromLiterals(map.valueArray().array, valueType)) { values =>
+            ColumnVector.makeStruct(map.numElements(), keys, values)
+          }
+        }
+        withResource(struct) { struct =>
+          Scalar.listFromColumnView(struct)
+        }
+      case _ => throw new IllegalArgumentException(s"'$v: ${v.getClass}' is not supported" +
+        s" for MapType, expecting MapData")
+    }
     case _ => throw new UnsupportedOperationException(s"${v.getClass} '$v' is not supported" +
       s" as a Scalar yet")
   }
@@ -376,7 +409,7 @@ class GpuScalar private(
 
   private var refCount: Int = 0
 
-  if(scalar.isEmpty && value.isEmpty) {
+  if (scalar.isEmpty && value.isEmpty) {
     throw new IllegalArgumentException("GpuScalar requires at least a value or a Scalar")
   }
   if (value.isDefined && value.get.isInstanceOf[Scalar]) {
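The new `MapType` branches above all lean on the same Catalyst input shape: a map literal arrives as a `MapData` holding two parallel arrays, so the code can build one cuDF column per side (`columnVectorFromLiterals` over `keyArray` and `valueArray`) and zip the results into key/value structs. A minimal, GPU-free sketch of that shape, constructing Catalyst's `ArrayBasedMapData` directly (the object name is ours):

```scala
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.unsafe.types.UTF8String

object MapDataShape {
  def main(args: Array[String]): Unit = {
    // A map literal reaches GpuScalar as a MapData: two parallel arrays, one of
    // keys and one of values (strings arrive as UTF8String, not java.lang.String).
    val mapLiteral = new ArrayBasedMapData(
      new GenericArrayData(Array[Any](UTF8String.fromString("a"), UTF8String.fromString("c"))),
      new GenericArrayData(Array[Any](UTF8String.fromString("b"), UTF8String.fromString("d"))))

    assert(mapLiteral.numElements() == 2)
    // The same accessors the new MapType branches call before zipping keys and
    // values into key/value structs.
    val keys = mapLiteral.keyArray.array     // Array(a, c)
    val values = mapLiteral.valueArray.array // Array(b, d)
    println(keys.zip(values).mkString(", "))
  }
}
```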
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
index fd216d2868d..6619ef7c5c0 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
@@ -148,10 +148,7 @@ case class GpuCreateNamedStruct(children: Seq[Expression]) extends GpuExpression
     withResource(new Array[ColumnVector](valExprs.size)) { columns =>
       val numRows = batch.numRows()
       valExprs.indices.foreach { index =>
-        val dt = dataType.fields(index).dataType
-        val ret = valExprs(index).columnarEval(batch)
-        columns(index) =
-          GpuExpressionsUtils.resolveColumnVector(ret, numRows, dt).getBase
+        columns(index) = GpuExpressionsUtils.columnarEvalToColumn(valExprs(index), batch).getBase
       }
       GpuColumnVector.from(ColumnVector.makeStruct(numRows, columns: _*), dataType)
     }
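Taken together, the changes standardize one representation: cuDF has no first-class map type, so a Spark `map<K, V>` is stored as `list<struct<key, value>>`. That is what `resolveElementType` builds for `MapType`, what the null-scalar path mirrors with its `StructField("key", ...)`/`StructField("value", ...)` pair, and what `Scalar.listFromColumnView` wraps. A plain-Spark sketch of the shape mapping (the helper and object names are ours, not plugin API):

```scala
import org.apache.spark.sql.types._

object MapAsListOfStructs {
  // Recursively rewrite a Spark type into the shape cuDF stores it in:
  // map<K, V> becomes array<struct<key: K, value: V>>.
  def cudfShape(dt: DataType): DataType = dt match {
    case MapType(keyType, valueType, _) =>
      // Mirrors the StructField("key", ...) / StructField("value", ...) pair
      // built in GpuScalar.from for null map scalars.
      ArrayType(StructType(Seq(
        StructField("key", cudfShape(keyType)),
        StructField("value", cudfShape(valueType)))))
    case ArrayType(elementType, containsNull) =>
      ArrayType(cudfShape(elementType), containsNull)
    case StructType(fields) =>
      StructType(fields.map(f => f.copy(dataType = cudfShape(f.dataType))))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // map<string, int> -> an array of struct<key: string, value: int>
    println(cudfShape(MapType(StringType, IntegerType)).sql)
  }
}
```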