diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py
index df896fb8a999..d37edfc667b3 100644
--- a/integration_tests/src/main/python/map_test.py
+++ b/integration_tests/src/main/python/map_test.py
@@ -64,7 +64,9 @@ def test_map_entries(data_gen):
 
 map_value_gens = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, StringGen, DateGen, TimestampGen]
 
-@pytest.mark.parametrize('data_gen', [MapGen(StringGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
+@pytest.mark.parametrize('data_gen',
+                         [MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
+                         ids=idfn)
 def test_get_map_value_string_keys(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -197,17 +199,55 @@ def test_map_get_map_value_ansi_not_fail(data_gen):
             'a["NOT_FOUND"]'),
         conf=ansi_enabled_conf)
 
-@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
-def test_simple_element_at_map(data_gen):
+
+@pytest.mark.parametrize('data_gen',
+                         [MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
+                         ids=idfn)
+def test_element_at_map_string_keys(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+        lambda spark: unary_op_df(spark, data_gen).selectExpr(
             'element_at(a, "key_0")',
             'element_at(a, "key_1")',
             'element_at(a, "null")',
             'element_at(a, "key_9")',
             'element_at(a, "NOT_FOUND")',
             'element_at(a, "key_5")'),
-        conf={'spark.sql.ansi.enabled':False})
+        conf={'spark.sql.ansi.enabled': False})
+
+
+@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
+def test_element_at_map_numeric_keys(data_gen):
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, data_gen).selectExpr(
+            'element_at(a, 0)',
+            'element_at(a, 1)',
+            'element_at(a, null)',
+            'element_at(a, -9)',
+            'element_at(a, 999)'),
+        conf={'spark.sql.ansi.enabled': False})
+
+
+@pytest.mark.parametrize('data_gen', [MapGen(DateGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
+def test_element_at_map_date_keys(data_gen):
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, data_gen).selectExpr(
+            'element_at(a, date "1997")',
+            'element_at(a, date "2022-01-01")',
+            'element_at(a, null)'),
+        conf={'spark.sql.ansi.enabled': False})
+
+
+@pytest.mark.parametrize('data_gen',
+                         [MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens],
+                         ids=idfn)
+def test_element_at_map_timestamp_keys(data_gen):
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: unary_op_df(spark, data_gen).selectExpr(
+            'element_at(a, timestamp "1997")',
+            'element_at(a, timestamp "2022-01-01")',
+            'element_at(a, null)'),
+        conf={'spark.sql.ansi.enabled': False})
+
 
 @pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
 @pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
diff --git a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala
index 83dd50e636a4..3a746d426bba 100644
--- a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala
+++ b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala
@@ -279,7 +279,10 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
       }),
     GpuOverrides.expr[GetMapValue](
       "Gets Value from a Map based on a key",
-      ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
+      ExprChecks.binaryProject(
+        (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+          TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
+        TypeSig.all,
         ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
         ("key", TypeSig.commonCudfTypes, TypeSig.all)),
       (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
@@ -294,14 +297,11 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
           TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
         TypeSig.all,
         ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
-          TypeSig.MAP.nested(TypeSig.STRING)
-            .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
+          TypeSig.MAP.nested(TypeSig.commonCudfTypes)
+            .withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
           TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
-        ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
-          .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
-            "not as maps keys")
-          .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
-            "not array indexes"),
+        ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes)
+          .withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
           TypeSig.all)),
       (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
         override def tagExprForGpu(): Unit = {
@@ -309,9 +309,12 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
           val checks = in.left.dataType match {
             case _: MapType =>
               // Match exactly with the checks for GetMapValue
-              ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
-                ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
-                ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
+              ExprChecks.binaryProject(
+                (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+                  TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
+                TypeSig.all,
+                ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
+                ("key", TypeSig.commonCudfTypes, TypeSig.all))
             case _: ArrayType =>
               // Match exactly with the checks for GetArrayItem
               ExprChecks.binaryProject(
diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala
index 4f84259bf720..3700d324b398 100644
--- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala
+++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala
@@ -278,7 +278,10 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
       }),
     GpuOverrides.expr[GetMapValue](
       "Gets Value from a Map based on a key",
-      ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
+      ExprChecks.binaryProject(
+        (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+          TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
+        TypeSig.all,
         ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
         ("key", TypeSig.commonCudfTypes, TypeSig.all)),
       (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
@@ -293,14 +296,11 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
           TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
         TypeSig.all,
         ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
-          TypeSig.MAP.nested(TypeSig.STRING)
-            .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
+          TypeSig.MAP.nested(TypeSig.commonCudfTypes)
+            .withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
           TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
-        ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
-          .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
-            "not as maps keys")
-          .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
-            "not array indexes"),
+        ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes)
+          .withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
           TypeSig.all)),
       (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
         override def tagExprForGpu(): Unit = {
@@ -308,9 +308,12 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
           val checks = in.left.dataType match {
             case _: MapType =>
               // Match exactly with the checks for GetMapValue
-              ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
-                ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
-                ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
+              ExprChecks.binaryProject(
+                (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+                  TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
+                TypeSig.all,
+                ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
+                ("key", TypeSig.commonCudfTypes, TypeSig.all))
             case _: ArrayType =>
               // Match exactly with the checks for GetArrayItem
               ExprChecks.binaryProject(
diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala
index 32f3f36644c9..095be06b63e7 100644
--- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala
+++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala
@@ -381,7 +381,10 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
       }),
     GpuOverrides.expr[GetMapValue](
       "Gets Value from a Map based on a key",
-      ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
+      ExprChecks.binaryProject(
+        (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+          TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
+        TypeSig.all,
         ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
         ("key", TypeSig.commonCudfTypes, TypeSig.all)),
       (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r) {
@@ -396,14 +399,11 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
           TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
         TypeSig.all,
         ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
-          TypeSig.MAP.nested(TypeSig.STRING)
-            .withPsNote(TypeEnum.MAP, "If it's map, only string is supported."),
+          TypeSig.MAP.nested(TypeSig.commonCudfTypes)
+            .withPsNote(TypeEnum.MAP, "If it's map, only primitive key types are supported."),
           TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING)) - .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " + - "not as maps keys") - .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " + - "not array indexes"), + ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes) + .withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -411,9 +411,12 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { val checks = in.left.dataType match { case _: MapType => // Match exactly with the checks for GetMapValue - ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all, - ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)), - ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)) + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), + TypeSig.all, + ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)), + ("key", TypeSig.commonCudfTypes, TypeSig.all)) case _: ArrayType => // Match exactly with the checks for GetArrayItem ExprChecks.binaryProject( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 37ca7070442b..5a1e73b9af77 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2564,7 +2564,10 @@ object GpuOverrides extends Logging { (in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)), expr[GetMapValue]( "Gets Value from a Map based on a key", - ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all, + ExprChecks.binaryProject( + (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + + TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), + TypeSig.all, ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)), ("key", TypeSig.commonCudfTypes, TypeSig.all)), (in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)), @@ -2576,14 +2579,11 @@ object GpuOverrides extends Logging { TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all, ("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) + - TypeSig.MAP.nested(TypeSig.STRING) - .withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."), + TypeSig.MAP.nested(TypeSig.commonCudfTypes) + .withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."), TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)), - ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING)) - .withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " + - "not as maps keys") - .withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " + - "not array indexes"), + ("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes) + .withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"), TypeSig.all)), (in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) { override def tagExprForGpu(): Unit = { @@ -2591,9 +2591,13 @@ object GpuOverrides extends Logging { val checks = in.left.dataType match { case _: MapType => // Match exactly with the checks for 
-              ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
-                ("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
-                ("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
+              ExprChecks.binaryProject(
+                (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
+                  TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
+                TypeSig.all,
+                ("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes),
+                  TypeSig.MAP.nested(TypeSig.all)),
+                ("key", TypeSig.commonCudfTypes, TypeSig.all))
             case _: ArrayType =>
               // Match exactly with the checks for GetArrayItem
               ExprChecks.binaryProject(
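
A minimal PySpark sketch of the behavior the widened checks enable: with this patch, element_at and the [] operator on maps keyed by non-string primitive types (ints, dates, timestamps, and the other key types exercised by the new tests) no longer force a CPU fallback. The session setup and data below are illustrative assumptions, not part of the patch; it assumes the RAPIDS Accelerator jar is on the classpath and uses ANSI mode off so that a missing key returns null.

    # Hypothetical local session; assumes the RAPIDS Accelerator jar is available.
    from pyspark.sql import SparkSession

    spark = (SparkSession.builder
             .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
             .getOrCreate())

    # ANSI mode off: lookups on missing keys return null instead of raising.
    spark.conf.set("spark.sql.ansi.enabled", "false")

    # A map keyed by ints; previously only string keys kept these on the GPU.
    df = spark.createDataFrame([({1: "a", 2: "b"},)], "m map<int,string>")
    df.selectExpr(
        "element_at(m, 2)",    # ElementAt with an int key
        "m[1]",                # GetMapValue with an int key
        "element_at(m, 999)"   # missing key -> null with ANSI off
    ).show()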