Merge branch 'branch-23.08' into bloom-filter-agg
jlowe committed Aug 2, 2023
2 parents fc5e391 + 9ec437f commit 26f3f5a
Showing 35 changed files with 1,083 additions and 569 deletions.
6 changes: 3 additions & 3 deletions docs/supported_ops.md
@@ -12936,7 +12936,7 @@ are limited.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>STRUCT is not supported as a child type for ARRAY;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, UDT</em></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, UDT</em></td>
<td><b>NS</b></td>
@@ -12957,7 +12957,7 @@ are limited.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>STRUCT is not supported as a child type for ARRAY;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, UDT</em></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, UDT</em></td>
<td><b>NS</b></td>
@@ -19341,7 +19341,7 @@ as `a` don't show up in the table. They are controlled by the rules for
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>STRUCT is not supported as a child type for ARRAY;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, UDT</em></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, UDT</em></td>
<td><b>NS</b></td>
67 changes: 48 additions & 19 deletions integration_tests/DATA_GEN.md
@@ -226,16 +226,44 @@ dataTable.toDF(spark).groupBy("a").count().orderBy(desc("count")).show()
+----------+-----+
| a|count|
+----------+-----+
| 0~{5)Uek>|34414|
|,J~pA-KqBn|34030|
|#"IlU%=azD|13651|
|9=o`YbIDy{|13444|
|xqW\HOC.;L| 2107|
|F&1rC%ge3P| 2089|
|bK%|@|fs9a| 129|
|(Z=9IR8h83| 128|
|Eb9fEBb-]B| 5|
|T<oMow6W_L| 3|
|,J~pA-KqBn|38286|
| 0~{5)Uek>|24279|
|9=o`YbIDy{|24122|
|F&1rC%ge3P| 6147|
|#"IlU%=azD| 5977|
|xqW\HOC.;L| 591|
|bK%|@|fs9a| 547|
|(Z=9IR8h83| 26|
|T<oMow6W_L| 24|
|Eb9fEBb-]B| 1|
+----------+-----+
```
### ExponentialDistribution
`ExponentialDistribution` takes a target seed value and a standard deviation, providing a way
to insert an exponential skew. The target seed is the seed that is most likely to show up, and
the standard deviation is `1/rate`. The median should be one standard deviation below the
target.
```scala
val dataTable = DBGen().addTable("data", "a string", 100000)
dataTable("a").setSeedMapping(ExponentialDistribution(50, 1.0)).setSeedRange(0, 100)
dataTable.toDF(spark).groupBy("a").count().orderBy(desc("count")).show()
+----------+-----+
| a|count|
+----------+-----+
|,J~pA-KqBn|63428|
| 0~{5)Uek>|23026|
|#"IlU%=azD| 8602|
|xqW\HOC.;L| 3141|
|(Z=9IR8h83| 1164|
|Eb9fEBb-]B| 412|
|do)6AwiT_T| 129|
||i2l\J)u8I| 62|
|VZav:oU#g[| 23|
|kFR]RZ9pu|| 8|
| aMZ({x5#1| 5|
+----------+-----+
```
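As a rough sanity check on the skew above (assuming the rate is `1/stdDev` as described, so here `rate = 1.0`), each seed one step further from the target should be roughly `e` times less likely, which matches the ratios between adjacent counts in the example output:

```python
# Hypothetical check, not part of the generator: compare adjacent count ratios
# from the output above against e**rate for ExponentialDistribution(50, 1.0).
import math

counts = [63428, 23026, 8602, 3141, 1164, 412]        # top counts shown above
ratios = [a / b for a, b in zip(counts, counts[1:])]
print([round(r, 2) for r in ratios])                  # ~[2.75, 2.68, 2.74, 2.7, 2.83]
print(round(math.exp(1.0), 2))                        # e**rate ~= 2.72
```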

@@ -258,29 +286,30 @@ val dataTable = DBGen().addTable("data", "a string", 100000)
dataTable("a").setSeedMapping(MultiDistribution(Seq(
(10.0, NormalDistribution(50, 1.0)),
(1.0, FlatDistribution())))).setSeedRange(0, 100)
dataTable.toDF(spark).groupBy("a").count().orderBy(desc("count")).show()
+----------+-----+
| a|count|
+----------+-----+
| 0~{5)Uek>|31727|
|,J~pA-KqBn|29461|
|9=o`YbIDy{|12910|
|#"IlU%=azD|12840|
|F&1rC%ge3P| 2103|
|xqW\HOC.;L| 2080|
|(Z=9IR8h83| 214|
|bK%|@|fs9a| 185|
|,J~pA-KqBn|33532|
| 0~{5)Uek>|23093|
|9=o`YbIDy{|22131|
|F&1rC%ge3P| 5711|
|#"IlU%=azD| 5646|
|xqW\HOC.;L| 659|
|bK%|@|fs9a| 615|
|(Z=9IR8h83| 120|
|n&]AosAQJf| 111|
||H6h"R!7CH| 110|
|-bVd8htg"^| 108|
|u2^.x?oJBb| 107|
| aMZ({x5#1| 107|
|Qb#XoQx[{Z| 107|
|5C&<?S31Kp| 106|
|T<oMow6W_L| 106|
|)Wf2']8yFm| 105|
|_qo)|Ti2}n| 105|
|S1Jdbn_hda| 104|
|\SANbeK.?`| 103|
|yba?^,?zP`| 103|
+----------+-----+
only showing top 20 rows
```
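A quick back-of-the-envelope check of the mix above (assuming the first element of each pair acts as a relative weight, and that the seed range 0–100 is inclusive): about 1/11 of the 100,000 rows come from the `FlatDistribution`, spread over 101 seeds, or roughly 90 rows per seed. The large counts at the top of the table come from the normal component, while the ~100–120 counts at the bottom are the high end of that flat background (`show()` lists the 20 largest groups):

```python
# Hypothetical arithmetic check of the 10:1 weighting shown above.
rows = 100000
weights = {"normal": 10.0, "flat": 1.0}                # relative weights from the example
flat_rows = rows * weights["flat"] / sum(weights.values())
seeds_in_range = 101                                   # assumes setSeedRange(0, 100) is inclusive
print(round(flat_rows))                                # ~9091 rows drawn from the flat component
print(round(flat_rows / seeds_in_range))               # ~90 rows per seed of flat background
```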
25 changes: 25 additions & 0 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -22,6 +22,7 @@
from pyspark.sql.types import IntegralType
from spark_session import *
import pyspark.sql.functions as f
import pyspark.sql.utils
from datetime import timedelta

# No overflow gens here because we just focus on verifying the fallback to CPU when
@@ -638,6 +639,7 @@ def test_decimal_bround(data_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            'bround(a)',
            'bround(1.234, 2)',
            'bround(a, -1)',
            'bround(a, 1)',
            'bround(a, 2)',
@@ -650,11 +652,34 @@ def test_decimal_round(data_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            'round(a)',
            'round(1.234, 2)',
            'round(a, -1)',
            'round(a, 1)',
            'round(a, 2)',
            'round(a, 10)'))


@incompat
@approximate_float
@pytest.mark.parametrize('data_gen', [int_gen], ids=idfn)
def test_illegal_args_round(data_gen):
    def check_analysis_exception(spark, sql_text):
        try:
            gen_df(spark, [("a", data_gen), ("b", int_gen)], length=10).selectExpr(sql_text)
            raise Exception("round/bround should not plan with invalid arguments %s" % sql_text)
        except pyspark.sql.utils.AnalysisException as e:
            pass

    def doit(spark):
        check_analysis_exception(spark, "round(1.2345, b)")
        check_analysis_exception(spark, "round(a, b)")
        check_analysis_exception(spark, "bround(1.2345, b)")
        check_analysis_exception(spark, "bround(a, b)")

    with_cpu_session(lambda spark: doit(spark))
    with_gpu_session(lambda spark: doit(spark))


@incompat
@approximate_float
def test_non_decimal_round_overflow():
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/array_test.py
@@ -21,7 +21,7 @@
from spark_session import is_before_spark_313, is_before_spark_330, is_databricks113_or_later, is_spark_330_or_later, is_databricks104_or_later, is_spark_33X, is_spark_340_or_later, is_spark_330, is_spark_330cdh
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from pyspark.sql.functions import array_contains, col, element_at, lit
from pyspark.sql.functions import array_contains, col, element_at, lit, array

# max_val is a little larger than the default max size(20) of ArrayGen
# so we can get the out-of-bound indices.
@@ -218,6 +218,9 @@ def get_input(spark):
        return two_col_df(spark, arr_gen, data_gen)

    assert_gpu_and_cpu_are_equal_collect(lambda spark: get_input(spark).select(
        array_contains(array(lit(None)), col('b')),
        array_contains(array(), col('b')),
        array_contains(array(lit(literal), lit(literal)), col('b')),
        array_contains(col('a'), literal.cast(data_gen.data_type)),
        array_contains(col('a'), col('b')),
        array_contains(col('a'), col('a')[5])))
21 changes: 20 additions & 1 deletion integration_tests/src/main/python/collection_ops_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@
from pyspark.sql.types import *
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
from spark_session import with_cpu_session, with_gpu_session

nested_gens = [ArrayGen(LongGen()), ArrayGen(decimal_gen_128bit),
    StructGen([("a", LongGen()), ("b", decimal_gen_128bit)]),
@@ -155,6 +157,23 @@ def test_sort_array_lit(data_gen):
            f.sort_array(f.lit(array_lit), False)))


@pytest.mark.parametrize('data_gen', [ArrayGen(IntegerGen())], ids=idfn)
def test_illegal_args_sort_array(data_gen):
    def check_analysis_exception(spark, sql_text):
        try:
            gen_df(spark, [("a", data_gen), ("b", boolean_gen)], length=10).selectExpr(sql_text)
            raise Exception("sort_array should not plan with invalid arguments %s" % sql_text)
        except pyspark.sql.utils.AnalysisException as e:
            pass

    def doit(spark):
        check_analysis_exception(spark, "sort_array(a, b)")
        check_analysis_exception(spark, "sort_array(array(), b)")

    with_cpu_session(lambda spark: doit(spark))
    with_gpu_session(lambda spark: doit(spark))


def test_sort_array_normalize_nans():
"""
When the average length of array is > 100,
89 changes: 77 additions & 12 deletions integration_tests/src/main/python/date_time_test.py
@@ -18,7 +18,7 @@
from datetime import date, datetime, timezone
from marks import ignore_order, incompat, allow_non_gpu
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_spark_session, is_before_spark_330
from spark_session import with_cpu_session, is_before_spark_330
import pyspark.sql.functions as f

# We only support literal intervals for TimeSub
@@ -229,6 +229,16 @@ def test_unix_timestamp(data_gen):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).select(f.unix_timestamp(f.col('a'))))


@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
def test_unsupported_fallback_unix_timestamp(data_gen):
    assert_gpu_fallback_collect(lambda spark: gen_df(
        spark, [("a", data_gen), ("b", string_gen)], length=10).selectExpr(
        "unix_timestamp(a, b)"),
        "UnixTimestamp")


@pytest.mark.parametrize('ansi_enabled', [True, False], ids=['ANSI_ON', 'ANSI_OFF'])
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
def test_to_unix_timestamp(data_gen, ansi_enabled):
@@ -237,19 +247,48 @@ def test_to_unix_timestamp(data_gen, ansi_enabled):
        {'spark.sql.ansi.enabled': ansi_enabled})


@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
def test_unsupported_fallback_to_unix_timestamp(data_gen):
    assert_gpu_fallback_collect(lambda spark: gen_df(
        spark, [("a", data_gen), ("b", string_gen)], length=10).selectExpr(
        "to_unix_timestamp(a, b)"),
        "ToUnixTimestamp")


@pytest.mark.parametrize('time_zone', ["UTC", "UTC+0", "UTC-0", "GMT", "GMT+0", "GMT-0"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
def test_from_utc_timestamp(data_gen, time_zone):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))
        lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)))

@allow_non_gpu('ProjectExec, FromUTCTimestamp')
@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('time_zone', ["PST", "MST", "EST", "VST", "NST", "AST"], ids=idfn)
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
def test_from_utc_timestamp_fallback(data_gen, time_zone):
def test_from_utc_timestamp_unsupported_timezone_fallback(data_gen, time_zone):
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)),
        'FromUTCTimestamp')


@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', [timestamp_gen], ids=idfn)
def test_unsupported_fallback_from_utc_timestamp(data_gen):
    time_zone_gen = StringGen(pattern="UTC")
    assert_gpu_fallback_collect(
        lambda spark: gen_df(spark, [("a", data_gen), ("tzone", time_zone_gen)]).selectExpr(
        "from_utc_timestamp(a, tzone)"),
        'FromUTCTimestamp')


@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', [long_gen], ids=idfn)
def test_unsupported_fallback_from_unixtime(data_gen):
    fmt_gen = StringGen(pattern="[M]")
    assert_gpu_fallback_collect(
        lambda spark : unary_op_df(spark, data_gen).select(f.from_utc_timestamp(f.col('a'), time_zone)),
        'ProjectExec')
        lambda spark: gen_df(spark, [("a", data_gen), ("fmt", fmt_gen)]).selectExpr(
        "from_unixtime(a, fmt)"),
        'FromUnixTime')


@pytest.mark.parametrize('invalid,fmt', [
@@ -388,28 +427,30 @@ def test_date_format(data_gen, date_format):
unsupported_date_formats = ['F']
@pytest.mark.parametrize('date_format', unsupported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
@allow_non_gpu('ProjectExec,Alias,DateFormatClass,Literal,Cast')
@allow_non_gpu('ProjectExec')
def test_date_format_f(data_gen, date_format):
    assert_gpu_fallback_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)), 'ProjectExec')
        lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
        'DateFormatClass')

@pytest.mark.parametrize('date_format', unsupported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
@allow_non_gpu('ProjectExec,Alias,DateFormatClass,Literal,Cast')
@allow_non_gpu('ProjectExec')
def test_date_format_f_incompat(data_gen, date_format):
    # note that we can't support it even with incompatibleDateFormats enabled
    conf = {"spark.rapids.sql.incompatibleDateFormats.enabled": "true"}
    assert_gpu_fallback_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)), 'ProjectExec', conf)
        lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
        'DateFormatClass', conf)

maybe_supported_date_formats = ['dd-MM-yyyy', 'yyyy-MM-dd HH:mm:ss.SSS', 'yyyy-MM-dd HH:mm:ss.SSSSSS']
@pytest.mark.parametrize('date_format', maybe_supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
@allow_non_gpu('ProjectExec,Alias,DateFormatClass,Literal,Cast')
@allow_non_gpu('ProjectExec')
def test_date_format_maybe(data_gen, date_format):
    assert_gpu_fallback_collect(
        lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)),
        'ProjectExec')
        'DateFormatClass')

@pytest.mark.parametrize('date_format', maybe_supported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
@@ -439,6 +480,30 @@ def do_join_cast(spark):
        return left.join(right, left.monthly_reporting_period == right.r_monthly_reporting_period, how='inner')
    assert_gpu_and_cpu_are_equal_collect(do_join_cast)


@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
def test_unsupported_fallback_date_format(data_gen):
    conf = {"spark.rapids.sql.incompatibleDateFormats.enabled": "true"}
    assert_gpu_fallback_collect(
        lambda spark : gen_df(spark, [("a", data_gen)]).selectExpr(
        "date_format(a, a)"),
        "DateFormatClass",
        conf)


@allow_non_gpu('ProjectExec')
def test_unsupported_fallback_to_date():
    date_gen = StringGen(pattern="2023-08-01")
    pattern_gen = StringGen(pattern="[M]")
    conf = {"spark.rapids.sql.incompatibleDateFormats.enabled": "true"}
    assert_gpu_fallback_collect(
        lambda spark: gen_df(spark, [("a", date_gen), ("b", pattern_gen)]).selectExpr(
        "to_date(a, b)"),
        'GetTimestamp',
        conf)


# (-62135510400, 253402214400) is the range of seconds that can be represented by timestamp_seconds
# considering the influence of time zone.
ts_float_gen = SetValuesGen(FloatType(), [0.0, -0.0, 1.0, -1.0, 1.234567, -1.234567, 16777215.0, float('inf'), float('-inf'), float('nan')])
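The bounds in the comment above appear to be the year 1 to year 9999 range in Unix seconds, pulled in by roughly a day at each end so that any session time zone still maps inside the representable range (an assumption about their derivation, but the arithmetic matches):

```python
# Hypothetical verification of the (-62135510400, 253402214400) bounds quoted above.
from datetime import datetime, timezone

min_ts = datetime(1, 1, 1, tzinfo=timezone.utc).timestamp()       # -62135596800.0 (0001-01-01 00:00 UTC)
max_ts = datetime(9999, 12, 31, tzinfo=timezone.utc).timestamp()  # 253402214400.0 (9999-12-31 00:00 UTC)
print(min_ts + 86400)   # -62135510400.0 -> matches the lower bound (start of 0001-01-02)
print(max_ts)           # 253402214400.0 -> matches the upper bound (start of the last day of 9999)
```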