From d691102bff070d387e6ba2f183fad2fc610884e4 Mon Sep 17 00:00:00 2001 From: MithunR Date: Thu, 10 Mar 2022 16:31:28 -0800 Subject: [PATCH] Permuted map key/value types for getMapValue tests. --- integration_tests/src/main/python/map_test.py | 34 ++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/integration_tests/src/main/python/map_test.py b/integration_tests/src/main/python/map_test.py index ee42a780170e..df896fb8a999 100644 --- a/integration_tests/src/main/python/map_test.py +++ b/integration_tests/src/main/python/map_test.py @@ -61,7 +61,10 @@ def test_map_entries(data_gen): 'map_entries(a)')) -@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn) +map_value_gens = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, StringGen, DateGen, TimestampGen] + + +@pytest.mark.parametrize('data_gen', [MapGen(StringGen(nullable=False), value()) for value in map_value_gens], ids=idfn) def test_get_map_value_string_keys(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).selectExpr( @@ -73,12 +76,15 @@ def test_get_map_value_string_keys(data_gen): 'a["key_5"]')) -integral_map_gens = [MapGen(f(nullable=False), f()) - for f in [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen]] +numeric_key_gens = [key(nullable=False) if key in [FloatGen, DoubleGen] + else key(nullable=False, min_val=0, max_val=100) + for key in [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen]] +numeric_key_map_gens = [MapGen(key, value()) for key in numeric_key_gens for value in map_value_gens] -@pytest.mark.parametrize('data_gen', integral_map_gens, ids=idfn) -def test_get_map_value_integral_keys(data_gen): + +@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn) +def test_get_map_value_numeric_keys(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).selectExpr( 'a[0]', @@ -88,6 +94,24 @@ def test_get_map_value_integral_keys(data_gen): 'a[999]')) +@pytest.mark.parametrize('data_gen', [MapGen(DateGen(nullable=False), value()) for value in map_value_gens], ids=idfn) +def test_get_map_value_date_keys(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, data_gen).selectExpr( + 'a[date "1997"]', + 'a[date "2022-01-01"]', + 'a[null]')) + + +@pytest.mark.parametrize('data_gen', [MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens], ids=idfn) +def test_get_map_value_timestamp_keys(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, data_gen).selectExpr( + 'a[timestamp "1997"]', + 'a[timestamp "2022-01-01"]', + 'a[null]')) + + @pytest.mark.parametrize('key_gen', [StringGen(nullable=False), IntegerGen(nullable=False), basic_struct_gen], ids=idfn) @pytest.mark.parametrize('value_gen', [StringGen(nullable=True), IntegerGen(nullable=True), basic_struct_gen], ids=idfn) def test_single_entry_map(key_gen, value_gen):