diff --git a/docs/compatibility.md b/docs/compatibility.md
index e14bc66cd222..560002f4d5ed 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -309,15 +309,6 @@ Also, the GPU does not support casting from strings containing hex values.
To enable this operation on the GPU, set
[`spark.rapids.sql.castStringToFloat.enabled`](configs.md#sql.castStringToFloat.enabled) to `true`.
-
-### String to Integral Types
-
-The GPU will return incorrect results for strings representing values greater than Long.MaxValue or
-less than Long.MinValue. The correct behavior would be to return null for these values, but the GPU
-currently overflows and returns an incorrect integer value.
-
-To enable this operation on the GPU, set
-[`spark.rapids.sql.castStringToInteger.enabled`](configs.md#sql.castStringToInteger.enabled) to `true`.
### String to Date
diff --git a/docs/configs.md b/docs/configs.md
index 8abcf8f00a2d..a65ca2db7588 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -51,6 +51,7 @@ Name | Description | Default Value
spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null
spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true
spark.rapids.sql.batchSizeBytes|Set the target number of bytes for a GPU batch. Splits sizes for input data is covered by separate configs. The maximum setting is 2 GB to avoid exceeding the cudf row count limit of a column.|2147483647
+spark.rapids.sql.castDecimalToString.enabled|When set to true, casting from decimal to string is supported on the GPU. The GPU does NOT produce exact same string as spark produces, but producing strings which are semantically equal. For instance, given input BigDecimal(123, -2), the GPU produces "12300", which spark produces "1.23E+4".|false
spark.rapids.sql.castFloatToDecimal.enabled|Casting from floating point types to decimal on the GPU returns results that have tiny difference compared to results returned from CPU.|false
spark.rapids.sql.castFloatToIntegralTypes.enabled|Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.|false
spark.rapids.sql.castFloatToString.enabled|Casting from floating point types to string on the GPU returns results that have a different precision than the default results of Spark.|false
@@ -158,6 +159,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Floor|`floor`|Floor of a number|true|None|
spark.rapids.sql.expression.FromUnixTime|`from_unixtime`|Get the string from a unix timestamp|true|None|
spark.rapids.sql.expression.GetArrayItem| |Gets the field at `ordinal` in the Array|true|None|
+spark.rapids.sql.expression.GetJsonObject|`get_json_object`|Extracts a json object from path|true|None|
spark.rapids.sql.expression.GetMapValue| |Gets Value from a Map based on a key|true|None|
spark.rapids.sql.expression.GetStructField| |Gets the named field of the struct|true|None|
spark.rapids.sql.expression.GetTimestamp| |Gets timestamps from strings using given pattern.|true|None|
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 650eb3a091b4..770cc1732c8d 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -144,12 +144,12 @@ Accelerator supports are described below.
S* |
S |
S* |
+S |
NS |
NS |
NS |
NS |
-NS |
-NS |
+PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
NS |
@@ -379,7 +379,7 @@ Accelerator supports are described below.
NS |
NS |
NS |
-NS |
+PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
NS |
@@ -5817,6 +5817,74 @@ Accelerator support is described below.
NS |
+GetJsonObject |
+`get_json_object` |
+Extracts a json object from path |
+None |
+project |
+json |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+path |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS (Literal value only) |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
GetMapValue |
|
Gets Value from a Map based on a key |
@@ -12421,7 +12489,7 @@ Accelerator support is described below.
NS |
NS |
|
-NS |
+PS* (missing nested BINARY, CALENDAR, ARRAY, STRUCT, UDT) |
NS |
@@ -12442,7 +12510,7 @@ Accelerator support is described below.
NS |
NS |
|
-NS |
+PS* (missing nested BINARY, CALENDAR, ARRAY, STRUCT, UDT) |
NS |
@@ -18049,7 +18117,7 @@ and the accelerator produces the same result.
NS |
|
NS |
-NS |
+S |
S* |
|
|
@@ -18175,7 +18243,7 @@ and the accelerator produces the same result.
|
|
|
-NS |
+PS (the struct's children must also support being cast to string) |
|
|
|
@@ -18453,7 +18521,7 @@ and the accelerator produces the same result.
NS |
|
NS |
-NS |
+S |
S* |
|
|
@@ -18579,7 +18647,7 @@ and the accelerator produces the same result.
|
|
|
-NS |
+PS (the struct's children must also support being cast to string) |
|
|
|
diff --git a/docs/tuning-guide.md b/docs/tuning-guide.md
index 835e57987fc0..ffd23e78107e 100644
--- a/docs/tuning-guide.md
+++ b/docs/tuning-guide.md
@@ -209,5 +209,4 @@ performance.
- [`spark.rapids.sql.variableFloatAgg.enabled`](configs.md#sql.variableFloatAgg.enabled)
- [`spark.rapids.sql.hasNans`](configs.md#sql.hasNans)
- [`spark.rapids.sql.castFloatToString.enabled`](configs.md#sql.castFloatToString.enabled)
-- [`spark.rapids.sql.castStringToInteger.enabled`](configs.md#sql.castStringToInteger.enabled)
- [`spark.rapids.sql.castStringToFloat.enabled`](configs.md#sql.castStringToFloat.enabled)
diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh
index ad7d8c7b180f..8d920e6eb487 100755
--- a/integration_tests/run_pyspark_from_build.sh
+++ b/integration_tests/run_pyspark_from_build.sh
@@ -84,11 +84,11 @@ else
then
# With xdist 0 and 1 are the same parallelsm but
# 0 is more effecient
- TEST_PARALLEL_OPTS=""
+ TEST_PARALLEL_OPTS=()
MEMORY_FRACTION='1'
else
MEMORY_FRACTION=`python -c "print(1/($TEST_PARALLEL + 1))"`
- TEST_PARALLEL_OPTS="-n $TEST_PARALLEL"
+ TEST_PARALLEL_OPTS=("-n" "$TEST_PARALLEL")
fi
RUN_DIR="$SCRIPTPATH"/target/run_dir
mkdir -p "$RUN_DIR"
@@ -99,41 +99,35 @@ else
## Under cloud environment, overwrite the '--std_input_path' param to point to the distributed file path
INPUT_PATH=${INPUT_PATH:-"$SCRIPTPATH"}
- if [[ "${TEST_PARALLEL_OPTS}" != "" ]];
+ RUN_TESTS_COMMAND=("$SCRIPTPATH"/runtests.py
+ --rootdir
+ "$LOCAL_ROOTDIR"
+ "$LOCAL_ROOTDIR"/src/main/python)
+
+ TEST_COMMON_OPTS=(-v
+ -rfExXs
+ "$TEST_TAGS"
+ --std_input_path="$INPUT_PATH"/src/test/resources
+ --color=yes
+ $TEST_TYPE_PARAM
+ "$TEST_ARGS"
+ $RUN_TEST_PARAMS
+ "$@")
+
+ export PYSP_TEST_spark_driver_extraClassPath="${ALL_JARS// /:}"
+ export PYSP_TEST_spark_driver_extraJavaOptions="-ea -Duser.timezone=UTC $COVERAGE_SUBMIT_FLAGS"
+ export PYSP_TEST_spark_executor_extraJavaOptions='-ea -Duser.timezone=UTC'
+ export PYSP_TEST_spark_ui_showConsoleProgress='false'
+ export PYSP_TEST_spark_sql_session_timeZone='UTC'
+ export PYSP_TEST_spark_sql_shuffle_partitions='12'
+ if ((${#TEST_PARALLEL_OPTS[@]} > 0));
then
- export PYSP_TEST_spark_driver_extraClassPath="${ALL_JARS// /:}"
- export PYSP_TEST_spark_driver_extraJavaOptions="-ea -Duser.timezone=UTC $COVERAGE_SUBMIT_FLAGS"
- export PYSP_TEST_spark_executor_extraJavaOptions='-ea -Duser.timezone=UTC'
- export PYSP_TEST_spark_ui_showConsoleProgress='false'
- export PYSP_TEST_spark_sql_session_timeZone='UTC'
- export PYSP_TEST_spark_sql_shuffle_partitions='12'
export PYSP_TEST_spark_rapids_memory_gpu_allocFraction=$MEMORY_FRACTION
export PYSP_TEST_spark_rapids_memory_gpu_maxAllocFraction=$MEMORY_FRACTION
-
- python \
- "$SCRIPTPATH"/runtests.py --rootdir "$LOCAL_ROOTDIR" "$LOCAL_ROOTDIR"/src/main/python \
- $TEST_PARALLEL_OPTS \
- -v -rfExXs "$TEST_TAGS" \
- --std_input_path="$INPUT_PATH"/src/test/resources/ \
- --color=yes \
- $TEST_TYPE_PARAM \
- "$TEST_ARGS" \
- $RUN_TEST_PARAMS \
- "$@"
+ python "${RUN_TESTS_COMMAND[@]}" "${TEST_PARALLEL_OPTS[@]}" "${TEST_COMMON_OPTS[@]}"
else
"$SPARK_HOME"/bin/spark-submit --jars "${ALL_JARS// /,}" \
- --conf "spark.driver.extraJavaOptions=-ea -Duser.timezone=UTC $COVERAGE_SUBMIT_FLAGS" \
- --conf 'spark.executor.extraJavaOptions=-ea -Duser.timezone=UTC' \
- --conf 'spark.sql.session.timeZone=UTC' \
- --conf 'spark.sql.shuffle.partitions=12' \
- $SPARK_SUBMIT_FLAGS \
- "$SCRIPTPATH"/runtests.py --rootdir "$LOCAL_ROOTDIR" "$LOCAL_ROOTDIR"/src/main/python \
- -v -rfExXs "$TEST_TAGS" \
- --std_input_path="$INPUT_PATH"/src/test/resources/ \
- --color=yes \
- $TEST_TYPE_PARAM \
- "$TEST_ARGS" \
- $RUN_TEST_PARAMS \
- "$@"
+ --driver-java-options "$PYSP_TEST_spark_driver_extraJavaOptions" \
+ $SPARK_SUBMIT_FLAGS "${RUN_TESTS_COMMAND[@]}" "${TEST_COMMON_OPTS[@]}"
fi
fi
diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py
index 80bd2e3634ab..e38d46aa0f2e 100644
--- a/integration_tests/src/main/python/asserts.py
+++ b/integration_tests/src/main/python/asserts.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
from spark_session import with_cpu_session, with_gpu_session
import time
import types as pytypes
+import data_gen
def _assert_equal(cpu, gpu, float_check, path):
t = type(cpu)
@@ -99,7 +100,7 @@ class _RowCmp(object):
"""Allows for sorting Rows in a consistent way"""
def __init__(self, wrapped):
#TODO will need others for maps, etc
- if isinstance(wrapped, Row):
+ if isinstance(wrapped, Row) or isinstance(wrapped, list):
self.wrapped = [_RowCmp(c) for c in wrapped]
else:
self.wrapped = wrapped
@@ -356,7 +357,7 @@ def assert_gpu_and_cpu_row_counts_equal(func, conf={}):
"""
_assert_gpu_and_cpu_are_equal(func, 'COUNT', conf=conf)
-def assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None):
+def assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None, debug=False):
"""
Assert that the specified SQL query produces equal results on CPU and GPU.
:param df_fun: a function that will create the dataframe
@@ -370,7 +371,10 @@ def assert_gpu_and_cpu_are_equal_sql(df_fun, table_name, sql, conf=None):
def do_it_all(spark):
df = df_fun(spark)
df.createOrReplaceTempView(table_name)
- return spark.sql(sql)
+ if debug:
+ return data_gen.debug_df(spark.sql(sql))
+ else:
+ return spark.sql(sql)
assert_gpu_and_cpu_are_equal_collect(do_it_all, conf)
def assert_py4j_exception(func, error_message):
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 22d9d94ddc73..1ebcfb5c923b 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -299,7 +299,7 @@ def start(self, rand):
POS_FLOAT_NAN_MAX_VALUE = struct.unpack('f', struct.pack('I', 0x7fffffff))[0]
class FloatGen(DataGen):
"""Generate floats, which some built in corner cases."""
- def __init__(self, nullable=True,
+ def __init__(self, nullable=True,
no_nans=False, special_cases=None):
self._no_nans = no_nans
if special_cases is None:
@@ -334,7 +334,7 @@ def gen_float():
POS_DOUBLE_NAN_MAX_VALUE = struct.unpack('d', struct.pack('L', 0x7fffffffffffffff))[0]
class DoubleGen(DataGen):
"""Generate doubles, which some built in corner cases."""
- def __init__(self, min_exp=DOUBLE_MIN_EXP, max_exp=DOUBLE_MAX_EXP, no_nans=False,
+ def __init__(self, min_exp=DOUBLE_MIN_EXP, max_exp=DOUBLE_MAX_EXP, no_nans=False,
nullable=True, special_cases = None):
self._min_exp = min_exp
self._max_exp = max_exp
@@ -447,7 +447,7 @@ def __init__(self, start=None, end=None, nullable=True):
self._start_day = self._to_days_since_epoch(start)
self._end_day = self._to_days_since_epoch(end)
-
+
self.with_special_case(start)
self.with_special_case(end)
@@ -652,9 +652,27 @@ def gen_scalar_value(data_gen, seed=0, force_no_nulls=False):
v = list(gen_scalar_values(data_gen, 1, seed=seed, force_no_nulls=force_no_nulls))
return v[0]
-def debug_df(df):
- """print out the contents of a dataframe for debugging."""
- print('COLLECTED\n{}'.format(df.collect()))
+def debug_df(df, path = None, file_format = 'json', num_parts = 1):
+ """Print out or save the contents and the schema of a dataframe for debugging."""
+
+ if path is not None:
+ # Save the dataframe and its schema
+ # The schema can be re-created by using DataType.fromJson and used
+ # for loading the dataframe
+ file_name = f"{path}.{file_format}"
+ schema_file_name = f"{path}.schema.json"
+
+ df.coalesce(num_parts).write.format(file_format).save(file_name)
+ print(f"SAVED df output for debugging at {file_name}")
+
+ schema_json = df.schema.json()
+ schema_file = open(schema_file_name , 'w')
+ schema_file.write(schema_json)
+ schema_file.close()
+ print(f"SAVED df schema for debugging along in the output dir")
+ else:
+ print('COLLECTED\n{}'.format(df.collect()))
+
df.explain()
df.printSchema()
return df
diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py
new file mode 100644
index 000000000000..ad7e1f861b00
--- /dev/null
+++ b/integration_tests/src/main/python/json_test.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from asserts import assert_gpu_and_cpu_are_equal_collect
+from data_gen import *
+from pyspark.sql.types import *
+
+def mk_json_str_gen(pattern):
+ return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')
+
+@pytest.mark.parametrize('json_str_pattern', [r'\{"store": \{"fruit": \[\{"weight":\d,"type":"[a-z]{1,9}"\}\], ' \
+ r'"bicycle":\{"price":\d\d\.\d\d,"color":"[a-z]{0,4}"\}\},' \
+ r'"email":"[a-z]{1,5}\@[a-z]{3,10}\.com","owner":"[a-z]{3,8}"\}',
+ r'\{"a": "[a-z]{1,3}"\}'], ids=idfn)
+def test_get_json_object(json_str_pattern):
+ gen = mk_json_str_gen(json_str_pattern)
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, gen, length=10).selectExpr(
+ 'get_json_object(a,"$.a")',
+ 'get_json_object(a, "$.owner")',
+ 'get_json_object(a, "$.store.fruit[0]")'),
+ conf={'spark.sql.parser.escapedStringLiterals': 'true'})
diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index e77ce88fc808..2a4723c3a39f 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -93,6 +93,8 @@ def test_repartion_df(num_parts, length):
([('a', string_gen)], ['a']),
([('a', null_gen)], ['a']),
([('a', StructGen([('c0', boolean_gen), ('c1', StructGen([('cc0', boolean_gen), ('cc1', string_gen)]))]))], ['a']),
+ ([('a', long_gen), ('b', StructGen([('b1', long_gen)]))], ['a']),
+ ([('a', long_gen), ('b', ArrayGen(long_gen, max_length=2))], ['a']),
([('a', byte_gen)], [f.col('a') - 5]),
([('a', long_gen)], [f.col('a') + 15]),
([('a', byte_gen), ('b', boolean_gen)], ['a', 'b']),
@@ -109,7 +111,9 @@ def test_hash_repartition_exact(gen, num_parts):
data_gen = gen[0]
part_on = gen[1]
assert_gpu_and_cpu_are_equal_collect(
- lambda spark : gen_df(spark, data_gen)\
+ lambda spark : gen_df(spark, data_gen, length=1024)\
.repartition(num_parts, *part_on)\
- .selectExpr('spark_partition_id() as id', '*', 'hash(*)', 'pmod(hash(*),{})'.format(num_parts)),
+ .withColumn('id', f.spark_partition_id())\
+ .withColumn('hashed', f.hash(*part_on))\
+ .selectExpr('*', 'pmod(hashed, {})'.format(num_parts)),
conf = allow_negative_scale_of_decimal_conf)
diff --git a/integration_tests/src/main/python/sort_test.py b/integration_tests/src/main/python/sort_test.py
index 5a706a2da0a4..9d90a0dd66d0 100644
--- a/integration_tests/src/main/python/sort_test.py
+++ b/integration_tests/src/main/python/sort_test.py
@@ -34,6 +34,46 @@ def test_single_orderby(data_gen, order):
lambda spark : unary_op_df(spark, data_gen).orderBy(order),
conf = allow_negative_scale_of_decimal_conf)
+@pytest.mark.parametrize('shuffle_parts', [
+ pytest.param(1),
+ pytest.param(200, marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1607"))
+])
+@pytest.mark.parametrize('stable_sort', [
+ pytest.param(True),
+ pytest.param(False, marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1607"))
+])
+@pytest.mark.parametrize('data_gen', [
+ pytest.param(all_basic_struct_gen),
+ pytest.param(StructGen([['child0', all_basic_struct_gen]]),
+ marks=pytest.mark.xfail(reason='second-level structs are not supported')),
+ pytest.param(ArrayGen(string_gen),
+ marks=pytest.mark.xfail(reason="arrays are not supported")),
+ pytest.param(MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
+ marks=pytest.mark.xfail(reason="maps are not supported")),
+], ids=idfn)
+@pytest.mark.parametrize('order', [
+ pytest.param(f.col('a').asc()),
+ pytest.param(f.col('a').asc_nulls_first()),
+ pytest.param(f.col('a').asc_nulls_last(),
+ marks=pytest.mark.xfail(reason='opposite null order not supported')),
+ pytest.param(f.col('a').desc()),
+ pytest.param(f.col('a').desc_nulls_first(),
+ marks=pytest.mark.xfail(reason='opposite null order not supported')),
+ pytest.param(f.col('a').desc_nulls_last()),
+], ids=idfn)
+def test_single_nested_orderby_plain(data_gen, order, shuffle_parts, stable_sort):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).orderBy(order),
+ # TODO no interference with range partition once implemented
+ conf = {
+ **allow_negative_scale_of_decimal_conf,
+ **{
+ 'spark.sql.shuffle.partitions': shuffle_parts,
+ 'spark.rapids.sql.stableSort.enabled': stable_sort,
+ 'spark.rapids.allowCpuRangePartitioning': False
+ }
+ })
+
# SPARK CPU itself has issue with negative scale for take ordered and project
orderable_without_neg_decimal = [n for n in (orderable_gens + orderable_not_null_gen) if not (isinstance(n, DecimalGen) and n.scale < 0)]
@pytest.mark.parametrize('data_gen', orderable_without_neg_decimal, ids=idfn)
@@ -42,6 +82,32 @@ def test_single_orderby_with_limit(data_gen, order):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).orderBy(order).limit(100))
+@pytest.mark.parametrize('data_gen', [
+ pytest.param(all_basic_struct_gen),
+ pytest.param(StructGen([['child0', all_basic_struct_gen]]),
+ marks=pytest.mark.xfail(reason='second-level structs are not supported')),
+ pytest.param(ArrayGen(string_gen),
+ marks=pytest.mark.xfail(reason="arrays are not supported")),
+ pytest.param(MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
+ marks=pytest.mark.xfail(reason="maps are not supported")),
+], ids=idfn)
+@pytest.mark.parametrize('order', [
+ pytest.param(f.col('a').asc()),
+ pytest.param(f.col('a').asc_nulls_first()),
+ pytest.param(f.col('a').asc_nulls_last(),
+ marks=pytest.mark.xfail(reason='opposite null order not supported')),
+ pytest.param(f.col('a').desc()),
+ pytest.param(f.col('a').desc_nulls_first(),
+ marks=pytest.mark.xfail(reason='opposite null order not supported')),
+ pytest.param(f.col('a').desc_nulls_last()),
+], ids=idfn)
+def test_single_nested_orderby_with_limit(data_gen, order):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).orderBy(order).limit(100),
+ conf = {
+ 'spark.rapids.allowCpuRangePartitioning': False
+ })
+
@pytest.mark.parametrize('data_gen', orderable_gens + orderable_not_null_gen, ids=idfn)
@pytest.mark.parametrize('order', [f.col('a').asc(), f.col('a').asc_nulls_last(), f.col('a').desc(), f.col('a').desc_nulls_first()], ids=idfn)
def test_single_sort_in_part(data_gen, order):
diff --git a/integration_tests/src/main/python/struct_test.py b/integration_tests/src/main/python/struct_test.py
index f00e492cfa92..ca8da5f751cb 100644
--- a/integration_tests/src/main/python/struct_test.py
+++ b/integration_tests/src/main/python/struct_test.py
@@ -57,3 +57,63 @@ def test_orderby_struct_2(data_gen):
lambda spark : append_unique_int_col_to_df(spark, unary_op_df(spark, data_gen)),
'struct_table',
'select struct_table.a, struct_table.uniq_int from struct_table order by uniq_int')
+
+# conf with legacy cast to string on
+legacy_complex_types_to_string = {'spark.sql.legacy.castComplexTypesToString.enabled': 'true'}
+@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", short_gen], ["fourth", int_gen], ["fifth", long_gen], ["sixth", string_gen], ["seventh", date_gen]])], ids=idfn)
+def test_legacy_cast_struct_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")),
+ conf = legacy_complex_types_to_string)
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", float_gen]])], ids=idfn)
+@pytest.mark.xfail(reason='casting float to string is not an exact match')
+def test_legacy_cast_struct_with_float_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")),
+ conf = legacy_complex_types_to_string)
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", double_gen]])], ids=idfn)
+@pytest.mark.xfail(reason='casting double to string is not an exact match')
+def test_legacy_cast_struct_with_double_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")),
+ conf = legacy_complex_types_to_string)
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", timestamp_gen]])], ids=idfn)
+@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/219')
+def test_legacy_cast_struct_with_timestamp_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")),
+ conf = legacy_complex_types_to_string)
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", boolean_gen], ["second", byte_gen], ["third", short_gen], ["fourth", int_gen], ["fifth", long_gen], ["sixth", string_gen], ["seventh", date_gen]])], ids=idfn)
+def test_cast_struct_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")))
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", float_gen]])], ids=idfn)
+@pytest.mark.xfail(reason='casting float to string is not an exact match')
+def test_cast_struct_with_float_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")))
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", double_gen]])], ids=idfn)
+@pytest.mark.xfail(reason='casting double to string is not an exact match')
+def test_cast_struct_with_double_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")))
+
+@pytest.mark.parametrize('data_gen', [StructGen([["first", timestamp_gen]])], ids=idfn)
+@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/219')
+def test_cast_struct_with_timestamp_to_string(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).select(
+ f.col('a').cast("STRING")))
diff --git a/jenkins/Dockerfile-blossom.integration.centos7 b/jenkins/Dockerfile-blossom.integration.centos7
index 3da1bf9d4954..eff9cbd0a9cd 100644
--- a/jenkins/Dockerfile-blossom.integration.centos7
+++ b/jenkins/Dockerfile-blossom.integration.centos7
@@ -17,8 +17,8 @@
###
#
# Arguments:
-# CUDA_VER=10.1 or 10.2
-# CUDF_VER=0.18 or 0.19-SNAPSHOT
+# CUDA_VER=10.1, 10.2 or 11.0
+# CUDF_VER=0.18 or 0.19
# URM_URL=
###
ARG CUDA_VER=10.1
diff --git a/jenkins/Dockerfile-blossom.ubuntu16 b/jenkins/Dockerfile-blossom.ubuntu
similarity index 78%
rename from jenkins/Dockerfile-blossom.ubuntu16
rename to jenkins/Dockerfile-blossom.ubuntu
index 1bda5c05313d..f8c4e4a51854 100644
--- a/jenkins/Dockerfile-blossom.ubuntu16
+++ b/jenkins/Dockerfile-blossom.ubuntu
@@ -18,22 +18,24 @@
#
# Build the image for rapids-plugin development environment
#
-# Arguments: CUDA_VER=10.1 or 10.2
-#
+# Arguments:
+# CUDA_VER=10.1, 10.2 or 11.0
+# UBUNTU_VER=18.04 or 20.04
###
-ARG CUDA_VER=10.1
-
-FROM nvidia/cuda:${CUDA_VER}-runtime-ubuntu16.04
+ARG CUDA_VER=11.0
+ARG UBUNTU_VER=18.04
+FROM nvidia/cuda:${CUDA_VER}-runtime-ubuntu${UBUNTU_VER}
#Install java-8, maven, docker image
RUN apt-get update -y && \
apt-get install -y software-properties-common
RUN add-apt-repository ppa:deadsnakes/ppa && \
apt-get update -y && \
- apt-get install -y maven \
+ DEBIAN_FRONTEND="noninteractive" apt-get install -y maven \
openjdk-8-jdk python3.8 python3.8-distutils python3-setuptools tzdata git
RUN python3.8 -m easy_install pip
+RUN update-java-alternatives --set /usr/lib/jvm/java-1.8.0-openjdk-amd64
RUN ln -s /usr/bin/python3.8 /usr/bin/python
RUN python -m pip install pytest sre_yield requests pandas pyarrow findspark pytest-xdist
diff --git a/jenkins/Jenkinsfile-blossom.premerge b/jenkins/Jenkinsfile-blossom.premerge
index 5dfbae28f449..9285188c4299 100644
--- a/jenkins/Jenkinsfile-blossom.premerge
+++ b/jenkins/Jenkinsfile-blossom.premerge
@@ -29,7 +29,8 @@ def pluginPremerge
def githubHelper // blossom github helper
def TEMP_IMAGE_BUILD = true
-def PREMERGE_DOCKERFILE = 'jenkins/Dockerfile-blossom.ubuntu16'
+def CUDA_NAME = 'cuda11.0' // hardcode cuda version for docker build part
+def PREMERGE_DOCKERFILE = 'jenkins/Dockerfile-blossom.ubuntu'
def IMAGE_PREMERGE // temp image for premerge test
def PREMERGE_TAG
def skipped = false
@@ -60,6 +61,7 @@ pipeline {
URM_URL = "https://${ArtifactoryConstants.ARTIFACTORY_NAME}/artifactory/sw-spark-maven"
PVC = credentials("pvc")
CUSTOM_WORKSPACE = "/home/jenkins/agent/workspace/${BUILD_TAG}"
+ CUDA_CLASSIFIER = 'cuda11'
}
stages {
@@ -116,9 +118,6 @@ pipeline {
)
container('docker-build') {
- def CUDA_NAME = sh(returnStdout: true,
- script: '. jenkins/version-def.sh>&2 && echo -n $CUDA_CLASSIFIER | sed "s/-/./g"')
-
// check if pre-merge dockerfile modified
def dockerfileModified = sh(returnStdout: true,
script: 'BASE=$(git --no-pager log --oneline -1 | awk \'{ print $NF }\'); ' +
@@ -129,7 +128,7 @@ pipeline {
}
if (TEMP_IMAGE_BUILD) {
- IMAGE_TAG = "dev-ubuntu16-${CUDA_NAME}"
+ IMAGE_TAG = "dev-ubuntu18-${CUDA_NAME}"
PREMERGE_TAG = "${IMAGE_TAG}-${BUILD_TAG}"
IMAGE_PREMERGE = "${ARTIFACTORY_NAME}/sw-spark-docker-local/plugin:${PREMERGE_TAG}"
def CUDA_VER = "$CUDA_NAME" - "cuda"
@@ -137,7 +136,7 @@ pipeline {
uploadDocker(IMAGE_PREMERGE)
} else {
// if no pre-merge dockerfile change, use nightly image
- IMAGE_PREMERGE = "$ARTIFACTORY_NAME/sw-spark-docker-local/plugin:dev-ubuntu16-$CUDA_NAME-blossom-dev"
+ IMAGE_PREMERGE = "$ARTIFACTORY_NAME/sw-spark-docker-local/plugin:dev-ubuntu18-$CUDA_NAME-blossom-dev"
}
diff --git a/jenkins/databricks/clusterutils.py b/jenkins/databricks/clusterutils.py
index fdb937014efd..a48e7cc3f23c 100644
--- a/jenkins/databricks/clusterutils.py
+++ b/jenkins/databricks/clusterutils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +46,13 @@ def generate_create_templ(sshKey, cluster_name, runtime, idle_timeout,
templ['driver_node_type_id'] = driver_node_type
templ['ssh_public_keys'] = [ sshKey ]
templ['num_workers'] = num_workers
+ templ['init_scripts'] = [
+ {
+ "dbfs": {
+ "destination": "dbfs:/databricks/init_scripts/init_cudf_udf.sh"
+ }
+ }
+ ]
return templ
diff --git a/jenkins/databricks/init_cudf_udf.sh b/jenkins/databricks/init_cudf_udf.sh
new file mode 100644
index 000000000000..70758a7de917
--- /dev/null
+++ b/jenkins/databricks/init_cudf_udf.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+#
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# The initscript to set up environment for the cudf_udf tests on Databrcks
+# Will be automatically pushed into the dbfs:/databricks/init_scripts once it is updated.
+
+CUDF_VER=${CUDF_VER:-0.19}
+
+# Use mamba to install cudf-udf packages to speed up conda resolve time
+base=$(conda info --base)
+conda create -y -n mamba -c conda-forge mamba
+pip uninstall -y pyarrow
+${base}/envs/mamba/bin/mamba remove -y c-ares zstd libprotobuf pandas
+${base}/envs/mamba/bin/mamba install -y pyarrow=1.0.1 -c conda-forge
+${base}/envs/mamba/bin/mamba install -y -c rapidsai -c rapidsai-nightly -c nvidia -c conda-forge -c defaults cudf=$CUDF_VER cudatoolkit=10.1
+conda env remove -n mamba
diff --git a/jenkins/databricks/test.sh b/jenkins/databricks/test.sh
index baefe8b6da7e..95e196def395 100755
--- a/jenkins/databricks/test.sh
+++ b/jenkins/databricks/test.sh
@@ -33,10 +33,25 @@ sudo chmod 777 /databricks/data/logs/
sudo chmod 777 /databricks/data/logs/*
echo { \"port\":\"15002\" } > ~/.databricks-connect
+CUDF_UDF_TEST_ARGS="--conf spark.python.daemon.module=rapids.daemon_databricks \
+ --conf spark.rapids.memory.gpu.allocFraction=0.1 \
+ --conf spark.rapids.python.memory.gpu.allocFraction=0.1 \
+ --conf spark.rapids.python.concurrentPythonWorkers=2"
+
if [ -d "$LOCAL_JAR_PATH" ]; then
## Run tests with jars in the LOCAL_JAR_PATH dir downloading from the denpedency repo
- LOCAL_JAR_PATH=$LOCAL_JAR_PATH bash $LOCAL_JAR_PATH/integration_tests/run_pyspark_from_build.sh --runtime_env="databricks"
+ LOCAL_JAR_PATH=$LOCAL_JAR_PATH bash $LOCAL_JAR_PATH/integration_tests/run_pyspark_from_build.sh --runtime_env="databricks"
+
+ ## Run cudf-udf tests
+ CUDF_UDF_TEST_ARGS="$CUDF_UDF_TEST_ARGS --conf spark.executorEnv.PYTHONPATH=`ls $LOCAL_JAR_PATH/rapids-4-spark_*.jar | grep -v 'tests.jar'`"
+ LOCAL_JAR_PATH=$LOCAL_JAR_PATH SPARK_SUBMIT_FLAGS=$CUDF_UDF_TEST_ARGS TEST_PARALLEL=1 \
+ bash $LOCAL_JAR_PATH/integration_tests/run_pyspark_from_build.sh --runtime_env="databricks" -m "cudf_udf" --cudf_udf
else
## Run tests with jars building from the spark-rapids source code
bash /home/ubuntu/spark-rapids/integration_tests/run_pyspark_from_build.sh --runtime_env="databricks"
+
+ ## Run cudf-udf tests
+ CUDF_UDF_TEST_ARGS="$CUDF_UDF_TEST_ARGS --conf spark.executorEnv.PYTHONPATH=`ls /home/ubuntu/spark-rapids/dist/target/rapids-4-spark_*.jar | grep -v 'tests.jar'`"
+ SPARK_SUBMIT_FLAGS=$CUDF_UDF_TEST_ARGS TEST_PARALLEL=1 \
+ bash /home/ubuntu/spark-rapids/integration_tests/run_pyspark_from_build.sh --runtime_env="databricks" -m "cudf_udf" --cudf_udf
fi
diff --git a/jenkins/spark-nightly-build.sh b/jenkins/spark-nightly-build.sh
index 6e75e28d9ae9..1e0a18948d4c 100755
--- a/jenkins/spark-nightly-build.sh
+++ b/jenkins/spark-nightly-build.sh
@@ -21,14 +21,16 @@ set -ex
## export 'M2DIR' so that shims can get the correct cudf/spark dependency info
export M2DIR="$WORKSPACE/.m2"
-mvn -U -B -Pinclude-databricks,snapshot-shims clean deploy $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dpytest.TEST_TAGS='' -Dpytest.TEST_TYPE="nightly"
+mvn -U -B -Pinclude-databricks,snapshot-shims clean deploy $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR \
+ -Dpytest.TEST_TAGS='' -Dpytest.TEST_TYPE="nightly" -Dcuda.version=$CUDA_CLASSIFIER
# Run unit tests against other spark versions
-mvn -U -B -Pspark301tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
-mvn -U -B -Pspark302tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
-mvn -U -B -Pspark303tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
-mvn -U -B -Pspark311tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
-mvn -U -B -Pspark312tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
-mvn -U -B -Pspark320tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR
+mvn -U -B -Pspark301tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dcuda.version=$CUDA_CLASSIFIER
+mvn -U -B -Pspark302tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dcuda.version=$CUDA_CLASSIFIER
+mvn -U -B -Pspark303tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dcuda.version=$CUDA_CLASSIFIER
+mvn -U -B -Pspark311tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dcuda.version=$CUDA_CLASSIFIER
+mvn -U -B -Pspark312tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dcuda.version=$CUDA_CLASSIFIER
+# Disabled until Spark 3.2 source incompatibility fixed, see https://github.com/NVIDIA/spark-rapids/issues/2052
+#mvn -U -B -Pspark320tests,snapshot-shims test $MVN_URM_MIRROR -Dmaven.repo.local=$M2DIR -Dcuda.version=$CUDA_CLASSIFIER
# Parse cudf and spark files from local mvn repo
jenkins/printJarVersion.sh "CUDFVersion" "$M2DIR/ai/rapids/cudf/${CUDF_VER}" "cudf-${CUDF_VER}" "-${CUDA_CLASSIFIER}.jar" $SERVER_ID
diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh
index 6ea3b972ae38..7eb2c1eed47c 100755
--- a/jenkins/spark-premerge-build.sh
+++ b/jenkins/spark-premerge-build.sh
@@ -37,15 +37,17 @@ export PATH="$SPARK_HOME/bin:$SPARK_HOME/sbin:$PATH"
tar zxf $SPARK_HOME.tgz -C $ARTF_ROOT && \
rm -f $SPARK_HOME.tgz
-mvn -U -B $MVN_URM_MIRROR '-P!snapshot-shims,pre-merge' clean verify -Dpytest.TEST_TAGS='' -Dpytest.TEST_TYPE="pre-commit" -Dpytest.TEST_PARALLEL=4
+mvn -U -B $MVN_URM_MIRROR '-P!snapshot-shims,pre-merge' clean verify -Dpytest.TEST_TAGS='' \
+ -Dpytest.TEST_TYPE="pre-commit" -Dpytest.TEST_PARALLEL=4 -Dcuda.version=$CUDA_CLASSIFIER
# Run the unit tests for other Spark versions but dont run full python integration tests
# NOT ALL TESTS NEEDED FOR PREMERGE
-#env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark301tests,snapshot-shims test -Dpytest.TEST_TAGS=''
-env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark302tests,snapshot-shims test -Dpytest.TEST_TAGS=''
-env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark303tests,snapshot-shims test -Dpytest.TEST_TAGS=''
-env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark311tests,snapshot-shims test -Dpytest.TEST_TAGS=''
-env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark312tests,snapshot-shims test -Dpytest.TEST_TAGS=''
-env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark320tests,snapshot-shims test -Dpytest.TEST_TAGS=''
+#env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark301tests,snapshot-shims test -Dpytest.TEST_TAGS='' -Dcuda.version=$CUDA_CLASSIFIER
+env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark302tests,snapshot-shims test -Dpytest.TEST_TAGS='' -Dcuda.version=$CUDA_CLASSIFIER
+env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark303tests,snapshot-shims test -Dpytest.TEST_TAGS='' -Dcuda.version=$CUDA_CLASSIFIER
+env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark311tests,snapshot-shims test -Dpytest.TEST_TAGS='' -Dcuda.version=$CUDA_CLASSIFIER
+env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark312tests,snapshot-shims test -Dpytest.TEST_TAGS='' -Dcuda.version=$CUDA_CLASSIFIER
+# Disabled until Spark 3.2 source incompatibility fixed, see https://github.com/NVIDIA/spark-rapids/issues/2052
+#env -u SPARK_HOME mvn -U -B $MVN_URM_MIRROR -Pspark320tests,snapshot-shims test -Dpytest.TEST_TAGS='' -Dcuda.version=$CUDA_CLASSIFIER
# The jacoco coverage should have been collected, but because of how the shade plugin
# works and jacoco we need to clean some things up so jacoco will only report for the
diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index 39df892a1f29..593fe341fe88 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -478,6 +478,8 @@ class Spark300Shims extends SparkShims {
InMemoryFileIndex.shouldFilterOut(path)
}
+ override def getLegacyComplexTypeToString(): Boolean = true
+
// Arrow version changed between Spark versions
override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = {
val arrowBuf = vec.getDataBuffer()
diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
index 557c3fb9fd22..380e302d05b4 100644
--- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
+++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
@@ -139,7 +139,7 @@ class Spark311Shims extends Spark301Shims {
// stringChecks are the same
// binaryChecks are the same
- override val decimalChecks: TypeSig = none
+ override val decimalChecks: TypeSig = DECIMAL + STRING
override val sparkDecimalSig: TypeSig = numeric + BOOLEAN + STRING
// calendarChecks are the same
@@ -424,6 +424,10 @@ class Spark311Shims extends Spark301Shims {
HadoopFSUtilsShim.shouldIgnorePath(path)
}
+ override def getLegacyComplexTypeToString(): Boolean = {
+ SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING)
+ }
+
// Arrow version changed between Spark versions
override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = {
val arrowBuf = vec.getDataBuffer()
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
index e24cc05cd85f..55b45d40d743 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java
@@ -20,6 +20,7 @@
import ai.rapids.cudf.DType;
import ai.rapids.cudf.ArrowColumnBuilder;
import ai.rapids.cudf.HostColumnVector;
+import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.Scalar;
import ai.rapids.cudf.Schema;
import ai.rapids.cudf.Table;
@@ -102,9 +103,9 @@ private static String hexString(byte[] bytes) {
* @param name the name of the column to print out.
* @param hostCol the column to print out.
*/
- public static synchronized void debug(String name, HostColumnVector hostCol) {
+ public static synchronized void debug(String name, HostColumnVectorCore hostCol) {
DType type = hostCol.getType();
- System.err.println("COLUMN " + name + " " + type);
+ System.err.println("COLUMN " + name + " - " + type);
if (type.getTypeId() == DType.DTypeEnum.DECIMAL64) {
for (int i = 0; i < hostCol.getRowCount(); i++) {
if (hostCol.isNull(i)) {
@@ -156,12 +157,32 @@ public static synchronized void debug(String name, HostColumnVector hostCol) {
System.err.println(i + " " + hostCol.getFloat(i));
}
}
+ } else if (DType.STRUCT.equals(type)) {
+ for (int i = 0; i < hostCol.getRowCount(); i++) {
+ if (hostCol.isNull(i)) {
+ System.err.println(i + " NULL");
+ } // The struct child columns are printed out later on.
+ }
+ for (int i = 0; i < hostCol.getNumChildren(); i++) {
+ debug(name + ":CHILD_" + i, hostCol.getChildColumnView(i));
+ }
+ } else if (DType.LIST.equals(type)) {
+ System.err.println("OFFSETS");
+ for (int i = 0; i < hostCol.getRowCount(); i++) {
+ if (hostCol.isNull(i)) {
+ System.err.println(i + " NULL");
+ } else {
+ System.err.println(i + " [" + hostCol.getStartListOffset(i) + " - " +
+ hostCol.getEndListOffset(i) + ")");
+ }
+ }
+ debug(name + ":DATA", hostCol.getChildColumnView(0));
} else {
System.err.println("TYPE " + type + " NOT SUPPORTED FOR DEBUG PRINT");
}
}
- private static void debugInteger(HostColumnVector hostCol, DType intType) {
+ private static void debugInteger(HostColumnVectorCore hostCol, DType intType) {
for (int i = 0; i < hostCol.getRowCount(); i++) {
if (hostCol.isNull(i)) {
System.err.println(i + " NULL");
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala
index ff3a3fc6b34d..09eb76415d27 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala
@@ -21,7 +21,9 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
import org.apache.spark.sql.internal.SQLConf
class CostBasedOptimizer(conf: RapidsConf) extends Logging {
@@ -39,27 +41,37 @@ class CostBasedOptimizer(conf: RapidsConf) extends Logging {
*/
def optimize(plan: SparkPlanMeta[SparkPlan]): Seq[Optimization] = {
val optimizations = new ListBuffer[Optimization]()
- recursivelyOptimize(plan, optimizations, finalOperator = true, "")
+ recursivelyOptimize(plan, optimizations, finalOperator = true)
optimizations
}
+
+ /**
+ * Walk the plan and determine CPU and GPU costs for each operator and then make decisions
+ * about whether operators should run on CPU or GPU.
+ *
+ * @param plan The plan to optimize
+ * @param optimizations Accumulator to store the optimizations that are applied
+ * @param finalOperator Is this the final (root) operator? We have special behavior for this
+ * case because we need the final output to be on the CPU in row format
+ * @return Tuple containing (cpuCost, gpuCost) for the specified plan and the subset of the
+ * tree beneath it that is a candidate for optimization.
+ */
private def recursivelyOptimize(
plan: SparkPlanMeta[SparkPlan],
optimizations: ListBuffer[Optimization],
- finalOperator: Boolean,
- indent: String = ""): (Double, Double) = {
+ finalOperator: Boolean): (Double, Double) = {
// get the CPU and GPU cost of the child plan(s)
val childCosts = plan.childPlans
.map(child => recursivelyOptimize(
child.asInstanceOf[SparkPlanMeta[SparkPlan]],
optimizations,
- finalOperator = false,
- indent + " "))
+ finalOperator = false))
val (childCpuCosts, childGpuCosts) = childCosts.unzip
- // get the CPU and GPU cost of this operator
+ // get the CPU and GPU cost of this operator (excluding cost of children)
val (operatorCpuCost, operatorGpuCost) = costModel.applyCost(plan)
// calculate total (this operator + children)
@@ -72,27 +84,41 @@ class CostBasedOptimizer(conf: RapidsConf) extends Logging {
.count(_.canThisBeReplaced != plan.canThisBeReplaced)
if (numTransitions > 0) {
+ // there are transitions between CPU and GPU so we need to calculate the transition costs
+ // and also make decisions based on those costs to see whether any parts of the plan would
+ // have been better off just staying on the CPU
+
+ // is this operator on the GPU?
if (plan.canThisBeReplaced) {
- // at least one child is transitioning from CPU to GPU
+ // at least one child is transitioning from CPU to GPU so we calculate the
+ // transition costs
val transitionCost = plan.childPlans.filter(!_.canThisBeReplaced)
.map(costModel.transitionToGpuCost).sum
- val gpuCost = operatorGpuCost + transitionCost
- if (gpuCost > operatorCpuCost) {
+
+ // if the GPU cost including transition is more than the CPU cost then avoid this
+ // transition and reset the GPU cost
+ if (operatorGpuCost + transitionCost > operatorCpuCost && !consumesQueryStage(plan)) {
+ // avoid transition and keep this operator on CPU
optimizations.append(AvoidTransition(plan))
plan.costPreventsRunningOnGpu()
- // stay on CPU, so costs are same
+ // reset GPU cost
totalGpuCost = totalCpuCost;
} else {
+ // add transition cost to total GPU cost
totalGpuCost += transitionCost
}
} else {
- // at least one child is transitioning from GPU to CPU
+ // at least one child is transitioning from GPU to CPU so we evaulate each of this
+ // child plans to see if it was worth running on GPU now that we have the cost of
+ // transitioning back to CPU
plan.childPlans.zip(childCosts).foreach {
case (child, childCosts) =>
val (childCpuCost, childGpuCost) = childCosts
val transitionCost = costModel.transitionToCpuCost(child)
val childGpuTotal = childGpuCost + transitionCost
- if (child.canThisBeReplaced && childGpuTotal > childCpuCost) {
+ if (child.canThisBeReplaced && !consumesQueryStage(child)
+ && childGpuTotal > childCpuCost) {
+ // force this child plan back onto CPU
optimizations.append(ReplaceSection(
child.asInstanceOf[SparkPlanMeta[SparkPlan]], totalCpuCost, totalGpuCost))
child.recursiveCostPreventsRunningOnGpu()
@@ -107,7 +133,8 @@ class CostBasedOptimizer(conf: RapidsConf) extends Logging {
}
}
- // special behavior if this is the final operator in the plan
+ // special behavior if this is the final operator in the plan because we always have the
+ // cost of going back to CPU at the end
if (finalOperator && plan.canThisBeReplaced) {
totalGpuCost += costModel.transitionToCpuCost(plan)
}
@@ -115,18 +142,17 @@ class CostBasedOptimizer(conf: RapidsConf) extends Logging {
if (totalGpuCost > totalCpuCost) {
// we have reached a point where we have transitioned onto GPU for part of this
// plan but with no benefit from doing so, so we want to undo this and go back to CPU
- if (plan.canThisBeReplaced) {
+ if (plan.canThisBeReplaced && !consumesQueryStage(plan)) {
// this plan would have been on GPU so we move it and onto CPU and recurse down
// until we reach a part of the plan that is already on CPU and then stop
optimizations.append(ReplaceSection(plan, totalCpuCost, totalGpuCost))
plan.recursiveCostPreventsRunningOnGpu()
+ // reset the costs because this section of the plan was not moved to GPU
+ totalGpuCost = totalCpuCost
}
-
- // reset the costs because this section of the plan was not moved to GPU
- totalGpuCost = totalCpuCost
}
- if (!plan.canThisBeReplaced) {
+ if (!plan.canThisBeReplaced || consumesQueryStage(plan)) {
// reset the costs because this section of the plan was not moved to GPU
totalGpuCost = totalCpuCost
}
@@ -134,6 +160,20 @@ class CostBasedOptimizer(conf: RapidsConf) extends Logging {
(totalCpuCost, totalGpuCost)
}
+ /**
+ * Determines whether the specified operator will read from a query stage.
+ */
+ private def consumesQueryStage(plan: SparkPlanMeta[_]): Boolean = {
+ // if the child query stage already executed on GPU then we need to keep the
+ // next operator on GPU in these cases
+ SQLConf.get.adaptiveExecutionEnabled && (plan.wrapped match {
+ case _: CustomShuffleReaderExec
+ | _: ShuffledHashJoinExec
+ | _: BroadcastHashJoinExec
+ | _: BroadcastNestedLoopJoinExec => true
+ case _ => false
+ })
+ }
}
/**
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/FloatUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/FloatUtils.scala
index de9258177d8a..4be623cc9298 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/FloatUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/FloatUtils.scala
@@ -16,11 +16,11 @@
package com.nvidia.spark.rapids
-import ai.rapids.cudf.{ColumnVector, DType, Scalar}
+import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}
object FloatUtils extends Arm {
- def nanToZero(cv: ColumnVector): ColumnVector = {
+ def nanToZero(cv: ColumnView): ColumnVector = {
if (cv.getType() != DType.FLOAT32 && cv.getType() != DType.FLOAT64) {
throw new IllegalArgumentException("Only Floats and Doubles allowed")
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index f996df383075..672c1ac20685 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -19,7 +19,9 @@ package com.nvidia.spark.rapids
import java.text.SimpleDateFormat
import java.time.DateTimeException
-import ai.rapids.cudf.{ColumnVector, DType, Scalar}
+import scala.collection.mutable.ArrayBuffer
+
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression}
@@ -34,20 +36,24 @@ class CastExprMeta[INPUT <: CastBase](
rule: DataFromReplacementRule)
extends UnaryExprMeta[INPUT](cast, conf, parent, rule) {
- private val castExpr = if (ansiEnabled) "ansi_cast" else "cast"
val fromType = cast.child.dataType
val toType = cast.dataType
+ var legacyCastToString = ShimLoader.getSparkShims.getLegacyComplexTypeToString()
override def tagExprForGpu(): Unit = {
+ recursiveTagExprForGpuCheck(fromType)
+ }
+
+ def recursiveTagExprForGpuCheck(fromDataType: DataType) {
if (!conf.isCastFloatToDecimalEnabled && toType.isInstanceOf[DecimalType] &&
- (fromType == DataTypes.FloatType || fromType == DataTypes.DoubleType)) {
+ (fromDataType == DataTypes.FloatType || fromDataType == DataTypes.DoubleType)) {
willNotWorkOnGpu("the GPU will use a different strategy from Java's BigDecimal to convert " +
"floating point data types to decimals and this can produce results that slightly " +
"differ from the default behavior in Spark. To enable this operation on the GPU, set " +
s"${RapidsConf.ENABLE_CAST_FLOAT_TO_DECIMAL} to true.")
}
if (!conf.isCastFloatToStringEnabled && toType == DataTypes.StringType &&
- (fromType == DataTypes.FloatType || fromType == DataTypes.DoubleType)) {
+ (fromDataType == DataTypes.FloatType || fromDataType == DataTypes.DoubleType)) {
willNotWorkOnGpu("the GPU will use different precision than Java's toString method when " +
"converting floating point data types to strings and this can produce results that " +
"differ from the default behavior in Spark. To enable this operation on the GPU, set" +
@@ -63,14 +69,6 @@ class CastExprMeta[INPUT <: CastBase](
"CPU returns \"+Infinity\" and \"-Infinity\" respectively. To enable this operation on " +
"the GPU, set" + s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.")
}
- if (!conf.isCastStringToIntegerEnabled && cast.child.dataType == DataTypes.StringType &&
- Seq(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType, DataTypes.LongType)
- .contains(cast.dataType)) {
- willNotWorkOnGpu("the GPU will return incorrect results for strings representing" +
- "values greater than Long.MaxValue or less than Long.MinValue. To enable this " +
- "operation on the GPU, set" +
- s" ${RapidsConf.ENABLE_CAST_STRING_TO_INTEGER} to true.")
- }
if (!conf.isCastStringToTimestampEnabled && fromType == DataTypes.StringType
&& toType == DataTypes.TimestampType) {
willNotWorkOnGpu("the GPU only supports a subset of formats " +
@@ -90,6 +88,19 @@ class CastExprMeta[INPUT <: CastBase](
"The first step may lead to precision loss. To enable this operation on the GPU, set " +
s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.")
}
+ if (fromDataType.isInstanceOf[StructType]) {
+ val checks = rule.getChecks.get.asInstanceOf[CastChecks]
+ fromDataType.asInstanceOf[StructType].foreach{field =>
+ recursiveTagExprForGpuCheck(field.dataType)
+ if (toType == StringType) {
+ if (!checks.gpuCanCast(field.dataType, toType)) {
+ willNotWorkOnGpu(s"Unsupported type ${field.dataType} found in Struct column. " +
+ s"Casting ${field.dataType} to ${toType} not currently supported. Refer to " +
+ "CAST documentation for more details.")
+ }
+ }
+ }
+ }
}
def buildTagMessage(entry: ConfEntry[_]): String = {
@@ -97,7 +108,7 @@ class CastExprMeta[INPUT <: CastBase](
}
override def convertToGpu(child: Expression): GpuExpression =
- GpuCast(child, toType, ansiEnabled, cast.timeZoneId)
+ GpuCast(child, toType, ansiEnabled, cast.timeZoneId, legacyCastToString)
}
object GpuCast {
@@ -114,18 +125,6 @@ object GpuCast {
*/
private val FULL_TIMESTAMP_LENGTH = 27
- /**
- * Regex for identifying strings that contain numeric values that can be casted to integral
- * types. This includes floating point numbers but not numbers containing exponents.
- */
- private val CASTABLE_TO_INT_REGEX = "\\s*[+\\-]?[0-9]*(\\.)?[0-9]+\\s*$"
-
- /**
- * Regex for identifying strings that contain numeric values that can be casted to integral
- * types when ansi is enabled.
- */
- private val ANSI_CASTABLE_TO_INT_REGEX = "\\s*[+\\-]?[0-9]+\\s*$"
-
/**
* Regex to match timestamps with or without trailing zeros.
*/
@@ -146,7 +145,8 @@ case class GpuCast(
child: Expression,
dataType: DataType,
ansiMode: Boolean = false,
- timeZoneId: Option[String] = None)
+ timeZoneId: Option[String] = None,
+ legacyCastToString: Boolean = false)
extends GpuUnaryExpression with TimeZoneAwareExpression with NullIntolerant {
import GpuCast._
@@ -203,6 +203,23 @@ case class GpuCast(
override def doColumnar(input: GpuColumnVector): ColumnVector = {
(input.dataType(), dataType) match {
+ // Filter out casts to Decimal that utilize the ColumnVector to avoid a copy
+ case (ShortType | IntegerType | LongType, dt: DecimalType) =>
+ castIntegralsToDecimal(input.getBase, dt)
+
+ case (FloatType | DoubleType, dt: DecimalType) =>
+ castFloatsToDecimal(input.getBase, dt)
+
+ case (from: DecimalType, to: DecimalType) =>
+ castDecimalToDecimal(input.getBase, from, to)
+
+ case _ =>
+ doColumnar(input.getBase, input.dataType())
+ }
+ }
+
+ def doColumnar(input: ColumnView, sparkType: DataType): ColumnVector = {
+ (sparkType, dataType) match {
case (NullType, to) =>
withResource(GpuScalar.from(null, to)) { scalar =>
ColumnVector.fromScalar(scalar, input.getRowCount.toInt)
@@ -210,12 +227,12 @@ case class GpuCast(
case (DateType, BooleanType | _: NumericType) =>
// casts from date type to numerics are always null
withResource(GpuScalar.from(null, dataType)) { scalar =>
- ColumnVector.fromScalar(scalar, input.getBase.getRowCount.toInt)
+ ColumnVector.fromScalar(scalar, input.getRowCount.toInt)
}
case (DateType, StringType) =>
- input.getBase.asStrings("%Y-%m-%d")
+ input.asStrings("%Y-%m-%d")
case (TimestampType, FloatType | DoubleType) =>
- withResource(input.getBase.castTo(DType.INT64)) { asLongs =>
+ withResource(input.castTo(DType.INT64)) { asLongs =>
withResource(Scalar.fromDouble(1000000)) { microsPerSec =>
// Use trueDiv to ensure cast to double before division for full precision
asLongs.trueDiv(microsPerSec, GpuColumnVector.getNonNestedRapidsType(dataType))
@@ -224,7 +241,7 @@ case class GpuCast(
case (TimestampType, ByteType | ShortType | IntegerType) =>
// normally we would just do a floordiv here, but cudf downcasts the operands to
// the output type before the divide. https://github.com/rapidsai/cudf/issues/2574
- withResource(input.getBase.castTo(DType.INT64)) { asLongs =>
+ withResource(input.castTo(DType.INT64)) { asLongs =>
withResource(Scalar.fromInt(1000000)) { microsPerSec =>
withResource(asLongs.floorDiv(microsPerSec, DType.INT64)) { cv =>
if (ansiMode) {
@@ -245,63 +262,65 @@ case class GpuCast(
}
}
case (TimestampType, _: LongType) =>
- withResource(input.getBase.castTo(DType.INT64)) { asLongs =>
+ withResource(input.castTo(DType.INT64)) { asLongs =>
withResource(Scalar.fromInt(1000000)) { microsPerSec =>
asLongs.floorDiv(microsPerSec, GpuColumnVector.getNonNestedRapidsType(dataType))
}
}
case (TimestampType, StringType) =>
castTimestampToString(input)
+ case (StructType(fields), StringType) =>
+ castStructToString(input, legacyCastToString, fields)
// ansi cast from larger-than-integer integral types, to integer
case (LongType, IntegerType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromInt(Int.MinValue),
+ assertValuesInRange(input, Scalar.fromInt(Int.MinValue),
Scalar.fromInt(Int.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
// ansi cast from larger-than-short integral types, to short
case (LongType|IntegerType, ShortType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromShort(Short.MinValue),
+ assertValuesInRange(input, Scalar.fromShort(Short.MinValue),
Scalar.fromShort(Short.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
// ansi cast from larger-than-byte integral types, to byte
case (LongType|IntegerType|ShortType, ByteType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromByte(Byte.MinValue),
+ assertValuesInRange(input, Scalar.fromByte(Byte.MinValue),
Scalar.fromByte(Byte.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
// ansi cast from floating-point types, to byte
case (FloatType|DoubleType, ByteType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromByte(Byte.MinValue),
+ assertValuesInRange(input, Scalar.fromByte(Byte.MinValue),
Scalar.fromByte(Byte.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
// ansi cast from floating-point types, to short
case (FloatType|DoubleType, ShortType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromShort(Short.MinValue),
+ assertValuesInRange(input, Scalar.fromShort(Short.MinValue),
Scalar.fromShort(Short.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
// ansi cast from floating-point types, to integer
case (FloatType|DoubleType, IntegerType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromInt(Int.MinValue),
+ assertValuesInRange(input, Scalar.fromInt(Int.MinValue),
Scalar.fromInt(Int.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
// ansi cast from floating-point types, to long
case (FloatType|DoubleType, LongType) if ansiMode =>
- assertValuesInRange(input.getBase, Scalar.fromLong(Long.MinValue),
+ assertValuesInRange(input, Scalar.fromLong(Long.MinValue),
Scalar.fromLong(Long.MaxValue))
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
case (FloatType | DoubleType, TimestampType) =>
// Spark casting to timestamp from double assumes value is in microseconds
withResource(Scalar.fromInt(1000000)) { microsPerSec =>
- withResource(input.getBase.nansToNulls()) { inputWithNansToNull =>
+ withResource(input.nansToNulls()) { inputWithNansToNull =>
withResource(FloatUtils.infinityToNulls(inputWithNansToNull)) {
inputWithoutNanAndInfinity =>
- if (input.dataType() == FloatType &&
+ if (sparkType == FloatType &&
ShimLoader.getSparkShims.hasCastFloatTimestampUpcast) {
withResource(inputWithoutNanAndInfinity.castTo(DType.FLOAT64)) { doubles =>
withResource(doubles.mul(microsPerSec, DType.INT64)) {
@@ -320,12 +339,12 @@ case class GpuCast(
}
case (BooleanType, TimestampType) =>
// cudf requires casting to a long first.
- withResource(input.getBase.castTo(DType.INT64)) { longs =>
+ withResource(input.castTo(DType.INT64)) { longs =>
longs.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
}
case (BooleanType | ByteType | ShortType | IntegerType, TimestampType) =>
// cudf requires casting to a long first
- withResource(input.getBase.castTo(DType.INT64)) { longs =>
+ withResource(input.castTo(DType.INT64)) { longs =>
withResource(longs.castTo(DType.TIMESTAMP_SECONDS)) { timestampSecs =>
timestampSecs.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
}
@@ -333,21 +352,21 @@ case class GpuCast(
case (_: NumericType, TimestampType) =>
// Spark casting to timestamp assumes value is in seconds, but timestamps
// are tracked in microseconds.
- withResource(input.getBase.castTo(DType.TIMESTAMP_SECONDS)) { timestampSecs =>
+ withResource(input.castTo(DType.TIMESTAMP_SECONDS)) { timestampSecs =>
timestampSecs.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
}
case (FloatType, LongType) | (DoubleType, IntegerType | LongType) =>
// Float.NaN => Int is casted to a zero but float.NaN => Long returns a small negative
// number Double.NaN => Int | Long, returns a small negative number so Nans have to be
// converted to zero first
- withResource(FloatUtils.nanToZero(input.getBase)) { inputWithNansToZero =>
+ withResource(FloatUtils.nanToZero(input)) { inputWithNansToZero =>
inputWithNansToZero.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
}
case (FloatType|DoubleType, StringType) =>
castFloatingTypeToString(input)
case (StringType, BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType
| DoubleType | DateType | TimestampType) =>
- withResource(input.getBase.strip()) { trimmed =>
+ withResource(input.strip()) { trimmed =>
dataType match {
case BooleanType =>
castStringToBool(trimmed, ansiMode)
@@ -359,69 +378,42 @@ case class GpuCast(
castStringToFloats(trimmed, ansiMode,
GpuColumnVector.getNonNestedRapidsType(dataType))
case ByteType | ShortType | IntegerType | LongType =>
- // filter out values that are not valid longs or nulls
- val regex = if (ansiMode) {
- GpuCast.ANSI_CASTABLE_TO_INT_REGEX
- } else {
- GpuCast.CASTABLE_TO_INT_REGEX
- }
- val longStrings = withResource(trimmed.matchesRe(regex)) { regexMatches =>
- if (ansiMode) {
- withResource(regexMatches.all(DType.BOOL8)) { allRegexMatches =>
- if (!allRegexMatches.getBoolean) {
- throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
- }
- }
- }
- withResource(Scalar.fromNull(DType.STRING)) { nullString =>
- regexMatches.ifElse(trimmed, nullString)
- }
- }
- // cast to specific integral type after filtering out values that are not in range
- // for that type. Note that the scalar values here are named parameters so are not
- // created until they are needed
- withResource(longStrings) { longStrings =>
- GpuColumnVector.getNonNestedRapidsType(dataType) match {
- case DType.INT8 =>
- castStringToIntegralType(longStrings, DType.INT8,
- Scalar.fromInt(Byte.MinValue), Scalar.fromInt(Byte.MaxValue))
- case DType.INT16 =>
- castStringToIntegralType(longStrings, DType.INT16,
- Scalar.fromInt(Short.MinValue), Scalar.fromInt(Short.MaxValue))
- case DType.INT32 =>
- castStringToIntegralType(longStrings, DType.INT32,
- Scalar.fromInt(Int.MinValue), Scalar.fromInt(Int.MaxValue))
- case DType.INT64 =>
- longStrings.castTo(DType.INT64)
- case _ =>
- throw new IllegalStateException("Invalid integral type")
- }
- }
+ castStringToInts(trimmed, ansiMode,
+ GpuColumnVector.getNonNestedRapidsType(dataType))
}
}
case (StringType, dt: DecimalType) =>
// To apply HALF_UP rounding strategy during casting to decimal, we firstly cast
// string to fp64. Then, cast fp64 to target decimal type to enforce HALF_UP rounding.
- withResource(input.getBase.strip()) { trimmed =>
+ withResource(input.strip()) { trimmed =>
withResource(castStringToFloats(trimmed, ansiMode, DType.FLOAT64)) { fp =>
castFloatsToDecimal(fp, dt)
}
}
case (ShortType | IntegerType | LongType | ByteType | StringType, BinaryType) =>
- input.getBase.asByteList(true)
+ input.asByteList(true)
case (ShortType | IntegerType | LongType, dt: DecimalType) =>
- castIntegralsToDecimal(input.getBase, dt)
+ withResource(input.copyToColumnVector()) { inputVector =>
+ castIntegralsToDecimal(inputVector, dt)
+ }
case (FloatType | DoubleType, dt: DecimalType) =>
- castFloatsToDecimal(input.getBase, dt)
+ withResource(input.copyToColumnVector()) { inputVector =>
+ castFloatsToDecimal(inputVector, dt)
+ }
case (from: DecimalType, to: DecimalType) =>
- castDecimalToDecimal(input.getBase, from, to)
+ withResource(input.copyToColumnVector()) { inputVector =>
+ castDecimalToDecimal(inputVector, from, to)
+ }
+
+ case (_: DecimalType, StringType) =>
+ input.castTo(DType.STRING)
case _ =>
- input.getBase.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
+ input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
}
}
@@ -435,13 +427,13 @@ case class GpuCast(
* @param inclusiveMax Whether the max value is included in the valid range or not
* @throws IllegalStateException if any values in the column are not within the specified range
*/
- private def assertValuesInRange(values: ColumnVector,
+ private def assertValuesInRange(values: ColumnView,
minValue: => Scalar,
maxValue: => Scalar,
inclusiveMin: Boolean = true,
inclusiveMax: Boolean = true): Unit = {
- def throwIfAny(cv: ColumnVector): Unit = {
+ def throwIfAny(cv: ColumnView): Unit = {
withResource(cv) { cv =>
withResource(cv.any()) { isAny =>
if (isAny.getBoolean) {
@@ -507,16 +499,139 @@ case class GpuCast(
}
}
- private def castTimestampToString(input: GpuColumnVector): ColumnVector = {
- withResource(input.getBase.castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
+ private def castTimestampToString(input: ColumnView): ColumnVector = {
+ withResource(input.castTo(DType.TIMESTAMP_MICROSECONDS)) { micros =>
withResource(micros.asStrings("%Y-%m-%d %H:%M:%S.%6f")) { cv =>
cv.stringReplaceWithBackrefs(GpuCast.TIMESTAMP_TRUNCATE_REGEX, "\\1\\2\\3")
}
}
}
- private def castFloatingTypeToString(input: GpuColumnVector): ColumnVector = {
- withResource(input.getBase.castTo(DType.STRING)) { cudfCast =>
+ private def legacyStructToString(input: ColumnView,
+ inputSchema: Array[StructField]): ColumnVector = {
+ var separatorColumn: ColumnVector = null
+ var spaceColumn: ColumnVector = null
+ val columns: ArrayBuffer[ColumnVector] = new ArrayBuffer[ColumnVector]()
+ // coreColumns tracks the casted child columns
+ val coreColumns: ArrayBuffer[ColumnVector] = new ArrayBuffer[ColumnVector]()
+
+ try {
+ withResource(GpuScalar.from(",", StringType)) { separatorScalar =>
+ separatorColumn = ColumnVector.fromScalar(separatorScalar, input.getRowCount.toInt)
+ }
+ withResource(GpuScalar.from(" ", StringType)) { separatorScalar =>
+ spaceColumn = ColumnVector.fromScalar(separatorScalar, input.getRowCount.toInt)
+ }
+ withResource(GpuScalar.from("[", StringType)) { bracketScalar =>
+ columns += ColumnVector.fromScalar(bracketScalar, input.getRowCount.toInt)
+ }
+
+ withResource(input.getChildColumnView(0)) { childView =>
+ columns += doColumnar(childView, inputSchema(0).dataType)
+ coreColumns += columns.last
+ }
+ for(childIndex <- 1 until input.getNumChildren()) {
+ withResource(input.getChildColumnView(childIndex)) { childView =>
+ columns += separatorColumn
+ // Copies the whitespace column's validity with the current column's validity.
+ // Mimics the Spark null behavior of consecutive commas with no space between them
+ columns += spaceColumn.mergeAndSetValidity(BinaryOp.BITWISE_AND, childView)
+ columns += doColumnar(childView, inputSchema(childIndex).dataType)
+ coreColumns += columns.last
+ }
+ }
+ withResource(GpuScalar.from("]", StringType)) { bracketScalar =>
+ columns += ColumnVector.fromScalar(bracketScalar, input.getRowCount.toInt)
+ }
+
+ // Merge casted child columns
+ withResource(GpuScalar.from("", StringType)) { emptyStrScalar =>
+ withResource(ColumnVector.stringConcatenate(emptyStrScalar, emptyStrScalar,
+ columns.toArray[ColumnView])) { fullResult =>
+ // Merge the validity of all child columns, fully null rows are null in the result
+ withResource(fullResult.mergeAndSetValidity(BinaryOp.BITWISE_OR,
+ coreColumns: _*)) { nulledResult =>
+ // Reflect the struct column's validity vector in the result
+ nulledResult.mergeAndSetValidity(BinaryOp.BITWISE_AND, input, nulledResult)
+ }
+ }
+ }
+ } finally {
+ if (separatorColumn != null) {
+ columns.foreach(col =>
+ if(col.getNativeView() != separatorColumn.getNativeView()) {
+ col.close()
+ })
+ separatorColumn.close()
+ }
+ if (spaceColumn != null) {
+ spaceColumn.close()
+ }
+ }
+ }
+
+ private def modernStructToString(input: ColumnView,
+ inputSchema: Array[StructField]): ColumnVector = {
+ var separatorColumn: ColumnVector = null
+ var spaceColumn: ColumnVector = null
+ val columns: ArrayBuffer[ColumnVector] = new ArrayBuffer[ColumnVector]()
+
+ try {
+ withResource(GpuScalar.from(", ", StringType)) { separatorScalar =>
+ separatorColumn = ColumnVector.fromScalar(separatorScalar, input.getRowCount.toInt)
+ }
+ withResource(GpuScalar.from("{", StringType)) { bracketScalar =>
+ columns += ColumnVector.fromScalar(bracketScalar, input.getRowCount.toInt)
+ }
+
+ withResource(input.getChildColumnView(0)) { childView =>
+ columns += doColumnar(childView, inputSchema(0).dataType)
+ }
+ for(childIndex <- 1 until input.getNumChildren()) {
+ withResource(input.getChildColumnView(childIndex)) { childView =>
+ columns += separatorColumn
+ columns += doColumnar(childView, inputSchema(childIndex).dataType)
+ }
+ }
+ withResource(GpuScalar.from("}", StringType)) { bracketScalar =>
+ columns += ColumnVector.fromScalar(bracketScalar, input.getRowCount.toInt)
+ }
+
+ // Merge casted child columns
+ withResource(GpuScalar.from("", StringType)) { emptyStrScalar =>
+ withResource(GpuScalar.from("null", StringType)) { nullStringScalar =>
+ withResource(ColumnVector.stringConcatenate(emptyStrScalar, nullStringScalar,
+ columns.toArray[ColumnView])) { fullResult =>
+ // Reflect the struct column's validity vector in the result
+ fullResult.mergeAndSetValidity(BinaryOp.BITWISE_AND, input)
+ }
+ }
+ }
+ } finally {
+ if (separatorColumn != null) {
+ columns.foreach(col =>
+ if(col.getNativeView() != separatorColumn.getNativeView()) {
+ col.close()
+ })
+ separatorColumn.close()
+ }
+ if (spaceColumn != null) {
+ spaceColumn.close()
+ }
+ }
+ }
+
+ private def castStructToString(input: ColumnView,
+ legacyCastToString: Boolean, inputSchema: Array[StructField]): ColumnVector = {
+ if (legacyCastToString) {
+ legacyStructToString(input, inputSchema)
+ } else {
+ modernStructToString(input,inputSchema)
+ }
+ }
+
+ private def castFloatingTypeToString(input: ColumnView): ColumnVector = {
+ withResource(input.castTo(DType.STRING)) { cudfCast =>
// replace "e+" with "E"
val replaceExponent = withResource(Scalar.fromString("e+")) { cudfExponent =>
@@ -546,7 +661,7 @@ case class GpuCast(
withResource(input.contains(boolStrings)) { validBools =>
// in ansi mode, fail if any values are not valid bool strings
if (ansiEnabled) {
- withResource(validBools.all(DType.BOOL8)) { isAllBool =>
+ withResource(validBools.all()) { isAllBool =>
if (!isAllBool.getBoolean) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
@@ -565,6 +680,52 @@ case class GpuCast(
}
}
+ def castStringToInts(
+ input: ColumnVector,
+ ansiEnabled: Boolean,
+ dType: DType): ColumnVector = {
+ val cleaned = if (!ansiEnabled) {
+ // TODO would be great to get rid of this regex, but the overflow checks don't work
+ // on the more lenient pattern.
+ // To avoid doing the expensive regex all the time, we will first check to see if we need
+ // to do it. The only time we do need to do it is when we have a '.' in any of the strings.
+ val data = input.getData
+ val hasDot = withResource(
+ ColumnView.fromDeviceBuffer(data, 0, DType.INT8, data.getLength.toInt)) { childData =>
+ withResource(GpuScalar.from('.'.toByte, ByteType)) { dot =>
+ childData.contains(dot)
+ }
+ }
+ if (hasDot) {
+ withResource(input.extractRe("^([+\\-]?[0-9]+)(?:\\.[0-9]*)?$")) { table =>
+ table.getColumn(0).incRefCount()
+ }
+ } else {
+ input.incRefCount()
+ }
+ } else {
+ input.incRefCount()
+ }
+ withResource(cleaned) { cleaned =>
+ withResource(cleaned.isInteger(dType)) { isInt =>
+ if (ansiEnabled) {
+ withResource(isInt.all()) { allInts =>
+ if (!allInts.getBoolean) {
+ throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
+ }
+ }
+ cleaned.castTo(dType)
+ } else {
+ withResource(cleaned.castTo(dType)) { parsedInt =>
+ withResource(GpuScalar.from(null, dataType)) { nullVal =>
+ isInt.ifElse(parsedInt, nullVal)
+ }
+ }
+ }
+ }
+ }
+ }
+
def castStringToFloats(
input: ColumnVector,
ansiEnabled: Boolean,
@@ -941,63 +1102,6 @@ case class GpuCast(
}
}
- /**
- * Cast column of long values to a smaller integral type (bytes, short, int).
- *
- * @param longStrings Long values in string format
- * @param castToType Type to cast to
- * @param minValue Named parameter for function to create Scalar representing range minimum value
- * @param maxValue Named parameter for function to create Scalar representing range maximum value
- * @return Values cast to specified integral type
- */
- private def castStringToIntegralType(longStrings: ColumnVector,
- castToType: DType,
- minValue: => Scalar,
- maxValue: => Scalar): ColumnVector = {
-
- // evaluate min and max named parameters once since they are used in multiple places
- withResource(minValue) { minValue: Scalar =>
- withResource(maxValue) { maxValue: Scalar =>
- withResource(Scalar.fromNull(DType.INT64)) { nulls =>
- withResource(longStrings.castTo(DType.INT64)) { values =>
-
- // replace values less than minValue with null
- val gtEqMinOrNull = withResource(values.greaterOrEqualTo(minValue)) { isGtEqMin =>
- if (ansiMode) {
- withResource(isGtEqMin.all(DType.BOOL8)) { all =>
- if (!all.getBoolean) {
- throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
- }
- }
- }
- isGtEqMin.ifElse(values, nulls)
- }
-
- // replace values greater than maxValue with null
- val ltEqMaxOrNull = withResource(gtEqMinOrNull) { gtEqMinOrNull =>
- withResource(gtEqMinOrNull.lessOrEqualTo(maxValue)) { isLtEqMax =>
- if (ansiMode) {
- withResource(isLtEqMax.all(DType.BOOL8)) { all =>
- if (!all.getBoolean) {
- throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
- }
- }
- }
- isLtEqMax.ifElse(gtEqMinOrNull, nulls)
- }
- }
-
- // cast the final values
- withResource(ltEqMaxOrNull) { ltEqMaxOrNull =>
- ltEqMaxOrNull.castTo(castToType)
- }
- }
- }
- }
-
- }
- }
-
private def castIntegralsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {
// Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison.
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala
new file mode 100644
index 000000000000..c5cb3f69b3db
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import ai.rapids.cudf.{ColumnVector, Scalar}
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.types.{DataType, StringType}
+
+case class GpuGetJsonObject(json: Expression, path: Expression) extends GpuBinaryExpression with
+ ExpectsInputTypes {
+ override def left: Expression = json
+ override def right: Expression = path
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+ override def nullable: Boolean = true
+ override def prettyName: String = "get_json_object"
+
+ override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = {
+ throw new UnsupportedOperationException("JSON path must be a scalar value")
+ }
+
+ override def doColumnar(lhs: Scalar, rhs: GpuColumnVector): ColumnVector = {
+ throw new UnsupportedOperationException("JSON path must be a scalar value")
+ }
+
+ override def doColumnar(lhs: GpuColumnVector, rhs: Scalar): ColumnVector = {
+ lhs.getBase().getJSONObject(rhs)
+ }
+
+ override def doColumnar(numRows: Int, lhs: Scalar, rhs: Scalar): ColumnVector = {
+ withResource(GpuColumnVector.from(lhs, numRows, left.dataType)) { expandedLhs =>
+ doColumnar(expandedLhs, rhs)
+ }
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala
index ebb764270bb8..9e02b4430b52 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioning.scala
@@ -16,10 +16,9 @@
package com.nvidia.spark.rapids
-import scala.collection.mutable.ArrayBuffer
-
-import ai.rapids.cudf.{ColumnVector, DType, NvtxColor, NvtxRange, Table}
+import ai.rapids.cudf.{DType, NvtxColor, NvtxRange}
+import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashClusteredDistribution}
import org.apache.spark.sql.rapids.GpuMurmur3Hash
@@ -47,60 +46,40 @@ case class GpuHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
}
}
- override def columnarEval(batch: ColumnarBatch): Any = {
- // We are doing this here because the cudf partition command is at this level
- val numRows = batch.numRows
- withResource(new NvtxRange("Hash partition", NvtxColor.PURPLE)) { _ =>
- val sortedTable = withResource(batch) { batch =>
- val parts = withResource(new NvtxRange("Calculate part", NvtxColor.CYAN)) { _ =>
- withResource(GpuMurmur3Hash.compute(batch, expressions)) { hash =>
- withResource(GpuScalar.from(numPartitions, IntegerType)) { partsLit =>
- hash.pmod(partsLit, DType.INT32)
- }
- }
- }
- withResource(new NvtxRange("sort by part", NvtxColor.DARK_GREEN)) { _ =>
- withResource(parts) { parts =>
- val allColumns = new ArrayBuffer[ColumnVector](batch.numCols() + 1)
- allColumns += parts
- allColumns ++= GpuColumnVector.extractBases(batch)
- withResource(new Table(allColumns: _*)) { fullTable =>
- fullTable.orderBy(Table.asc(0))
- }
+ def partitionInternalAndClose(batch: ColumnarBatch): (Array[Int], Array[GpuColumnVector]) = {
+ val types = GpuColumnVector.extractTypes(batch)
+ val partedTable = withResource(batch) { batch =>
+ val parts = withResource(new NvtxRange("Calculate part", NvtxColor.CYAN)) { _ =>
+ withResource(GpuMurmur3Hash.compute(batch, expressions)) { hash =>
+ withResource(GpuScalar.from(numPartitions, IntegerType)) { partsLit =>
+ hash.pmod(partsLit, DType.INT32)
}
}
}
- val (partitionIndexes, partitionColumns) = withResource(sortedTable) { sortedTable =>
- val cutoffs = withResource(new Table(sortedTable.getColumn(0))) { justPartitions =>
- val partsTable = withResource(GpuScalar.from(0, IntegerType)) { zeroLit =>
- withResource(ColumnVector.sequence(zeroLit, numPartitions)) { partsColumn =>
- new Table(partsColumn)
- }
- }
- withResource(partsTable) { partsTable =>
- justPartitions.upperBound(Array(false), partsTable, Array(false))
- }
- }
- val partitionIndexes = withResource(cutoffs) { cutoffs =>
- val buffer = new ArrayBuffer[Int](numPartitions)
- // The first index is always 0
- buffer += 0
- withResource(cutoffs.copyToHost()) { hostCutoffs =>
- (0 until numPartitions).foreach { i =>
- buffer += hostCutoffs.getInt(i)
- }
- }
- buffer.toArray
+ withResource(parts) { parts =>
+ withResource(GpuColumnVector.from(batch)) { table =>
+ table.partition(parts, numPartitions)
}
- val dataTypes = GpuColumnVector.extractTypes(batch)
- closeOnExcept(new ArrayBuffer[GpuColumnVector]()) { partitionColumns =>
- (1 until sortedTable.getNumberOfColumns).foreach { index =>
- partitionColumns +=
- GpuColumnVector.from(sortedTable.getColumn(index).incRefCount(),
- dataTypes(index - 1))
- }
+ }
+ }
+ withResource(partedTable) { partedTable =>
+ val parts = partedTable.getPartitions
+ val tp = partedTable.getTable
+ val columns = (0 until partedTable.getNumberOfColumns.toInt).zip(types).map {
+ case (index, sparkType) =>
+ GpuColumnVector.from(tp.getColumn(index).incRefCount(), sparkType)
+ }
+ (parts, columns.toArray)
+ }
+ }
- (partitionIndexes, partitionColumns.toArray)
+ override def columnarEval(batch: ColumnarBatch): Any = {
+ // We are doing this here because the cudf partition command is at this level
+ withResource(new NvtxRange("Hash partition", NvtxColor.PURPLE)) { _ =>
+ val numRows = batch.numRows
+ val (partitionIndexes, partitionColumns) = {
+ withResource(new NvtxRange("partition", NvtxColor.BLUE)) { _ =>
+ partitionInternalAndClose(batch)
}
}
val ret = withResource(partitionColumns) { partitionColumns =>
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 317fc3bd0135..0667ccd5fe51 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -424,6 +424,16 @@ object GpuOverrides {
"\\S", "\\v", "\\V", "\\w", "\\w", "\\p", "$", "\\b", "\\B", "\\A", "\\G", "\\Z", "\\z", "\\R",
"?", "|", "(", ")", "{", "}", "\\k", "\\Q", "\\E", ":", "!", "<=", ">")
+ private[this] val _commonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
+
+ private[this] val pluginSupportedOrderableSig = _commonTypes +
+ TypeSig.STRUCT.nested(_commonTypes)
+
+ private[this] def isStructType(dataType: DataType) = dataType match {
+ case StructType(_) => true
+ case _ => false
+ }
+
// this listener mechanism is global and is intended for use by unit tests only
private val listeners: ListBuffer[GpuOverridesListener] = new ListBuffer[GpuOverridesListener]()
@@ -1814,16 +1824,28 @@ object GpuOverrides {
expr[SortOrder](
"Sort order",
ExprChecks.projectOnly(
- TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
+ pluginSupportedOrderableSig,
TypeSig.orderable,
Seq(ParamCheck(
"input",
- TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
+ pluginSupportedOrderableSig,
TypeSig.orderable))),
- (a, conf, p, r) => new BaseExprMeta[SortOrder](a, conf, p, r) {
+ (sortOrder, conf, p, r) => new BaseExprMeta[SortOrder](sortOrder, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ if (isStructType(sortOrder.dataType)) {
+ val nullOrdering = sortOrder.nullOrdering
+ val directionDefaultNullOrdering = sortOrder.direction.defaultNullOrdering
+ val direction = sortOrder.direction.sql
+ if (nullOrdering != directionDefaultNullOrdering) {
+ willNotWorkOnGpu(s"only default null ordering $directionDefaultNullOrdering " +
+ s"for direction $direction is supported for nested types; actual: ${nullOrdering}")
+ }
+ }
+ }
+
// One of the few expressions that are not replaced with a GPU version
override def convertToGpu(): Expression =
- a.withNewChildren(childExprs.map(_.convertToGpu()))
+ sortOrder.withNewChildren(childExprs.map(_.convertToGpu()))
}),
expr[Count](
"Count aggregate operator",
@@ -2416,6 +2438,16 @@ object GpuOverrides {
override def convertToGpu(): GpuExpression = GpuCollectList(
childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset)
}),
+ expr[GetJsonObject](
+ "Extracts a json object from path",
+ ExprChecks.projectOnly(
+ TypeSig.STRING, TypeSig.STRING, Seq(ParamCheck("json", TypeSig.STRING, TypeSig.STRING),
+ ParamCheck("path", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))),
+ (a, conf, p, r) => new BinaryExprMeta[GetJsonObject](a, conf, p, r) {
+ override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+ GpuGetJsonObject(lhs, rhs)
+ }
+ ),
expr[ScalarSubquery](
"Subquery that will return only one row and one column",
ExprChecks.projectOnly(
@@ -2500,6 +2532,14 @@ object GpuOverrides {
override val childExprs: Seq[BaseExprMeta[_]] =
rp.ordering.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ override def tagPartForGpu() {
+ val numPartitions = rp.numPartitions
+ if (numPartitions > 1 && rp.ordering.exists(so => isStructType(so.dataType))) {
+ willNotWorkOnGpu("only single partition sort is supported for nested types, " +
+ s"actual partitions: $numPartitions")
+ }
+ }
+
override def convertToGpu(): GpuPartitioning = {
if (rp.numPartitions > 1) {
val gpuOrdering = childExprs.map(_.convertToGpu()).asInstanceOf[Seq[SortOrder]]
@@ -2613,7 +2653,7 @@ object GpuOverrides {
}),
exec[TakeOrderedAndProjectExec](
"Take the first limit elements as defined by the sortOrder, and do projection if needed.",
- ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.NULL, TypeSig.all),
+ ExecChecks(pluginSupportedOrderableSig, TypeSig.all),
(takeExec, conf, p, r) =>
new SparkPlanMeta[TakeOrderedAndProjectExec](takeExec, conf, p, r) {
val sortOrder: Seq[BaseExprMeta[SortOrder]] =
@@ -2679,7 +2719,7 @@ object GpuOverrides {
}),
exec[CollectLimitExec](
"Reduce to single partition and apply limit",
- ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all),
+ ExecChecks(pluginSupportedOrderableSig, TypeSig.all),
(collectLimitExec, conf, p, r) => new GpuCollectLimitMeta(collectLimitExec, conf, p, r))
.disabledByDefault("Collect Limit replacement can be slower on the GPU, if huge number " +
"of rows in a batch it could help by limiting the number of rows transferred from " +
@@ -2752,9 +2792,16 @@ object GpuOverrides {
"The backend for the sort operator",
// The SortOrder TypeSig will govern what types can actually be used as sorting key data type.
// The types below are allowed as inputs and outputs.
- ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.ARRAY +
- TypeSig.STRUCT).nested(), TypeSig.all),
- (sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r)),
+ ExecChecks(pluginSupportedOrderableSig + (TypeSig.ARRAY + TypeSig.STRUCT).nested(),
+ TypeSig.all),
+ (sort, conf, p, r) => new GpuSortMeta(sort, conf, p, r) {
+ override def tagPlanForGpu() {
+ if (!conf.stableSort && sort.sortOrder.exists(so => isStructType(so.dataType))) {
+ willNotWorkOnGpu("it's disabled for nested types " +
+ s"unless ${RapidsConf.STABLE_SORT.key} is true")
+ }
+ }
+ }),
exec[ExpandExec](
"The backend for the expand operator",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
index d00ddb23e875..c0960c70f617 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -16,7 +16,9 @@
package com.nvidia.spark.rapids
+import org.apache.spark.RangePartitioner
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
@@ -406,8 +408,11 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
case _: GpuColumnarToRowExecParent => () // Ignored
case _: ExecutedCommandExec => () // Ignored
case _: RDDScanExec => () // Ignored
- case _: ShuffleExchangeExec => () // Ignored for now, we don't force it to the GPU if
- // children are not on the gpu
+ case shuffleExchange: ShuffleExchangeExec if conf.cpuRangePartitioningPermitted
+ || !shuffleExchange.outputPartitioning.isInstanceOf[RangePartitioning] => {
+ // Ignored for now, we don't force it to the GPU if
+ // children are not on the gpu
+ }
case other =>
if (!plan.supportsColumnar &&
!conf.testingAllowedNonGpu.contains(getBaseNameFromClass(other.getClass.toString))) {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
index d52d25d54341..d68bd12c3596 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
@@ -107,14 +107,23 @@ trait RapidsBuffer extends AutoCloseable {
def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch
/**
- * Get the underlying memory buffer. This may be either a HostMemoryBuffer
- * or a DeviceMemoryBuffer depending on where the buffer currently resides.
+ * Get the underlying memory buffer. This may be either a HostMemoryBuffer or a DeviceMemoryBuffer
+ * depending on where the buffer currently resides.
* The caller must have successfully acquired the buffer beforehand.
* @see [[addReference]]
* @note It is the responsibility of the caller to close the buffer.
*/
def getMemoryBuffer: MemoryBuffer
+ /**
+ * Get the device memory buffer from the underlying storage. If the buffer currently resides
+ * outside of device memory, a new DeviceMemoryBuffer is created with the data copied over.
+ * The caller must have successfully acquired the buffer beforehand.
+ * @see [[addReference]]
+ * @note It is the responsibility of the caller to close the buffer.
+ */
+ def getDeviceMemoryBuffer: DeviceMemoryBuffer
+
/**
* Try to add a reference to this buffer to acquire it.
* @note The close method must be called for every successfully obtained reference.
@@ -184,6 +193,9 @@ sealed class DegenerateRapidsBuffer(
override def getMemoryBuffer: MemoryBuffer =
throw new UnsupportedOperationException("degenerate buffer has no memory buffer")
+ override def getDeviceMemoryBuffer: DeviceMemoryBuffer =
+ throw new UnsupportedOperationException("degenerate buffer has no device memory buffer")
+
override def addReference(): Boolean = true
override def getSpillPriority: Long = Long.MaxValue
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
index ad9eaee836e4..bd48e0d7a1a3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
@@ -27,26 +27,34 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.RapidsDiskBlockManager
+/**
+ * Exception thrown when inserting a buffer into the catalog with a duplicate buffer ID
+ * and storage tier combination.
+ */
+class DuplicateBufferException(s: String) extends RuntimeException(s) {}
+
/**
* Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally
* `RapidsBufferCatalog.singleton` should be used instead.
*/
class RapidsBufferCatalog extends Logging {
- /** Map of buffer IDs to buffers */
- private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, RapidsBuffer]
+ /** Map of buffer IDs to buffers sorted by storage tier */
+ private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]]
/**
- * Lookup the buffer that corresponds to the specified buffer ID and acquire it.
+ * Lookup the buffer that corresponds to the specified buffer ID at the highest storage tier,
+ * and acquire it.
* NOTE: It is the responsibility of the caller to close the buffer.
* @param id buffer identifier
* @return buffer that has been acquired
*/
def acquireBuffer(id: RapidsBufferId): RapidsBuffer = {
(0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ =>
- val buffer = bufferMap.get(id)
- if (buffer == null) {
- throw new NoSuchElementException(s"Cannot locate buffer associated with ID: $id")
+ val buffers = bufferMap.get(id)
+ if (buffers == null || buffers.isEmpty) {
+ throw new NoSuchElementException(s"Cannot locate buffers associated with ID: $id")
}
+ val buffer = buffers.head
if (buffer.addReference()) {
return buffer
}
@@ -54,51 +62,90 @@ class RapidsBufferCatalog extends Logging {
throw new IllegalStateException(s"Unable to acquire buffer for ID: $id")
}
+ /**
+ * Lookup the buffer that corresponds to the specified buffer ID at the specified storage tier,
+ * and acquire it.
+ * NOTE: It is the responsibility of the caller to close the buffer.
+ * @param id buffer identifier
+ * @return buffer that has been acquired, None if not found
+ */
+ def acquireBuffer(id: RapidsBufferId, tier: StorageTier): Option[RapidsBuffer] = {
+ val buffers = bufferMap.get(id)
+ if (buffers != null) {
+ buffers.find(_.storageTier == tier).foreach(buffer =>
+ if (buffer.addReference()) {
+ return Some(buffer)
+ }
+ )
+ }
+ None
+ }
+
+ /**
+ * Check if the buffer that corresponds to the specified buffer ID is stored in a slower storage
+ * tier.
+ *
+ * @param id buffer identifier
+ * @param tier storage tier to check
+ * @return true if the buffer is stored in multiple tiers
+ */
+ def isBufferSpilled(id: RapidsBufferId, tier: StorageTier): Boolean = {
+ val buffers = bufferMap.get(id)
+ buffers != null && buffers.exists(_.storageTier > tier)
+ }
+
/** Get the table metadata corresponding to a buffer ID. */
def getBufferMeta(id: RapidsBufferId): TableMeta = {
- val buffer = bufferMap.get(id)
- if (buffer == null) {
+ val buffers = bufferMap.get(id)
+ if (buffers == null || buffers.isEmpty) {
throw new NoSuchElementException(s"Cannot locate buffer associated with ID: $id")
}
- buffer.meta
+ buffers.head.meta
}
/**
* Register a new buffer with the catalog. An exception will be thrown if an
- * existing buffer was registered with the same buffer ID.
+ * existing buffer was registered with the same buffer ID and storage tier.
*/
def registerNewBuffer(buffer: RapidsBuffer): Unit = {
- val old = bufferMap.putIfAbsent(buffer.id, buffer)
- if (old != null) {
- throw new IllegalStateException(s"Buffer ID ${buffer.id} already registered $old")
+ val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] {
+ override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = {
+ if (value == null) {
+ Seq(buffer)
+ } else {
+ val(first, second) = value.partition(_.storageTier < buffer.storageTier)
+ if (second.nonEmpty && second.head.storageTier == buffer.storageTier) {
+ throw new DuplicateBufferException(
+ s"Buffer ID ${buffer.id} at tier ${buffer.storageTier} already registered " +
+ s"${second.head}")
+ }
+ first ++ Seq(buffer) ++ second
+ }
+ }
}
+ bufferMap.compute(buffer.id, updater)
}
- /**
- * Replace the mapping at the specified tier with a specified buffer.
- * NOTE: The mapping will not be updated if the current mapping is to a higher priority
- * storage tier.
- * @param tier the storage tier of the buffer being replaced
- * @param buffer the new buffer to associate
- */
- def updateBufferMap(tier: StorageTier, buffer: RapidsBuffer): Unit = {
- val updater = new BiFunction[RapidsBufferId, RapidsBuffer, RapidsBuffer] {
- override def apply(key: RapidsBufferId, value: RapidsBuffer): RapidsBuffer = {
- if (value == null || value.storageTier >= tier) {
- buffer
+ /** Remove a buffer ID from the catalog at the specified storage tier. */
+ def removeBufferTier(id: RapidsBufferId, tier: StorageTier): Unit = {
+ val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] {
+ override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = {
+ val updated = value.filter(_.storageTier != tier)
+ if (updated.isEmpty) {
+ null
} else {
- value
+ updated
}
}
}
- bufferMap.compute(buffer.id, updater)
+ bufferMap.computeIfPresent(id, updater)
}
- /** Remove a buffer ID from the catalog and release the resources of the registered buffer. */
+ /** Remove a buffer ID from the catalog and release the resources of the registered buffers. */
def removeBuffer(id: RapidsBufferId): Unit = {
- val buffer = bufferMap.remove(id)
- if (buffer != null) {
- buffer.free()
+ val buffers = bufferMap.remove(id)
+ if (buffers != null) {
+ buffers.foreach(_.free())
}
}
@@ -115,6 +162,7 @@ object RapidsBufferCatalog extends Logging with Arm {
private var diskStorage: RapidsDiskStore = _
private var gdsStorage: RapidsGdsStore = _
private var memoryEventHandler: DeviceMemoryEventHandler = _
+ private var _shouldUnspill: Boolean = _
private lazy val conf: SparkConf = {
val env = SparkEnv.get
@@ -145,6 +193,8 @@ object RapidsBufferCatalog extends Logging with Arm {
logInfo("Installing GPU memory handler for spill")
memoryEventHandler = new DeviceMemoryEventHandler(deviceStorage, rapidsConf.gpuOomDumpDir)
Rmm.setEventHandler(memoryEventHandler)
+
+ _shouldUnspill = rapidsConf.isUnspillEnabled
}
def close(): Unit = {
@@ -180,6 +230,8 @@ object RapidsBufferCatalog extends Logging with Arm {
def getDeviceStorage: RapidsDeviceMemoryStore = deviceStorage
+ def shouldUnspill: Boolean = _shouldUnspill
+
/**
* Adds a contiguous table to the device storage, taking ownership of the table.
* @param id buffer ID to associate with this buffer
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala
index 1e618d611dec..934939b35890 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala
@@ -20,8 +20,8 @@ import java.util.Comparator
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
-import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, NvtxColor, NvtxRange}
-import com.nvidia.spark.rapids.StorageTier.StorageTier
+import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange}
+import com.nvidia.spark.rapids.StorageTier.{DEVICE, StorageTier}
import com.nvidia.spark.rapids.format.TableMeta
import org.apache.spark.internal.Logging
@@ -35,13 +35,14 @@ object RapidsBufferStore {
/**
* Base class for all buffer store types.
*
- * @param name name of this store
+ * @param tier storage tier of this store
* @param catalog catalog to register this store
*/
abstract class RapidsBufferStore(
val tier: StorageTier,
- catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton)
- extends AutoCloseable with Logging {
+ catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton,
+ deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage)
+ extends AutoCloseable with Logging with Arm {
val name: String = tier.toString
@@ -55,7 +56,9 @@ abstract class RapidsBufferStore(
def add(buffer: RapidsBufferBase): Unit = synchronized {
val old = buffers.put(buffer.id, buffer)
- require(old == null, s"duplicate buffer registered: ${buffer.id}")
+ if (old != null) {
+ throw new DuplicateBufferException(s"duplicate buffer registered: ${buffer.id}")
+ }
spillable.offer(buffer)
totalBytesStored += buffer.size
}
@@ -125,13 +128,23 @@ abstract class RapidsBufferStore(
* (i.e.: this method will not take ownership of the incoming buffer object).
* This does not need to update the catalog, the caller is responsible for that.
* @param buffer data from another store
+ * @param memoryBuffer memory buffer obtained from the specified Rapids buffer. It will be closed
+ * by this method
* @param stream CUDA stream to use for copy or null
* @return new buffer that was created
*/
- def copyBuffer(buffer: RapidsBuffer, stream: Cuda.Stream): RapidsBufferBase = {
- val newBuffer = createBuffer(buffer, stream)
- buffers.add(newBuffer)
- newBuffer
+ def copyBuffer(buffer: RapidsBuffer, memoryBuffer: MemoryBuffer, stream: Cuda.Stream)
+ : RapidsBufferBase = {
+ val newBuffer = createBuffer(buffer, memoryBuffer, stream)
+ try {
+ buffers.add(newBuffer)
+ catalog.registerNewBuffer(newBuffer)
+ newBuffer
+ } catch {
+ case e: Exception =>
+ newBuffer.free()
+ throw e
+ }
}
/**
@@ -198,10 +211,13 @@ abstract class RapidsBufferStore(
* adding a reference to the existing buffer and later closing it when the transfer completes.
* @note DO NOT close the buffer unless adding a reference!
* @param buffer data from another store
+ * @param memoryBuffer memory buffer obtained from the specified Rapids buffer. It will be closed
+ * by this method
* @param stream CUDA stream to use or null
* @return new buffer tracking the data in this store
*/
- protected def createBuffer(buffer: RapidsBuffer, stream: Cuda.Stream): RapidsBufferBase
+ protected def createBuffer(buffer: RapidsBuffer, memoryBuffer: MemoryBuffer, stream: Cuda.Stream)
+ : RapidsBufferBase
/** Update bookkeeping for a new buffer */
protected def addBuffer(buffer: RapidsBufferBase): Unit = synchronized {
@@ -230,17 +246,21 @@ abstract class RapidsBufferStore(
// If we fail to get a reference then this buffer has since been freed and probably best
// to return back to the outer loop to see if enough has been freed.
if (buffer.addReference()) {
- val newBuffer = try {
- logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name} " +
- s"total mem=${buffers.getTotalBytes}")
- buffer.spillCallback(buffer.storageTier, spillStore.tier, buffer.size)
- spillStore.copyBuffer(buffer, stream)
+ try {
+ if (catalog.isBufferSpilled(buffer.id, buffer.storageTier)) {
+ logDebug(s"Skipping spilling $buffer ${buffer.id} to ${spillStore.name} as it is " +
+ s"already stored in multiple tiers total mem=${buffers.getTotalBytes}")
+ catalog.removeBufferTier(buffer.id, buffer.storageTier)
+ } else {
+ logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name} " +
+ s"total mem=${buffers.getTotalBytes}")
+ buffer.spillCallback(buffer.storageTier, spillStore.tier, buffer.size)
+ spillStore.copyBuffer(buffer, buffer.getMemoryBuffer, stream)
+ }
} finally {
buffer.close()
}
- if (newBuffer != null) {
- catalog.updateBufferMap(buffer.storageTier, newBuffer)
- }
+ catalog.removeBufferTier(buffer.id, buffer.storageTier)
buffer.free()
}
}
@@ -251,7 +271,11 @@ abstract class RapidsBufferStore(
override val size: Long,
override val meta: TableMeta,
initialSpillPriority: Long,
- override val spillCallback: RapidsBuffer.SpillCallback) extends RapidsBuffer with Arm {
+ override val spillCallback: RapidsBuffer.SpillCallback,
+ catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton,
+ deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage)
+ extends RapidsBuffer with Arm {
+ private val MAX_UNSPILL_ATTEMPTS = 100
private[this] var isValid = true
protected[this] var refcount = 0
private[this] var spillPriority: Long = initialSpillPriority
@@ -259,6 +283,19 @@ abstract class RapidsBufferStore(
/** Release the underlying resources for this buffer. */
protected def releaseResources(): Unit
+ /**
+ * Materialize the memory buffer from the underlying storage.
+ *
+ * If the buffer resides in device or host memory, only reference count is incremented.
+ * If the buffer resides in secondary storage, a new host or device memory buffer is created,
+ * with the data copied to the new buffer.
+ * The caller must have successfully acquired the buffer beforehand.
+ * @see [[addReference]]
+ * @note It is the responsibility of the caller to close the buffer.
+ * @note This is an internal API only used by Rapids buffer stores.
+ */
+ protected def materializeMemoryBuffer: MemoryBuffer = getMemoryBuffer
+
/**
* Determine if a buffer is currently acquired.
* @note Unless this is called by the thread that currently "owns" an
@@ -282,14 +319,7 @@ abstract class RapidsBufferStore(
// allocated. Allocations can trigger synchronous spills which can
// deadlock if another thread holds the device store lock and is trying
// to spill to this store.
- withResource(DeviceMemoryBuffer.allocate(size)) { deviceBuffer =>
- withResource(getMemoryBuffer) {
- case h: HostMemoryBuffer =>
- logDebug(s"copying from host $h to device $deviceBuffer")
- deviceBuffer.copyFromHostBuffer(h)
- case _ => throw new IllegalStateException(
- "must override getColumnarBatch if not providing a host buffer")
- }
+ withResource(getDeviceMemoryBuffer) { deviceBuffer =>
columnarBatchFromDeviceBuffer(deviceBuffer, sparkTypes)
}
}
@@ -304,6 +334,45 @@ abstract class RapidsBufferStore(
}
}
+ override def getDeviceMemoryBuffer: DeviceMemoryBuffer = {
+ if (RapidsBufferCatalog.shouldUnspill) {
+ (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ =>
+ catalog.acquireBuffer(id, DEVICE) match {
+ case Some(buffer) =>
+ withResource(buffer) { _ =>
+ return buffer.getDeviceMemoryBuffer
+ }
+ case _ =>
+ try {
+ logDebug(s"Unspilling $this $id to $DEVICE")
+ val newBuffer = deviceStorage.copyBuffer(
+ this, materializeMemoryBuffer, Cuda.DEFAULT_STREAM)
+ if (newBuffer.addReference()) {
+ withResource(newBuffer) { _ =>
+ return newBuffer.getDeviceMemoryBuffer
+ }
+ }
+ } catch {
+ case _: DuplicateBufferException =>
+ logDebug(s"Lost device buffer registration race for buffer $id, retrying...")
+ }
+ }
+ }
+ throw new IllegalStateException(s"Unable to get device memory buffer for ID: $id")
+ } else {
+ withResource(materializeMemoryBuffer) {
+ case h: HostMemoryBuffer =>
+ closeOnExcept(DeviceMemoryBuffer.allocate(size)) { deviceBuffer =>
+ logDebug(s"copying from host $h to device $deviceBuffer")
+ deviceBuffer.copyFromHostBuffer(h)
+ deviceBuffer
+ }
+ case d: DeviceMemoryBuffer => d
+ case b => throw new IllegalStateException(s"Unrecognized buffer: $b")
+ }
+ }
+ }
+
override def close(): Unit = synchronized {
if (refcount == 0) {
throw new IllegalStateException("Buffer already closed")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 0f19b654ae61..9a2392a2bdb4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -352,6 +352,15 @@ object RapidsConf {
.bytesConf(ByteUnit.BYTE)
.createWithDefault(ByteUnit.GiB.toBytes(1))
+ val UNSPILL = conf("spark.rapids.memory.gpu.unspill.enabled")
+ .doc("When a spilled GPU buffer is needed again, should it be unspilled, or only copied " +
+ "back into GPU memory temporarily. Unspilling may be useful for GPU buffers that are " +
+ "needed frequently, for example, broadcast variables; however, it may also increase GPU " +
+ "memory usage")
+ .internal()
+ .booleanConf
+ .createWithDefault(false)
+
val GDS_SPILL = conf("spark.rapids.memory.gpu.direct.storage.spill.enabled")
.doc("Should GPUDirect Storage (GDS) be used to spill GPU memory buffers directly to disk. " +
"GDS must be enabled and the directory `spark.local.dir` must support GDS. This is an " +
@@ -590,6 +599,14 @@ object RapidsConf {
.booleanConf
.createWithDefault(false)
+ val ENABLE_CAST_DECIMAL_TO_STRING = conf("spark.rapids.sql.castDecimalToString.enabled")
+ .doc("When set to true, casting from decimal to string is supported on the GPU. The GPU " +
+ "does NOT produce exact same string as spark produces, but producing strings which are " +
+ "semantically equal. For instance, given input BigDecimal(123, -2), the GPU produces " +
+ "\"12300\", which spark produces \"1.23E+4\".")
+ .booleanConf
+ .createWithDefault(false)
+
val ENABLE_CSV_TIMESTAMPS = conf("spark.rapids.sql.csvTimestamps.enabled")
.doc("When set to true, enables the CSV parser to read timestamps. The default output " +
"format for Spark includes a timezone at the end. Anything except the UTC timezone is not " +
@@ -944,14 +961,14 @@ object RapidsConf {
.internal()
.doc("Default cost of transitioning from GPU to CPU")
.doubleConf
- .createWithDefault(0.15)
+ .createWithDefault(0.1)
val OPTIMIZER_DEFAULT_TRANSITION_TO_GPU_COST = conf(
"spark.rapids.sql.optimizer.defaultTransitionToGpuCost")
.internal()
.doc("Default cost of transitioning from CPU to GPU")
.doubleConf
- .createWithDefault(0.15)
+ .createWithDefault(0.1)
val USE_ARROW_OPT = conf("spark.rapids.arrowCopyOptimizationEnabled")
.doc("Option to turn off using the optimized Arrow copy code when reading from " +
@@ -961,6 +978,12 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)
+ val CPU_RANGE_PARTITIONING_ALLOWED = conf("spark.rapids.allowCpuRangePartitioning")
+ .doc("Option to control enforcement of range partitioning on GPU.")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
+
private def printSectionHeader(category: String): Unit =
println(s"\n### $category")
@@ -1149,6 +1172,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val hostSpillStorageSize: Long = get(HOST_SPILL_STORAGE_SIZE)
+ lazy val isUnspillEnabled: Boolean = get(UNSPILL)
+
lazy val isGdsSpillEnabled: Boolean = get(GDS_SPILL)
lazy val hasNans: Boolean = get(HAS_NANS)
@@ -1185,14 +1210,14 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val isCastStringToTimestampEnabled: Boolean = get(ENABLE_CAST_STRING_TO_TIMESTAMP)
- lazy val isCastStringToIntegerEnabled: Boolean = get(ENABLE_CAST_STRING_TO_INTEGER)
-
lazy val isCastStringToFloatEnabled: Boolean = get(ENABLE_CAST_STRING_TO_FLOAT)
lazy val isCastStringToDecimalEnabled: Boolean = get(ENABLE_CAST_STRING_TO_DECIMAL)
lazy val isCastFloatToIntegralTypesEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES)
+ lazy val isCastDecimalToStringEnabled: Boolean = get(ENABLE_CAST_DECIMAL_TO_STRING)
+
lazy val isCsvTimestampEnabled: Boolean = get(ENABLE_CSV_TIMESTAMPS)
lazy val isParquetEnabled: Boolean = get(ENABLE_PARQUET)
@@ -1287,6 +1312,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val getAlluxioPathsToReplace: Option[Seq[String]] = get(ALLUXIO_PATHS_REPLACE)
+ lazy val cpuRangePartitioningPermitted = get(CPU_RANGE_PARTITIONING_ALLOWED)
+
def isOperatorEnabled(key: String, incompat: Boolean, isDisabledByDefault: Boolean): Boolean = {
val default = !(isDisabledByDefault || incompat) || (incompat && isIncompatEnabled)
conf.get(key).map(toBoolean(_, key)).getOrElse(default)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala
index 57f2c15ac6cb..3b411a9340b3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala
@@ -16,7 +16,7 @@
package com.nvidia.spark.rapids
-import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, MemoryBuffer, Table}
+import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table}
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
@@ -28,11 +28,26 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
* @param catalog catalog to register this store
*/
class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton)
- extends RapidsBufferStore(StorageTier.DEVICE, catalog) {
- override protected def createBuffer(
- other: RapidsBuffer,
+ extends RapidsBufferStore(StorageTier.DEVICE, catalog) with Arm {
+
+ override protected def createBuffer(other: RapidsBuffer, memoryBuffer: MemoryBuffer,
stream: Cuda.Stream): RapidsBufferBase = {
- throw new IllegalStateException("should not be spilling to device memory")
+ val deviceBuffer = {
+ memoryBuffer match {
+ case d: DeviceMemoryBuffer => d
+ case h: HostMemoryBuffer =>
+ withResource(h) { _ =>
+ closeOnExcept(DeviceMemoryBuffer.allocate(other.size)) { deviceBuffer =>
+ logDebug(s"copying from host $h to device $deviceBuffer")
+ deviceBuffer.copyFromHostBuffer(h, stream)
+ deviceBuffer
+ }
+ }
+ case b => throw new IllegalStateException(s"Unrecognized buffer: $b")
+ }
+ }
+ new RapidsDeviceMemoryBuffer(other.id, other.size, other.meta, None,
+ deviceBuffer, other.getSpillPriority, other.spillCallback)
}
/**
@@ -163,11 +178,13 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog
table.foreach(_.close())
}
- override def getMemoryBuffer: MemoryBuffer = {
+ override def getDeviceMemoryBuffer: DeviceMemoryBuffer = {
contigBuffer.incRefCount()
contigBuffer
}
+ override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer
+
override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
if (table.isDefined) {
//REFCOUNT ++ of all columns
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
index 399eeee1639a..97de1b564f96 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
@@ -29,15 +29,14 @@ import org.apache.spark.sql.rapids.RapidsDiskBlockManager
/** A buffer store using files on the local disks. */
class RapidsDiskStore(
diskBlockManager: RapidsDiskBlockManager,
- catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton)
- extends RapidsBufferStore(StorageTier.DISK, catalog) {
+ catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton,
+ deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage)
+ extends RapidsBufferStore(StorageTier.DISK, catalog, deviceStorage) {
private[this] val sharedBufferFiles = new ConcurrentHashMap[RapidsBufferId, File]
- override def createBuffer(
- incoming: RapidsBuffer,
+ override protected def createBuffer(incoming: RapidsBuffer, incomingBuffer: MemoryBuffer,
stream: Cuda.Stream): RapidsBufferBase = {
- val incomingBuffer = incoming.getMemoryBuffer
- try {
+ withResource(incomingBuffer) { _ =>
val hostBuffer = incomingBuffer match {
case h: HostMemoryBuffer => h
case _ => throw new UnsupportedOperationException("buffer without host memory")
@@ -58,9 +57,7 @@ class RapidsDiskStore(
}
logDebug(s"Spilled to $path $fileOffset:${incoming.size}")
new this.RapidsDiskBuffer(id, fileOffset, incoming.size, incoming.meta,
- incoming.getSpillPriority, incoming.spillCallback)
- } finally {
- incomingBuffer.close()
+ incoming.getSpillPriority, incoming.spillCallback, deviceStorage)
}
}
@@ -92,8 +89,10 @@ class RapidsDiskStore(
size: Long,
meta: TableMeta,
spillPriority: Long,
- spillCallback: RapidsBuffer.SpillCallback)
- extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) {
+ spillCallback: RapidsBuffer.SpillCallback,
+ deviceStorage: RapidsDeviceMemoryStore)
+ extends RapidsBufferBase(
+ id, size, meta, spillPriority, spillCallback, deviceStorage = deviceStorage) {
private[this] var hostBuffer: Option[HostMemoryBuffer] = None
override val storageTier: StorageTier = StorageTier.DISK
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala
index 2dfeed5511aa..26180827ec53 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala
@@ -24,8 +24,6 @@ import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
import org.apache.spark.sql.rapids.RapidsDiskBlockManager
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.vectorized.ColumnarBatch
/** A buffer store using GPUDirect Storage (GDS). */
class RapidsGdsStore(
@@ -34,10 +32,9 @@ class RapidsGdsStore(
extends RapidsBufferStore(StorageTier.GDS, catalog) with Arm {
private[this] val sharedBufferFiles = new ConcurrentHashMap[RapidsBufferId, File]
- override def createBuffer(
- other: RapidsBuffer,
+ override protected def createBuffer(other: RapidsBuffer, otherBuffer: MemoryBuffer,
stream: Cuda.Stream): RapidsBufferBase = {
- withResource(other.getMemoryBuffer) { otherBuffer =>
+ withResource(otherBuffer) { _ =>
val deviceBuffer = otherBuffer match {
case d: DeviceMemoryBuffer => d
case _ => throw new IllegalStateException("copying from buffer without device memory")
@@ -66,7 +63,7 @@ class RapidsGdsStore(
class RapidsGdsBuffer(
id: RapidsBufferId,
- fileOffset: Long,
+ val fileOffset: Long,
size: Long,
meta: TableMeta,
spillPriority: Long,
@@ -74,8 +71,9 @@ class RapidsGdsStore(
extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) {
override val storageTier: StorageTier = StorageTier.GDS
- // TODO(rongou): cache this buffer to avoid repeated reads from disk.
- override def getMemoryBuffer: DeviceMemoryBuffer = synchronized {
+ override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer
+
+ override def materializeMemoryBuffer: MemoryBuffer = {
val path = if (id.canShareDiskPaths) {
sharedBufferFiles.get(id)
} else {
@@ -99,11 +97,5 @@ class RapidsGdsStore(
}
}
}
-
- override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = {
- withResource(getMemoryBuffer) { deviceBuffer =>
- columnarBatchFromDeviceBuffer(deviceBuffer, sparkTypes)
- }
- }
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala
index 5022e8f2772e..5f90820dda39 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala
@@ -29,8 +29,9 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil
*/
class RapidsHostMemoryStore(
maxSize: Long,
- catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton)
- extends RapidsBufferStore(StorageTier.HOST, catalog) {
+ catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton,
+ deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage)
+ extends RapidsBufferStore(StorageTier.HOST, catalog, deviceStorage) {
private[this] val pool = HostMemoryBuffer.allocate(maxSize, false)
private[this] val addressAllocator = new AddressSpaceAllocator(maxSize)
private[this] var haveLoggedMaxExceeded = false
@@ -71,12 +72,10 @@ class RapidsHostMemoryStore(
(buffer, true)
}
- override protected def createBuffer(
- other: RapidsBuffer,
+ override protected def createBuffer(other: RapidsBuffer, otherBuffer: MemoryBuffer,
stream: Cuda.Stream): RapidsBufferBase = {
- val (hostBuffer, isPinned) = allocateHostBuffer(other.size)
- try {
- val otherBuffer = other.getMemoryBuffer
+ withResource(otherBuffer) { _ =>
+ val (hostBuffer, isPinned) = allocateHostBuffer(other.size)
try {
otherBuffer match {
case devBuffer: DeviceMemoryBuffer =>
@@ -87,22 +86,21 @@ class RapidsHostMemoryStore(
}
case _ => throw new IllegalStateException("copying from buffer without device memory")
}
- } finally {
- otherBuffer.close()
+ } catch {
+ case e: Exception =>
+ hostBuffer.close()
+ throw e
}
- } catch {
- case e: Exception =>
- hostBuffer.close()
- throw e
+ new RapidsHostMemoryBuffer(
+ other.id,
+ other.size,
+ other.meta,
+ other.getSpillPriority,
+ hostBuffer,
+ isPinned,
+ other.spillCallback,
+ deviceStorage)
}
- new RapidsHostMemoryBuffer(
- other.id,
- other.size,
- other.meta,
- other.getSpillPriority,
- hostBuffer,
- isPinned,
- other.spillCallback)
}
def numBytesFree: Long = maxSize - currentSize
@@ -119,8 +117,10 @@ class RapidsHostMemoryStore(
spillPriority: Long,
buffer: HostMemoryBuffer,
isInternalPoolAllocated: Boolean,
- spillCallback: RapidsBuffer.SpillCallback)
- extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) {
+ spillCallback: RapidsBuffer.SpillCallback,
+ deviceStorage: RapidsDeviceMemoryStore)
+ extends RapidsBufferBase(
+ id, size, meta, spillPriority, spillCallback, deviceStorage = deviceStorage) {
override val storageTier: StorageTier = StorageTier.HOST
override def getMemoryBuffer: MemoryBuffer = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index fda3f2bd72c2..f3499c39ed9b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -112,6 +112,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
def convertToCpu(): BASE = wrapped
private var cannotBeReplacedReasons: Option[mutable.Set[String]] = None
+ private var mustBeReplacedReasons: Option[mutable.Set[String]] = None
private var cannotReplaceAnyOfPlanReasons: Option[mutable.Set[String]] = None
private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None
protected var cannotRunOnGpuBecauseOfSparkPlan: Boolean = false
@@ -124,7 +125,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
* is reached that is already on CPU.
*/
final def recursiveCostPreventsRunningOnGpu(): Unit = {
- if (canThisBeReplaced) {
+ if (canThisBeReplaced && !mustThisBeReplaced) {
costPreventsRunningOnGpu()
childDataWriteCmds.foreach(_.recursiveCostPreventsRunningOnGpu())
}
@@ -170,6 +171,10 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
}
}
+ final def mustBeReplaced(because: String): Unit = {
+ mustBeReplacedReasons.get.add(because)
+ }
+
/**
* Call this if there is a condition found that the entire plan is not allowed
* to run on the GPU.
@@ -191,6 +196,12 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
*/
final def canThisBeReplaced: Boolean = cannotBeReplacedReasons.exists(_.isEmpty)
+ /**
+ * Returns true iff this must be replaced because its children have already been
+ * replaced and this needs to also be replaced for compatibility.
+ */
+ final def mustThisBeReplaced: Boolean = mustBeReplacedReasons.exists(_.nonEmpty)
+
/**
* Returns the list of reasons the entire plan can't be replaced. An empty
* set means the entire plan is ok to be replaced, do the normal checking
@@ -229,6 +240,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
def initReasons(): Unit = {
cannotBeReplacedReasons = Some(mutable.Set[String]())
+ mustBeReplacedReasons = Some(mutable.Set[String]())
shouldBeRemovedReasons = Some(mutable.Set[String]())
cannotReplaceAnyOfPlanReasons = Some(mutable.Set[String]())
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala
index b4361d7631d8..cb57d905ac78 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import ai.rapids.cudf.{ColumnVector, NvtxColor, Table}
+import ai.rapids.cudf.{ColumnVector, NvtxColor, OrderByArg, Table}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Expression, NullsFirst, NullsLast, SortOrder}
import org.apache.spark.sql.types.DataType
@@ -33,11 +33,11 @@ object SortUtils extends Arm {
case _ => None
}
- def getOrder(order: SortOrder, index: Int): Table.OrderByArg =
+ def getOrder(order: SortOrder, index: Int): OrderByArg =
if (order.isAscending) {
- Table.asc(index, order.nullOrdering == NullsFirst)
+ OrderByArg.asc(index, order.nullOrdering == NullsFirst)
} else {
- Table.desc(index, order.nullOrdering == NullsLast)
+ OrderByArg.desc(index, order.nullOrdering == NullsLast)
}
}
@@ -88,7 +88,7 @@ class GpuSorter(
private[this] lazy val (sortOrdersThatNeedComputation, cudfOrdering, cpuOrderingInternal) = {
val sortOrdersThatNeedsComputation = mutable.ArrayBuffer[SortOrder]()
val cpuOrdering = mutable.ArrayBuffer[SortOrder]()
- val cudfOrdering = mutable.ArrayBuffer[Table.OrderByArg]()
+ val cudfOrdering = mutable.ArrayBuffer[OrderByArg]()
var newColumnIndex = numInputColumns
// Remove duplicates in the ordering itself because
// there is no need to do it twice.
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index 0da3ec1cdab4..ec68248c1aa9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -187,6 +187,8 @@ trait SparkShims {
def shouldIgnorePath(path: String): Boolean
+ def getLegacyComplexTypeToString(): Boolean
+
def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager)
def getArrowValidityBuf(vec: ValueVector): (ByteBuffer, ReferenceManager)
def getArrowOffsetsBuf(vec: ValueVector): (ByteBuffer, ReferenceManager)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index bfd8a6b9b94c..4539271cb5a4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -170,6 +170,19 @@ final class TypeSig private(
new TypeSig(it, nt, lt, nts)
}
+ /**
+ * Remove a type signature. The reverse of +
+ * @param other what to remove
+ * @return the new signature
+ */
+ def - (other: TypeSig): TypeSig = {
+ val it = initialTypes -- other.initialTypes
+ val nt = nestedTypes -- other.nestedTypes
+ val lt = litOnlyTypes -- other.litOnlyTypes
+ val nts = notes -- other.notes.keySet
+ new TypeSig(it, nt, lt, nts)
+ }
+
/**
* Add nested types to this type signature. Note that these do not stack so if nesting has
* nested types too they are ignored.
@@ -542,18 +555,23 @@ class ExecChecks private(
override def tag(meta: RapidsMeta[_, _, _]): Unit = {
val plan = meta.wrapped.asInstanceOf[SparkPlan]
val allowDecimal = meta.conf.decimalTypeEnabled
- if (!check.areAllSupportedByPlugin(plan.output.map(_.dataType), allowDecimal)) {
- val unsupported = plan.output.map(_.dataType)
- .filter(!check.isSupportedByPlugin(_, allowDecimal))
- .toSet
- meta.willNotWorkOnGpu(s"unsupported data types in output: ${unsupported.mkString(", ")}")
+
+ val unsupportedOutputTypes = plan.output
+ .filterNot(attr => check.isSupportedByPlugin(attr.dataType, allowDecimal))
+ .toSet
+
+ if (unsupportedOutputTypes.nonEmpty) {
+ meta.willNotWorkOnGpu("unsupported data types in output: " +
+ unsupportedOutputTypes.mkString(", "))
}
- if (!check.areAllSupportedByPlugin(
- plan.children.flatMap(_.output.map(_.dataType)),
- allowDecimal)) {
- val unsupported = plan.children.flatMap(_.output.map(_.dataType))
- .filter(!check.isSupportedByPlugin(_, allowDecimal)).toSet
- meta.willNotWorkOnGpu(s"unsupported data types in input: ${unsupported.mkString(", ")}")
+
+ val unsupportedInputTypes = plan.children.flatMap { childPlan =>
+ childPlan.output.filterNot(attr => check.isSupportedByPlugin(attr.dataType, allowDecimal))
+ }.toSet
+
+ if (unsupportedInputTypes.nonEmpty) {
+ meta.willNotWorkOnGpu("unsupported data types in input: " +
+ unsupportedInputTypes.mkString(", "))
}
}
@@ -754,7 +772,7 @@ class CastChecks extends ExprChecks {
val binaryChecks: TypeSig = none
val sparkBinarySig: TypeSig = STRING + BINARY
- val decimalChecks: TypeSig = DECIMAL
+ val decimalChecks: TypeSig = DECIMAL + STRING
val sparkDecimalSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING
val calendarChecks: TypeSig = none
@@ -766,7 +784,8 @@ class CastChecks extends ExprChecks {
val mapChecks: TypeSig = none
val sparkMapSig: TypeSig = STRING + MAP.nested(all)
- val structChecks: TypeSig = none
+ val structChecks: TypeSig = psNote(TypeEnum.STRING, "the struct's children must also support " +
+ "being cast to string")
val sparkStructSig: TypeSig = STRING + STRUCT.nested(all)
val udtChecks: TypeSig = none
@@ -840,8 +859,8 @@ class CastChecks extends ExprChecks {
}
def gpuCanCast(from: DataType, to: DataType, allowDecimal: Boolean = true): Boolean = {
- val (_, sparkSig) = getChecksAndSigs(from)
- sparkSig.isSupportedByPlugin(to, allowDecimal)
+ val (checks, _) = getChecksAndSigs(from)
+ checks.isSupportedByPlugin(to, allowDecimal)
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
index e0df29805a1f..aeda82eff934 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/decimalExpressions.scala
@@ -56,7 +56,7 @@ case class GpuUnscaledValue(child: Expression) extends GpuUnaryExpression {
override def toString: String = s"UnscaledValue($child)"
override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
- withResource(input.getBase.logicalCastTo(DType.INT64)) { view =>
+ withResource(input.getBase.bitCastTo(DType.INT64)) { view =>
view.copyToColumnVector()
}
}
@@ -85,13 +85,13 @@ case class GpuMakeDecimal(
}
withResource(overflowed) { overflowed =>
withResource(Scalar.fromNull(outputType)) { nullVal =>
- withResource(base.logicalCastTo(outputType)) { view =>
+ withResource(base.bitCastTo(outputType)) { view =>
overflowed.ifElse(nullVal, view)
}
}
}
} else {
- withResource(base.logicalCastTo(outputType)) { view =>
+ withResource(base.bitCastTo(outputType)) { view =>
view.copyToColumnVector()
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala
index cee61bfe2e0c..5b895eeecad6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/BufferSendState.scala
@@ -176,6 +176,7 @@ class BufferSendState(
case _: HostMemoryBuffer =>
//TODO: HostMemoryBuffer needs the same functionality that
// DeviceMemoryBuffer has to copy from/to device/host buffers
+ logDebug(s"copying to host memory bounce buffer $memBuff")
CudaUtil.copy(
memBuff,
blockRange.rangeStart,
@@ -186,10 +187,12 @@ class BufferSendState(
memBuff match {
case mh: HostMemoryBuffer =>
// host original => device bounce
+ logDebug(s"copying from host to device memory bounce buffer $memBuff")
d.copyFromHostBufferAsync(buffOffset, mh, blockRange.rangeStart,
blockRange.rangeSize(), serverStream)
case md: DeviceMemoryBuffer =>
// device original => device bounce
+ logDebug(s"copying from device to device memory bounce buffer $memBuff")
d.copyFromDeviceBufferAsync(buffOffset, md, blockRange.rangeStart,
blockRange.rangeSize(), serverStream)
case _ => throw new IllegalStateException("What buffer is this")
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index 0fcd3b4bfcee..6f6e55322eed 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.rapids
import ai.rapids.cudf
import ai.rapids.cudf.{Aggregation, AggregationOnColumn, ColumnVector, DType}
+import ai.rapids.cudf.Aggregation.NullPolicy
import com.nvidia.spark.rapids._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -180,13 +181,11 @@ abstract case class CudfAggregate(ref: Expression) extends GpuUnevaluable {
}
class CudfCount(ref: Expression) extends CudfAggregate(ref) {
- // includeNulls set to false in count aggregate to exclude nulls while calculating count(column)
- val includeNulls = false
override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar =
(col: cudf.ColumnVector) => cudf.Scalar.fromLong(col.getRowCount - col.getNullCount)
override val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar =
(col: cudf.ColumnVector) => col.sum
- override lazy val updateAggregate: Aggregation = Aggregation.count(includeNulls)
+ override lazy val updateAggregate: Aggregation = Aggregation.count(NullPolicy.EXCLUDE)
override lazy val mergeAggregate: Aggregation = Aggregation.sum()
override def toString(): String = "CudfCount"
}
@@ -241,7 +240,7 @@ class CudfMin(ref: Expression) extends CudfAggregate(ref) {
}
abstract class CudfFirstLastBase(ref: Expression) extends CudfAggregate(ref) {
- val includeNulls: Boolean
+ val includeNulls: NullPolicy
val offset: Int
override val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar =
@@ -253,22 +252,22 @@ abstract class CudfFirstLastBase(ref: Expression) extends CudfAggregate(ref) {
}
class CudfFirstIncludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
- override val includeNulls: Boolean = true
+ override val includeNulls: NullPolicy = NullPolicy.INCLUDE
override val offset: Int = 0
}
class CudfFirstExcludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
- override val includeNulls: Boolean = false
+ override val includeNulls: NullPolicy = NullPolicy.EXCLUDE
override val offset: Int = 0
}
class CudfLastIncludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
- override val includeNulls: Boolean = true
+ override val includeNulls: NullPolicy = NullPolicy.INCLUDE
override val offset: Int = -1
}
class CudfLastExcludeNulls(ref: Expression) extends CudfFirstLastBase(ref) {
- override val includeNulls: Boolean = false
+ override val includeNulls: NullPolicy = NullPolicy.EXCLUDE
override val offset: Int = -1
}
@@ -399,7 +398,7 @@ case class GpuCount(children: Seq[Expression]) extends GpuDeclarativeAggregate
// we could support it by doing an `Aggregation.nunique(false)`
override lazy val windowInputProjection: Seq[Expression] = inputProjection
override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn =
- Aggregation.count(false).onColumn(inputs.head._2)
+ Aggregation.count(NullPolicy.EXCLUDE).onColumn(inputs.head._2)
}
case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
index 34a40f3d6d3a..4ef353197920 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids
import scala.collection.mutable
-import ai.rapids.cudf.{ContiguousTable, Table}
+import ai.rapids.cudf.{ContiguousTable, OrderByArg, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.hadoop.fs.Path
@@ -276,7 +276,7 @@ class GpuDynamicPartitionDataWriter(
val columnIds = 0 until t.getNumberOfColumns
val distinct = t.groupBy(columnIds: _*).aggregate()
try {
- distinct.orderBy(columnIds.map(Table.asc(_, nullsSmallest)): _*)
+ distinct.orderBy(columnIds.map(OrderByArg.asc(_, nullsSmallest)): _*)
} finally {
distinct.close()
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala
index f79a7d9d9f6c..4e32dc219e50 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala
@@ -16,14 +16,16 @@
package org.apache.spark.sql.rapids.execution.python
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import ai.rapids.cudf
-import ai.rapids.cudf.{Aggregation, Table}
+import ai.rapids.cudf.{Aggregation, OrderByArg}
+import ai.rapids.cudf.Aggregation.NullPolicy
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -121,11 +123,11 @@ class GroupingIterator(
withResource(GpuColumnVector.from(projected)) { table =>
table
.groupBy(partitionIndices:_*)
- .aggregate(Aggregation.count(true).onColumn(0))
+ .aggregate(Aggregation.count(NullPolicy.INCLUDE).onColumn(0))
}
}
val orderedTable = withResource(cntTable) { table =>
- table.orderBy(partitionIndices.map(id => Table.asc(id, true)): _*)
+ table.orderBy(partitionIndices.map(id => OrderByArg.asc(id, true)): _*)
}
val (countHostCol, numRows) = withResource(orderedTable) { table =>
// Yes copying the data to host, it would be OK since just copying the aggregated
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index eb589dd277c6..59960c751a9e 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -335,7 +335,6 @@ class AdaptiveQueryExecSuite
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set(SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key, "true")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "400")
- .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "50")
// disable DemoteBroadcastHashJoin rule from removing BHJ due to empty partitions
.set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0")
@@ -370,7 +369,6 @@ class AdaptiveQueryExecSuite
// disable DemoteBroadcastHashJoin rule from removing BHJ due to empty partitions
.set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0")
.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
- .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
.set(RapidsConf.DECIMAL_TYPE_ENABLED.key, "true")
.set(RapidsConf.TEST_ALLOWED_NONGPU.key, "DataWritingCommandExec")
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
index 38f113602e9c..fdc8472f1dec 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,7 +37,6 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
.set("spark.sql.storeAssignmentPolicy", "ANSI") // note this is the default in 3.0.0
.set(RapidsConf.ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES.key, "true")
.set(RapidsConf.ENABLE_CAST_FLOAT_TO_STRING.key, "true")
- .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_FLOAT.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "true")
@@ -382,6 +381,20 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
comparisonFunc = Some(compareStringifiedFloats))
}
+ test("ansi_cast decimal to string") {
+ val sqlCtx = SparkSession.getActiveSession.get.sqlContext
+ sqlCtx.setConf("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
+ sqlCtx.setConf("spark.rapids.sql.castDecimalToString.enabled", "true")
+
+ Seq(10, 15, 18).foreach { precision =>
+ Seq(-precision, -5, 0, 5, precision).foreach { scale =>
+ testCastToString(DataTypes.createDecimalType(precision, scale),
+ ansiMode = true,
+ comparisonFunc = Some(compareStringifiedDecimalsInSemantic))
+ }
+ }
+ }
+
private def castToStringExpectedFun[T]: T => Option[String] = (d: T) => Some(String.valueOf(d))
private def testCastToString[T](dataType: DataType, ansiMode: Boolean,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
index b86c16da5f7e..f145c56ff75a 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -63,7 +63,6 @@ class CastOpSuite extends GpuExpressionTestSuite {
.set(RapidsConf.ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES.key, "true")
.set(RapidsConf.ENABLE_CAST_FLOAT_TO_STRING.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "true")
- .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_FLOAT.key, "true")
.set("spark.sql.ansi.enabled", String.valueOf(ansiEnabled))
@@ -229,6 +228,19 @@ class CastOpSuite extends GpuExpressionTestSuite {
testCastToString[Double](DataTypes.DoubleType, comparisonFunc = Some(compareStringifiedFloats))
}
+ test("cast decimal to string") {
+ val sqlCtx = SparkSession.getActiveSession.get.sqlContext
+ sqlCtx.setConf("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
+ sqlCtx.setConf("spark.rapids.sql.castDecimalToString.enabled", "true")
+
+ Seq(10, 15, 18).foreach { precision =>
+ Seq(-precision, -5, 0, 5, precision).foreach { scale =>
+ testCastToString(DataTypes.createDecimalType(precision, scale),
+ comparisonFunc = Some(compareStringifiedDecimalsInSemantic))
+ }
+ }
+ }
+
private def testCastToString[T](
dataType: DataType,
comparisonFunc: Option[(String, String) => Boolean] = None) {
@@ -362,6 +374,15 @@ class CastOpSuite extends GpuExpressionTestSuite {
col("doubles").cast(TimestampType))
}
+ testSparkResultsAreEqual("Test cast from strings to int", doublesAsStrings,
+ conf = sparkConf) {
+ frame => frame.select(
+ col("c0").cast(LongType),
+ col("c0").cast(IntegerType),
+ col("c0").cast(ShortType),
+ col("c0").cast(ByteType))
+ }
+
testSparkResultsAreEqual("Test cast from strings to doubles", doublesAsStrings,
conf = sparkConf, maxFloatDiff = 0.0001) {
frame => frame.select(
@@ -473,6 +494,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
customRandGenerator = Some(new scala.util.Random(1234L)))
testCastToDecimal(DataTypes.createDecimalType(18, 2),
scale = 2,
+ ansiEnabled = true,
customRandGenerator = Some(new scala.util.Random(1234L)))
// fromScale > toScale
@@ -481,6 +503,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
customRandGenerator = Some(new scala.util.Random(1234L)))
testCastToDecimal(DataTypes.createDecimalType(18, 10),
scale = 2,
+ ansiEnabled = true,
customRandGenerator = Some(new scala.util.Random(1234L)))
testCastToDecimal(DataTypes.createDecimalType(18, 18),
scale = 15,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CostBasedOptimizerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CostBasedOptimizerSuite.scala
index 0118ef1baed5..e7de0b806a78 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/CostBasedOptimizerSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CostBasedOptimizerSuite.scala
@@ -45,6 +45,8 @@ class CostBasedOptimizerSuite extends SparkQueryCompareTestSuite with BeforeAndA
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
.set(RapidsConf.OPTIMIZER_ENABLED.key, "true")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_CPU_COST.key, "0.15")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_GPU_COST.key, "0.15")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "false")
.set(RapidsConf.EXPLAIN.key, "ALL")
.set(RapidsConf.ENABLE_REPLACE_SORTMERGEJOIN.key, "false")
@@ -100,6 +102,8 @@ class CostBasedOptimizerSuite extends SparkQueryCompareTestSuite with BeforeAndA
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
.set(RapidsConf.OPTIMIZER_ENABLED.key, "true")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_CPU_COST.key, "0.15")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_GPU_COST.key, "0.15")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "false")
.set(RapidsConf.EXPLAIN.key, "ALL")
.set(RapidsConf.ENABLE_REPLACE_SORTMERGEJOIN.key, "false")
@@ -155,6 +159,8 @@ class CostBasedOptimizerSuite extends SparkQueryCompareTestSuite with BeforeAndA
val conf = new SparkConf()
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set(RapidsConf.OPTIMIZER_ENABLED.key, "true")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_CPU_COST.key, "0.15")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_GPU_COST.key, "0.15")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "false")
.set(RapidsConf.EXPLAIN.key, "ALL")
.set(RapidsConf.TEST_ALLOWED_NONGPU.key,
@@ -194,6 +200,8 @@ class CostBasedOptimizerSuite extends SparkQueryCompareTestSuite with BeforeAndA
val conf = new SparkConf()
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
.set(RapidsConf.OPTIMIZER_ENABLED.key, "true")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_CPU_COST.key, "0.15")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_GPU_COST.key, "0.15")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "false")
.set(RapidsConf.EXPLAIN.key, "ALL")
.set(RapidsConf.TEST_ALLOWED_NONGPU.key,
@@ -353,6 +361,37 @@ class CostBasedOptimizerSuite extends SparkQueryCompareTestSuite with BeforeAndA
}, conf)
}
+
+ test("keep CustomShuffleReaderExec on GPU") {
+
+ // if we force a GPU CustomShuffleReaderExec back onto CPU due to cost then the query will
+ // fail because the shuffle already happened on GPU and we end up with an invalid plan
+
+ val conf = new SparkConf()
+ .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+ .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "1")
+ .set(RapidsConf.OPTIMIZER_ENABLED.key, "true")
+ .set(RapidsConf.OPTIMIZER_EXPLAIN.key, "ALL")
+ .set(RapidsConf.EXPLAIN.key, "ALL")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_CPU_COST.key, "0")
+ .set(RapidsConf.OPTIMIZER_DEFAULT_TRANSITION_TO_GPU_COST.key, "0")
+ .set("spark.rapids.sql.optimizer.exec.CustomShuffleReaderExec", "99999999")
+ .set(RapidsConf.TEST_ALLOWED_NONGPU.key,
+ "ProjectExec,SortMergeJoinExec,SortExec,Alias,Cast,LessThan")
+
+ withGpuSparkSession(spark => {
+ val df1: DataFrame = createQuery(spark).alias("l")
+ val df2: DataFrame = createQuery(spark).alias("r")
+ val df = df1.join(df2,
+ col("l.more_strings_1").equalTo(col("r.more_strings_2")))
+ df.collect()
+
+ println(df.queryExecution.executedPlan)
+
+ df
+ }, conf)
+ }
+
private def createQuery(spark: SparkSession) = {
val df1 = nullableStringsDf(spark)
.repartition(2)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala
index fa0ff51a891f..f42d7e3c65f0 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
package com.nvidia.spark.rapids
-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{DataType, DataTypes, Decimal, DecimalType, StructType}
abstract class GpuExpressionTestSuite extends SparkQueryCompareTestSuite {
@@ -172,6 +172,11 @@ abstract class GpuExpressionTestSuite extends SparkQueryCompareTestSuite {
}
}
+ def compareStringifiedDecimalsInSemantic(expected: String, actual: String): Boolean = {
+ (expected == null && actual == null) ||
+ (expected != null && actual != null && Decimal(expected) == Decimal(actual))
+ }
+
private def getAs(column: RapidsHostColumnVector, index: Int, dataType: DataType): Option[Any] = {
if (column.isNullAt(index)) {
None
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala
index 81003d4c42cd..6bd6b6458d0b 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import java.io.File
import java.util.NoSuchElementException
-import com.nvidia.spark.rapids.StorageTier.StorageTier
+import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, GDS, HOST, StorageTier}
import com.nvidia.spark.rapids.format.TableMeta
import org.mockito.Mockito._
import org.scalatest.FunSuite
@@ -44,7 +44,22 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
val buffer = mockBuffer(bufferId)
catalog.registerNewBuffer(buffer)
val buffer2 = mockBuffer(bufferId)
- assertThrows[IllegalStateException](catalog.registerNewBuffer(buffer2))
+ assertThrows[DuplicateBufferException](catalog.registerNewBuffer(buffer2))
+ }
+
+ test("buffer registering slower tier does not hide faster tier") {
+ val catalog = new RapidsBufferCatalog
+ val bufferId = MockBufferId(5)
+ val buffer = mockBuffer(bufferId, tier = DEVICE)
+ catalog.registerNewBuffer(buffer)
+ val buffer2 = mockBuffer(bufferId, tier = HOST)
+ catalog.registerNewBuffer(buffer2)
+ val buffer3 = mockBuffer(bufferId, tier = DISK)
+ catalog.registerNewBuffer(buffer3)
+ val acquired = catalog.acquireBuffer(MockBufferId(5))
+ assertResult(5)(acquired.id.tableId)
+ assertResult(buffer)(acquired)
+ verify(buffer).addReference()
}
test("acquire buffer") {
@@ -69,6 +84,28 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
verify(buffer, times(9)).addReference()
}
+ test("acquire buffer at specific tier") {
+ val catalog = new RapidsBufferCatalog
+ val bufferId = MockBufferId(5)
+ val buffer = mockBuffer(bufferId, tier = DEVICE)
+ catalog.registerNewBuffer(buffer)
+ val buffer2 = mockBuffer(bufferId, tier = HOST)
+ catalog.registerNewBuffer(buffer2)
+ val acquired = catalog.acquireBuffer(MockBufferId(5), HOST).get
+ assertResult(5)(acquired.id.tableId)
+ assertResult(buffer2)(acquired)
+ verify(buffer2).addReference()
+ }
+
+ test("acquire buffer at nonexistent tier") {
+ val catalog = new RapidsBufferCatalog
+ val bufferId = MockBufferId(5)
+ val buffer = mockBuffer(bufferId, tier = HOST)
+ catalog.registerNewBuffer(buffer)
+ assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isEmpty)
+ assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty)
+ }
+
test("get buffer meta") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
@@ -79,18 +116,46 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
assertResult(expectedMeta)(meta)
}
- test("update buffer map only updates for faster tier") {
+ test("buffer is spilled to slower tier only") {
+ val catalog = new RapidsBufferCatalog
+ val bufferId = MockBufferId(5)
+ val buffer = mockBuffer(bufferId, tier = DEVICE)
+ catalog.registerNewBuffer(buffer)
+ val buffer2 = mockBuffer(bufferId, tier = HOST)
+ catalog.registerNewBuffer(buffer2)
+ val buffer3 = mockBuffer(bufferId, tier = DISK)
+ catalog.registerNewBuffer(buffer3)
+ assert(catalog.isBufferSpilled(bufferId, DEVICE))
+ assert(catalog.isBufferSpilled(bufferId, HOST))
+ assert(!catalog.isBufferSpilled(bufferId, DISK))
+ }
+
+ test("remove buffer tier") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
- val buffer1 = mockBuffer(bufferId, tier = StorageTier.HOST)
- catalog.registerNewBuffer(buffer1)
- val buffer2 = mockBuffer(bufferId, tier = StorageTier.DEVICE)
- catalog.updateBufferMap(StorageTier.HOST, buffer2)
- var resultBuffer = catalog.acquireBuffer(bufferId)
- assertResult(buffer2)(resultBuffer)
- catalog.updateBufferMap(StorageTier.HOST, buffer1)
- resultBuffer = catalog.acquireBuffer(bufferId)
- assertResult(buffer2)(resultBuffer)
+ val buffer = mockBuffer(bufferId, tier = DEVICE)
+ catalog.registerNewBuffer(buffer)
+ val buffer2 = mockBuffer(bufferId, tier = HOST)
+ catalog.registerNewBuffer(buffer2)
+ val buffer3 = mockBuffer(bufferId, tier = DISK)
+ catalog.registerNewBuffer(buffer3)
+ catalog.removeBufferTier(bufferId, DEVICE)
+ catalog.removeBufferTier(bufferId, DISK)
+ assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isEmpty)
+ assert(catalog.acquireBuffer(MockBufferId(5), HOST).isDefined)
+ assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty)
+ }
+
+ test("remove nonexistent buffer tier") {
+ val catalog = new RapidsBufferCatalog
+ val bufferId = MockBufferId(5)
+ val buffer = mockBuffer(bufferId, tier = DEVICE)
+ catalog.registerNewBuffer(buffer)
+ catalog.removeBufferTier(bufferId, HOST)
+ catalog.removeBufferTier(bufferId, DISK)
+ assert(catalog.acquireBuffer(MockBufferId(5), DEVICE).isDefined)
+ assert(catalog.acquireBuffer(MockBufferId(5), HOST).isEmpty)
+ assert(catalog.acquireBuffer(MockBufferId(5), DISK).isEmpty)
}
test("remove buffer releases buffer resources") {
@@ -102,6 +167,21 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
verify(buffer).free()
}
+ test("remove buffer releases buffer resources at all tiers") {
+ val catalog = new RapidsBufferCatalog
+ val bufferId = MockBufferId(5)
+ val buffer = mockBuffer(bufferId, tier = DEVICE)
+ catalog.registerNewBuffer(buffer)
+ val buffer2 = mockBuffer(bufferId, tier = HOST)
+ catalog.registerNewBuffer(buffer2)
+ val buffer3 = mockBuffer(bufferId, tier = DISK)
+ catalog.registerNewBuffer(buffer3)
+ catalog.removeBuffer(bufferId)
+ verify(buffer).free()
+ verify(buffer2).free()
+ verify(buffer3).free()
+ }
+
private def mockBuffer(
bufferId: RapidsBufferId,
meta: TableMeta = null,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala
index 5159166f25c6..737675376b48 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala
@@ -131,7 +131,8 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
test("cannot receive spilled buffers") {
val catalog = new RapidsBufferCatalog
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
- assertThrows[IllegalStateException](store.copyBuffer(mock[RapidsBuffer], Cuda.DEFAULT_STREAM))
+ assertThrows[IllegalStateException](store.copyBuffer(
+ mock[RapidsBuffer], mock[MemoryBuffer], Cuda.DEFAULT_STREAM))
}
}
@@ -204,7 +205,8 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
extends RapidsBufferStore(StorageTier.HOST, catalog) {
val spilledBuffers = new ArrayBuffer[RapidsBufferId]
- override protected def createBuffer(b: RapidsBuffer, s: Cuda.Stream): RapidsBufferBase = {
+ override protected def createBuffer(b: RapidsBuffer, m: MemoryBuffer, s: Cuda.Stream)
+ : RapidsBufferBase = {
spilledBuffers += b.id
new MockRapidsBuffer(b.id, b.size, b.meta, b.getSpillPriority)
}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala
index 44286fdc3e0f..e1769d6f20d1 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala
@@ -21,7 +21,7 @@ import java.math.RoundingMode
import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer, Table}
import org.mockito.ArgumentMatchers
-import org.mockito.Mockito.{spy, verify}
+import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.mockito.MockitoSugar
@@ -69,10 +69,9 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit
val path = bufferId.getDiskPath(null)
assert(path.exists)
assertResult(bufferSize)(path.length)
- verify(catalog).updateBufferMap(
- ArgumentMatchers.eq(StorageTier.DEVICE), ArgumentMatchers.any[RapidsBuffer])
- verify(catalog).updateBufferMap(
- ArgumentMatchers.eq(StorageTier.HOST), ArgumentMatchers.any[RapidsBuffer])
+ verify(catalog, times(3)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer])
+ verify(catalog).removeBufferTier(
+ ArgumentMatchers.eq(bufferId), ArgumentMatchers.eq(StorageTier.DEVICE))
withResource(catalog.acquireBuffer(bufferId)) { buffer =>
assertResult(StorageTier.DISK)(buffer.storageTier)
assertResult(bufferSize)(buffer.size)
@@ -96,23 +95,24 @@ class RapidsDiskStoreSuite extends FunSuite with BeforeAndAfterEach with Arm wit
withResource(new RapidsDeviceMemoryStore(catalog)) { devStore =>
withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore =>
devStore.setSpillStore(hostStore)
- withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore =>
- hostStore.setSpillStore(diskStore)
- addTableToStore(devStore, bufferId, spillPriority)
- val expectedBatch = withResource(catalog.acquireBuffer(bufferId)) { buffer =>
- assertResult(StorageTier.DEVICE)(buffer.storageTier)
- buffer.getColumnarBatch(sparkTypes)
- }
- withResource(expectedBatch) { expectedBatch =>
- devStore.synchronousSpill(0)
- hostStore.synchronousSpill(0)
- withResource(catalog.acquireBuffer(bufferId)) { buffer =>
- assertResult(StorageTier.DISK)(buffer.storageTier)
- withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
- TestUtils.compareBatches(expectedBatch, actualBatch)
+ withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog, devStore)) {
+ diskStore =>
+ hostStore.setSpillStore(diskStore)
+ addTableToStore(devStore, bufferId, spillPriority)
+ val expectedBatch = withResource(catalog.acquireBuffer(bufferId)) { buffer =>
+ assertResult(StorageTier.DEVICE)(buffer.storageTier)
+ buffer.getColumnarBatch(sparkTypes)
+ }
+ withResource(expectedBatch) { expectedBatch =>
+ devStore.synchronousSpill(0)
+ hostStore.synchronousSpill(0)
+ withResource(catalog.acquireBuffer(bufferId)) { buffer =>
+ assertResult(StorageTier.DISK)(buffer.storageTier)
+ withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
+ TestUtils.compareBatches(expectedBatch, actualBatch)
+ }
}
}
- }
}
}
}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala
index c66b0bb64118..06fd532ff6d1 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala
@@ -19,10 +19,10 @@ package com.nvidia.spark.rapids
import java.io.File
import java.math.RoundingMode
-import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, Table}
+import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, MemoryBuffer, Table}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.mockito.{ArgumentCaptor, ArgumentMatchers}
-import org.mockito.Mockito.{never, spy, verify}
+import org.mockito.Mockito.{never, spy, times, verify, when}
import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar
@@ -73,8 +73,9 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
devStore.synchronousSpill(0)
assertResult(bufferSize)(hostStore.currentSize)
assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree)
- verify(catalog).updateBufferMap(
- ArgumentMatchers.eq(StorageTier.DEVICE), ArgumentMatchers.any[RapidsBuffer])
+ verify(catalog, times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer])
+ verify(catalog).removeBufferTier(
+ ArgumentMatchers.eq(bufferId), ArgumentMatchers.eq(StorageTier.DEVICE))
withResource(catalog.acquireBuffer(bufferId)) { buffer =>
assertResult(StorageTier.HOST)(buffer.storageTier)
assertResult(bufferSize)(buffer.size)
@@ -120,7 +121,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
val hostStoreMaxSize = 1L * 1024 * 1024
val catalog = new RapidsBufferCatalog
withResource(new RapidsDeviceMemoryStore(catalog)) { devStore =>
- withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore =>
+ withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog, devStore)) { hostStore =>
devStore.setSpillStore(hostStore)
withResource(buildContiguousTable()) { ct =>
withResource(GpuColumnVector.from(ct.getTable, sparkTypes)) {
@@ -148,7 +149,8 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
val catalog = new RapidsBufferCatalog
withResource(new RapidsDeviceMemoryStore(catalog)) { devStore =>
val mockStore = mock[RapidsBufferStore]
- withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog)) { hostStore =>
+ when(mockStore.tier) thenReturn(StorageTier.DISK)
+ withResource(new RapidsHostMemoryStore(hostStoreMaxSize, catalog, devStore)) { hostStore =>
devStore.setSpillStore(hostStore)
hostStore.setSpillStore(mockStore)
withResource(buildContiguousTable(1024 * 1024)) { bigTable =>
@@ -158,6 +160,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
devStore.addContiguousTable(bigBufferId, bigTable, spillPriority)
devStore.synchronousSpill(0)
verify(mockStore, never()).copyBuffer(ArgumentMatchers.any[RapidsBuffer],
+ ArgumentMatchers.any[MemoryBuffer],
ArgumentMatchers.any[Cuda.Stream])
withResource(catalog.acquireBuffer(bigBufferId)) { buffer =>
assertResult(StorageTier.HOST)(buffer.storageTier)
@@ -169,7 +172,8 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
devStore.addContiguousTable(smallBufferId, smallTable, spillPriority)
devStore.synchronousSpill(0)
val ac: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer])
- verify(mockStore).copyBuffer(ac.capture(), ArgumentMatchers.any[Cuda.Stream])
+ verify(mockStore).copyBuffer(ac.capture(), ArgumentMatchers.any[MemoryBuffer],
+ ArgumentMatchers.any[Cuda.Stream])
assertResult(bigBufferId)(ac.getValue.id)
}
}
diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala
index ff31aef0860f..c5cb403b4d50 100644
--- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala
+++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids
import java.util.UUID
-import ai.rapids.cudf.MemoryBuffer
+import ai.rapids.cudf.{DeviceMemoryBuffer, MemoryBuffer}
import com.nvidia.spark.rapids.{Arm, RapidsBuffer, RapidsBufferCatalog, RapidsBufferId, SpillableColumnarBatchImpl, StorageTier}
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
@@ -46,6 +46,7 @@ class SpillableColumnarBatchSuite extends FunSuite with Arm {
override val meta: TableMeta = null
override val storageTier: StorageTier = StorageTier.DEVICE
override def getMemoryBuffer: MemoryBuffer = null
+ override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null
override def addReference(): Boolean = true
override def free(): Unit = {}
override def getSpillPriority: Long = 0