Skip to content

Commit

Permalink
Add support for Structs for UnionExec (NVIDIA#1919)
Browse files Browse the repository at this point in the history
* union support for structs

Signed-off-by: Raza Jafri <[email protected]>

* added more tests

Signed-off-by: Raza Jafri <[email protected]>

* format change

Signed-off-by: Raza Jafri <[email protected]>

* added missing all_gen

Signed-off-by: Raza Jafri <[email protected]>

* addressed review comments

Signed-off-by: Raza Jafri <[email protected]>

* Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala

Co-authored-by: Gera Shegalov <[email protected]>
Signed-off-by: Raza Jafri <[email protected]>

* Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala

Co-authored-by: Gera Shegalov <[email protected]>
Signed-off-by: Raza Jafri <[email protected]>

* line exceeds limit

Signed-off-by: Raza Jafri <[email protected]>

* added comments for tests

Signed-off-by: Raza Jafri <[email protected]>

Co-authored-by: Raza Jafri <[email protected]>
Co-authored-by: Gera Shegalov <[email protected]>
  • Loading branch information
3 people authored Mar 23, 2021
1 parent 8746585 commit 80311bf
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ Accelerator supports are described below.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (unionByName will not optionally impute nulls for missing struct fields when the column is a struct and there are non-overlapping fields; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from data_gen import *
from marks import incompat, approximate_float
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_310
from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_311
import pyspark.sql.functions as f

decimal_gens_not_max_prec = [decimal_gen_neg_scale, decimal_gen_scale_precision,
Expand Down Expand Up @@ -524,7 +524,7 @@ def _test_div_by_zero(ansi_mode, expr):


@pytest.mark.parametrize('expr', ['1/0', 'a/0', 'a/b'])
@pytest.mark.xfail(condition=is_before_spark_310(), reason='https://github.com/apache/spark/pull/29882')
@pytest.mark.xfail(condition=is_before_spark_311(), reason='https://github.com/apache/spark/pull/29882')
def test_div_by_zero_ansi(expr):
_test_div_by_zero(ansi_mode='ansi', expr=expr)

Expand Down
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pyspark.sql.types import *
from marks import *
import pyspark.sql.functions as f
from spark_session import with_spark_session, is_spark_300, is_before_spark_310
from spark_session import with_spark_session, is_spark_300, is_before_spark_311

_no_nans_float_conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
'spark.rapids.sql.hasNans': 'false',
Expand Down Expand Up @@ -400,7 +400,7 @@ def test_distinct_count_reductions(data_gen):
lambda spark : binary_op_df(spark, data_gen).selectExpr(
'count(DISTINCT a)'))

@pytest.mark.xfail(condition=is_before_spark_310(),
@pytest.mark.xfail(condition=is_before_spark_311(),
reason='Spark fixed distinct count of NaNs in 3.1')
@pytest.mark.parametrize('data_gen', [float_gen, double_gen], ids=idfn)
def test_distinct_float_count_reductions(data_gen):
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from conftest import is_databricks_runtime, is_emr_runtime
from data_gen import *
from marks import ignore_order, allow_non_gpu, incompat
from spark_session import with_cpu_session, with_spark_session, is_before_spark_310
from spark_session import with_cpu_session, with_spark_session

all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
BooleanGen(), DateGen(), TimestampGen(), null_gen,
Expand Down
30 changes: 28 additions & 2 deletions integration_tests/src/main/python/repart_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,42 @@
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from spark_session import is_before_spark_311
from data_gen import *
from marks import ignore_order
import pyspark.sql.functions as f

@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
nested_scalar_mark=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1459")
@pytest.mark.parametrize('data_gen', [pytest.param((StructGen([['child0', DecimalGen(7, 2)]]),
StructGen([['child1', IntegerGen()]])), marks=nested_scalar_mark),
(StructGen([['child0', DecimalGen(7, 2)]], nullable=False),
StructGen([['child1', IntegerGen()]], nullable=False))], ids=idfn)
@pytest.mark.skipif(is_before_spark_311(), reason="This is supported only in Spark 3.1.1+")
# This tests the union of DF of structs with different types of cols as long as the struct itself
# isn't null. This is a limitation in cudf because we don't support nested types as literals
def test_union_struct_missing_children(data_gen):
left_gen, right_gen = data_gen
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, left_gen).unionByName(binary_op_df(
spark, right_gen), True))

@pytest.mark.parametrize('data_gen', all_gen + [all_basic_struct_gen, StructGen([['child0', DecimalGen(7, 2)]])], ids=idfn)
# This tests union of two DFs of two cols each. The types of the left col and right col is the same
def test_union(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).union(binary_op_df(spark, data_gen)))

