diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py
index 7100afaa5e2..840931e20ab 100644
--- a/integration_tests/src/main/python/arithmetic_ops_test.py
+++ b/integration_tests/src/main/python/arithmetic_ops_test.py
@@ -18,7 +18,7 @@
 from data_gen import *
 from marks import incompat, approximate_float
 from pyspark.sql.types import *
-from spark_session import with_spark_session
+from spark_session import with_spark_session, is_before_spark_310
 import pyspark.sql.functions as f
 
 @pytest.mark.parametrize('data_gen', numeric_gens, ids=idfn)
@@ -360,7 +360,7 @@ def test_expm1(data_gen):
             lambda spark : unary_op_df(spark, data_gen).selectExpr('expm1(a)'))
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://issues.apache.org/jira/browse/SPARK-32640')
 @approximate_float
 @pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
@@ -369,7 +369,7 @@ def test_log(data_gen):
             lambda spark : unary_op_df(spark, data_gen).selectExpr('log(a)'))
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://issues.apache.org/jira/browse/SPARK-32640')
 @approximate_float
 @pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
@@ -378,7 +378,7 @@ def test_log1p(data_gen):
             lambda spark : unary_op_df(spark, data_gen).selectExpr('log1p(a)'))
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://issues.apache.org/jira/browse/SPARK-32640')
 @approximate_float
 @pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
@@ -387,7 +387,7 @@ def test_log2(data_gen):
             lambda spark : unary_op_df(spark, data_gen).selectExpr('log2(a)'))
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://issues.apache.org/jira/browse/SPARK-32640')
 @approximate_float
 @pytest.mark.parametrize('data_gen', double_gens, ids=idfn)
diff --git a/integration_tests/src/main/python/cache_test.py b/integration_tests/src/main/python/cache_test.py
index 96d17d21d54..573f4e8ccb3 100644
--- a/integration_tests/src/main/python/cache_test.py
+++ b/integration_tests/src/main/python/cache_test.py
@@ -18,7 +18,7 @@
 from data_gen import *
 from datetime import date
 import pyspark.sql.functions as f
-from spark_session import with_cpu_session, with_gpu_session
+from spark_session import with_cpu_session, with_gpu_session, is_spark_300
 from join_test import create_df
 from generate_expr_test import four_op_df
 from marks import incompat, allow_non_gpu, ignore_order
@@ -61,8 +61,8 @@ def test_passing_gpuExpr_as_Expr():
 @pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
 @ignore_order
 def test_cache_join(data_gen, join_type):
-    if data_gen.data_type == BooleanType():
-        pytest.xfail("https://github.com/NVIDIA/spark-rapids/issues/350")
+    if is_spark_300() and data_gen.data_type == BooleanType():
+        pytest.xfail("https://issues.apache.org/jira/browse/SPARK-32672")
 
     def do_join(spark):
         left, right = create_df(spark, data_gen, 500, 500)
@@ -81,8 +81,8 @@ def do_join(spark):
 @ignore_order
 def test_cached_join_filter(data_gen, join_type):
     data, filter = data_gen
-    if data.data_type == BooleanType():
-        pytest.xfail("https://github.com/NVIDIA/spark-rapids/issues/350")
+    if is_spark_300() and data.data_type == BooleanType():
+        pytest.xfail("https://issues.apache.org/jira/browse/SPARK-32672")
 
     def do_join(spark):
         left, right = create_df(spark, data, 500, 500)
@@ -96,8 +96,8 @@ def do_join(spark):
 @pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
 @ignore_order
 def test_cache_broadcast_hash_join(data_gen, join_type):
-    if data_gen.data_type == BooleanType():
-        pytest.xfail("https://github.com/NVIDIA/spark-rapids/issues/350")
+    if is_spark_300() and data_gen.data_type == BooleanType():
+        pytest.xfail("https://issues.apache.org/jira/browse/SPARK-32672")
 
     def do_join(spark):
         left, right = create_df(spark, data_gen, 500, 500)
@@ -116,8 +116,8 @@ def do_join(spark):
 @pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
 @ignore_order
 def test_cache_shuffled_hash_join(data_gen, join_type):
-    if data_gen.data_type == BooleanType():
-        pytest.xfail("https://github.com/NVIDIA/spark-rapids/issues/350")
+    if is_spark_300() and data_gen.data_type == BooleanType():
+        pytest.xfail("https://issues.apache.org/jira/browse/SPARK-32672")
 
     def do_join(spark):
         left, right = create_df(spark, data_gen, 50, 500)
@@ -151,8 +151,8 @@ def do_join(spark):
 @pytest.mark.parametrize('data_gen', all_gen_restricting_dates, ids=idfn)
 @allow_non_gpu('InMemoryTableScanExec', 'DataWritingCommandExec')
 def test_cache_posexplode_makearray(spark_tmp_path, data_gen):
-    if data_gen.data_type == BooleanType():
-        pytest.xfail("https://github.com/NVIDIA/spark-rapids/issues/350")
+    if is_spark_300() and data_gen.data_type == BooleanType():
+        pytest.xfail("https://issues.apache.org/jira/browse/SPARK-32672")
     data_path_cpu = spark_tmp_path + '/PARQUET_DATA_CPU'
     data_path_gpu = spark_tmp_path + '/PARQUET_DATA_GPU'
     def write_posExplode(data_path):
diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py
index c5c587021fd..06caf94aba2 100644
--- a/integration_tests/src/main/python/date_time_test.py
+++ b/integration_tests/src/main/python/date_time_test.py
@@ -18,14 +18,14 @@
 from datetime import date, datetime, timezone
 from marks import incompat
 from pyspark.sql.types import *
-from spark_session import with_spark_session
+from spark_session import with_spark_session, is_before_spark_310
 import pyspark.sql.functions as f
 
 # We only support literal intervals for TimeSub
 vals = [(-584, 1563), (1943, 1101), (2693, 2167), (2729, 0), (44, 1534), (2635, 3319),
         (1885, -2828), (0, 2463), (932, 2286), (0, 0)]
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://issues.apache.org/jira/browse/SPARK-32640')
 @pytest.mark.parametrize('data_gen', vals, ids=idfn)
 def test_timesub(data_gen):
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index b50d12c85af..3f42132eeb4 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -19,7 +19,7 @@
 from pyspark.sql.types import *
 from marks import *
 import pyspark.sql.functions as f
-from spark_session import with_spark_session
+from spark_session import with_spark_session, is_spark_300
 
 _no_nans_float_conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
                        'spark.rapids.sql.hasNans': 'false',
@@ -316,7 +316,7 @@ def test_hash_agg_with_nan_keys(data_gen):
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : spark.sparkContext.version == "3.0.0"),
+    condition=is_spark_300(),
     reason="[SPARK-32038][SQL] NormalizeFloatingNumbers should also work on distinct aggregate "
            "(https://github.com/apache/spark/pull/28876) "
            "Fixed in later Apache Spark releases.")
diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index 83a058b401c..01fa612e2f5 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -18,7 +18,7 @@
 from conftest import is_databricks_runtime
 from data_gen import *
 from marks import ignore_order, allow_non_gpu, incompat
-from spark_session import with_spark_session
+from spark_session import with_spark_session, is_before_spark_310
 
 all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
            BooleanGen(), DateGen(), TimestampGen(),
@@ -152,7 +152,7 @@ def do_join(spark):
 @ignore_order
 @pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti',
     pytest.param('FullOuter', marks=pytest.mark.xfail(
-        condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+        condition=not(is_before_spark_310()),
         reason='https://github.com/NVIDIA/spark-rapids/issues/575')),
     'Cross'], ids=idfn)
 def test_broadcast_join_mixed(join_type):
diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py
index e9e121f6354..4dce6149966 100644
--- a/integration_tests/src/main/python/orc_test.py
+++ b/integration_tests/src/main/python/orc_test.py
@@ -19,7 +19,7 @@
 from data_gen import *
 from marks import *
 from pyspark.sql.types import *
-from spark_session import with_cpu_session, with_spark_session
+from spark_session import with_cpu_session, with_spark_session, is_before_spark_310
 
 def read_orc_df(data_path):
     return lambda spark : spark.read.orc(data_path)
@@ -200,7 +200,7 @@ def test_compress_write_round_trip(spark_tmp_path, compress):
             conf={'spark.sql.orc.compression.codec': compress})
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://github.com/NVIDIA/spark-rapids/issues/576')
 def test_input_meta(spark_tmp_path):
     first_data_path = spark_tmp_path + '/ORC_DATA/key=0'
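The join_test.py hunk above, and the qa_nightly_sql.py entries just below, lean on pytest.param to attach an xfail mark to a single entry in a parametrize list while the remaining entries run normally. A minimal standalone sketch of that technique (parameter values and the reason string are hypothetical):

# Sketch of per-parameter xfail via pytest.param (hypothetical values).
import pytest
from spark_session import is_before_spark_310

@pytest.mark.parametrize('join_type', [
    'Inner',
    'Left',
    pytest.param('FullOuter', marks=pytest.mark.xfail(
        condition=not(is_before_spark_310()),
        reason='hypothetical: full outer join support changes on Spark 3.1.0+'))])
def test_param_marks_example(join_type):
    assert join_type in ('Inner', 'Left', 'FullOuter')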
diff --git a/integration_tests/src/main/python/qa_nightly_sql.py b/integration_tests/src/main/python/qa_nightly_sql.py
index 9024e025d5a..4f811ff0116 100644
--- a/integration_tests/src/main/python/qa_nightly_sql.py
+++ b/integration_tests/src/main/python/qa_nightly_sql.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from conftest import is_databricks_runtime
-from spark_session import with_spark_session
+from spark_session import with_spark_session, is_before_spark_310
 import pytest
 
 SELECT_SQL = [
@@ -745,16 +745,16 @@
 ("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF RIGHT JOIN test_table1 ON test_table.strF=test_table1.strF"),
 ("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF RIGHT JOIN test_table1 ON test_table.dateF=test_table1.dateF"),
 ("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF RIGHT JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"),
-pytest.param(("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
-pytest.param(("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), marks=pytest.mark.xfail(condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")), reason='https://github.com/NVIDIA/spark-rapids/issues/578'))
+pytest.param(("SELECT test_table.byteF as byteF, test_table1.byteF as byteF1 from test_table FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF", "test_table.byteF, test_table1.byteF FULL JOIN test_table1 ON test_table.byteF=test_table1.byteF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.shortF as shortF, test_table1.shortF as shortF1 from test_table FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF", "test_table.shortF, test_table1.shortF FULL JOIN test_table1 ON test_table.shortF=test_table1.shortF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.intF as intF, test_table1.intF as intF1 from test_table FULL JOIN test_table1 ON test_table.intF=test_table1.intF", "test_table.intF, test_table1.intF FULL JOIN test_table1 ON test_table.intF=test_table1.intF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.longF as longF, test_table1.longF as longF1 from test_table FULL JOIN test_table1 ON test_table.longF=test_table1.longF", "test_table.longF, test_table1.longF FULL JOIN test_table1 ON test_table.longF=test_table1.longF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.floatF as floatF, test_table1.floatF as floatF1 from test_table FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF", "test_table.floatF, test_table1.floatF FULL JOIN test_table1 ON test_table.floatF=test_table1.floatF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.doubleF as doubleF, test_table1.doubleF as doubleF1 from test_table FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF", "test_table.doubleF, test_table1.doubleF FULL JOIN test_table1 ON test_table.doubleF=test_table1.doubleF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.booleanF as booleanF, test_table1.booleanF as booleanF1 from test_table FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF", "test_table.booleanF, test_table1.booleanF FULL JOIN test_table1 ON test_table.booleanF=test_table1.booleanF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.strF as strF, test_table1.strF as strF1 from test_table FULL JOIN test_table1 ON test_table.strF=test_table1.strF", "test_table.strF, test_table1.strF FULL JOIN test_table1 ON test_table.strF=test_table1.strF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.dateF as dateF, test_table1.dateF as dateF1 from test_table FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF", "test_table.dateF, test_table1.dateF FULL JOIN test_table1 ON test_table.dateF=test_table1.dateF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578')),
+pytest.param(("SELECT test_table.timestampF as timestampF, test_table1.timestampF as timestampF1 from test_table FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF", "test_table.timestampF, test_table1.timestampF FULL JOIN test_table1 ON test_table.timestampF=test_table1.timestampF"), marks=pytest.mark.xfail(condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/578'))
 ]
 
 SELECT_PRE_ORDER_SQL=[
diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py
index 9aaabd1cde3..65e1e0335eb 100644
--- a/integration_tests/src/main/python/spark_init_internal.py
+++ b/integration_tests/src/main/python/spark_init_internal.py
@@ -41,3 +41,5 @@ def get_spark_i_know_what_i_am_doing():
     """
     return _spark
 
+def spark_version():
+    return _spark.version
diff --git a/integration_tests/src/main/python/spark_session.py b/integration_tests/src/main/python/spark_session.py
index 1a4b8a7ba29..7394f445a2c 100644
--- a/integration_tests/src/main/python/spark_session.py
+++ b/integration_tests/src/main/python/spark_session.py
@@ -14,7 +14,7 @@
 
 from conftest import is_allowing_any_non_gpu, get_non_gpu_allowed
 from pyspark.sql import SparkSession, DataFrame
-from spark_init_internal import get_spark_i_know_what_i_am_doing
+from spark_init_internal import get_spark_i_know_what_i_am_doing, spark_version
 
 def _from_scala_map(scala_map):
     ret = {}
@@ -90,3 +90,8 @@ def with_gpu_session(func, conf={}):
         copy['spark.rapids.sql.test.allowedNonGpu'] = ','.join(get_non_gpu_allowed())
     return with_spark_session(func, conf=copy)
 
+def is_spark_300():
+    return spark_version() == "3.0.0"
+
+def is_before_spark_310():
+    return spark_version() < "3.1.0"
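One caveat on the helpers just added: spark_version() < "3.1.0" is a plain lexicographic string comparison. That is correct for the 3.0.x and 3.1.0 versions in play here, but it would misorder a hypothetical "3.10.0" (lexicographically "3.10.0" < "3.2.0"). A sketch of a numeric alternative, for illustration only; neither function below is part of this change:

# Illustrative only -- compares release components numerically instead of
# lexicographically.
def _version_parts(version):
    # "3.0.1-SNAPSHOT" -> (3, 0, 1): keep the leading numeric dotted part
    return tuple(int(p) for p in version.split('-')[0].split('.'))

def is_before_spark_310_numeric(version):
    return _version_parts(version) < (3, 1, 0)

assert is_before_spark_310_numeric("3.0.1")
assert not is_before_spark_310_numeric("3.10.0")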
diff --git a/integration_tests/src/main/python/tpch_test.py b/integration_tests/src/main/python/tpch_test.py
index fe39073bc1f..72fa68295d2 100644
--- a/integration_tests/src/main/python/tpch_test.py
+++ b/integration_tests/src/main/python/tpch_test.py
@@ -16,7 +16,7 @@
 
 from asserts import assert_gpu_and_cpu_are_equal_collect
 from marks import approximate_float, incompat, ignore_order, allow_non_gpu
-from spark_session import with_spark_session
+from spark_session import with_spark_session, is_before_spark_310
 
 _base_conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
               'spark.rapids.sql.hasNans': 'false'}
@@ -119,7 +119,7 @@ def test_tpch_q15(tpch, conf):
             lambda spark : tpch.do_test_query("q15"))
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://github.com/NVIDIA/spark-rapids/issues/586')
 @allow_non_gpu('BroadcastNestedLoopJoinExec', 'Or', 'IsNull', 'EqualTo', 'AttributeReference', 'BroadcastExchangeExec')
 @pytest.mark.parametrize('conf', [_base_conf, _adaptive_conf])
@@ -128,7 +128,7 @@ def test_tpch_q16(tpch, conf):
             lambda spark : tpch.do_test_query("q16"), conf=conf)
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://github.com/NVIDIA/spark-rapids/issues/586')
 @approximate_float
 @pytest.mark.parametrize('conf', [_base_conf, _adaptive_conf])
@@ -137,7 +137,7 @@ def test_tpch_q17(tpch, conf):
             lambda spark : tpch.do_test_query("q17"), conf=conf)
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://github.com/NVIDIA/spark-rapids/issues/586')
 @incompat
 @approximate_float
@@ -154,7 +154,7 @@ def test_tpch_q19(tpch, conf):
             lambda spark : tpch.do_test_query("q19"), conf=conf)
 
 @pytest.mark.xfail(
-    condition=with_spark_session(lambda spark : not(spark.sparkContext.version < "3.1.0")),
+    condition=not(is_before_spark_310()),
     reason='https://github.com/NVIDIA/spark-rapids/issues/586')
 @pytest.mark.parametrize('conf', [_base_conf, _adaptive_conf])
 def test_tpch_q20(tpch, conf):
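Finally, a short hypothetical usage sketch of is_spark_300() for new tests (the test and reason are illustrative, not part of the diff above): where the xfail gating shown earlier still runs the test and reports XFAIL/XPASS, a skipif gate avoids running it at all.

# Hypothetical usage; assumes the integration_tests python path.
import pytest
from spark_session import is_spark_300

@pytest.mark.skipif(is_spark_300(), reason='hypothetical: broken only on Spark 3.0.0')
def test_skipped_on_spark_300():
    assert True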