Fix Parquet test_round_trip to avoid CPU write exception #3203

Merged
51 changes: 36 additions & 15 deletions integration_tests/src/main/python/parquet_test.py
@@ -31,6 +31,25 @@ def read_parquet_sql(data_path):
decimal_gens = [DecimalGen(), DecimalGen(precision=7, scale=3), DecimalGen(precision=10, scale=10),
DecimalGen(precision=9, scale=0), DecimalGen(precision=18, scale=15)]

rebase_write_corrected_conf = {
'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED',
'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED'
}

rebase_write_legacy_conf = {
'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY',
'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'LEGACY'
}
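# Usage sketch (illustrative only, not part of this change): the tests below pass these dicts
# straight to with_cpu_session as the write conf, so the CPU write sets the datetime and INT96
# rebase modes together instead of only datetimeRebaseModeInWrite, e.g.:
#   with_cpu_session(
#       lambda spark: gen_df(spark, gen_list).write.parquet(data_path),
#       conf=rebase_write_corrected_conf)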

# Like the standard map_gens_sample but with timestamps limited
parquet_map_gens = [MapGen(f(nullable=False), f()) for f in [
BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DateGen,
lambda nullable=True: TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc), nullable=nullable)]] +\
[simple_string_to_string_map_gen,
MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
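# Illustrative expansion (not part of this change) of what the comprehension above yields for the
# timestamp entry: map keys and values are both limited to 1900-01-01 UTC or later, presumably so
# the CPU write under the CORRECTED confs above never needs the legacy datetime rebase path:
#   MapGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc), nullable=False),
#          TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)))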

parquet_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen,
TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)), ArrayGen(byte_gen),
@@ -40,7 +59,7 @@ def read_parquet_sql(data_path):
ArrayGen(ArrayGen(byte_gen)),
StructGen([['child0', ArrayGen(byte_gen)], ['child1', byte_gen], ['child2', float_gen], ['child3', DecimalGen()]]),
ArrayGen(StructGen([['child0', string_gen], ['child1', double_gen], ['child2', int_gen]]))] +
- map_gens_sample + decimal_gens,
+ parquet_map_gens + decimal_gens,
pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/132'))]

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
@@ -60,7 +79,7 @@ def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader_confs,
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'})
# once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround
@@ -122,7 +141,7 @@ def test_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1_enabled
s0 = gen_scalar(parquet_gen, force_no_nulls=True)
with_cpu_session(
lambda spark : gen_df(spark, gen_list).orderBy('a').write.parquet(data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
rf = read_func(data_path)
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
@@ -147,6 +166,7 @@ def test_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_rebase, v1_
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_write})
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
@@ -166,6 +186,7 @@ def test_ts_read_round_trip(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_write})
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
@@ -230,7 +251,7 @@ def test_read_round_trip_legacy(spark_tmp_path, parquet_gens, v1_enabled_list, r
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY'})
+ conf=rebase_write_legacy_conf)
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
assert_gpu_and_cpu_are_equal_collect(
@@ -249,15 +270,15 @@ def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list, reader_confs):
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0/key2=20'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(first_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY'})
+ conf=rebase_write_legacy_conf)
second_data_path = spark_tmp_path + '/PARQUET_DATA/key=1/key2=21'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(second_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
third_data_path = spark_tmp_path + '/PARQUET_DATA/key=2/key2=22'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(third_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
@@ -274,11 +295,11 @@ def test_partitioned_read_just_partitions(spark_tmp_path, v1_enabled_list, reade
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(first_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY'})
+ conf=rebase_write_legacy_conf)
second_data_path = spark_tmp_path + '/PARQUET_DATA/key=1'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).write.parquet(second_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
@@ -323,12 +344,12 @@ def test_read_merge_schema(spark_tmp_path, v1_enabled_list, reader_confs):
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
lambda spark : gen_df(spark, first_gen_list).write.parquet(first_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY'})
+ conf=rebase_write_legacy_conf)
second_gen_list = [(('_c' if i % 2 == 0 else '_b') + str(i), gen) for i, gen in enumerate(parquet_gens)]
second_data_path = spark_tmp_path + '/PARQUET_DATA/key=1'
with_cpu_session(
lambda spark : gen_df(spark, second_gen_list).write.parquet(second_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
@@ -348,12 +369,12 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, reader_con
first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0'
with_cpu_session(
lambda spark : gen_df(spark, first_gen_list).write.parquet(first_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY'})
+ conf=rebase_write_legacy_conf)
second_gen_list = [(('_c' if i % 2 == 0 else '_b') + str(i), gen) for i, gen in enumerate(parquet_gens)]
second_data_path = spark_tmp_path + '/PARQUET_DATA/key=1'
with_cpu_session(
lambda spark : gen_df(spark, second_gen_list).write.parquet(second_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list,
'spark.sql.parquet.mergeSchema': "true"})
@@ -416,7 +437,7 @@ def test_small_file_memory(spark_tmp_path, v1_enabled_list):
first_data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : gen_df(spark, gen_list).repartition(2000).write.parquet(first_data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.parquet(data_path),
@@ -448,7 +469,7 @@ def test_nested_pruning(spark_tmp_path, data_gen, read_schema, reader_confs, v1_
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : gen_df(spark, data_gen).write.parquet(data_path),
- conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'})
+ conf=rebase_write_corrected_conf)
all_confs = reader_confs.copy()
all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list,
'spark.sql.optimizer.nestedSchemaPruning.enabled': nested_enabled,