From 228cc7987d23aea286dbcf101f979fef2a23b420 Mon Sep 17 00:00:00 2001 From: MithunR Date: Mon, 14 Mar 2022 15:32:38 -0700 Subject: [PATCH] Implement `maps_column_view` abstraction over `LIST<STRUCT<K,V>>` (#10380) Fixes #9109. This commit adds a `map` abstraction over a `column_view` of type `LIST<STRUCT<K,V>>`, where `K` and `V` are key and value types. A list column of structs with two members may thus be viewed as a `map` column. `maps_column_view` is to a `LIST<STRUCT<K,V>>` column what `lists_column_view` is to a `LIST` column. The `maps_column_view` abstraction provides methods to fetch lists of keys and values (as `LIST<K>` and `LIST<V>` respectively). It also provides map lookup methods to find the values corresponding to a specified key, for each row in the "map" column. E.g. ```c++ auto input_column = get_list_of_structs_col(); // input_column == [ {1:10, 2:20}, {1:100, 3:300}, {2:2000, 3:3000, 4:4000} ]; auto maps_view = cudf::jni::maps_column_view{input_column->view()}; auto keys = maps_view.keys(); // keys == [ {1,2}, {1,3}, {2,3,4} ]; auto values = maps_view.values(); // values == [ {10,20}, {100, 300}, {2000, 3000, 4000} ]; auto lookup_1 = maps_view.get_values_for( *make_numeric_scalar(1) ); // lookup_1 = [ {10, 100, null} ]; ``` This abstraction should help replace the Java/JNI `map_lookup` and `map_contains` kernels, which only handle `MAP<STRING,STRING>`. 
Authors: - MithunR (https://github.com/mythrocks) Approvers: - Jason Lowe (https://github.com/jlowe) - AJ Schmidt (https://github.com/ajschmidt8) - Nghia Truong (https://github.com/ttnghia) - Jake Hemstad (https://github.com/jrhemstad) URL: https://github.com/rapidsai/cudf/pull/10380 --- conda/recipes/libcudf/meta.yaml | 2 + cpp/include/cudf/lists/detail/contains.hpp | 78 +++++++++++ cpp/include/cudf/lists/detail/extract.hpp | 49 +++++++ cpp/src/lists/contains.cu | 61 ++++---- cpp/src/lists/extract.cu | 35 ++++- .../main/java/ai/rapids/cudf/ColumnView.java | 27 ++-- java/src/main/native/CMakeLists.txt | 1 + .../main/native/include/maps_column_view.hpp | 130 ++++++++++++++++++ java/src/main/native/src/ColumnViewJni.cpp | 19 +-- java/src/main/native/src/maps_column_view.cu | 94 +++++++++++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 65 +++++++-- 11 files changed, 495 insertions(+), 66 deletions(-) create mode 100644 cpp/include/cudf/lists/detail/contains.hpp create mode 100644 cpp/include/cudf/lists/detail/extract.hpp create mode 100644 java/src/main/native/include/maps_column_view.hpp create mode 100644 java/src/main/native/src/maps_column_view.cu diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index ebfc649c0d2..4ea4ace11da 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -150,7 +150,9 @@ test: - test -f $PREFIX/include/cudf/labeling/label_bins.hpp - test -f $PREFIX/include/cudf/lists/detail/combine.hpp - test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp + - test -f $PREFIX/include/cudf/lists/detail/contains.hpp - test -f $PREFIX/include/cudf/lists/detail/copying.hpp + - test -f $PREFIX/include/cudf/lists/detail/extract.hpp - test -f $PREFIX/include/cudf/lists/lists_column_factories.hpp - test -f $PREFIX/include/cudf/lists/detail/drop_list_duplicates.hpp - test -f $PREFIX/include/cudf/lists/detail/interleave_columns.hpp diff --git a/cpp/include/cudf/lists/detail/contains.hpp 
b/cpp/include/cudf/lists/detail/contains.hpp new file mode 100644 index 00000000000..24318e72e98 --- /dev/null +++ b/cpp/include/cudf/lists/detail/contains.hpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace cudf { +namespace lists { +namespace detail { + +/** + * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, + * cudf::scalar const&, + * duplicate_find_option, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr index_of( + cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + cudf::lists::duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, + * cudf::column_view const&, + * duplicate_find_option, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. 
+ */ +std::unique_ptr index_of( + cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + cudf::lists::duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::contains(cudf::lists_column_view const&, + * cudf::scalar const&, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr contains( + cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::contains(cudf::lists_column_view const&, + * cudf::column_view const&, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr contains( + cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +} // namespace detail +} // namespace lists +} // namespace cudf diff --git a/cpp/include/cudf/lists/detail/extract.hpp b/cpp/include/cudf/lists/detail/extract.hpp new file mode 100644 index 00000000000..44c31c9ddb2 --- /dev/null +++ b/cpp/include/cudf/lists/detail/extract.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace cudf { +namespace lists { +namespace detail { + +/** + * @copydoc cudf::lists::extract_list_element(lists_column_view, size_type, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr extract_list_element( + lists_column_view lists_column, + size_type const index, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @copydoc cudf::lists::extract_list_element(lists_column_view, column_view const&, + * rmm::mr::device_memory_resource*) + * @param stream CUDA stream used for device memory operations and kernel launches. + */ +std::unique_ptr extract_list_element( + lists_column_view lists_column, + column_view const& indices, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +} // namespace detail +} // namespace lists +} // namespace cudf diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index 5d095fdd5a3..5704ff81665 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -251,18 +252,17 @@ std::unique_ptr to_contains(std::unique_ptr&& key_positions, namespace detail { /** - * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, - * cudf::scalar const&, - * duplicate_find_option, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @copydoc cudf::lists::detail::index_of(cudf::lists_column_view const&, + * cudf::scalar const&, + * duplicate_find_option, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ -std::unique_ptr index_of( - cudf::lists_column_view const& lists, - cudf::scalar const& search_key, - duplicate_find_option find_option, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr index_of(cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { return search_key.is_valid(stream) ? cudf::type_dispatcher(search_key.type(), @@ -282,18 +282,17 @@ std::unique_ptr index_of( } /** - * @copydoc cudf::lists::index_of(cudf::lists_column_view const&, - * cudf::column_view const&, - * duplicate_find_option, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @copydoc cudf::lists::detail::index_of(cudf::lists_column_view const&, + * cudf::column_view const&, + * duplicate_find_option, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ -std::unique_ptr index_of( - cudf::lists_column_view const& lists, - cudf::column_view const& search_keys, - duplicate_find_option find_option, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) +std::unique_ptr index_of(cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + duplicate_find_option find_option, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { CUDF_EXPECTS(search_keys.size() == lists.size(), "Number of search keys must match list column size."); @@ -316,10 +315,10 @@ std::unique_ptr index_of( } /** - * @copydoc cudf::lists::contains(cudf::lists_column_view const&, - * cudf::scalar const&, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. + * @copydoc cudf::lists::detail::contains(cudf::lists_column_view const&, + * cudf::scalar const&, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ std::unique_ptr contains(cudf::lists_column_view const& lists, cudf::scalar const& search_key, @@ -331,10 +330,10 @@ std::unique_ptr contains(cudf::lists_column_view const& lists, } /** - * @copydoc cudf::lists::contains(cudf::lists_column_view const&, - * cudf::column_view const&, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @copydoc cudf::lists::detail::contains(cudf::lists_column_view const&, + * cudf::column_view const&, + * rmm::cuda_stream_view, + * rmm::mr::device_memory_resource*) */ std::unique_ptr contains(cudf::lists_column_view const& lists, cudf::column_view const& search_keys, diff --git a/cpp/src/lists/extract.cu b/cpp/src/lists/extract.cu index 7c6c612eb25..0e8659b54ff 100644 --- a/cpp/src/lists/extract.cu +++ b/cpp/src/lists/extract.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -107,10 +108,10 @@ std::unique_ptr make_index_offsets(size_type num_lists, rmm::cuda_ * @param stream CUDA stream used for device memory operations and kernel launches. */ template -std::unique_ptr extract_list_element(lists_column_view lists_column, - index_t const& index, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr extract_list_element_impl(lists_column_view lists_column, + index_t const& index, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { auto const num_lists = lists_column.size(); if (num_lists == 0) { return empty_like(lists_column.child()); } @@ -135,6 +136,26 @@ std::unique_ptr extract_list_element(lists_column_view lists_column, return std::move(extracted_lists->release().children[lists_column_view::child_column_index]); } +/** + * @copydoc cudf::lists::extract_list_element + * @param stream CUDA stream used for device memory operations and kernel launches. 
+ */ +std::unique_ptr extract_list_element(lists_column_view lists_column, + size_type const index, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + return detail::extract_list_element_impl(lists_column, index, stream, mr); +} + +std::unique_ptr extract_list_element(lists_column_view lists_column, + column_view const& indices, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + return detail::extract_list_element_impl(lists_column, indices, stream, mr); +} + } // namespace detail /** @@ -146,7 +167,7 @@ std::unique_ptr extract_list_element(lists_column_view const& lists_colu size_type index, rmm::mr::device_memory_resource* mr) { - return detail::extract_list_element(lists_column, index, rmm::cuda_stream_default, mr); + return detail::extract_list_element_impl(lists_column, index, rmm::cuda_stream_default, mr); } /** @@ -160,7 +181,7 @@ std::unique_ptr extract_list_element(lists_column_view const& lists_colu { CUDF_EXPECTS(indices.size() == lists_column.size(), "Index column must have as many elements as lists column."); - return detail::extract_list_element(lists_column, indices, rmm::cuda_stream_default, mr); + return detail::extract_list_element_impl(lists_column, indices, rmm::cuda_stream_default, mr); } } // namespace lists diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 3fe244c0112..ed3ac124216 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3244,17 +3244,23 @@ public final ColumnVector urlEncode() throws CudfException { return new ColumnVector(urlEncode(getNativeView())); } - /** For a column of type List> and a passed in String key, return a string column - * for all the values in the struct that match the key, null otherwise. 
- * @param key the String scalar to lookup in the column - * @return a string column of values or nulls based on the lookup result + private static void assertIsSupportedMapKeyType(DType keyType) { + boolean isSupportedKeyType = + !keyType.equals(DType.EMPTY) && !keyType.equals(DType.LIST) && !keyType.equals(DType.STRUCT); + assert isSupportedKeyType : "Map lookup by STRUCT and LIST keys is not supported."; + } + + /** + * Given a column of type List> and a key of type X, return a column of type Y, + * where each row in the output column is the Y value corresponding to the X key. + * If the key is not found, the corresponding output value is null. + * @param key the scalar key to lookup in the column + * @return a column of values or nulls based on the lookup result */ public final ColumnVector getMapValue(Scalar key) { - assert type.equals(DType.LIST) : "column type must be a LIST"; - assert key != null : "target string may not be null"; - assert key.getType().equals(DType.STRING) : "target string must be a string scalar"; - + assert key != null : "Lookup key may not be null"; + assertIsSupportedMapKeyType(key.getType()); return new ColumnVector(mapLookup(getNativeView(), key.getScalarHandle())); } @@ -3266,9 +3272,8 @@ public final ColumnVector getMapValue(Scalar key) { */ public final ColumnVector getMapKeyExistence(Scalar key) { assert type.equals(DType.LIST) : "column type must be a LIST"; - assert key != null : "target string may not be null"; - assert key.getType().equals(DType.STRING) : "target must be a string scalar"; - + assert key != null : "Lookup key may not be null"; + assertIsSupportedMapKeyType(key.getType()); return new ColumnVector(mapContains(getNativeView(), key.getScalarHandle())); } diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index ffbeeb155e0..6e0c07bc4f0 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -238,6 +238,7 @@ add_library( src/TableJni.cpp 
src/aggregation128_utils.cu src/map_lookup.cu + src/maps_column_view.cu src/row_conversion.cu src/check_nvcomp_output_sizes.cu ) diff --git a/java/src/main/native/include/maps_column_view.hpp b/java/src/main/native/include/maps_column_view.hpp new file mode 100644 index 00000000000..26d97fd5789 --- /dev/null +++ b/java/src/main/native/include/maps_column_view.hpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace cudf { + +class scalar; + +namespace jni { + +/** + * @brief Given a column-view of LIST>, an instance of this class + * provides an abstraction of a column of maps. + * + * Each list row is treated as a map of key->value, with possibly repeated keys. + * The list may be looked up by a scalar key, or by a column of keys, to + * retrieve the corresponding value. + */ +class maps_column_view { +public: + maps_column_view(lists_column_view const &lists_of_structs, + rmm::cuda_stream_view stream = rmm::cuda_stream_default); + + // Rule of 5. + maps_column_view(maps_column_view const &maps_view) = default; + maps_column_view(maps_column_view &&maps_view) = default; + maps_column_view &operator=(maps_column_view const &) = default; + maps_column_view &operator=(maps_column_view &&) = default; + ~maps_column_view() = default; + + /** + * @brief Returns number of map rows in the column. 
+ */ + size_type size() const { return keys_.size(); } + + /** + * @brief Getter for keys as a list column. + * + * Note: Keys are not deduped. Repeated keys are returned in order. + */ + lists_column_view const &keys() const { return keys_; } + + /** + * @brief Getter for values as a list column. + * + * Note: Values for repeated keys are not dropped. + */ + lists_column_view const &values() const { return values_; } + + /** + * @brief Map lookup by a column of keys. + * + * The lookup column must have as many rows as the map column, + * and must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * If a key is repeated in a map row, the value corresponding to the last matching + * key is returned. + * If a lookup key is null or not found, the corresponding value is null. + * + * @param keys Column of keys to be looked up in each corresponding map row. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. + * @return std::unique_ptr Column of values corresponding the value of the lookup key. + */ + std::unique_ptr get_values_for( + column_view const &keys, rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()) const; + + /** + * @brief Map lookup by a scalar key. + * + * The type of the lookup scalar must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * If a key is repeated in a map row, the value corresponding to the last matching + * key is returned. + * If the lookup key is null or not found, the corresponding value is null. + * + * @param keys Column of keys to be looked up in each corresponding map row. + * @param stream CUDA stream used for device memory operations and kernel launches. 
+ * @param mr Device memory resource used to allocate the returned column's device memory. + * @return std::unique_ptr + */ + std::unique_ptr get_values_for( + scalar const &key, rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()) const; + + /** + * @brief Check if each map row contains a specified scalar key. + * + * The type of the lookup scalar must match the key-type of the map. + * A column of values is returned, with the same number of rows as the map column. + * + * Each row in the returned column contains a bool indicating whether the row contains + * the specified key (`true`) or not (`false`). + * The returned column contains no nulls. i.e. If the search key is null, or if the + * map row is null, the result row is `false`. + * + * @param keys Column of keys to be looked up in each corresponding map row. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. 
+ * @return std::unique_ptr + */ + std::unique_ptr + contains(scalar const &key, rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource()) const; + +private: + lists_column_view keys_, values_; +}; + +} // namespace jni +} // namespace cudf diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index ec9e19e518d..8c8e9b91e8d 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -75,6 +75,7 @@ #include "dtype_utils.hpp" #include "jni_utils.hpp" #include "map_lookup.hpp" +#include "maps_column_view.hpp" using cudf::jni::ptr_as_jlong; using cudf::jni::release_as_jlong; @@ -1361,12 +1362,13 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jc jlong map_column_view, jlong lookup_key) { JNI_NULL_CHECK(env, map_column_view, "column is null", 0); - JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0); + JNI_NULL_CHECK(env, lookup_key, "lookup key is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = reinterpret_cast(map_column_view); - cudf::string_scalar *ss_key = reinterpret_cast(lookup_key); - return release_as_jlong(cudf::jni::map_lookup(*cv, *ss_key)); + auto const *cv = reinterpret_cast(map_column_view); + auto const *scalar_key = reinterpret_cast(lookup_key); + auto const maps_view = cudf::jni::maps_column_view{*cv}; + return release_as_jlong(maps_view.get_values_for(*scalar_key)); } CATCH_STD(env, 0); } @@ -1375,12 +1377,13 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, jlong map_column_view, jlong lookup_key) { JNI_NULL_CHECK(env, map_column_view, "column is null", 0); - JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0); + JNI_NULL_CHECK(env, lookup_key, "lookup key is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = 
reinterpret_cast(map_column_view); - cudf::string_scalar *ss_key = reinterpret_cast(lookup_key); - return release_as_jlong(cudf::jni::map_contains(*cv, *ss_key)); + auto const *cv = reinterpret_cast(map_column_view); + auto const *scalar_key = reinterpret_cast(lookup_key); + auto const maps_view = cudf::jni::maps_column_view{*cv}; + return release_as_jlong(maps_view.contains(*scalar_key)); } CATCH_STD(env, 0); } diff --git a/java/src/main/native/src/maps_column_view.cu b/java/src/main/native/src/maps_column_view.cu new file mode 100644 index 00000000000..e2d352812fd --- /dev/null +++ b/java/src/main/native/src/maps_column_view.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +namespace cudf::jni { + +namespace { +column_view make_lists(column_view const &lists_child, lists_column_view const &lists_of_structs) { + return column_view{data_type{type_id::LIST}, + lists_of_structs.size(), + nullptr, + lists_of_structs.null_mask(), + lists_of_structs.null_count(), + lists_of_structs.offset(), + {lists_of_structs.offsets(), lists_child}}; +} +} // namespace + +maps_column_view::maps_column_view(lists_column_view const &lists_of_structs, + rmm::cuda_stream_view stream) + : keys_{make_lists(lists_of_structs.child().child(0), lists_of_structs)}, + values_{make_lists(lists_of_structs.child().child(1), lists_of_structs)} { + auto const structs = lists_of_structs.child(); + CUDF_EXPECTS(structs.type().id() == type_id::STRUCT, + "maps_column_view input must have exactly 1 child (STRUCT) column."); + CUDF_EXPECTS(structs.num_children() == 2, + "maps_column_view key-value struct must have exactly 2 children."); +} + +template +std::unique_ptr get_values_for_impl(maps_column_view const &maps_view, + KeyT const &lookup_keys, rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) { + auto const keys_ = maps_view.keys(); + auto const values_ = maps_view.values(); + CUDF_EXPECTS(lookup_keys.type().id() == keys_.child().type().id(), + "Lookup keys must have the same type as the keys of the map column."); + auto key_indices = + lists::detail::index_of(keys_, lookup_keys, lists::duplicate_find_option::FIND_LAST, stream); + auto constexpr absent_offset = size_type{-1}; + auto constexpr nullity_offset = std::numeric_limits::min(); + thrust::replace(rmm::exec_policy(stream), key_indices->mutable_view().template begin(), + key_indices->mutable_view().template end(), absent_offset, + nullity_offset); + return lists::detail::extract_list_element(values_, key_indices->view(), stream, mr); +} + +std::unique_ptr +maps_column_view::get_values_for(column_view const &lookup_keys, 
rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) const { + CUDF_EXPECTS(lookup_keys.size() == size(), + "Lookup keys must have the same size as the map column."); + + return get_values_for_impl(*this, lookup_keys, stream, mr); +} + +std::unique_ptr +maps_column_view::get_values_for(scalar const &lookup_key, rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) const { + return get_values_for_impl(*this, lookup_key, stream, mr); +} + +std::unique_ptr maps_column_view::contains(scalar const &lookup_key, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) const { + CUDF_EXPECTS(lookup_key.type().id() == keys_.child().type().id(), + "Lookup keys must have the same type as the keys of the map column."); + auto const contains = lists::detail::contains(keys_, lookup_key, stream); + + // Replace nulls with BOOL8{false}; + auto const scalar_false = numeric_scalar{false, true, stream}; + return detail::replace_nulls(contains->view(), scalar_false, stream, mr); +} + +} // namespace cudf::jni diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0ba29840156..d1509f14c6e 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -5833,14 +5833,30 @@ void testStructChildValidity() { } @Test - void testGetMapValue() { + void testGetMapValueForInteger() { + List list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 2))); + List list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 3))); + List list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(5, 4))); + HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT32))); + try (ColumnVector cv = ColumnVector.fromLists(new 
HostColumnVector.ListType(true, structType), list1, list2, list3); + Scalar lookupKey = Scalar.fromInt(1); + ColumnVector res = cv.getMapValue(lookupKey); + ColumnVector expected = ColumnVector.fromBoxedInts(2, 3, null)) { + assertColumnsAreEqual(expected, res); + } + } + + @Test + void testGetMapValueForStrings() { List list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "b"))); List list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "c"))); List list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("e", "d"))); HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING), new HostColumnVector.BasicType(true, DType.STRING))); try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3); - ColumnVector res = cv.getMapValue(Scalar.fromString("a")); + Scalar lookupKey = Scalar.fromString("a"); + ColumnVector res = cv.getMapValue(lookupKey); ColumnVector expected = ColumnVector.fromStrings("b", "c", null)) { assertColumnsAreEqual(expected, res); } @@ -5851,14 +5867,45 @@ void testGetMapValueEmptyInput() { HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING), new HostColumnVector.BasicType(true, DType.STRING))); try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType)); - ColumnVector res = cv.getMapValue(Scalar.fromString("a")); + Scalar lookupKey = Scalar.fromString("a"); + ColumnVector res = cv.getMapValue(lookupKey); ColumnVector expected = ColumnVector.fromStrings()) { assertColumnsAreEqual(expected, res); } } @Test - void testGetMapKeyExistence() { + void testGetMapKeyExistenceForInteger() { + List list1 = Arrays.asList(new HostColumnVector.StructData(1, 2)); + List list2 = Arrays.asList(new HostColumnVector.StructData(1, 3)); + List list3 = 
Arrays.asList(new HostColumnVector.StructData(5, 4)); + List list4 = Arrays.asList(new HostColumnVector.StructData(1, 7)); + List list5 = Arrays.asList(new HostColumnVector.StructData(1, null)); + List list6 = Arrays.asList(new HostColumnVector.StructData(null, null)); + List list7 = Arrays.asList(new HostColumnVector.StructData()); + HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT32))); + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7); + Scalar lookup1 = Scalar.fromInt(1); + ColumnVector resValidKey = cv.getMapKeyExistence(lookup1); + ColumnVector expectedValid = ColumnVector.fromBoxedBooleans(true, true, false, true, true, false, false); + ColumnVector expectedNull = ColumnVector.fromBoxedBooleans(false, false, false, false, false, false, false); + Scalar lookupNull = Scalar.fromNull(DType.INT32); + ColumnVector resNullKey = cv.getMapKeyExistence(lookupNull)) { + assertColumnsAreEqual(expectedValid, resValidKey); + assertColumnsAreEqual(expectedNull, resNullKey); + } + + AssertionError e = assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7); + ColumnVector resNullKey = cv.getMapKeyExistence(null)) { + } + }); + assertTrue(e.getMessage().contains("Lookup key may not be null")); + } + + @Test + void testGetMapKeyExistenceForStrings() { List list1 = Arrays.asList(new HostColumnVector.StructData("a", "b")); List list2 = Arrays.asList(new HostColumnVector.StructData("a", "c")); List list3 = Arrays.asList(new HostColumnVector.StructData("e", "d")); @@ -5869,10 +5916,12 @@ void testGetMapKeyExistence() { HostColumnVector.StructType structType = new HostColumnVector.StructType(true, 
Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING), new HostColumnVector.BasicType(true, DType.STRING))); try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7); - ColumnVector resValidKey = cv.getMapKeyExistence(Scalar.fromString("a")); + Scalar lookupA = Scalar.fromString("a"); + ColumnVector resValidKey = cv.getMapKeyExistence(lookupA); ColumnVector expectedValid = ColumnVector.fromBoxedBooleans(true, true, false, true, true, false, false); ColumnVector expectedNull = ColumnVector.fromBoxedBooleans(false, false, false, false, false, false, false); - ColumnVector resNullKey = cv.getMapKeyExistence(Scalar.fromNull(DType.STRING))) { + Scalar lookupNull = Scalar.fromNull(DType.STRING); + ColumnVector resNullKey = cv.getMapKeyExistence(lookupNull)) { assertColumnsAreEqual(expectedValid, resValidKey); assertColumnsAreEqual(expectedNull, resNullKey); } @@ -5882,10 +5931,8 @@ void testGetMapKeyExistence() { ColumnVector resNullKey = cv.getMapKeyExistence(null)) { } }); - assertTrue(e.getMessage().contains("target string may not be null")); + assertTrue(e.getMessage().contains("Lookup key may not be null")); } - - @Test void testListOfStructsOfStructs() { List list1 = Arrays.asList(