Support for map element_at() with non-string keys:
Also, fixed return types for GetMapValue.
Also, tests for permutations of key types.
mythrocks committed Mar 12, 2022
1 parent 1986ca7 commit af4306e
Showing 5 changed files with 102 additions and 49 deletions.
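For context, a minimal PySpark sketch (not part of this commit) of the behavior the change enables: element_at() on a map column whose keys are not strings. It assumes a plain local SparkSession and omits any GPU plugin configuration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# A map<bigint, string> column; prior to this change the plugin's element_at only handled string keys.
df = spark.createDataFrame([({1: "a", 2: "b"},)], ["m"])

# Look up the value for integer key 2; prints "b".
df.selectExpr("element_at(m, 2) AS v").show()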
50 changes: 45 additions & 5 deletions integration_tests/src/main/python/map_test.py
@@ -64,7 +64,9 @@ def test_map_entries(data_gen):
map_value_gens = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, StringGen, DateGen, TimestampGen]


@pytest.mark.parametrize('data_gen', [MapGen(StringGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
@pytest.mark.parametrize('data_gen',
[MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
ids=idfn)
def test_get_map_value_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -197,17 +199,55 @@ def test_map_get_map_value_ansi_not_fail(data_gen):
'a["NOT_FOUND"]'),
conf=ansi_enabled_conf)

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_element_at_map(data_gen):

@pytest.mark.parametrize('data_gen',
[MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
ids=idfn)
def test_element_at_map_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, "key_0")',
'element_at(a, "key_1")',
'element_at(a, "null")',
'element_at(a, "key_9")',
'element_at(a, "NOT_FOUND")',
'element_at(a, "key_5")'),
conf={'spark.sql.ansi.enabled':False})
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_element_at_map_numeric_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, 0)',
'element_at(a, 1)',
'element_at(a, null)',
'element_at(a, -9)',
'element_at(a, 999)'),
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen', [MapGen(DateGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
def test_element_at_map_date_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, date "1997")',
'element_at(a, date "2022-01-01")',
'element_at(a, null)'),
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen',
[MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens],
ids=idfn)
def test_element_at_map_timestamp_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, timestamp "1997")',
'element_at(a, timestamp "2022-01-01")',
'element_at(a, null)'),
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
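As the tests above rely on, element_at() returns null for a null or missing map key when spark.sql.ansi.enabled is false; on Spark 3.1.1+ with ANSI mode enabled it throws instead. A minimal sketch of the non-ANSI behavior, under the same assumptions as before (local SparkSession, no GPU plugin configuration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "false")

# Key 999 is absent from the map, so the lookup yields null rather than an error.
spark.sql("SELECT element_at(map(1, 'a', 2, 'b'), 999) AS v").show()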
@@ -279,7 +279,10 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
@@ -294,24 +297,24 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
@@ -278,7 +278,10 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
@@ -293,24 +296,24 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
@@ -381,7 +381,10 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r) {
@@ -396,24 +399,24 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP, "If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP, "If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
@@ -2564,7 +2564,10 @@ object GpuOverrides extends Logging {
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)),
expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
@@ -2576,24 +2579,25 @@ object GpuOverrides extends Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.lit(TypeEnum.INT) + TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
