Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow structs and arrays to pass through for Shuffle and Sort #1477

Merged
merged 5 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -536,9 +536,9 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS* (missing nested BINARY, CALENDAR, MAP, UDT)</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from marks import incompat
from pyspark.sql.types import *
import pyspark.sql.functions as f

# Once we support arrays as literals then we can support a[null] and
# negative indexes for all array gens. When that happens
Expand Down Expand Up @@ -46,3 +44,10 @@ def test_nested_array_index(data_gen):
'a[1]',
'a[3]',
'a[50]'))

@pytest.mark.parametrize('data_gen', single_level_array_gens_non_decimal, ids=idfn)
def test_orderby_array(data_gen):
    """Order-by on a value extracted from an array column must match between CPU and GPU."""
    def build_df(spark):
        # Single-column ('a') DataFrame of the parametrized array generator.
        return unary_op_df(spark, data_gen)
    assert_gpu_and_cpu_are_equal_sql(
        build_df,
        'array_table',
        'select array_table.a, array_table.a[0] as first_val from array_table order by first_val')
1 change: 1 addition & 0 deletions integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
boolean_gens = [boolean_gen]

# One-level-deep arrays over every basic element type, plus decimals and a null-only child.
single_level_array_gens = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + decimal_gens + [null_gen]]
# Same as above but without decimal children: per the PR discussion, sorting arrays of
# decimals hit an analysis exception with the legacy negative-scale setting, so decimal
# element types are excluded from tests that use this list — TODO revisit once resolved.
single_level_array_gens_non_decimal = [ArrayGen(sub_gen) for sub_gen in all_basic_gens + [null_gen]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not include decimal?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw some analysis exception with legacy setting for negative scale when trying to sort arrays with decimal type which I am not fully aware on where we stand in terms of support so I excluded those in that particular test.

I will check it out further


# Be careful to not make these too large or data generation takes forever
# This is only a few nested array gens, because nesting can be very deep
Expand Down
15 changes: 12 additions & 3 deletions integration_tests/src/main/python/struct_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from marks import incompat
from pyspark.sql.types import *
import pyspark.sql.functions as f

@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]),
StructGen([["first", short_gen], ["second", int_gen], ["third", long_gen]]),
Expand All @@ -32,10 +30,21 @@ def test_struct_get_item(data_gen):
'a.second',
'a.third'))


@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_default, decimal_gen_scale_precision], ids=idfn)
def test_make_struct(data_gen):
    """struct() and named_struct() construction must produce identical results on CPU and GPU."""
    # Both struct-building expressions are evaluated over a two-column ('a', 'b') DataFrame.
    exprs = [
        'struct(a, b)',
        'named_struct("foo", b, "bar", 5, "end", a)',
    ]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: binary_op_df(spark, data_gen).selectExpr(*exprs))


@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", float_gen]]),
                                      StructGen([["first", short_gen], ["second", int_gen], ["third", long_gen]]),
                                      StructGen([["first", long_gen], ["second", long_gen], ["third", long_gen]]),
                                      StructGen([["first", string_gen], ["second", ArrayGen(string_gen)], ["third", ArrayGen(string_gen)]])], ids=idfn)
def test_orderby_struct(data_gen):
    """Order-by on a field extracted from a struct column must match between CPU and GPU."""
    def build_df(spark):
        # Single-column ('a') DataFrame of the parametrized struct generator.
        return unary_op_df(spark, data_gen)
    assert_gpu_and_cpu_are_equal_sql(
        build_df,
        'struct_table',
        'select struct_table.a, struct_table.a.first as first_val from struct_table order by first_val')
Original file line number Diff line number Diff line change
Expand Up @@ -2321,7 +2321,8 @@ object GpuOverrides {
}),
exec[ShuffleExchangeExec](
"The backend for most data being exchanged between processes",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
revans2 marked this conversation as resolved.
Show resolved Hide resolved
TypeSig.STRUCT).nested(), TypeSig.all),
(shuffle, conf, p, r) => new GpuShuffleMeta(shuffle, conf, p, r)),
exec[UnionExec](
"The backend for the union operator",
Expand Down Expand Up @@ -2372,7 +2373,8 @@ object GpuOverrides {
(agg, conf, p, r) => new GpuSortAggregateMeta(agg, conf, p, r)),
exec[SortExec](
"The backend for the sort operator",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
TypeSig.STRUCT).nested(), TypeSig.all),
revans2 marked this conversation as resolved.
Show resolved Hide resolved
(sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r)),
exec[ExpandExec](
"The backend for the expand operator",
Expand Down