Fix Integration Test Failures for Databricks 13.3 Support (NVIDIA#9646)
* Fixed orc_test, parquet_test and regexp_test

* Added Support for PythonUDAF

* moved the PythonUDAF override to its correct place

* removed leftover imports from the bad commit

* build fix after upmerge

* fixed imports

* fix 341db

* Signing off

Signed-off-by: Raza Jafri <[email protected]>

* enable test_read_hive_fixed_length_char for 341db

---------

Signed-off-by: Raza Jafri <[email protected]>
razajafri authored Nov 10, 2023
1 parent b414029 commit f4a898c
Showing 4 changed files with 10 additions and 8 deletions.
6 changes: 3 additions & 3 deletions integration_tests/src/main/python/orc_test.py
@@ -19,6 +19,7 @@
from data_gen import *
from marks import *
from pyspark.sql.types import *
+from spark_init_internal import spark_version
from spark_session import with_cpu_session, is_before_spark_320, is_before_spark_330, is_spark_cdh, is_spark_340_or_later
from parquet_test import _nested_pruning_schemas
from conftest import is_databricks_runtime
@@ -805,8 +806,7 @@ def test_simple_partitioned_read_for_multithreaded_combining(spark_tmp_path, kee
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.orc(data_path), conf=all_confs)


-@pytest.mark.skipif(is_spark_340_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
+@pytest.mark.skipif(is_spark_340_or_later() and (not (is_databricks_runtime() and spark_version() == "3.4.1")), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
@pytest.mark.parametrize('data_file', ['fixed-length-char-column-from-hive.orc'])
@pytest.mark.parametrize('reader', [read_orc_df, read_orc_sql])
def test_read_hive_fixed_length_char(std_input_path, data_file, reader):
@@ -819,7 +819,7 @@ def test_read_hive_fixed_length_char(std_input_path, data_file, reader):


@allow_non_gpu("ProjectExec")
-@pytest.mark.skipif(is_before_spark_340(), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
+@pytest.mark.skipif(is_before_spark_340() or (is_databricks_runtime() and spark_version() == "3.4.1"), reason="https://github.com/NVIDIA/spark-rapids/issues/8324")
@pytest.mark.parametrize('data_file', ['fixed-length-char-column-from-hive.orc'])
@pytest.mark.parametrize('reader', [read_orc_df, read_orc_sql])
def test_project_fallback_when_reading_hive_fixed_length_char(std_input_path, data_file, reader):
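The two rewritten markers above share one gate: Databricks 13.3 ships Spark 3.4.1, so it has to be carved out of the generic 3.4.0+ skip. A minimal sketch of that gate as a reusable predicate (hypothetical; `is_databricks_341` is not an existing helper, and the imports mirror the ones orc_test.py already uses):

```python
from conftest import is_databricks_runtime
from spark_init_internal import spark_version

def is_databricks_341():
    # Databricks 13.3 reports its Spark version as 3.4.1.
    return is_databricks_runtime() and spark_version() == "3.4.1"
```

With it, the first marker would read `is_spark_340_or_later() and not is_databricks_341()` and the second `is_before_spark_340() or is_databricks_341()`.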
1 change: 1 addition & 0 deletions integration_tests/src/main/python/parquet_test.py
@@ -819,6 +819,7 @@ def read_timestamp_nano_parquet(spark):
@pytest.mark.skipif(spark_version() >= '3.2.0' and spark_version() < '3.2.4', reason='New config added in 3.2.4')
@pytest.mark.skipif(spark_version() >= '3.3.0' and spark_version() < '3.3.2', reason='New config added in 3.3.2')
@pytest.mark.skipif(is_databricks_runtime() and spark_version() == '3.3.2', reason='Config not in DB 12.2')
+@pytest.mark.skipif(is_databricks_runtime() and spark_version() == '3.4.1', reason='Config not in DB 13.3')
@allow_non_gpu('FileSourceScanExec, ColumnarToRowExec')
def test_parquet_read_nano_as_longs_true(std_input_path):
data_path = "%s/timestamp-nanos.parquet" % (std_input_path)
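The added marker stacks onto three existing ones; pytest skips a test when any one of its `skipif` conditions is true, so each unsupported version range gets its own line and reason. A self-contained sketch of that behavior, with stand-in predicates in place of the real `spark_session` helpers:

```python
import pytest

def spark_version():
    return "3.4.1"  # stand-in: pretend we are on Spark 3.4.1

def is_databricks_runtime():
    return True     # stand-in: pretend we are on Databricks

@pytest.mark.skipif(spark_version() < "3.2.4", reason="New config added in 3.2.4")
@pytest.mark.skipif(is_databricks_runtime() and spark_version() == "3.4.1",
                    reason="Config not in DB 13.3")
def test_stacked_skipif_demo():
    # Skipped under these stand-ins: the second condition is true.
    assert True
```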
10 changes: 5 additions & 5 deletions integration_tests/src/main/python/regexp_test.py
@@ -20,7 +20,7 @@
from data_gen import *
from marks import *
from pyspark.sql.types import *
-from spark_session import is_before_spark_320, is_before_spark_350, is_jvm_charset_utf8
+from spark_session import is_before_spark_320, is_before_spark_350, is_jvm_charset_utf8, is_databricks_runtime, spark_version

if not is_jvm_charset_utf8():
pytestmark = [pytest.mark.regexp, pytest.mark.skip(reason=str("Current locale doesn't support UTF-8, regexp support is disabled"))]
@@ -489,7 +489,7 @@ def test_regexp_extract_no_match():
# Spark takes care of the error handling
@allow_non_gpu('ProjectExec', 'RegExpExtract')
def test_regexp_extract_idx_negative():
message = "The specified group index cannot be less than zero" if is_before_spark_350() else \
message = "The specified group index cannot be less than zero" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
"[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid"

gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}')
@@ -503,7 +503,7 @@ def test_regexp_extract_idx_negative():
# Spark takes care of the error handling
@allow_non_gpu('ProjectExec', 'RegExpExtract')
def test_regexp_extract_idx_out_of_bounds():
message = "Regex group count is 3, but the specified group index is 4" if is_before_spark_350() else \
message = "Regex group count is 3, but the specified group index is 4" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
"[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract` is invalid: Expects group index between 0 and 3, but got 4."
gen = mk_str_gen('[abcd]{1,3}[0-9]{1,3}[abcd]{1,3}')
assert_gpu_and_cpu_error(
@@ -826,7 +826,7 @@ def test_regexp_extract_all_idx_positive():

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_negative():
message = "The specified group index cannot be less than zero" if is_before_spark_350() else \
message = "The specified group index cannot be less than zero" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
"[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid"

gen = mk_str_gen('[abcd]{0,3}')
@@ -839,7 +839,7 @@ def test_regexp_extract_all_idx_negative():

@allow_non_gpu('ProjectExec', 'RegExpExtractAll')
def test_regexp_extract_all_idx_out_of_bounds():
message = "Regex group count is 2, but the specified group index is 3" if is_before_spark_350() else \
message = "Regex group count is 2, but the specified group index is 3" if is_before_spark_350() and not (is_databricks_runtime() and spark_version() == "3.4.1") else \
"[INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX] The value of parameter(s) `idx` in `regexp_extract_all` is invalid: Expects group index between 0 and 2, but got 3."
gen = mk_str_gen('[a-d]{1,2}.{0,1}[0-9]{1,2}')
assert_gpu_and_cpu_error(
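All four updated assertions repeat the same selection: use the legacy error text before Spark 3.5.0, except on Databricks 13.3, whose Spark 3.4.1 already surfaces the 3.5-style `INVALID_PARAMETER_VALUE.REGEX_GROUP_INDEX` error class. A hedged sketch of that selection as a hypothetical helper (not part of the test suite):

```python
from spark_session import is_before_spark_350, is_databricks_runtime, spark_version

def expected_regexp_error(legacy_msg, error_class_msg):
    # DB 13.3 (Spark 3.4.1) already raises the Spark 3.5 error class,
    # so it is grouped with 3.5.0+ despite its lower version number.
    uses_legacy = is_before_spark_350() and not (
        is_databricks_runtime() and spark_version() == "3.4.1")
    return legacy_msg if uses_legacy else error_class_msg
```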
1 change: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
*/

/*** spark-rapids-shim-json-lines
{"spark": "341db"}
{"spark": "350"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims
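The `spark-rapids-shim-json-lines` comment block declares every shim a source file is compiled into; adding `{"spark": "341db"}` pulls this file into the Databricks 13.3 (Spark 3.4.1) build alongside the existing 3.5.0 shim. A rough sketch of how such a block could be parsed; the real selection happens in the project's build tooling, not in Python:

```python
import json
import re

SHIM_BLOCK = re.compile(
    r"/\*\*\* spark-rapids-shim-json-lines\n(.*?)\nspark-rapids-shim-json-lines \*\*\*/",
    re.S)

def shims_for(source_text: str) -> list[str]:
    # Each line inside the block is a standalone JSON object naming a shim.
    match = SHIM_BLOCK.search(source_text)
    if match is None:
        return []
    return [json.loads(line)["spark"] for line in match.group(1).splitlines()]
```

Fed the header of this file, it would return `['341db', '350']`.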
