Skip to content

Commit

Permalink
Force datagen_seed for conv_dec_to_from_hex, cast_string_date_valid_f…
Browse files Browse the repository at this point in the history
…ormat (#9785)

Relates to #9781 and 9784

Signed-off-by: Gera Shegalov <[email protected]>
  • Loading branch information
gerashegalov authored Nov 20, 2023
1 parent 02e40d9 commit 15e58aa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
15 changes: 8 additions & 7 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_cast_nested(data_gen, to_type):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(f.col('a').cast(to_type)))

@datagen_overrides(seed=0, reason="https://github.com/NVIDIA/spark-rapids/issues/9781")
def test_cast_string_date_valid_format():
# In Spark 3.2.0+ the valid format changed, and we cannot support all of the format.
# This provides values that are valid in all of those formats.
Expand Down Expand Up @@ -259,7 +260,7 @@ def test_cast_long_to_decimal_overflow():
f.col('a').cast(DecimalType(18, -1))))

# casting these types to string should be passed
basic_gens_for_cast_to_string = [ByteGen, ShortGen, IntegerGen, LongGen, StringGen, BooleanGen, DateGen, TimestampGen]
basic_gens_for_cast_to_string = [ByteGen, ShortGen, IntegerGen, LongGen, StringGen, BooleanGen, DateGen, TimestampGen]
basic_array_struct_gens_for_cast_to_string = [f() for f in basic_gens_for_cast_to_string] + [null_gen] + decimal_gens

# We currently do not generate the exact string as Spark for some decimal values of zero
Expand Down Expand Up @@ -300,7 +301,7 @@ def _assert_cast_to_string_equal (data_gen, conf):
@pytest.mark.parametrize('legacy', ['true', 'false'])
def test_cast_array_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
data_gen,
{"spark.sql.legacy.castComplexTypesToString.enabled": legacy})


Expand All @@ -319,7 +320,7 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):
@pytest.mark.parametrize('legacy', ['true', 'false'])
def test_cast_map_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
data_gen,
{"spark.sql.legacy.castComplexTypesToString.enabled": legacy})


Expand All @@ -338,7 +339,7 @@ def test_cast_map_with_unmatched_element_to_string(data_gen, legacy):
@pytest.mark.parametrize('legacy', ['true', 'false'])
def test_cast_struct_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
data_gen,
{"spark.sql.legacy.castComplexTypesToString.enabled": legacy}
)

Expand All @@ -355,7 +356,7 @@ def was_broken_for_nested_null(spark):
return df.select(df._1.cast(StringType()))

assert_gpu_and_cpu_are_equal_collect(
was_broken_for_nested_null,
was_broken_for_nested_null,
{"spark.sql.legacy.castComplexTypesToString.enabled": 'true' if cast_conf == 'LEGACY' else 'false'}
)

Expand All @@ -372,7 +373,7 @@ def broken_df(spark):
return df.select(df.a.cast(StringType())).filter(df.b > 1)

assert_gpu_and_cpu_are_equal_collect(
broken_df,
broken_df,
{"spark.sql.legacy.castComplexTypesToString.enabled": 'true' if cast_conf == 'LEGACY' else 'false'}
)

Expand All @@ -381,7 +382,7 @@ def broken_df(spark):
@pytest.mark.xfail(reason='casting this type to string is not an exact match')
def test_cast_struct_with_unmatched_element_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
data_gen,
{"spark.rapids.sql.castFloatToString.enabled" : "true",
"spark.sql.legacy.castComplexTypesToString.enabled": legacy}
)
Expand Down
9 changes: 5 additions & 4 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ def test_like_complex_escape():
pytest.param(10, r'-?[0-9]{1,18}', id='from_10'),
pytest.param(16, r'-?[0-9a-fA-F]{1,15}', id='from_16')
])
@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9784')
# to_base can be positive and negative
@pytest.mark.parametrize('to_base', [10, 16], ids=['to_plus10', 'to_plus16'])
def test_conv_dec_to_from_hex(from_base, to_base, pattern):
Expand All @@ -798,10 +799,10 @@ def test_conv_dec_to_from_hex(from_base, to_base, pattern):
conf={'spark.rapids.sql.expression.Conv': True}
)

format_number_gens = integral_gens + [DecimalGen(precision=7, scale=7), DecimalGen(precision=18, scale=0),
DecimalGen(precision=18, scale=3), DecimalGen(precision=36, scale=5),
DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10),
DecimalGen(precision=38, scale=-10),
format_number_gens = integral_gens + [DecimalGen(precision=7, scale=7), DecimalGen(precision=18, scale=0),
DecimalGen(precision=18, scale=3), DecimalGen(precision=36, scale=5),
DecimalGen(precision=36, scale=-5), DecimalGen(precision=38, scale=10),
DecimalGen(precision=38, scale=-10),
DecimalGen(precision=38, scale=30, special_cases=[Decimal('0.000125')]),
DecimalGen(precision=38, scale=32, special_cases=[Decimal('0.000125')]),
DecimalGen(precision=38, scale=37, special_cases=[Decimal('0.000125')])]
Expand Down

0 comments on commit 15e58aa

Please sign in to comment.