diff --git a/tests/test_transformations.py b/tests/test_transformations.py index ee24ac6f..25c51967 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -16,6 +16,19 @@ ) from quinn.transformations import flatten_struct, flatten_map, flatten_dataframe from .spark import spark +from functools import wraps + + +def skip_if_spark_connect_mode(func): + @wraps(func) + def wrapper(*args, **kwargs): + spark_version = args[0].version # Assuming the first argument is the Spark session + if spark_version < "3.5.2" and os.getenv("SPARK_CONNECT_MODE_ENABLED"): + pytest.skip( + "Skipping test because sort_columns is not supported in Spark-Connect mode for Spark versions < 3.5.2" + ) + return func(*args, **kwargs) + return wrapper def describe_with_columns_renamed(): @@ -295,10 +308,12 @@ def _get_simple_test_dataframes(sort_order) -> tuple[(DataFrame, DataFrame)]: ) +@skip_if_spark_connect_mode def test_sort_struct_flat(): _test_sort_struct_flat(spark, "asc") +@skip_if_spark_connect_mode def test_sort_struct_flat_desc(): _test_sort_struct_flat(spark, "desc") @@ -537,15 +552,11 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: return unsorted_df, expected_df unsorted_df, expected_df = _get_test_dataframes() - if spark.version < "3.5.2" and os.getenv("SPARK_CONNECT_MODE_ENABLED"): - with pytest.raises(Exception) as excinfo: - quinn.sort_columns(unsorted_df, "asc", sort_nested=True) - assert str(excinfo.value) == "sort_columns is not supported on Spark-Connect mode for Spark versions < 3.5.2" - else: - sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True) - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable - ) + + sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True) + chispa.schema_comparer.assert_schema_equality( + sorted_df.schema, expected_df.schema, ignore_nullable + ) def _test_sort_struct_nested_with_arraytypes_desc(spark, ignore_nullable: bool): @@ -704,34 +715,42 @@ def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: ) +@skip_if_spark_connect_mode def test_sort_struct_nested(): _test_sort_struct_nested(spark, True) +@skip_if_spark_connect_mode def test_sort_struct_nested_desc(): _test_sort_struct_nested_desc(spark, True) +@skip_if_spark_connect_mode def test_sort_struct_nested_with_arraytypes(): _test_sort_struct_nested_with_arraytypes(spark, True) +@skip_if_spark_connect_mode def test_sort_struct_nested_with_arraytypes_desc(): _test_sort_struct_nested_with_arraytypes_desc(spark, True) +@skip_if_spark_connect_mode def test_sort_struct_nested_nullable(): _test_sort_struct_nested(spark, True) +@skip_if_spark_connect_mode def test_sort_struct_nested_nullable_desc(): _test_sort_struct_nested_desc(spark, False) +@skip_if_spark_connect_mode def test_sort_struct_nested_with_arraytypes_nullable(): _test_sort_struct_nested_with_arraytypes(spark, False) +@skip_if_spark_connect_mode def test_sort_struct_nested_with_arraytypes_nullable_desc(): _test_sort_struct_nested_with_arraytypes_desc(spark, True)