From f4a898cd0faafbc972d87a8114078f0bd2fa5cb1 Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Thu, 9 Nov 2023 16:27:30 -0800
Subject: [PATCH] Fix Integration Test Failures for Databricks 13.3 Support
 (#9646)

* Fixed orc_test, parquet_test and regexp_test

* Added Support for PythonUDAF

* moved the PythonUDAF override to its correct place

* removed left over imports from the bad commit

* build fix after upmerge

* fixed imports

* fix 341db

* Signing off

Signed-off-by: Raza Jafri

* enable test_read_hive_fixed_length_char for 341db

---------

Signed-off-by: Raza Jafri
---
 integration_tests/src/main/python/orc_test.py         |  6 +++---
 integration_tests/src/main/python/parquet_test.py     |  1 +
 integration_tests/src/main/python/regexp_test.py      | 10 +++++-----
 .../com/nvidia/spark/rapids/shims/PythonUDFShim.scala |  1 +
 4 files changed, 10 insertions(+), 8 deletions(-)
 rename sql-plugin/src/main/{spark350 => spark341db}/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala (98%)

diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py
index b66903955bd..cbb2ee9e703 100644
--- a/integration_tests/src/main/python/orc_test.py
+++ b/integration_tests/src/main/python/orc_test.py
@@ -19,6 +19,7 @@
 from data_gen import *
 from marks import *
 from pyspark.sql.types import *
+from spark_init_internal import spark_version
 from spark_session import with_cpu_session, is_before_spark_320, is_before_spark_330, is_spark_cdh, is_spark_340_or_later
 from parquet_test import _nested_pruning_schemas
 from conftest import is_databricks_runtime
@@ -805,8 +806,7 @@ def test_simple_partitioned_read_for_multithreaded_combining(spark_tmp_path, kee
     assert_gpu_and_cpu_are_equal_collect(
             lambda spark: spark.read.orc(data_path),
             conf=all_confs)
-
-@pytest.mark.skipif(is_spark_340_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
+@pytest.mark.skipif(is_spark_340_or_later() and (not (is_databricks_runtime() and spark_version() == "3.4.1")), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
 @pytest.mark.parametrize('data_file', ['fixed-length-char-column-from-hive.orc'])
 @pytest.mark.parametrize('reader', [read_orc_df, read_orc_sql])
 def test_read_hive_fixed_length_char(std_input_path, data_file, reader):
@@ -819,7 +819,7 @@ def test_read_hive_fixed_length_char(std_input_path, data_file, reader):
 
 
 @allow_non_gpu("ProjectExec")
-@pytest.mark.skipif(is_before_spark_340(), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
+@pytest.mark.skipif(is_before_spark_340() or (is_databricks_runtime() and spark_version() == "3.4.1"), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
 @pytest.mark.parametrize('data_file', ['fixed-length-char-column-from-hive.orc'])
 @pytest.mark.parametrize('reader', [read_orc_df, read_orc_sql])
 def test_project_fallback_when_reading_hive_fixed_length_char(std_input_path, data_file, reader):
diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index b3e04b91d93..bf312e2dd81 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -819,6 +819,7 @@ def read_timestamp_nano_parquet(spark):
 @pytest.mark.skipif(spark_version() >= '3.2.0' and spark_version() < '3.2.4', reason='New config added in 3.2.4')
 @pytest.mark.skipif(spark_version() >= '3.3.0' and spark_version() < '3.3.2', reason='New config added in 3.3.2')
 @pytest.mark.skipif(is_databricks_runtime() and spark_version() == '3.3.2', reason='Config not in DB 12.2')
+@pytest.mark.skipif(is_databricks_runtime() and spark_version() == '3.4.1', reason='Config not in DB 13.3')
 @allow_non_gpu('FileSourceScanExec, ColumnarToRowExec')
 def test_parquet_read_nano_as_longs_true(std_input_path):
     data_path = "%s/timestamp-nanos.parquet" % (std_input_path)
diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py
index fa563d69e88..3c1e2b0df78 100644
--- a/integration_tests/src/main/python/regexp_test.py
+++ b/integration_tests/src/main/python/regexp_test.py
@@ -20,7 +20,7 @@
 from data_gen import *
 from marks import *
 from pyspark.sql.types import *
-from spark_session import is_before_spark_320, is_before_spark_350, is_jvm_charset_utf8
+from spark_session import is_before_spark_320, is_before_spark_350, is_jvm_charset_utf8, is_databricks_runtime, spark_version
 
 if not is_jvm_charset_utf8():
     pytestmark = [pytest.mark.regexp, pytest.mark.skip(reason=str("Current locale doesn't support UTF-8, regexp support is disabled"))]
@@ -489,7 +489,7 @@ def test_regexp_extract_no_match():
 # Spark take care of the error handling
 @allow_non_gpu('ProjectExec', 'RegExpExtract')
 def test_regexp_extract_idx_negative():
-    message = "The specified group index cannot be less than zero" if is_before_spark_350() else \
+    message = "The specified group index cannot be less than zero" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid"
 
     gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}')
@@ -503,7 +503,7 @@
 # Spark take care of the error handling
 @allow_non_gpu('ProjectExec', 'RegExpExtract')
 def test_regexp_extract_idx_out_of_bounds():
-    message = "Regex group count is 3, but the specified group index is 4" if is_before_spark_350() else \
+    message = "Regex group count is 3, but the specified group index is 4" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid: Expects group index between 0 and 3, but got 4."
     gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}')
     assert_gpu_and_cpu_error(
@@ -826,7 +826,7 @@ def test_regexp_extract_all_idx_positive():
 
 @allow_non_gpu('ProjectExec', 'RegExpExtractAll')
 def test_regexp_extract_all_idx_negative():
-    message = "The specified group index cannot be less than zero" if is_before_spark_350() else \
+    message = "The specified group index cannot be less than zero" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid"
 
     gen = mk_str_gen('[abcd]{0,3}')
@@ -839,7 +839,7 @@
 
 @allow_non_gpu('ProjectExec', 'RegExpExtractAll')
 def test_regexp_extract_all_idx_out_of_bounds():
-    message = "Regex group count is 2, but the specified group index is 3" if is_before_spark_350() else \
+    message = "Regex group count is 2, but the specified group index is 3" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
         "[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid: Expects group index between 0 and 2, but got 3."
     gen = mk_str_gen('[a-d]{1,2}.{0,1}[0-9]{1,2}')
     assert_gpu_and_cpu_error(
diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala
similarity index 98%
rename from sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala
rename to sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala
index 890aa978001..c207313268a 100644
--- a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala
+++ b/sql-plugin/src/main/spark341db/scala/com/nvidia/spark/rapids/shims/PythonUDFShim.scala
@@ -15,6 +15,7 @@
  */
 
 /*** spark-rapids-shim-json-lines
+{"spark": "341db"}
 {"spark": "350"}
 spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
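
Databricks 13.3 ships Spark 3.4.1, which is why every new gate in this patch pairs is_databricks_runtime() with spark_version() == "3.4.1": a plain is_before_spark_350() check misclassifies that runtime, since DB 13.3 already raises the Spark 3.5.0-style error classes exercised by regexp_test.py. Below is a minimal sketch of how the repeated pair could be folded into one predicate. The helper names is_databricks133 and uses_spark350_error_classes are hypothetical and not part of this patch; is_databricks_runtime, spark_version, and is_before_spark_350 are the existing utilities the tests above already import from spark_session.

    # Hypothetical consolidation of the repeated Databricks 13.3 check;
    # a sketch only, assuming the integration-test utilities used above.
    from spark_session import is_before_spark_350, is_databricks_runtime, spark_version

    def is_databricks133():
        # Databricks 13.3 reports its underlying Spark version as 3.4.1.
        return is_databricks_runtime() and spark_version() == "3.4.1"

    def uses_spark350_error_classes():
        # Mirrors the regexp_test.py conditions above: the new
        # INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX message is expected on
        # Spark 3.5.0+ and also on DB 13.3, despite its 3.4.1 version string.
        return not is_before_spark_350() or is_databricks133()

With such a predicate, the orc_test.py gate would shrink to skipif(is_spark_340_or_later() and not is_databricks133(), ...), and each regexp test would pick the legacy message exactly when not uses_spark350_error_classes().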
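The Scala change, by contrast, is almost entirely mechanical (similarity index 98%): each shim source file opens with a spark-rapids-shim-json-lines comment listing, one JSON object per line, the Spark builds it is compiled for. Relocating PythonUDFShim.scala from spark350/ to spark341db/ and adding {"spark": "341db"} lets the single file back both the Databricks 13.3 shim and the Spark 3.5.0 shim. Purely to illustrate the header format, and not the plugin's actual build tooling, here is a tiny parser for it:

    # Illustrative only: extracts the shim IDs from a
    # spark-rapids-shim-json-lines header like the one in PythonUDFShim.scala.
    import json
    import re

    SHIM_BLOCK = re.compile(
        r"spark-rapids-shim-json-lines\s*(.*?)\s*spark-rapids-shim-json-lines",
        re.DOTALL)

    def shims_for(source_text):
        match = SHIM_BLOCK.search(source_text)
        if match is None:
            return []
        return [json.loads(line)["spark"]
                for line in match.group(1).splitlines() if line.strip()]

    header = '''/*** spark-rapids-shim-json-lines
    {"spark": "341db"}
    {"spark": "350"}
    spark-rapids-shim-json-lines ***/'''

    assert shims_for(header) == ["341db", "350"]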