From 228cc7987d23aea286dbcf101f979fef2a23b420 Mon Sep 17 00:00:00 2001 From: MithunR Date: Mon, 14 Mar 2022 15:32:38 -0700 Subject: [PATCH 1/2] Implement `maps_column_view` abstraction over `LIST<STRUCT<K,V>>` (#10380) Fixes #9109. This commit adds a `map` abstraction over a `column_view` of type `LIST<STRUCT<K,V>>`, where `K` and `V` are key and value types. A list column of structs with two members may thus be viewed as a `map` column. `maps_column_view` is to a `LIST<STRUCT<K,V>>` column what `lists_column_view` is to a `LIST` column. The `maps_column_view` abstraction provides methods to fetch lists of keys and values (as `LIST<K>` and `LIST<V>` respectively). It also provides map lookup methods to find the values corresponding to a specified key, for each row in the "map" column. E.g. ```c++ auto input_column = get_list_of_structs_col(); // input_column == [ {1:10, 2:20}, {1:100, 3:300}, {2:2000, 3:3000, 4:4000} ]; auto maps_view = cudf::jni::maps_column_view{input_column->view()}; auto keys = maps_view.keys(); // keys == [ {1,2}, {1,3}, {2,3,4} ]; auto values = maps_view.values(); // values == [ {10,20}, {100, 300}, {2000, 3000, 4000} ]; auto lookup_1 = maps_view.get_values_for( *make_numeric_scalar(1) ); // lookup_1 = [ {10, 100, null} ]; ``` This abstraction should help replace the Java/JNI `map_lookup` and `map_contains` kernels, which only handle `MAP<STRING, STRING>`. 
Authors: - MithunR (https://github.com/mythrocks) Approvers: - Jason Lowe (https://github.com/jlowe) - AJ Schmidt (https://github.com/ajschmidt8) - Nghia Truong (https://github.com/ttnghia) - Jake Hemstad (https://github.com/jrhemstad) URL: https://github.com/rapidsai/cudf/pull/10380 --- conda/recipes/libcudf/meta.yaml | 2 + cpp/include/cudf/lists/detail/contains.hpp | 78 +++++++++++ cpp/include/cudf/lists/detail/extract.hpp | 49 +++++++ cpp/src/lists/contains.cu | 61 ++++---- cpp/src/lists/extract.cu | 35 ++++- .../main/java/ai/rapids/cudf/ColumnView.java | 27 ++-- java/src/main/native/CMakeLists.txt | 1 + .../main/native/include/maps_column_view.hpp | 130 ++++++++++++++++++ java/src/main/native/src/ColumnViewJni.cpp | 19 +-- java/src/main/native/src/maps_column_view.cu | 94 +++++++++++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 65 +++++++-- 11 files changed, 495 insertions(+), 66 deletions(-) create mode 100644 cpp/include/cudf/lists/detail/contains.hpp create mode 100644 cpp/include/cudf/lists/detail/extract.hpp create mode 100644 java/src/main/native/include/maps_column_view.hpp create mode 100644 java/src/main/native/src/maps_column_view.cu diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index ebfc649c0d2..4ea4ace11da 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -150,7 +150,9 @@ test: - test -f $PREFIX/include/cudf/labeling/label_bins.hpp - test -f $PREFIX/include/cudf/lists/detail/combine.hpp - test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp + - test -f $PREFIX/include/cudf/lists/detail/contains.hpp - test -f $PREFIX/include/cudf/lists/detail/copying.hpp + - test -f $PREFIX/include/cudf/lists/detail/extract.hpp - test -f $PREFIX/include/cudf/lists/lists_column_factories.hpp - test -f $PREFIX/include/cudf/lists/detail/drop_list_duplicates.hpp - test -f $PREFIX/include/cudf/lists/detail/interleave_columns.hpp diff --git a/cpp/include/cudf/lists/detail/contains.hpp 
b/cpp/include/cudf/lists/detail/contains.hpp new file mode 100644 index 00000000000..24318e72e98 --- /dev/null +++ b/cpp/include/cudf/lists/detail/contains.hpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace cudf { +namespace lists { +namespace detail { + +/** + * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, + * cudf::scalar const&, + * duplicate_find_option, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr index_of( + cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + cudf::lists::duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, + * cudf::column_view const&, + * duplicate_find_option, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. 
+ */ +std::unique_ptr index_of( + cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + cudf::lists::duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::contains(cudf::lists_column_view const&, + * cudf::scalar const&, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr contains( + cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::contains(cudf::lists_column_view const&, + * cudf::column_view const&, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr contains( + cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +} // namespace detail +} // namespace lists +} // namespace cudf diff --git a/cpp/include/cudf/lists/detail/extract.hpp b/cpp/include/cudf/lists/detail/extract.hpp new file mode 100644 index 00000000000..44c31c9ddb2 --- /dev/null +++ b/cpp/include/cudf/lists/detail/extract.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace cudf { +namespace lists { +namespace detail { + +/** + * @copydoc cudf::lists::extract_list_element(lists_column_view, size_type, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr extract_list_element( + lists_column_view lists_column, + size_type const index, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::extract_list_element(lists_column_view, column_view const&, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr extract_list_element( + lists_column_view lists_column, + column_view const& indices, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +} // namespace detail +} // namespace lists +} // namespace cudf diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index 5d095fdd5a3..5704ff81665 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -251,18 +252,17 @@ std::unique_ptr to_contains(std::unique_ptr&& key_positions, namespace detail { /** - * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, - * cudf::scalar const&, - * duplicate_find_option, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @copydoc cudf::lists::detail::index_of(cudf::lists_column_view const&, + * cudf::scalar const&, + * duplicate_find_option, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ -std::unique_ptr index_of( - cudf::lists_column_view const& lists, - cudf::scalar const& search_key, - duplicate_find_option find_option, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr index_of(cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { return search_key.is_valid(stream) ? cudf::type_dispatcher(search_key.type(), @@ -282,18 +282,17 @@ std::unique_ptr index_of( } /** - * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, - * cudf::column_view const&, - * duplicate_find_option, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @copydoc cudf::lists::detail::index_of(cudf::lists_column_view const&, + * cudf::column_view const&, + * duplicate_find_option, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ -std::unique_ptr index_of( - cudf::lists_column_view const& lists, - cudf::column_view const& search_keys, - duplicate_find_option find_option, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr index_of(cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { CUDF_EXPECTS(search_keys.size() == lists.size(), "Number of search keys must match list column size."); @@ -316,10 +315,10 @@ std::unique_ptr index_of( } /** - * @copydoc cudf::lists::contains(cudf::lists_column_view const&, - * cudf::scalar const&, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. + * @copydoc cudf::lists::detail::contains(cudf::lists_column_view const&, + * cudf::scalar const&, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ std::unique_ptr contains(cudf::lists_column_view const& lists, cudf::scalar const& search_key, @@ -331,10 +330,10 @@ std::unique_ptr contains(cudf::lists_column_view const& lists, } /** - * @copydoc cudf::lists::contains(cudf::lists_column_view const&, - * cudf::column_view const&, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @copydoc cudf::lists::detail::contains(cudf::lists_column_view const&, + * cudf::column_view const&, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ std::unique_ptr contains(cudf::lists_column_view const& lists, cudf::column_view const& search_keys, diff --git a/cpp/src/lists/extract.cu b/cpp/src/lists/extract.cu index 7c6c612eb25..0e8659b54ff 100644 --- a/cpp/src/lists/extract.cu +++ b/cpp/src/lists/extract.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -107,10 +108,10 @@ std::unique_ptr make_index_offsets(size_type num_lists, rmm::cuda_ * @param stream CUDA stream used for device memory operations and kernel launches. */ template -std::unique_ptr extract_list_element(lists_column_view lists_column, - index_t const& index, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr extract_list_element_impl(lists_column_view lists_column, + index_t const& index, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { auto const num_lists = lists_column.size(); if (num_lists == 0) { return empty_like(lists_column.child()); } @@ -135,6 +136,26 @@ std::unique_ptr extract_list_element(lists_column_view lists_column, return std::move(extracted_lists->release().children[lists_column_view::child_column_index]); } +/** + * @copydoc cudf::lists::extract_list_element + * @param stream CUDA stream used for device memory operations and kernel launches. 
+ */ +std::unique_ptr extract_list_element(lists_column_view lists_column, + size_type const index, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + return detail::extract_list_element_impl(lists_column, index, stream, mr); +} + +std::unique_ptr extract_list_element(lists_column_view lists_column, + column_view const& indices, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + return detail::extract_list_element_impl(lists_column, indices, stream, mr); +} + } // namespace detail /** @@ -146,7 +167,7 @@ std::unique_ptr extract_list_element(lists_column_view const& lists_colu size_type index, rmm::mr::device_memory_resource* mr) { - return detail::extract_list_element(lists_column, index, rmm::cuda_stream_default, mr); + return detail::extract_list_element_impl(lists_column, index, rmm::cuda_stream_default, mr); } /** @@ -160,7 +181,7 @@ std::unique_ptr extract_list_element(lists_column_view const& lists_colu { CUDF_EXPECTS(indices.size() == lists_column.size(), "Index column must have as many elements as lists column."); - return detail::extract_list_element(lists_column, indices, rmm::cuda_stream_default, mr); + return detail::extract_list_element_impl(lists_column, indices, rmm::cuda_stream_default, mr); } } // namespace lists diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 3fe244c0112..ed3ac124216 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3244,17 +3244,23 @@ public final ColumnVector urlEncode() throws CudfException { return new ColumnVector(urlEncode(getNativeView())); } - /** For a column of type List> and a passed in String key, return a string column - * for all the values in the struct that match the key, null otherwise. 
- * @param key the String scalar to lookup in the column - * @return a string column of values or nulls based on the lookup result + private static void assertIsSupportedMapKeyType(DType keyType) { + boolean isSupportedKeyType = + !keyType.equals(DType.EMPTY) && !keyType.equals(DType.LIST) && !keyType.equals(DType.STRUCT); + assert isSupportedKeyType : "Map lookup by STRUCT and LIST keys is not supported."; + } + + /** + * Given a column of type List> and a key of type X, return a column of type Y, + * where each row in the output column is the Y value corresponding to the X key. + * If the key is not found, the corresponding output value is null. + * @param key the scalar key to lookup in the column + * @return a column of values or nulls based on the lookup result */ public final ColumnVector getMapValue(Scalar key) { - assert type.equals(DType.LIST) : "column type must be a LIST"; - assert key != null : "target string may not be null"; - assert key.getType().equals(DType.STRING) : "target string must be a string scalar"; - + assert key != null : "Lookup key may not be null"; + assertIsSupportedMapKeyType(key.getType()); return new ColumnVector(mapLookup(getNativeView(), key.getScalarHandle())); } @@ -3266,9 +3272,8 @@ public final ColumnVector getMapValue(Scalar key) { */ public final ColumnVector getMapKeyExistence(Scalar key) { assert type.equals(DType.LIST) : "column type must be a LIST"; - assert key != null : "target string may not be null"; - assert key.getType().equals(DType.STRING) : "target must be a string scalar"; - + assert key != null : "Lookup key may not be null"; + assertIsSupportedMapKeyType(key.getType()); return new ColumnVector(mapContains(getNativeView(), key.getScalarHandle())); } diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index ffbeeb155e0..6e0c07bc4f0 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -238,6 +238,7 @@ add_library( src/TableJni.cpp 
src/aggregation128_utils.cu src/map_lookup.cu + src/maps_column_view.cu src/row_conversion.cu src/check_nvcomp_output_sizes.cu ) diff --git a/java/src/main/native/include/maps_column_view.hpp b/java/src/main/native/include/maps_column_view.hpp new file mode 100644 index 00000000000..26d97fd5789 --- /dev/null +++ b/java/src/main/native/include/maps_column_view.hpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace cudf { + +class scalar; + +namespace jni { + +/** + * @brief Given a column-view of LIST>, an instance of this class + * provides an abstraction of a column of maps. + * + * Each list row is treated as a map of key->value, with possibly repeated keys. + * The list may be looked up by a scalar key, or by a column of keys, to + * retrieve the corresponding value. + */ +class maps_column_view { +public: + maps_column_view(lists_column_view const &lists_of_structs, + rmm::cuda_stream_view stream = rmm::cuda_stream_default); + + // Rule of 5. + maps_column_view(maps_column_view const &maps_view) = default; + maps_column_view(maps_column_view &&maps_view) = default; + maps_column_view &operator=(maps_column_view const &) = default; + maps_column_view &operator=(maps_column_view &&) = default; + ~maps_column_view() = default; + + /** + * @brief Returns number of map rows in the column. 
+ */ + size_type size() const { return keys_.size(); } + + /** + * @brief Getter for keys as a list column. + * + * Note: Keys are not deduped. Repeated keys are returned in order. + */ + lists_column_view const &keys() const { return keys_; } + + /** + * @brief Getter for values as a list column. + * + * Note: Values for repeated keys are not dropped. + */ + lists_column_view const &values() const { return values_; } + + /** + * @brief Map lookup by a column of keys. + * + * The lookup column must have as many rows as the map column, + * and must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * If a key is repeated in a map row, the value corresponding to the last matching + * key is returned. + * If a lookup key is null or not found, the corresponding value is null. + * + * @param keys Column of keys to be looked up in each corresponding map row. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. + * @return std::unique_ptr Column of values corresponding to the lookup keys. + */ + std::unique_ptr get_values_for( + column_view const &keys, rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()) const; + + /** + * @brief Map lookup by a scalar key. + * + * The type of the lookup scalar must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * If a key is repeated in a map row, the value corresponding to the last matching + * key is returned. + * If the lookup key is null or not found, the corresponding value is null. + * + * @param key Scalar key to be looked up in each map row. + * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @param mr Device memory resource used to allocate the returned column's device memory. + * @return std::unique_ptr Column of values corresponding to the lookup key. + */ + std::unique_ptr get_values_for( + scalar const &key, rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()) const; + + /** + * @brief Check if each map row contains a specified scalar key. + * + * The type of the lookup scalar must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * + * Each row in the returned column contains a bool indicating whether the row contains + * the specified key (`true`) or not (`false`). + * The returned column contains no nulls; i.e., if the search key is null, or if the + * map row is null, the result row is `false`. + * + * @param key Scalar key to be looked up in each map row. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. 
+ * @return std::unique_ptr + */ + std::unique_ptr + contains(scalar const &key, rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()) const; + +private: + lists_column_view keys_, values_; +}; + +} // namespace jni +} // namespace cudf diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index ec9e19e518d..8c8e9b91e8d 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -75,6 +75,7 @@ #include "dtype_utils.hpp" #include "jni_utils.hpp" #include "map_lookup.hpp" +#include "maps_column_view.hpp" using cudf::jni::ptr_as_jlong; using cudf::jni::release_as_jlong; @@ -1361,12 +1362,13 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jc jlong map_column_view, jlong lookup_key) { JNI_NULL_CHECK(env, map_column_view, "column is null", 0); - JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0); + JNI_NULL_CHECK(env, lookup_key, "lookup key is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = reinterpret_cast(map_column_view); - cudf::string_scalar *ss_key = reinterpret_cast(lookup_key); - return release_as_jlong(cudf::jni::map_lookup(*cv, *ss_key)); + auto const *cv = reinterpret_cast(map_column_view); + auto const *scalar_key = reinterpret_cast(lookup_key); + auto const maps_view = cudf::jni::maps_column_view{*cv}; + return release_as_jlong(maps_view.get_values_for(*scalar_key)); } CATCH_STD(env, 0); } @@ -1375,12 +1377,13 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, jlong map_column_view, jlong lookup_key) { JNI_NULL_CHECK(env, map_column_view, "column is null", 0); - JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0); + JNI_NULL_CHECK(env, lookup_key, "lookup key is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = 
reinterpret_cast(map_column_view); - cudf::string_scalar *ss_key = reinterpret_cast(lookup_key); - return release_as_jlong(cudf::jni::map_contains(*cv, *ss_key)); + auto const *cv = reinterpret_cast(map_column_view); + auto const *scalar_key = reinterpret_cast(lookup_key); + auto const maps_view = cudf::jni::maps_column_view{*cv}; + return release_as_jlong(maps_view.contains(*scalar_key)); } CATCH_STD(env, 0); } diff --git a/java/src/main/native/src/maps_column_view.cu b/java/src/main/native/src/maps_column_view.cu new file mode 100644 index 00000000000..e2d352812fd --- /dev/null +++ b/java/src/main/native/src/maps_column_view.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +namespace cudf::jni { + +namespace { +column_view make_lists(column_view const &lists_child, lists_column_view const &lists_of_structs) { + return column_view{data_type{type_id::LIST}, + lists_of_structs.size(), + nullptr, + lists_of_structs.null_mask(), + lists_of_structs.null_count(), + lists_of_structs.offset(), + {lists_of_structs.offsets(), lists_child}}; +} +} // namespace + +maps_column_view::maps_column_view(lists_column_view const &lists_of_structs, + rmm::cuda_stream_view stream) + : keys_{make_lists(lists_of_structs.child().child(0), lists_of_structs)}, + values_{make_lists(lists_of_structs.child().child(1), lists_of_structs)} { + auto const structs = lists_of_structs.child(); + CUDF_EXPECTS(structs.type().id() == type_id::STRUCT, + "maps_column_view input must have exactly 1 child (STRUCT) column."); + CUDF_EXPECTS(structs.num_children() == 2, + "maps_column_view key-value struct must have exactly 2 children."); +} + +template +std::unique_ptr get_values_for_impl(maps_column_view const &maps_view, + KeyT const &lookup_keys, rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) { + auto const keys_ = maps_view.keys(); + auto const values_ = maps_view.values(); + CUDF_EXPECTS(lookup_keys.type().id() == keys_.child().type().id(), + "Lookup keys must have the same type as the keys of the map column."); + auto key_indices = + lists::detail::index_of(keys_, lookup_keys, lists::duplicate_find_option::FIND_LAST, stream); + auto constexpr absent_offset = size_type{-1}; + auto constexpr nullity_offset = std::numeric_limits::min(); + thrust::replace(rmm::exec_policy(stream), key_indices->mutable_view().template begin(), + key_indices->mutable_view().template end(), absent_offset, + nullity_offset); + return lists::detail::extract_list_element(values_, key_indices->view(), stream, mr); +} + +std::unique_ptr +maps_column_view::get_values_for(column_view const &lookup_keys, 
rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) const { + CUDF_EXPECTS(lookup_keys.size() == size(), + "Lookup keys must have the same size as the map column."); + + return get_values_for_impl(*this, lookup_keys, stream, mr); +} + +std::unique_ptr +maps_column_view::get_values_for(scalar const &lookup_key, rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) const { + return get_values_for_impl(*this, lookup_key, stream, mr); +} + +std::unique_ptr maps_column_view::contains(scalar const &lookup_key, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) const { + CUDF_EXPECTS(lookup_key.type().id() == keys_.child().type().id(), + "Lookup keys must have the same type as the keys of the map column."); + auto const contains = lists::detail::contains(keys_, lookup_key, stream); + + // Replace nulls with BOOL8{false}; + auto const scalar_false = numeric_scalar{false, true, stream}; + return detail::replace_nulls(contains->view(), scalar_false, stream, mr); +} + +} // namespace cudf::jni diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0ba29840156..d1509f14c6e 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -5833,14 +5833,30 @@ void testStructChildValidity() { } @Test - void testGetMapValue() { + void testGetMapValueForInteger() { + List list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 2))); + List list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 3))); + List list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(5, 4))); + HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT32))); + try (ColumnVector cv = ColumnVector.fromLists(new 
HostColumnVector.ListType(true, structType), list1, list2, list3); + Scalar lookupKey = Scalar.fromInt(1); + ColumnVector res = cv.getMapValue(lookupKey); + ColumnVector expected = ColumnVector.fromBoxedInts(2, 3, null)) { + assertColumnsAreEqual(expected, res); + } + } + + @Test + void testGetMapValueForStrings() { List list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "b"))); List list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "c"))); List list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("e", "d"))); HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING), new HostColumnVector.BasicType(true, DType.STRING))); try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3); - ColumnVector res = cv.getMapValue(Scalar.fromString("a")); + Scalar lookupKey = Scalar.fromString("a"); + ColumnVector res = cv.getMapValue(lookupKey); ColumnVector expected = ColumnVector.fromStrings("b", "c", null)) { assertColumnsAreEqual(expected, res); } @@ -5851,14 +5867,45 @@ void testGetMapValueEmptyInput() { HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING), new HostColumnVector.BasicType(true, DType.STRING))); try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType)); - ColumnVector res = cv.getMapValue(Scalar.fromString("a")); + Scalar lookupKey = Scalar.fromString("a"); + ColumnVector res = cv.getMapValue(lookupKey); ColumnVector expected = ColumnVector.fromStrings()) { assertColumnsAreEqual(expected, res); } } @Test - void testGetMapKeyExistence() { + void testGetMapKeyExistenceForInteger() { + List list1 = Arrays.asList(new HostColumnVector.StructData(1, 2)); + List list2 = Arrays.asList(new HostColumnVector.StructData(1, 3)); + List list3 = 
Arrays.asList(new HostColumnVector.StructData(5, 4)); + List list4 = Arrays.asList(new HostColumnVector.StructData(1, 7)); + List list5 = Arrays.asList(new HostColumnVector.StructData(1, null)); + List list6 = Arrays.asList(new HostColumnVector.StructData(null, null)); + List list7 = Arrays.asList(new HostColumnVector.StructData()); + HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT32))); + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7); + Scalar lookup1 = Scalar.fromInt(1); + ColumnVector resValidKey = cv.getMapKeyExistence(lookup1); + ColumnVector expectedValid = ColumnVector.fromBoxedBooleans(true, true, false, true, true, false, false); + ColumnVector expectedNull = ColumnVector.fromBoxedBooleans(false, false, false, false, false, false, false); + Scalar lookupNull = Scalar.fromNull(DType.INT32); + ColumnVector resNullKey = cv.getMapKeyExistence(lookupNull)) { + assertColumnsAreEqual(expectedValid, resValidKey); + assertColumnsAreEqual(expectedNull, resNullKey); + } + + AssertionError e = assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7); + ColumnVector resNullKey = cv.getMapKeyExistence(null)) { + } + }); + assertTrue(e.getMessage().contains("Lookup key may not be null")); + } + + @Test + void testGetMapKeyExistenceForStrings() { List list1 = Arrays.asList(new HostColumnVector.StructData("a", "b")); List list2 = Arrays.asList(new HostColumnVector.StructData("a", "c")); List list3 = Arrays.asList(new HostColumnVector.StructData("e", "d")); @@ -5869,10 +5916,12 @@ void testGetMapKeyExistence() { HostColumnVector.StructType structType = new HostColumnVector.StructType(true, 
Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING), new HostColumnVector.BasicType(true, DType.STRING))); try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7); - ColumnVector resValidKey = cv.getMapKeyExistence(Scalar.fromString("a")); + Scalar lookupA = Scalar.fromString("a"); + ColumnVector resValidKey = cv.getMapKeyExistence(lookupA); ColumnVector expectedValid = ColumnVector.fromBoxedBooleans(true, true, false, true, true, false, false); ColumnVector expectedNull = ColumnVector.fromBoxedBooleans(false, false, false, false, false, false, false); - ColumnVector resNullKey = cv.getMapKeyExistence(Scalar.fromNull(DType.STRING))) { + Scalar lookupNull = Scalar.fromNull(DType.STRING); + ColumnVector resNullKey = cv.getMapKeyExistence(lookupNull)) { assertColumnsAreEqual(expectedValid, resValidKey); assertColumnsAreEqual(expectedNull, resNullKey); } @@ -5882,10 +5931,8 @@ void testGetMapKeyExistence() { ColumnVector resNullKey = cv.getMapKeyExistence(null)) { } }); - assertTrue(e.getMessage().contains("target string may not be null")); + assertTrue(e.getMessage().contains("Lookup key may not be null")); } - - @Test void testListOfStructsOfStructs() { List list1 = Arrays.asList( From 4596244b5ac185f3c11d4712824ad2d65948cf26 Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Tue, 15 Mar 2022 13:19:33 +0800 Subject: [PATCH 2/2] JNI support for Collect Ops in Reduction (#10427) Exposes public APIs for collect operations as `ReductionAggregation`, which are essential to spark-rapids. In addition, this PR also extends the test framework of Reduction to discriminate output types from input types. 
Authors: - Alfred Xu (https://github.com/sperlingxx) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/10427 --- .../java/ai/rapids/cudf/ColumnVector.java | 11 + .../ai/rapids/cudf/GroupByAggregation.java | 4 +- .../ai/rapids/cudf/ReductionAggregation.java | 65 ++- .../java/ai/rapids/cudf/ReductionTest.java | 470 +++++++++++++----- 4 files changed, 411 insertions(+), 139 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 11b654ccec6..aab8e7dd475 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -1206,6 +1206,17 @@ public static ColumnVector emptyStructs(HostColumnVector.DataType dataType, long } } + /** + * Create a new vector from the given values. + */ + public static ColumnVector fromBooleans(boolean... values) { + byte[] bytes = new byte[values.length]; + for (int i = 0; i < values.length; i++) { + bytes[i] = values[i] ? (byte) 1 : (byte) 0; + } + return build(DType.BOOL8, values.length, (b) -> b.appendArray(bytes)); + } + /** * Create a new vector from the given values. */ diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java index 682d844c43c..500d18f7eae 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -243,7 +243,7 @@ public static GroupByAggregation collectList(NullPolicy nullPolicy) { } /** - * Collect the values into a set. All null values will be excluded, and all nan values are regarded as + * Collect the values into a set. All null values will be excluded, and all NaN values are regarded as * unique instances. 
*/ public static GroupByAggregation collectSet() { @@ -270,7 +270,7 @@ public static GroupByAggregation mergeLists() { } /** - * Merge the partial sets produced by multiple CollectSetAggregations. Each null/nan value will be regarded as + * Merge the partial sets produced by multiple CollectSetAggregations. Each null/NaN value will be regarded as * a unique instance. */ public static GroupByAggregation mergeSets() { diff --git a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java index 7eff85dcd0d..9147d6763ac 100644 --- a/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/ReductionAggregation.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -209,4 +209,67 @@ public static ReductionAggregation nth(int offset) { public static ReductionAggregation nth(int offset, NullPolicy nullPolicy) { return new ReductionAggregation(Aggregation.nth(offset, nullPolicy)); } + + /** + * Collect the values into a list. Nulls will be skipped. + */ + public static ReductionAggregation collectList() { + return new ReductionAggregation(Aggregation.collectList()); + } + + /** + * Collect the values into a list. + * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + */ + public static ReductionAggregation collectList(NullPolicy nullPolicy) { + return new ReductionAggregation(Aggregation.collectList(nullPolicy)); + } + + /** + * Collect the values into a set. All null values will be excluded, and all NaN values are regarded as + * unique instances. + */ + public static ReductionAggregation collectSet() { + return new ReductionAggregation(Aggregation.collectSet()); + } + + /** + * Collect the values into a set. 
+ * + * @param nullPolicy Indicates whether to include/exclude nulls during collection. + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. + */ + public static ReductionAggregation collectSet(NullPolicy nullPolicy, + NullEquality nullEquality, NaNEquality nanEquality) { + return new ReductionAggregation(Aggregation.collectSet(nullPolicy, nullEquality, nanEquality)); + } + + /** + * Merge the partial lists produced by multiple CollectListAggregations. + * NOTICE: The partial lists to be merged should NOT include any null list element (but can include null list entries). + */ + public static ReductionAggregation mergeLists() { + return new ReductionAggregation(Aggregation.mergeLists()); + } + + /** + * Merge the partial sets produced by multiple CollectSetAggregations. Each null/NaN value will be regarded as + * a unique instance. + */ + public static ReductionAggregation mergeSets() { + return new ReductionAggregation(Aggregation.mergeSets()); + } + + /** + * Merge the partial sets produced by multiple CollectSetAggregations. + * + * @param nullEquality Flag to specify whether null entries within each list should be considered equal. + * @param nanEquality Flag to specify whether NaN values in floating point column should be considered equal. + */ + public static ReductionAggregation mergeSets(NullEquality nullEquality, NaNEquality nanEquality) { + return new ReductionAggregation(Aggregation.mergeSets(nullEquality, nanEquality)); + } + } diff --git a/java/src/test/java/ai/rapids/cudf/ReductionTest.java b/java/src/test/java/ai/rapids/cudf/ReductionTest.java index 2b26597c8f7..2efd23703bc 100644 --- a/java/src/test/java/ai/rapids/cudf/ReductionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ReductionTest.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019, NVIDIA CORPORATION. 
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,14 @@ */ package ai.rapids.cudf; +import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import java.util.EnumSet; +import java.util.List; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -43,12 +45,14 @@ class ReductionTest extends CudfTestBase { Aggregation.Kind.ANY, Aggregation.Kind.ALL); - private static Scalar buildExpectedScalar(ReductionAggregation op, DType baseType, Object expectedObject) { + private static Scalar buildExpectedScalar(ReductionAggregation op, + HostColumnVector.DataType dataType, Object expectedObject) { + if (expectedObject == null) { - return Scalar.fromNull(baseType); + return Scalar.fromNull(dataType.getType()); } if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { - if (baseType.equals(DType.FLOAT32)) { + if (dataType.getType().equals(DType.FLOAT32)) { return Scalar.fromFloat((Float) expectedObject); } return Scalar.fromDouble((Double) expectedObject); @@ -56,7 +60,7 @@ private static Scalar buildExpectedScalar(ReductionAggregation op, DType baseTyp if (BOOL_REDUCTIONS.contains(op.getWrapped().kind)) { return Scalar.fromBool((Boolean) expectedObject); } - switch (baseType.typeId) { + switch (dataType.getType().typeId) { case BOOL8: return Scalar.fromBool((Boolean) expectedObject); case INT8: @@ -77,177 +81,346 @@ private static Scalar buildExpectedScalar(ReductionAggregation op, DType baseTyp case TIMESTAMP_MILLISECONDS: case TIMESTAMP_MICROSECONDS: case TIMESTAMP_NANOSECONDS: - return Scalar.timestampFromLong(baseType, (Long) expectedObject); + return Scalar.timestampFromLong(dataType.getType(), (Long) 
expectedObject); case STRING: return Scalar.fromString((String) expectedObject); + case LIST: + HostColumnVector.DataType et = dataType.getChild(0); + ColumnVector col = null; + try { + switch (et.getType().typeId) { + case BOOL8: + col = et.isNullable() ? ColumnVector.fromBoxedBooleans((Boolean[]) expectedObject) : + ColumnVector.fromBooleans((boolean[]) expectedObject); + return Scalar.listFromColumnView(col); + case INT8: + col = et.isNullable() ? ColumnVector.fromBoxedBytes((Byte[]) expectedObject) : + ColumnVector.fromBytes((byte[]) expectedObject); + return Scalar.listFromColumnView(col); + case INT16: + col = et.isNullable() ? ColumnVector.fromBoxedShorts((Short[]) expectedObject) : + ColumnVector.fromShorts((short[]) expectedObject); + return Scalar.listFromColumnView(col); + case INT32: + col = et.isNullable() ? ColumnVector.fromBoxedInts((Integer[]) expectedObject) : + ColumnVector.fromInts((int[]) expectedObject); + return Scalar.listFromColumnView(col); + case INT64: + col = et.isNullable() ? ColumnVector.fromBoxedLongs((Long[]) expectedObject) : + ColumnVector.fromLongs((long[]) expectedObject); + return Scalar.listFromColumnView(col); + case FLOAT32: + col = et.isNullable() ? ColumnVector.fromBoxedFloats((Float[]) expectedObject) : + ColumnVector.fromFloats((float[]) expectedObject); + return Scalar.listFromColumnView(col); + case FLOAT64: + col = et.isNullable() ? 
ColumnVector.fromBoxedDoubles((Double[]) expectedObject) : + ColumnVector.fromDoubles((double[]) expectedObject); + return Scalar.listFromColumnView(col); + case STRING: + col = ColumnVector.fromStrings((String[]) expectedObject); + return Scalar.listFromColumnView(col); + default: + throw new IllegalArgumentException("Unexpected element type of List: " + et); + } + } finally { + if (col != null) { + col.close(); + } + } default: - throw new IllegalArgumentException("Unexpected type: " + baseType); + throw new IllegalArgumentException("Unexpected type: " + dataType); } } private static Stream createBooleanParams() { Boolean[] vals = new Boolean[]{true, true, null, false, true, false, null}; + HostColumnVector.DataType bool = new HostColumnVector.BasicType(true, DType.BOOL8); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Boolean[0], null, 0.), - Arguments.of(ReductionAggregation.sum(), new Boolean[]{null, null, null}, null, 0.), - Arguments.of(ReductionAggregation.sum(), vals, true, 0.), - Arguments.of(ReductionAggregation.min(), vals, false, 0.), - Arguments.of(ReductionAggregation.max(), vals, true, 0.), - Arguments.of(ReductionAggregation.product(), vals, false, 0.), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, true, 0.), - Arguments.of(ReductionAggregation.mean(), vals, 0.6, DELTAD), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 0.5477225575051662, DELTAD), - Arguments.of(ReductionAggregation.variance(), vals, 0.3, DELTAD), - Arguments.of(ReductionAggregation.any(), vals, true, 0.), - Arguments.of(ReductionAggregation.all(), vals, false, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Boolean[0], bool, null, 0.), + Arguments.of(ReductionAggregation.sum(), new Boolean[]{null, null, null}, bool, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, bool, true, 0.), + Arguments.of(ReductionAggregation.min(), vals, bool, false, 0.), + Arguments.of(ReductionAggregation.max(), vals, bool, true, 0.), + Arguments.of(ReductionAggregation.product(), vals, bool, false, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, bool, true, 0.), + Arguments.of(ReductionAggregation.mean(), vals, bool, 0.6, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, bool, 0.5477225575051662, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, bool, 0.3, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, bool, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, bool, false, 0.) ); } private static Stream createByteParams() { Byte[] vals = new Byte[]{-1, 7, 123, null, 50, 60, 100}; + HostColumnVector.DataType int8 = new HostColumnVector.BasicType(true, DType.INT8); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Byte[0], null, 0.), - Arguments.of(ReductionAggregation.sum(), new Byte[]{null, null, null}, null, 0.), - Arguments.of(ReductionAggregation.sum(), vals, (byte) 83, 0.), - Arguments.of(ReductionAggregation.min(), vals, (byte) -1, 0.), - Arguments.of(ReductionAggregation.max(), vals, (byte) 123, 0.), - Arguments.of(ReductionAggregation.product(), vals, (byte) 160, 0.), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, (byte) 47, 0.), - Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(ReductionAggregation.any(), vals, true, 0.), - Arguments.of(ReductionAggregation.all(), vals, true, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Byte[0], int8, null, 0.), + Arguments.of(ReductionAggregation.sum(), new Byte[]{null, null, null}, int8, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, int8, (byte) 83, 0.), + Arguments.of(ReductionAggregation.min(), vals, int8, (byte) -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, int8, (byte) 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, int8, (byte) 160, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, int8, (byte) 47, 0.), + Arguments.of(ReductionAggregation.mean(), vals, int8, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, int8, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, int8, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, int8, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, int8, true, 0.) ); } private static Stream createShortParams() { Short[] vals = new Short[]{-1, 7, 123, null, 50, 60, 100}; + HostColumnVector.DataType int16 = new HostColumnVector.BasicType(true, DType.INT16); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Short[0], null, 0.), - Arguments.of(ReductionAggregation.sum(), new Short[]{null, null, null}, null, 0.), - Arguments.of(ReductionAggregation.sum(), vals, (short) 339, 0.), - Arguments.of(ReductionAggregation.min(), vals, (short) -1, 0.), - Arguments.of(ReductionAggregation.max(), vals, (short) 123, 0.), - Arguments.of(ReductionAggregation.product(), vals, (short) -22624, 0.), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, (short) 31279, 0.), - Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(ReductionAggregation.any(), vals, true, 0.), - Arguments.of(ReductionAggregation.all(), vals, true, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Short[0], int16, null, 0.), + Arguments.of(ReductionAggregation.sum(), new Short[]{null, null, null}, int16, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, int16, (short) 339, 0.), + Arguments.of(ReductionAggregation.min(), vals, int16, (short) -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, int16, (short) 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, int16, (short) -22624, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, int16, (short) 31279, 0.), + Arguments.of(ReductionAggregation.mean(), vals, int16, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, int16, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, int16, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, int16, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, int16, true, 0.) ); } private static Stream createIntParams() { Integer[] vals = new Integer[]{-1, 7, 123, null, 50, 60, 100}; + HostColumnVector.BasicType int32 = new HostColumnVector.BasicType(true, DType.INT32); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Integer[0], null, 0.), - Arguments.of(ReductionAggregation.sum(), new Integer[]{null, null, null}, null, 0.), - Arguments.of(ReductionAggregation.sum(), vals, 339, 0.), - Arguments.of(ReductionAggregation.min(), vals, -1, 0.), - Arguments.of(ReductionAggregation.max(), vals, 123, 0.), - Arguments.of(ReductionAggregation.product(), vals, -258300000, 0.), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279, 0.), - Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(ReductionAggregation.any(), vals, true, 0.), - Arguments.of(ReductionAggregation.all(), vals, true, 0.) 
+ Arguments.of(ReductionAggregation.sum(), new Integer[0], int32, null, 0.), + Arguments.of(ReductionAggregation.sum(), new Integer[]{null, null, null}, int32, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, int32, 339, 0.), + Arguments.of(ReductionAggregation.min(), vals, int32, -1, 0.), + Arguments.of(ReductionAggregation.max(), vals, int32, 123, 0.), + Arguments.of(ReductionAggregation.product(), vals, int32, -258300000, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, int32, 31279, 0.), + Arguments.of(ReductionAggregation.mean(), vals, int32, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, int32, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, int32, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, int32, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, int32, true, 0.) ); } private static Stream createLongParams() { Long[] vals = new Long[]{-1L, 7L, 123L, null, 50L, 60L, 100L}; + HostColumnVector.BasicType int64 = new HostColumnVector.BasicType(true, DType.INT64); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Long[0], null, 0.), - Arguments.of(ReductionAggregation.sum(), new Long[]{null, null, null}, null, 0.), - Arguments.of(ReductionAggregation.sum(), vals, 339L, 0.), - Arguments.of(ReductionAggregation.min(), vals, -1L, 0.), - Arguments.of(ReductionAggregation.max(), vals, 123L, 0.), - Arguments.of(ReductionAggregation.product(), vals, -258300000L, 0.), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279L, 0.), - Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(ReductionAggregation.any(), vals, true, 0.), - Arguments.of(ReductionAggregation.all(), vals, true, 0.), - Arguments.of(ReductionAggregation.quantile(0.5), vals, 
55.0, DELTAD), - Arguments.of(ReductionAggregation.quantile(0.9), vals, 111.5, DELTAD) + Arguments.of(ReductionAggregation.sum(), new Long[0], int64, null, 0.), + Arguments.of(ReductionAggregation.sum(), new Long[]{null, null, null}, int64, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, int64, 339L, 0.), + Arguments.of(ReductionAggregation.min(), vals, int64, -1L, 0.), + Arguments.of(ReductionAggregation.max(), vals, int64, 123L, 0.), + Arguments.of(ReductionAggregation.product(), vals, int64, -258300000L, 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, int64, 31279L, 0.), + Arguments.of(ReductionAggregation.mean(), vals, int64, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, int64, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, int64, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, int64, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, int64, true, 0.), + Arguments.of(ReductionAggregation.quantile(0.5), vals, int64, 55.0, DELTAD), + Arguments.of(ReductionAggregation.quantile(0.9), vals, int64, 111.5, DELTAD) ); } private static Stream createFloatParams() { Float[] vals = new Float[]{-1f, 7f, 123f, null, 50f, 60f, 100f}; + Float[] notNulls = new Float[]{-1f, 7f, 123f, 50f, 60f, 100f}; + Float[] repeats = new Float[]{Float.MIN_VALUE, 7f, 7f, null, null, Float.NaN, Float.NaN, 50f, 50f, 100f}; + HostColumnVector.BasicType fp32 = new HostColumnVector.BasicType(true, DType.FLOAT32); + HostColumnVector.DataType listOfFloat = new HostColumnVector.ListType( + true, new HostColumnVector.BasicType(true, DType.FLOAT32)); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Float[0], null, 0f), - Arguments.of(ReductionAggregation.sum(), new Float[]{null, null, null}, null, 0f), - Arguments.of(ReductionAggregation.sum(), vals, 339f, 0f), - Arguments.of(ReductionAggregation.min(), vals, -1f, 0f), - Arguments.of(ReductionAggregation.max(), vals, 
123f, 0f), - Arguments.of(ReductionAggregation.product(), vals, -258300000f, 0f), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279f, 0f), - Arguments.of(ReductionAggregation.mean(), vals, 56.5f, DELTAF), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839f, DELTAF), - Arguments.of(ReductionAggregation.variance(), vals, 2425.1f, DELTAF), - Arguments.of(ReductionAggregation.any(), vals, true, 0f), - Arguments.of(ReductionAggregation.all(), vals, true, 0f) + Arguments.of(ReductionAggregation.sum(), new Float[0], fp32, null, 0f), + Arguments.of(ReductionAggregation.sum(), new Float[]{null, null, null}, fp32, null, 0f), + Arguments.of(ReductionAggregation.sum(), vals, fp32, 339f, 0f), + Arguments.of(ReductionAggregation.min(), vals, fp32, -1f, 0f), + Arguments.of(ReductionAggregation.max(), vals, fp32, 123f, 0f), + Arguments.of(ReductionAggregation.product(), vals, fp32, -258300000f, 0f), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, fp32, 31279f, 0f), + Arguments.of(ReductionAggregation.mean(), vals, fp32, 56.5f, DELTAF), + Arguments.of(ReductionAggregation.standardDeviation(), vals, fp32, 49.24530434467839f, DELTAF), + Arguments.of(ReductionAggregation.variance(), vals, fp32, 2425.1f, DELTAF), + Arguments.of(ReductionAggregation.any(), vals, fp32, true, 0f), + Arguments.of(ReductionAggregation.all(), vals, fp32, true, 0f), + Arguments.of(ReductionAggregation.collectList(NullPolicy.INCLUDE), vals, listOfFloat, vals, 0f), + Arguments.of(ReductionAggregation.collectList(), vals, listOfFloat, notNulls, 0f), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.EXCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL), + repeats, listOfFloat, + new Float[]{Float.MIN_VALUE, 7f, 50f, 100f, Float.NaN}, 0f), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL), + repeats, listOfFloat, + new Float[]{Float.MIN_VALUE, 7f, 50f, 100f, Float.NaN, null}, 0f), + 
Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL), + repeats, listOfFloat, + new Float[]{Float.MIN_VALUE, 7f, 50f, 100f, Float.NaN, null, null}, 0f), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL), + repeats, listOfFloat, + new Float[]{Float.MIN_VALUE, 7f, 50f, 100f, Float.NaN, Float.NaN, null}, 0f), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL), + repeats, listOfFloat, + new Float[]{Float.MIN_VALUE, 7f, 50f, 100f, Float.NaN, Float.NaN, null, null}, 0f), + Arguments.of(ReductionAggregation.collectSet(), + repeats, listOfFloat, + new Float[]{Float.MIN_VALUE, 7f, 50f, 100f, Float.NaN, Float.NaN}, 0f) ); } private static Stream createDoubleParams() { Double[] vals = new Double[]{-1., 7., 123., null, 50., 60., 100.}; + Double[] notNulls = new Double[]{-1., 7., 123., 50., 60., 100.}; + Double[] repeats = new Double[]{Double.MIN_VALUE, 7., 7., null, null, Double.NaN, Double.NaN, 50., 50., 100.}; + HostColumnVector.BasicType fp64 = new HostColumnVector.BasicType(true, DType.FLOAT64); + HostColumnVector.DataType listOfDouble = new HostColumnVector.ListType( + true, new HostColumnVector.BasicType(true, DType.FLOAT64)); return Stream.of( - Arguments.of(ReductionAggregation.sum(), new Double[0], null, 0.), - Arguments.of(ReductionAggregation.sum(), new Double[]{null, null, null}, null, 0.), - Arguments.of(ReductionAggregation.sum(), vals, 339., 0.), - Arguments.of(ReductionAggregation.min(), vals, -1., 0.), - Arguments.of(ReductionAggregation.max(), vals, 123., 0.), - Arguments.of(ReductionAggregation.product(), vals, -258300000., 0.), - Arguments.of(ReductionAggregation.sumOfSquares(), vals, 31279., 0.), - Arguments.of(ReductionAggregation.mean(), vals, 56.5, DELTAD), - Arguments.of(ReductionAggregation.standardDeviation(), vals, 49.24530434467839, DELTAD), - 
Arguments.of(ReductionAggregation.variance(), vals, 2425.1, DELTAD), - Arguments.of(ReductionAggregation.any(), vals, true, 0.), - Arguments.of(ReductionAggregation.all(), vals, true, 0.), - Arguments.of(ReductionAggregation.quantile(0.5), vals, 55.0, DELTAD), - Arguments.of(ReductionAggregation.quantile(0.9), vals, 111.5, DELTAD) + Arguments.of(ReductionAggregation.sum(), new Double[0], fp64, null, 0.), + Arguments.of(ReductionAggregation.sum(), new Double[]{null, null, null}, fp64, null, 0.), + Arguments.of(ReductionAggregation.sum(), vals, fp64, 339., 0.), + Arguments.of(ReductionAggregation.min(), vals, fp64, -1., 0.), + Arguments.of(ReductionAggregation.max(), vals, fp64, 123., 0.), + Arguments.of(ReductionAggregation.product(), vals, fp64, -258300000., 0.), + Arguments.of(ReductionAggregation.sumOfSquares(), vals, fp64, 31279., 0.), + Arguments.of(ReductionAggregation.mean(), vals, fp64, 56.5, DELTAD), + Arguments.of(ReductionAggregation.standardDeviation(), vals, fp64, 49.24530434467839, DELTAD), + Arguments.of(ReductionAggregation.variance(), vals, fp64, 2425.1, DELTAD), + Arguments.of(ReductionAggregation.any(), vals, fp64, true, 0.), + Arguments.of(ReductionAggregation.all(), vals, fp64, true, 0.), + Arguments.of(ReductionAggregation.quantile(0.5), vals, fp64, 55.0, DELTAD), + Arguments.of(ReductionAggregation.quantile(0.9), vals, fp64, 111.5, DELTAD), + Arguments.of(ReductionAggregation.collectList(NullPolicy.INCLUDE), vals, listOfDouble, vals, 0.), + Arguments.of(ReductionAggregation.collectList(NullPolicy.EXCLUDE), vals, listOfDouble, notNulls, 0.), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.EXCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL), + repeats, listOfDouble, + new Double[]{Double.MIN_VALUE, 7., 50., 100., Double.NaN}, 0.), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.ALL_EQUAL), + repeats, listOfDouble, + new Double[]{Double.MIN_VALUE, 7., 50., 100., Double.NaN, null}, 
0.), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL), + repeats, listOfDouble, + new Double[]{Double.MIN_VALUE, 7., 50., 100., Double.NaN, null, null}, 0.), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.EQUAL, NaNEquality.UNEQUAL), + repeats, listOfDouble, + new Double[]{Double.MIN_VALUE, 7., 50., 100., Double.NaN, Double.NaN, null}, 0.), + Arguments.of(ReductionAggregation.collectSet( + NullPolicy.INCLUDE, NullEquality.UNEQUAL, NaNEquality.UNEQUAL), + repeats, listOfDouble, + new Double[]{Double.MIN_VALUE, 7., 50., 100., Double.NaN, Double.NaN, null, null}, 0.), + Arguments.of(ReductionAggregation.collectSet(), + repeats, listOfDouble, + new Double[]{Double.MIN_VALUE, 7., 50., 100., Double.NaN, Double.NaN}, 0.) ); } private static Stream createTimestampDaysParams() { Integer[] vals = new Integer[]{-1, 7, 123, null, 50, 60, 100}; + HostColumnVector.BasicType tsDay = new HostColumnVector.BasicType(true, DType.TIMESTAMP_DAYS); return Stream.of( - Arguments.of(ReductionAggregation.max(), new Integer[0], null), - Arguments.of(ReductionAggregation.max(), new Integer[]{null, null, null}, null), - Arguments.of(ReductionAggregation.max(), vals, 123), - Arguments.of(ReductionAggregation.min(), vals, -1) + Arguments.of(ReductionAggregation.max(), new Integer[0], tsDay, null), + Arguments.of(ReductionAggregation.max(), new Integer[]{null, null, null}, tsDay, null), + Arguments.of(ReductionAggregation.max(), vals, tsDay, 123), + Arguments.of(ReductionAggregation.min(), vals, tsDay, -1) ); } - private static Stream createTimestampResolutionParams() { + private static Stream createTimestampResolutionParams(HostColumnVector.BasicType type) { Long[] vals = new Long[]{-1L, 7L, 123L, null, 50L, 60L, 100L}; return Stream.of( - Arguments.of(ReductionAggregation.max(), new Long[0], null), - Arguments.of(ReductionAggregation.max(), new Long[]{null, null, null}, null), - 
Arguments.of(ReductionAggregation.min(), vals, -1L), - Arguments.of(ReductionAggregation.max(), vals, 123L) + Arguments.of(ReductionAggregation.max(), new Long[0], type, null), + Arguments.of(ReductionAggregation.max(), new Long[]{null, null, null}, type, null), + Arguments.of(ReductionAggregation.min(), vals, type, -1L), + Arguments.of(ReductionAggregation.max(), vals, type, 123L) + ); + } + + private static Stream createTimestampSecondsParams() { + return createTimestampResolutionParams( + new HostColumnVector.BasicType(true, DType.TIMESTAMP_SECONDS)); + } + + private static Stream createTimestampMilliSecondsParams() { + return createTimestampResolutionParams( + new HostColumnVector.BasicType(true, DType.TIMESTAMP_MILLISECONDS)); + } + + private static Stream createTimestampMicroSecondsParams() { + return createTimestampResolutionParams( + new HostColumnVector.BasicType(true, DType.TIMESTAMP_MICROSECONDS)); + } + + private static Stream createTimestampNanoSecondsParams() { + return createTimestampResolutionParams( + new HostColumnVector.BasicType(true, DType.TIMESTAMP_NANOSECONDS)); + } + + private static Stream createFloatArrayParams() { + List[] inputs = new List[]{ + Lists.newArrayList(-1f, 7f, null), + Lists.newArrayList(7f, 50f, 60f, Float.NaN), + Lists.newArrayList(), + Lists.newArrayList(60f, 100f, Float.NaN, null) + }; + HostColumnVector.DataType fpList = new HostColumnVector.ListType( + true, new HostColumnVector.BasicType(true, DType.FLOAT32)); + return Stream.of( + Arguments.of(ReductionAggregation.mergeLists(), inputs, fpList, + new Float[]{-1f, 7f, null, + 7f, 50f, 60f, Float.NaN, + 60f, 100f, Float.NaN, null}, 0f), + Arguments.of(ReductionAggregation.mergeSets(NullEquality.EQUAL, NaNEquality.ALL_EQUAL), + inputs, fpList, + new Float[]{-1f, 7f, 50f, 60f, 100f, Float.NaN, null}, 0f), + Arguments.of(ReductionAggregation.mergeSets(NullEquality.UNEQUAL, NaNEquality.ALL_EQUAL), + inputs, fpList, + new Float[]{-1f, 7f, 50f, 60f, 100f, Float.NaN, null, 
null}, 0f), + Arguments.of(ReductionAggregation.mergeSets(NullEquality.EQUAL, NaNEquality.UNEQUAL), + inputs, fpList, + new Float[]{-1f, 7f, 50f, 60f, 100f, Float.NaN, Float.NaN, null}, 0f), + Arguments.of(ReductionAggregation.mergeSets(), + inputs, fpList, + new Float[]{-1f, 7f, 50f, 60f, 100f, Float.NaN, Float.NaN, null, null}, 0f) ); } private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, - Double percentage) { + Double percentage) { if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getDouble(), result.getDouble(), percentage); + } else if (expected.getType().typeId == DType.DTypeEnum.LIST) { + try (ColumnView e = expected.getListAsColumnView(); + ColumnView r = result.getListAsColumnView()) { + AssertUtils.assertColumnsAreEqual(e, r); + } } else { assertEquals(expected, result); } } private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, Scalar result, - Float percentage) { + Float percentage) { if (FLOAT_REDUCTIONS.contains(op.getWrapped().kind)) { assertEqualsWithinPercentage(expected.getFloat(), result.getFloat(), percentage); + } else if (expected.getType().typeId == DType.DTypeEnum.LIST) { + try (ColumnView e = expected.getListAsColumnView(); + ColumnView r = result.getListAsColumnView()) { + AssertUtils.assertColumnsAreEqual(e, r); + } } else { assertEquals(expected, result); } @@ -255,8 +428,9 @@ private static void assertEqualsDelta(ReductionAggregation op, Scalar expected, @ParameterizedTest @MethodSource("createBooleanParams") - void testBoolean(ReductionAggregation op, Boolean[] values, Object expectedObject, Double delta) { - try (Scalar expected = buildExpectedScalar(op, DType.BOOL8, expectedObject); + void testBoolean(ReductionAggregation op, Boolean[] values, + HostColumnVector.DataType type, Object expectedObject, Double delta) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = 
ColumnVector.fromBoxedBooleans(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -265,8 +439,9 @@ void testBoolean(ReductionAggregation op, Boolean[] values, Object expectedObjec @ParameterizedTest @MethodSource("createByteParams") - void testByte(ReductionAggregation op, Byte[] values, Object expectedObject, Double delta) { - try (Scalar expected = buildExpectedScalar(op, DType.INT8, expectedObject); + void testByte(ReductionAggregation op, Byte[] values, + HostColumnVector.DataType type, Object expectedObject, Double delta) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.fromBoxedBytes(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -275,8 +450,9 @@ void testByte(ReductionAggregation op, Byte[] values, Object expectedObject, Dou @ParameterizedTest @MethodSource("createShortParams") - void testShort(ReductionAggregation op, Short[] values, Object expectedObject, Double delta) { - try (Scalar expected = buildExpectedScalar(op, DType.INT16, expectedObject); + void testShort(ReductionAggregation op, Short[] values, + HostColumnVector.DataType type, Object expectedObject, Double delta) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.fromBoxedShorts(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -285,8 +461,9 @@ void testShort(ReductionAggregation op, Short[] values, Object expectedObject, D @ParameterizedTest @MethodSource("createIntParams") - void testInt(ReductionAggregation op, Integer[] values, Object expectedObject, Double delta) { - try (Scalar expected = buildExpectedScalar(op, DType.INT32, expectedObject); + void testInt(ReductionAggregation op, Integer[] values, + HostColumnVector.DataType type, Object expectedObject, Double delta) { + try (Scalar expected = 
buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.fromBoxedInts(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -295,8 +472,9 @@ void testInt(ReductionAggregation op, Integer[] values, Object expectedObject, D @ParameterizedTest @MethodSource("createLongParams") - void testLong(ReductionAggregation op, Long[] values, Object expectedObject, Double delta) { - try (Scalar expected = buildExpectedScalar(op, DType.INT64, expectedObject); + void testLong(ReductionAggregation op, Long[] values, + HostColumnVector.DataType type, Object expectedObject, Double delta) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.fromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -305,8 +483,9 @@ void testLong(ReductionAggregation op, Long[] values, Object expectedObject, Dou @ParameterizedTest @MethodSource("createFloatParams") - void testFloat(ReductionAggregation op, Float[] values, Object expectedObject, Float delta) { - try (Scalar expected = buildExpectedScalar(op, DType.FLOAT32, expectedObject); + void testFloat(ReductionAggregation op, Float[] values, + HostColumnVector.DataType type, Object expectedObject, Float delta) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.fromBoxedFloats(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -315,8 +494,9 @@ void testFloat(ReductionAggregation op, Float[] values, Object expectedObject, F @ParameterizedTest @MethodSource("createDoubleParams") - void testDouble(ReductionAggregation op, Double[] values, Object expectedObject, Double delta) { - try (Scalar expected = buildExpectedScalar(op, DType.FLOAT64, expectedObject); + void testDouble(ReductionAggregation op, Double[] values, + HostColumnVector.DataType 
type, Object expectedObject, Double delta) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.fromBoxedDoubles(values); Scalar result = v.reduce(op, expected.getType())) { assertEqualsDelta(op, expected, result, delta); @@ -325,8 +505,9 @@ void testDouble(ReductionAggregation op, Double[] values, Object expectedObject, @ParameterizedTest @MethodSource("createTimestampDaysParams") - void testTimestampDays(ReductionAggregation op, Integer[] values, Object expectedObject) { - try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_DAYS, expectedObject); + void testTimestampDays(ReductionAggregation op, Integer[] values, + HostColumnVector.DataType type, Object expectedObject) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(values); Scalar result = v.reduce(op, expected.getType())) { assertEquals(expected, result); @@ -334,9 +515,10 @@ void testTimestampDays(ReductionAggregation op, Integer[] values, Object expecte } @ParameterizedTest - @MethodSource("createTimestampResolutionParams") - void testTimestampSeconds(ReductionAggregation op, Long[] values, Object expectedObject) { - try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_SECONDS, expectedObject); + @MethodSource("createTimestampSecondsParams") + void testTimestampSeconds(ReductionAggregation op, Long[] values, + HostColumnVector.DataType type, Object expectedObject) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.timestampSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { assertEquals(expected, result); @@ -344,9 +526,10 @@ void testTimestampSeconds(ReductionAggregation op, Long[] values, Object expecte } @ParameterizedTest - @MethodSource("createTimestampResolutionParams") - void testTimestampMilliseconds(ReductionAggregation op, Long[] values, Object expectedObject) { - 
try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_MILLISECONDS, expectedObject); + @MethodSource("createTimestampMilliSecondsParams") + void testTimestampMilliseconds(ReductionAggregation op, Long[] values, + HostColumnVector.DataType type, Object expectedObject) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.timestampMilliSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { assertEquals(expected, result); @@ -354,9 +537,10 @@ void testTimestampMilliseconds(ReductionAggregation op, Long[] values, Object ex } @ParameterizedTest - @MethodSource("createTimestampResolutionParams") - void testTimestampMicroseconds(ReductionAggregation op, Long[] values, Object expectedObject) { - try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_MICROSECONDS, expectedObject); + @MethodSource("createTimestampMicroSecondsParams") + void testTimestampMicroseconds(ReductionAggregation op, Long[] values, + HostColumnVector.DataType type, Object expectedObject) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(values); Scalar result = v.reduce(op, expected.getType())) { assertEquals(expected, result); @@ -364,15 +548,29 @@ void testTimestampMicroseconds(ReductionAggregation op, Long[] values, Object ex } @ParameterizedTest - @MethodSource("createTimestampResolutionParams") - void testTimestampNanoseconds(ReductionAggregation op, Long[] values, Object expectedObject) { - try (Scalar expected = buildExpectedScalar(op, DType.TIMESTAMP_NANOSECONDS, expectedObject); + @MethodSource("createTimestampNanoSecondsParams") + void testTimestampNanoseconds(ReductionAggregation op, Long[] values, + HostColumnVector.DataType type, Object expectedObject) { + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); ColumnVector v = ColumnVector.timestampNanoSecondsFromBoxedLongs(values); Scalar 
result = v.reduce(op, expected.getType())) { assertEquals(expected, result); } } + @ParameterizedTest + @MethodSource("createFloatArrayParams") + void testFloatArray(ReductionAggregation op, List[] values, + HostColumnVector.DataType type, Object expectedObject, Float delta) { + HostColumnVector.DataType listType = new HostColumnVector.ListType( + true, new HostColumnVector.BasicType(true, DType.FLOAT32)); + try (Scalar expected = buildExpectedScalar(op, type, expectedObject); + ColumnVector v = ColumnVector.fromLists(listType, values); + Scalar result = v.reduce(op, expected.getType())) { + assertEqualsDelta(op, expected, result, delta); + } + } + @Test void testWithSetOutputType() { try (Scalar expected = Scalar.fromLong(1 * 2 * 3 * 4L); @@ -387,13 +585,13 @@ void testWithSetOutputType() { assertEquals(expected, result); } - try (Scalar expected = Scalar.fromLong((1*1L) + (2*2L) + (3*3L) + (4*4L)); + try (Scalar expected = Scalar.fromLong((1 * 1L) + (2 * 2L) + (3 * 3L) + (4 * 4L)); ColumnVector cv = ColumnVector.fromBytes(new byte[]{1, 2, 3, 4}); Scalar result = cv.sumOfSquares(DType.INT64)) { assertEquals(expected, result); } - try (Scalar expected = Scalar.fromFloat((1 + 2 + 3 + 4f)/4); + try (Scalar expected = Scalar.fromFloat((1 + 2 + 3 + 4f) / 4); ColumnVector cv = ColumnVector.fromBytes(new byte[]{1, 2, 3, 4}); Scalar result = cv.mean(DType.FLOAT32)) { assertEquals(expected, result);