Skip to content

Commit

Permalink
Initial support for multi-type scalar map keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
mythrocks committed Mar 8, 2022
1 parent 9767748 commit becd662
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
21 changes: 19 additions & 2 deletions integration_tests/src/main/python/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,34 @@ def test_map_entries(data_gen):
# in here yet, and would need some special case code for checking equality
'map_entries(a)'))


@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value(data_gen):
def test_get_map_value_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a["key_0"]',
'a["key_1"]',
'a[null]',
'a["key_9"]',
'a["NOT_FOUND"]',
'a["key_5"]'))


integral_map_gens = [MapGen(f(nullable=False), f())
for f in [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen]]


@pytest.mark.parametrize('data_gen', integral_map_gens, ids=idfn)
def test_get_map_value_integral_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a[0]',
'a[1]',
'a[null]',
'a[-9]',
'a[999]'))


@pytest.mark.parametrize('key_gen', [StringGen(nullable=False), IntegerGen(nullable=False), basic_struct_gen], ids=idfn)
@pytest.mark.parametrize('value_gen', [StringGen(nullable=True), IntegerGen(nullable=True), basic_struct_gen], ids=idfn)
def test_single_entry_map(key_gen, value_gen):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(TypeSig.commonCudfTypes, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypes, TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, SQLConf.get.ansiEnabled)
Expand Down

0 comments on commit becd662

Please sign in to comment.