@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('data_gen', all_gen + [pytest.param(all_basic_struct_gen, marks=nested_scalar_mark),
pytest.param(StructGen([[ 'child0', DecimalGen(7, 2)]], nullable=False), marks=nested_scalar_mark)])
@pytest.mark.skipif(is_before_spark_311(), reason="This is supported only in Spark 3.1.1+")
# This tests the union of two DFs of structs with missing child column names. The missing child
# column will be replaced by nulls in the output DF. This is a feature added in 3.1+
def test_union_by_missing_col_name(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).withColumnRenamed("a", "x")
.unionByName(binary_op_df(spark, data_gen).withColumnRenamed("a", "y"), True))

@pytest.mark.parametrize('data_gen', all_gen + [all_basic_struct_gen, StructGen([['child0', DecimalGen(7, 2)]])], ids=idfn)
def test_union_by_name(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : binary_op_df(spark, data_gen).unionByName(binary_op_df(spark, data_gen)))
Expand Down
6 changes: 3 additions & 3 deletions integration_tests/src/main/python/sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from marks import *
from pyspark.sql.types import *
import pyspark.sql.functions as f
from spark_session import is_before_spark_310
from spark_session import is_before_spark_311

orderable_not_null_gen = [ByteGen(nullable=False), ShortGen(nullable=False), IntegerGen(nullable=False),
LongGen(nullable=False), FloatGen(nullable=False), DoubleGen(nullable=False), BooleanGen(nullable=False),
Expand Down Expand Up @@ -50,9 +50,9 @@ def test_single_sort_in_part(data_gen, order):
conf = allow_negative_scale_of_decimal_conf)

orderable_gens_sort = [byte_gen, short_gen, int_gen, long_gen,
pytest.param(float_gen, marks=pytest.mark.xfail(condition=is_before_spark_310(),
pytest.param(float_gen, marks=pytest.mark.xfail(condition=is_before_spark_311(),
reason='Spark has -0.0 < 0.0 before Spark 3.1')),
pytest.param(double_gen, marks=pytest.mark.xfail(condition=is_before_spark_310(),
pytest.param(double_gen, marks=pytest.mark.xfail(condition=is_before_spark_311(),
reason='Spark has -0.0 < 0.0 before Spark 3.1')),
boolean_gen, timestamp_gen, date_gen, string_gen, null_gen] + decimal_gens
@pytest.mark.parametrize('data_gen', orderable_gens_sort, ids=idfn)
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/spark_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,5 @@ def with_gpu_session(func, conf={}):
def is_spark_300():
return (spark_version() == "3.0.0" or spark_version().startswith('3.0.0-amzn'))

def is_before_spark_310():
return spark_version() < "3.1.0"
def is_before_spark_311():
return spark_version() < "3.1.1"
1 change: 0 additions & 1 deletion integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import pytest

from spark_session import is_before_spark_310
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from marks import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2688,7 +2688,11 @@ object GpuOverrides {
(shuffle, conf, p, r) => new GpuShuffleMeta(shuffle, conf, p, r)),
exec[UnionExec](
"The backend for the union operator",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL)
.withPsNote(TypeEnum.STRUCT,
"unionByName will not optionally impute nulls for missing struct fields " +
"when the column is a struct and there are non-overlapping fields"), TypeSig.all),
(union, conf, p, r) => new SparkPlanMeta[UnionExec](union, conf, p, r) {
override def convertToGpu(): GpuExec =
GpuUnionExec(childPlans.map(_.convertIfNeeded()))
Expand Down

0 comments on commit 80311bf

Please sign in to comment.