Prevent a degenerative join in test_dpp_reuse_broadcast_exchange [databricks] #10168

Merged
merged 8 commits on Jan 10, 2024
38 changes: 31 additions & 7 deletions integration_tests/src/main/python/data_gen.py
@@ -329,19 +329,40 @@ def start(self, rand):
         self._start(rand, lambda: self.next_val())

 class RepeatSeqGen(DataGen):
-    """Generate Repeated seq of `length` random items"""
-    def __init__(self, child, length):
-        super().__init__(child.data_type, nullable=False)
-        self.nullable = child.nullable
+    """Generate Repeated seq of `length` random items if child is a DataGen,
+    otherwise repeat the provided seq when child is a list.
+
+    When child is a list:
+        data_type must be specified
+        length must be <= length of child
+        nullable is honored
+    When child is a DataGen:
+        length must be specified
+        data_type must be None or match child's
+        nullable is set to child's nullable attribute
+    """
+    def __init__(self, child, length=None, data_type=None, nullable=False):
+        if isinstance(child, list):
+            super().__init__(data_type, nullable=False)
+            self.nullable = nullable
+            assert (length is None or length <= len(child))
+            self._length = length if length is not None else len(child)
+        else:
+            super().__init__(child.data_type, nullable=False)
+            self.nullable = child.nullable
+            assert (data_type is None or data_type == child.data_type)
+            assert (length is not None)
+            self._length = length
         self._child = child
         self._vals = []
-        self._length = length
         self._index = 0

     def __repr__(self):
         return super().__repr__() + '(' + str(self._child) + ')'

     def _cache_repr(self):
+        if isinstance(self._child, list):
+            return super()._cache_repr() + '(' + str(self._child) + ',' + str(self._length) + ')'
         return super()._cache_repr() + '(' + self._child._cache_repr() + ',' + str(self._length) + ')'

     def _loop_values(self):
@@ -351,9 +372,12 @@ def _loop_values(self):

     def start(self, rand):
         self._index = 0
-        self._child.start(rand)
         self._start(rand, self._loop_values)
-        self._vals = [self._child.gen() for _ in range(0, self._length)]
+        if isinstance(self._child, list):
+            self._vals = self._child
+        else:
+            self._child.start(rand)
+            self._vals = [self._child.gen() for _ in range(0, self._length)]

 class SetValuesGen(DataGen):
     """A set of values that are randomly selected"""
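For context, this change gives RepeatSeqGen two construction modes. Below is a minimal sketch (not part of the diff) of how each mode would be used, assuming `integration_tests/src/main/python` is on the Python path so `data_gen` is importable:

```python
from pyspark.sql.types import IntegerType

from data_gen import IntegerGen, RepeatSeqGen

# DataGen mode (pre-existing behavior): draw `length` random values from the
# child generator once, then cycle through that sequence for every row.
random_repeats = RepeatSeqGen(IntegerGen(nullable=False), length=10)

# List mode (added in this PR): cycle through exactly the provided values, so
# every listed value is guaranteed to appear in the generated column.
fixed_repeats = RepeatSeqGen([None, -1, 0, 1], data_type=IntegerType(), nullable=True)
```

The list mode is what lets dpp_test.py below pin down the contents of the `value` column instead of relying on purely random data.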
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/dpp_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2023, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,20 +14,25 @@

 import pytest

+from pyspark.sql.types import IntegerType
+
 from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture, assert_gpu_and_cpu_are_equal_collect
 from conftest import spark_tmp_table_factory
 from data_gen import *
 from marks import ignore_order, allow_non_gpu
 from spark_session import is_before_spark_320, with_cpu_session, is_before_spark_312, is_databricks_runtime, is_databricks113_or_later

+# non-positive values here can produce a degenerative join, so here we ensure that positive values
+# always appear so the join will produce rows. See https://github.com/NVIDIA/spark-rapids/issues/10147
+value_gen = RepeatSeqGen([None, INT_MIN, -1, 0, 1, INT_MAX], data_type=IntegerType(), nullable=False)

 def create_dim_table(table_name, table_format, length=500):
     def fn(spark):
         df = gen_df(spark, [
             ('key', IntegerGen(nullable=False, min_val=0, max_val=9, special_cases=[])),
             ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
             ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
-            ('value', int_gen),
+            ('value', value_gen),
             # specify nullable=False for `filter` to avoid generating invalid SQL with
             # expression `filter = None` (https://github.com/NVIDIA/spark-rapids/issues/9817)
             ('filter', RepeatSeqGen(
@@ -49,7 +54,7 @@ def fn(spark):
             ('skey', IntegerGen(nullable=False, min_val=0, max_val=4, special_cases=[])),
             # ex_key is not a partition column
             ('ex_key', IntegerGen(nullable=False, min_val=0, max_val=3, special_cases=[])),
-            ('value', int_gen)], length)
+            ('value', value_gen)], length)
         df.write.format(table_format) \
             .mode("overwrite") \
             .partitionBy('key', 'skey') \
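Why this fixes the flakiness: per the comment added above, with the old purely random `int_gen` an unlucky seed could produce a `value` column with no positive entries, so the filtered dim-table side could select zero rows and the join under test degenerated, changing the plan the test asserts on. A toy sketch of the difference (standalone Python, not the test harness; the function names are illustrative only):

```python
import random

INT_MIN, INT_MAX = -(2 ** 31), 2 ** 31 - 1

def random_values(rng, n):
    # Old behavior: every row is independently random, so positive values
    # are likely but not guaranteed to appear in any given run.
    return [rng.randint(INT_MIN, INT_MAX) for _ in range(n)]

def repeated_values(seq, n):
    # New behavior: cycle deterministically through seq, so each listed
    # value appears once in every len(seq) consecutive rows.
    return [seq[i % len(seq)] for i in range(n)]

sample = random_values(random.Random(0), 12)   # positives not guaranteed
vals = repeated_values([None, INT_MIN, -1, 0, 1, INT_MAX], 12)
assert any(v is not None and v > 0 for v in vals)  # positives always present
```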