Window tests with smaller batches (#2482)
* Window tests with smaller batches

Signed-off-by: Robert (Bobby) Evans <[email protected]>

* Addressed review comments

* Update integration_tests/src/main/python/data_gen.py

Co-authored-by: Gera Shegalov <[email protected]>

Co-authored-by: Gera Shegalov <[email protected]>
revans2 and gerashegalov authored May 24, 2021
1 parent 1cf741b commit 3f64354
Showing 2 changed files with 30 additions and 24 deletions.
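The recurring change in the diff below is that each window test is parametrized over a batch_size of '1000' and '1g', which is folded into the Spark conf as spark.rapids.sql.batchSizeBytes, so every query runs once with many small stream batches and once with a single large batch. A minimal, self-contained sketch of that pattern (plain pytest; run_window_query is a hypothetical stand-in for the repo's assert_gpu_and_cpu_are_equal_sql helper):

import pytest

def run_window_query(conf):
    # Hypothetical stand-in: a real test would run the window query on CPU and
    # GPU with this conf and compare the results.
    assert conf['spark.rapids.sql.batchSizeBytes'] in ('1000', '1g')

# '1000' forces the window exec to see many small stream batches,
# '1g' keeps the whole input in one batch.
@pytest.mark.parametrize('batch_size', ['1000', '1g'])
def test_window_agg_with_batch_size(batch_size):
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
            'spark.rapids.sql.castFloatToDecimal.enabled': True}
    run_window_query(conf)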
13 changes: 8 additions & 5 deletions integration_tests/src/main/python/data_gen.py
@@ -453,14 +453,14 @@ class DateGen(DataGen):
def __init__(self, start=None, end=None, nullable=True):
super().__init__(DateType(), nullable=nullable)
if start is None:
- # spark supports times starting at
+ # Spark supports times starting at
# "0001-01-01 00:00:00.000000"
start = date(1, 1, 1)
elif not isinstance(start, date):
raise RuntimeError('Unsupported type passed in for start {}'.format(start))

if end is None:
- # spark supports time through
+ # Spark supports time through
# "9999-12-31 23:59:59.999999"
end = date(9999, 12, 31)
elif isinstance(end, timedelta):
@@ -513,14 +513,17 @@ class TimestampGen(DataGen):
def __init__(self, start=None, end=None, nullable=True):
super().__init__(TimestampType(), nullable=nullable)
if start is None:
- # spark supports times starting at
+ # Spark supports times starting at
# "0001-01-01 00:00:00.000000"
- start = datetime(1, 1, 1, tzinfo=timezone.utc)
+ # but it has issues if you get really close to that because it tries to do things
+ # in a different format which causes roundoff, so we have to add a few days,
+ # just to be sure
+ start = datetime(1, 1, 3, tzinfo=timezone.utc)
elif not isinstance(start, datetime):
raise RuntimeError('Unsupported type passed in for start {}'.format(start))

if end is None:
- # spark supports time through
+ # Spark supports time through
# "9999-12-31 23:59:59.999999"
end = datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
elif isinstance(end, timedelta):
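The TimestampGen change above moves the default start from 0001-01-01 to 0001-01-03, keeping generated values a couple of days clear of Spark's minimum supported timestamp, where the comment notes roundoff problems. A small illustration (not part of the commit) of the margin this adds, in the microseconds-since-epoch representation Spark's TimestampType uses:

from datetime import datetime, timedelta, timezone

epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
old_start = datetime(1, 1, 1, tzinfo=timezone.utc)  # previous default: the minimum itself
new_start = datetime(1, 1, 3, tzinfo=timezone.utc)  # new default: two days later

def micros_since_epoch(ts):
    # Exact integer microseconds relative to the Unix epoch.
    return (ts - epoch) // timedelta(microseconds=1)

print(micros_since_epoch(old_start))  # large negative value right at the boundary
print(micros_since_epoch(new_start) - micros_since_epoch(old_start))  # 172_800_000_000 us == 2 days of margin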
41 changes: 22 additions & 19 deletions integration_tests/src/main/python/window_function_test.py
@@ -153,13 +153,17 @@ def test_window_aggs_for_ranges_numeric_long_overflow(data_gen):


@ignore_order
+ @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('data_gen', [
_grpkey_byte_with_nulls,
_grpkey_short_with_nulls,
_grpkey_int_with_nulls,
_grpkey_long_with_nulls
], ids=idfn)
- def test_window_aggs_for_range_numeric(data_gen):
+ def test_window_aggs_for_range_numeric(data_gen, batch_size):
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
+ 'spark.rapids.sql.window.range.byte.enabled': True,
+ 'spark.rapids.sql.window.range.short.enabled': True}
assert_gpu_and_cpu_are_equal_sql(
lambda spark: gen_df(spark, data_gen, length=2048),
"window_agg_table",
@@ -192,19 +196,20 @@ def test_window_aggs_for_range_numeric(data_gen):
' (partition by a order by b asc '
' range between UNBOUNDED preceding and UNBOUNDED following) as max_b_unbounded '
'from window_agg_table ',
- conf={'spark.rapids.sql.window.range.byte.enabled': True,
- 'spark.rapids.sql.window.range.short.enabled': True})

+ conf = conf)

@ignore_order
+ @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
_grpkey_longs_with_nulls,
_grpkey_longs_with_timestamps,
_grpkey_longs_with_nullable_timestamps,
_grpkey_longs_with_decimals,
_grpkey_longs_with_nullable_decimals,
_grpkey_decimals_with_nulls], ids=idfn)
- def test_window_aggs_for_rows(data_gen):
+ def test_window_aggs_for_rows(data_gen, batch_size):
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
+ 'spark.rapids.sql.castFloatToDecimal.enabled': True}
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, data_gen, length=2048),
"window_agg_table",
@@ -224,7 +229,7 @@ def test_window_aggs_for_rows(data_gen):
' row_number() over '
' (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as row_num '
'from window_agg_table ',
- conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True})
+ conf = conf)


part_and_order_gens = [long_gen, DoubleGen(no_nans=True, special_cases=[]),
@@ -240,10 +245,13 @@ def tmp(something):

@ignore_order
@approximate_float
+ @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('c_gen', lead_lag_data_gens, ids=idfn)
@pytest.mark.parametrize('b_gen', part_and_order_gens, ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('a_gen', part_and_order_gens, ids=meta_idfn('partBy:'))
- def test_multi_types_window_aggs_for_rows_lead_lag(a_gen, b_gen, c_gen):
+ def test_multi_types_window_aggs_for_rows_lead_lag(a_gen, b_gen, c_gen, batch_size):
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
+ 'spark.rapids.sql.hasNans': False}
data_gen = [
('a', RepeatSeqGen(a_gen, length=20)),
('b', b_gen),
@@ -270,7 +278,7 @@ def do_it(spark):
.withColumn('lag_1_c', f.lag('c', 1).over(baseWindowSpec)) \
.withColumn('lag_def_c', f.lag('c', 4, defaultVal).over(baseWindowSpec)) \
.withColumn('row_num', f.row_number().over(baseWindowSpec))
- assert_gpu_and_cpu_are_equal_collect(do_it, conf={'spark.rapids.sql.hasNans': 'false'})
+ assert_gpu_and_cpu_are_equal_collect(do_it, conf = conf)


lead_lag_array_data_gens =\
@@ -279,20 +287,14 @@ def do_it(spark):
[ArrayGen(ArrayGen(ArrayGen(sub_gen, max_length=10), max_length=10), max_length=10) \
for sub_gen in lead_lag_data_gens]

- # lead and lag are supported for arrays, but the other window operations like min and max are not right now
- # once they are all supported the tests should be combined.
- @pytest.mark.skip(reason="If some rows of order-by columns (here is a,b,c) are equal, then it may fail because"
- "CPU and GPU can't guarantee the order for the same rows, while lead/lag is typically"
- "depending on row's order. The solution is we should add the d and d_default columns"
- "into the order-by to guarantee the order. But for now, sorting on array has not been"
- "supported yet, see https://github.com/NVIDIA/spark-rapids/issues/2470."
- "Once the issue is resolved, we should remove skip mark")
@ignore_order(local=True)
+ @pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('d_gen', lead_lag_array_data_gens, ids=meta_idfn('agg:'))
- @pytest.mark.parametrize('c_gen', [long_gen], ids=meta_idfn('orderBy:'))
+ @pytest.mark.parametrize('c_gen', [LongRangeGen()], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('b_gen', [long_gen], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('a_gen', [long_gen], ids=meta_idfn('partBy:'))
- def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen):
+ def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen, batch_size):
+ conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
data_gen = [
('a', RepeatSeqGen(a_gen, length=20)),
('b', b_gen),
@@ -310,7 +312,8 @@ def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen):
LAG(d, 5) OVER (PARTITION by a ORDER BY b,c) lag_d_5,
LAG(d, 2, d_default) OVER (PARTITION by a ORDER BY b,c) lag_d_2_default
FROM window_agg_table
- ''')
+ ''',
+ conf = conf)


# lead and lag don't currently work for string columns, so redo the tests, but just for strings
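The array lead/lag test above also drops its skip mark, with the order-by generator switched from long_gen to LongRangeGen(), sidestepping the duplicate order-by values the skip reason warned about. A standalone sketch of that ordering issue (illustration only, assuming a local pyspark installation; none of this code is in the commit):

from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.window import Window

spark = SparkSession.builder.master("local[1]").appName("lag-ties").getOrCreate()
# Rows 'a' and 'b' tie on ord=10, so for the ord=20 row lag('val') returns the
# value of whichever tied row happens to sort second -- CPU and GPU runs can
# legitimately disagree unless the order-by values are unique.
df = spark.createDataFrame([(1, 10, "a"), (1, 10, "b"), (1, 20, "c")],
                           ["part", "ord", "val"])
w = Window.partitionBy("part").orderBy("ord")
df.withColumn("lag_val", f.lag("val", 1).over(w)).show()
spark.stop()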
