From 0e080225d3446682b038466b698de053ad383d17 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 23 Jun 2022 08:19:04 -0700 Subject: [PATCH] Refactor `lists::contains` (#11019) This is just a not-very-simple refactor to `lists::contains`: * Remove some wrong lines in doxygen of `lists/contains.hpp`, and rewrite doxygen there a little bit. * Add more comments to the code. * Reduce the number of code paths of the template struct functors. * Rename some/many variables, and reorganize code to make it cleaner. No new feature is added in this PR, just modifying the existing functions and moving things around. This PR is extracted from the bigger PR for easier review. The original PR is https://github.com/rapidsai/cudf/pull/10548 for supporting nested type in `lists::contains`. As such, this blocks it. Authors: - Nghia Truong (https://github.com/ttnghia) - Bradley Dice (https://github.com/bdice) Approvers: - Bradley Dice (https://github.com/bdice) - Karthikeyan (https://github.com/karthikeyann) URL: https://github.com/rapidsai/cudf/pull/11019 --- cpp/include/cudf/lists/contains.hpp | 12 +- cpp/src/lists/contains.cu | 480 ++++++++++++---------------- cpp/tests/lists/contains_tests.cpp | 52 +-- 3 files changed, 236 insertions(+), 308 deletions(-) diff --git a/cpp/include/cudf/lists/contains.hpp b/cpp/include/cudf/lists/contains.hpp index d529677d505..d9d37f21bd3 100644 --- a/cpp/include/cudf/lists/contains.hpp +++ b/cpp/include/cudf/lists/contains.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,8 +37,6 @@ namespace lists { * Output `column[i]` is set to null if one or more of the following are true: * 1. The search key `search_key` is null * 2. The list row `lists[i]` is null - * 3. The list row `lists[i]` does not contain the search key, and contains at least - * one null. * * @param lists Lists column whose `n` rows are to be searched * @param search_key The scalar key to be looked up in each list row @@ -61,8 +59,6 @@ std::unique_ptr contains( * Output `column[i]` is set to null if one or more of the following are true: * 1. The row `search_keys[i]` is null * 2. The list row `lists[i]` is null - * 3. The list row `lists[i]` does not contain the `search_keys[i]`, and contains at least - * one null. * * @param lists Lists column whose `n` rows are to be searched * @param search_keys Column of elements to be looked up in each list row @@ -79,10 +75,12 @@ std::unique_ptr contains( * contains at least one null element. * * The output column has as many elements as the input `lists` column. - * Output `column[i]` is set to null the list row `lists[i]` is null. + * Output `column[i]` is set to null if the row `lists[i]` is null. * Otherwise, `column[i]` is set to a non-null boolean value, depending on whether that list * contains a null element. - * (Empty list rows are considered *NOT* to contain a null element.) + * + * A row with an empty list will always return false. + * Nulls inside non-null nested elements (such as lists or structs) are not considered. * * @param lists Lists column whose `n` rows are to be searched * @param mr Device memory resource used to allocate the returned column's device memory. diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index 8662bad0c8c..f4c7292ba0d 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -17,13 +17,11 @@ #include #include #include -#include #include #include #include #include #include -#include #include #include #include @@ -33,176 +31,135 @@ #include #include #include -#include #include #include -#include #include -#include #include #include -#include #include -namespace cudf { -namespace lists { +namespace cudf::lists { namespace { -auto constexpr absent_index = size_type{-1}; +/** + * @brief A sentinel value used for marking that a given key has not been found in the search list. + * + * The value should be `-1` as indicated in the public API documentation. + */ +auto constexpr __device__ NOT_FOUND_SENTINEL = size_type{-1}; -auto get_search_keys_device_iterable_view(cudf::column_view const& search_keys, - rmm::cuda_stream_view stream) -{ - return column_device_view::create(search_keys, stream); -} +/** + * @brief A sentinel value used for marking that a given output row should be null. + */ +auto constexpr __device__ NULL_SENTINEL = std::numeric_limits::min(); -auto get_search_keys_device_iterable_view(cudf::scalar const& search_key, rmm::cuda_stream_view) +/** + * @brief Indicate the current supported types in `cudf::lists::contains`. + * + * TODO: Add supported nested types. + */ +template +static auto constexpr is_supported_non_nested_type() { - return &search_key; + return cudf::is_fixed_width() || std::is_same_v; } -template -auto __device__ find_begin(list_device_view const& list) -{ - if constexpr (find_option == duplicate_find_option::FIND_FIRST) { - return list.pair_rep_begin(); - } else { - return thrust::make_reverse_iterator(list.pair_rep_end()); - } -} +/** + * @brief Functor to perform searching for index of a key element in a given list. + */ +struct search_list_fn { + duplicate_find_option const find_option; -template -auto __device__ find_end(list_device_view const& list) -{ - if constexpr (find_option == duplicate_find_option::FIND_FIRST) { - return list.pair_rep_end(); - } else { - return thrust::make_reverse_iterator(list.pair_rep_begin()); - } -} + template ())> + __device__ size_type operator()(list_device_view list, thrust::optional key_opt) const + { + // A null list or null key will result in a null output row. + if (list.is_null() || !key_opt) { return NULL_SENTINEL; } -template -size_type __device__ distance([[maybe_unused]] Iterator begin, Iterator end, Iterator find_iter) -{ - if (find_iter == end) { - return absent_index; // Not found. + return find_option == duplicate_find_option::FIND_FIRST + ? search_list(list, *key_opt) + : search_list(list, *key_opt); } - if constexpr (find_option == duplicate_find_option::FIND_FIRST) { - return find_iter - begin; // Distance of find_position from begin. - } else { - return end - find_iter - 1; // Distance of find_position from end. + template ())> + __device__ size_type operator()(list_device_view, thrust::optional) const + { + CUDF_UNREACHABLE("Unsupported type."); } -} -/** - * @brief __device__ functor to search for a key in a `list_device_view`. - */ -template -struct finder { - template - __device__ size_type operator()(list_device_view const& list, ElementType const& search_key) const + private: + template ())> + static __device__ inline size_type search_list(list_device_view const list, + Element const search_key) { - auto const list_begin = find_begin(list); - auto const list_end = find_end(list); - auto const find_iter = thrust::find_if( - thrust::seq, list_begin, list_end, [search_key] __device__(auto element_and_validity) { - auto [element, element_is_valid] = element_and_validity; - return element_is_valid && cudf::equality_compare(element, search_key); + auto const [begin, end] = element_index_pair_iter(list.size()); + auto const found_iter = + thrust::find_if(thrust::seq, begin, end, [&] __device__(auto const idx) { + return !list.is_null(idx) && + cudf::equality_compare(list.template element(idx), search_key); }); - return distance(list_begin, list_end, find_iter); - }; -}; - -/** - * @brief Functor to search each list row for the specified search keys. - */ -template -struct lookup_functor { - template - struct is_supported { - static constexpr bool value = - cudf::is_numeric() || cudf::is_chrono() || - cudf::is_fixed_point() || std::is_same_v; - }; - - template - std::enable_if_t::value, std::unique_ptr> operator()( - Args&&...) const - { - CUDF_FAIL( - "List search operations are only supported on numeric types, decimals, chrono types, and " - "strings."); + // If the key is found, return its found position in the list from `found_iter`. + return found_iter == end ? NOT_FOUND_SENTINEL : *found_iter; } - std::pair construct_null_mask( - lists_column_view const& input_lists, - column_view const& result_validity, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) const + /** + * @brief Return a pair of index iterators {begin, end} to loop through elements within a list. + * + * Depending on the value of `forward`, a pair of forward or reverse iterators will be + * returned, allowing to loop through elements in the list in first-to-last or last-to-first + * order. + * + * Note that the element indices always restart to `0` at the first position in each list. + * + * @tparam forward A boolean value indicating whether we want to iterate elements in the list by + * forward or reverse order. + * @param size The number of elements in the list. + * @return A pair of {begin, end} iterators to iterate through the range `[0, size)`. + */ + template + static __device__ auto element_index_pair_iter(size_type const size) { - if (!search_keys_have_nulls && !input_lists.has_nulls() && !input_lists.child().has_nulls()) { - return {rmm::device_buffer{0, stream, mr}, size_type{0}}; + if constexpr (forward) { + return thrust::pair(thrust::make_counting_iterator(0), thrust::make_counting_iterator(size)); } else { - return cudf::detail::valid_if( - result_validity.begin(), result_validity.end(), thrust::identity{}, stream, mr); + return thrust::pair(thrust::make_reverse_iterator(thrust::make_counting_iterator(size)), + thrust::make_reverse_iterator(thrust::make_counting_iterator(0))); } } +}; - template - void search_each_list_row(cudf::detail::lists_column_device_view const& d_lists, - SearchKeyPairIter search_key_pair_iter, - duplicate_find_option find_option, - cudf::mutable_column_device_view ret_positions, - cudf::mutable_column_device_view ret_validity, - rmm::cuda_stream_view stream) const - { - auto output_iterator = thrust::make_zip_iterator( - thrust::make_tuple(ret_positions.data(), ret_validity.data())); - - thrust::tabulate( - rmm::exec_policy(stream), - output_iterator, - output_iterator + d_lists.size(), - [d_lists, search_key_pair_iter, absent_index = absent_index, find_option] __device__( - auto row_index) -> thrust::pair { - auto [search_key, search_key_is_valid] = search_key_pair_iter[row_index]; - - if (search_keys_have_nulls && !search_key_is_valid) { return {absent_index, false}; } - - auto list = cudf::list_device_view(d_lists, row_index); - if (list.is_null()) { return {absent_index, false}; } - - auto const position = find_option == duplicate_find_option::FIND_FIRST - ? finder{}(list, search_key) - : finder{}(list, search_key); - return {position, true}; - }); - } - - template - std::enable_if_t::value, std::unique_ptr> operator()( - cudf::lists_column_view const& lists, - SearchKeyType const& search_key, +/** + * @brief Dispatch functor to search for key element(s) in the corresponding rows of a lists column. + */ +struct dispatch_index_of { + template + std::enable_if_t(), std::unique_ptr> operator()( + lists_column_view const& lists, + SearchKeyType const& search_keys, duplicate_find_option find_option, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const + rmm::mr::device_memory_resource* mr) const { - using namespace cudf; - using namespace cudf::detail; - CUDF_EXPECTS(!cudf::is_nested(lists.child().type()), "Nested types not supported in list search operations."); - CUDF_EXPECTS(lists.child().type() == search_key.type(), + CUDF_EXPECTS(lists.child().type() == search_keys.type(), "Type/Scale of search key does not match list column element type."); - CUDF_EXPECTS(search_key.type().id() != type_id::EMPTY, "Type cannot be empty."); + CUDF_EXPECTS(search_keys.type().id() != type_id::EMPTY, "Type cannot be empty."); auto constexpr search_key_is_scalar = std::is_same_v; - - if constexpr (search_keys_have_nulls && search_key_is_scalar) { - return make_numeric_column(data_type(type_id::INT32), + auto const search_keys_have_nulls = [&search_keys, stream] { + if constexpr (search_key_is_scalar) { + return !search_keys.is_valid(stream); + } else { + return search_keys.has_nulls(); + } + }(); + + if (search_key_is_scalar && search_keys_have_nulls) { + // If the scalar key is invalid/null, the entire output column will be all nulls. + return make_numeric_column(data_type{cudf::type_to_id()}, lists.size(), cudf::create_null_mask(lists.size(), mask_state::ALL_NULL, mr), lists.size(), @@ -210,131 +167,115 @@ struct lookup_functor { mr); } - auto const device_view = column_device_view::create(lists.parent(), stream); - auto const d_lists = lists_column_device_view{*device_view}; - auto const d_skeys = get_search_keys_device_iterable_view(search_key, stream); - - auto result_positions = make_numeric_column( - data_type{type_id::INT32}, lists.size(), cudf::mask_state::UNALLOCATED, stream, mr); - auto result_validity = make_numeric_column( - data_type{type_id::BOOL8}, lists.size(), cudf::mask_state::UNALLOCATED, stream, mr); - auto mutable_result_positions = - mutable_column_device_view::create(result_positions->mutable_view(), stream); - auto mutable_result_validity = - mutable_column_device_view::create(result_validity->mutable_view(), stream); - auto search_key_iter = - cudf::detail::make_pair_rep_iterator(*d_skeys); - - search_each_list_row(d_lists, - search_key_iter, - find_option, - *mutable_result_positions, - *mutable_result_validity, - stream); - - auto [null_mask, num_nulls] = construct_null_mask(lists, result_validity->view(), stream, mr); - result_positions->set_null_mask(std::move(null_mask), num_nulls); - return result_positions; + auto const lists_cdv_ptr = column_device_view::create(lists.parent(), stream); + auto const input_it = cudf::detail::make_counting_transform_iterator( + size_type{0}, + [lists = cudf::detail::lists_column_device_view{*lists_cdv_ptr}] __device__(auto const idx) { + return list_device_view{lists, idx}; + }); + + auto out_positions = make_numeric_column( + data_type{type_to_id()}, lists.size(), cudf::mask_state::UNALLOCATED, stream, mr); + auto const out_begin = out_positions->mutable_view().template begin(); + + auto const do_search = [&](auto const keys_iter) { + thrust::transform(rmm::exec_policy(stream), + input_it, + input_it + lists.size(), + keys_iter, + out_begin, + search_list_fn{find_option}); + }; + + if constexpr (search_key_is_scalar) { + auto const keys_iter = cudf::detail::make_optional_iterator( + search_keys, nullate::DYNAMIC{search_keys_have_nulls}); + do_search(keys_iter); + } else { + auto const keys_cdv_ptr = column_device_view::create(search_keys, stream); + auto const keys_iter = cudf::detail::make_optional_iterator( + *keys_cdv_ptr, nullate::DYNAMIC{search_keys_have_nulls}); + do_search(keys_iter); + } + + if (search_keys_have_nulls || lists.has_nulls()) { + auto [null_mask, null_count] = cudf::detail::valid_if( + out_begin, + out_begin + lists.size(), + [] __device__(auto const idx) { return idx != NULL_SENTINEL; }, + stream, + mr); + out_positions->set_null_mask(std::move(null_mask), null_count); + } + return out_positions; + } + + template + std::enable_if_t(), std::unique_ptr> operator()( + lists_column_view const&, + SearchKeyType const&, + duplicate_find_option, + rmm::cuda_stream_view, + rmm::mr::device_memory_resource*) const + { + CUDF_FAIL("Unsupported type in `dispatch_index_of` functor."); } }; /** - * @brief Converts key-positions vector (from index_of()) to a BOOL8 vector, indicating if - * the search key was found. + * @brief Converts key-positions vector (from `index_of()`) to a BOOL8 vector, indicating if + * the search key(s) were found. */ std::unique_ptr to_contains(std::unique_ptr&& key_positions, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - CUDF_EXPECTS(key_positions->type().id() == type_id::INT32, - "Expected input column of type INT32."); - // If position == -1, the list did not contain the search key. - auto const num_rows = key_positions->size(); - auto const positions_begin = key_positions->view().begin(); - auto result = - make_numeric_column(data_type{type_id::BOOL8}, num_rows, mask_state::UNALLOCATED, stream, mr); + CUDF_EXPECTS(key_positions->type().id() == type_to_id(), + "Expected input column of type cudf::size_type."); + auto const positions_begin = key_positions->view().template begin(); + auto result = make_numeric_column( + data_type{type_id::BOOL8}, key_positions->size(), mask_state::UNALLOCATED, stream, mr); thrust::transform(rmm::exec_policy(stream), positions_begin, - positions_begin + num_rows, - result->mutable_view().begin(), - [] __device__(auto i) { return i != absent_index; }); - [[maybe_unused]] auto [_, null_mask, __] = key_positions->release(); - result->set_null_mask(std::move(*null_mask)); + positions_begin + key_positions->size(), + result->mutable_view().template begin(), + [] __device__(auto const i) { + // position == NOT_FOUND_SENTINEL: the list does not contain the search key. + return i != NOT_FOUND_SENTINEL; + }); + + auto const null_count = key_positions->null_count(); + [[maybe_unused]] auto [data, null_mask, children] = key_positions->release(); + result->set_null_mask(std::move(*null_mask.release()), null_count); + return result; } } // namespace namespace detail { -/** - * @copydoc cudf::lists::detail::index_of(cudf::lists_column_view const&, - * cudf::scalar const&, - * duplicate_find_option, - * rmm::cuda_stream_view, - * rmm::mr::device_memory_resource*) - */ -std::unique_ptr index_of(cudf::lists_column_view const& lists, +std::unique_ptr index_of(lists_column_view const& lists, cudf::scalar const& search_key, duplicate_find_option find_option, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - return search_key.is_valid(stream) - ? cudf::type_dispatcher(search_key.type(), - lookup_functor{}, // No nulls in search key - lists, - search_key, - find_option, - stream, - mr) - : cudf::type_dispatcher(search_key.type(), - lookup_functor{}, // Nulls in search key - lists, - search_key, - find_option, - stream, - mr); + return cudf::type_dispatcher( + search_key.type(), dispatch_index_of{}, lists, search_key, find_option, stream, mr); } -/** - * @copydoc cudf::lists::detail::index_of(cudf::lists_column_view const&, - * cudf::column_view const&, - * duplicate_find_option, - * rmm::cuda_stream_view, - * rmm::mr::device_memory_resource*) - */ -std::unique_ptr index_of(cudf::lists_column_view const& lists, - cudf::column_view const& search_keys, +std::unique_ptr index_of(lists_column_view const& lists, + column_view const& search_keys, duplicate_find_option find_option, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { CUDF_EXPECTS(search_keys.size() == lists.size(), "Number of search keys must match list column size."); - - return search_keys.has_nulls() - ? cudf::type_dispatcher(search_keys.type(), - lookup_functor{}, // Nulls in search keys - lists, - search_keys, - find_option, - stream, - mr) - : cudf::type_dispatcher(search_keys.type(), - lookup_functor{}, // No nulls in search keys - lists, - search_keys, - find_option, - stream, - mr); + return cudf::type_dispatcher( + search_keys.type(), dispatch_index_of{}, lists, search_keys, find_option, stream, mr); } -/** - * @copydoc cudf::lists::detail::contains(cudf::lists_column_view const&, - * cudf::scalar const&, - * rmm::cuda_stream_view, - * rmm::mr::device_memory_resource*) - */ -std::unique_ptr contains(cudf::lists_column_view const& lists, +std::unique_ptr contains(lists_column_view const& lists, cudf::scalar const& search_key, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -343,14 +284,8 @@ std::unique_ptr contains(cudf::lists_column_view const& lists, index_of(lists, search_key, duplicate_find_option::FIND_FIRST, stream), stream, mr); } -/** - * @copydoc cudf::lists::detail::contains(cudf::lists_column_view const&, - * cudf::column_view const&, - * rmm::cuda_stream_view, - * rmm::mr::device_memory_resource*) - */ -std::unique_ptr contains(cudf::lists_column_view const& lists, - cudf::column_view const& search_keys, +std::unique_ptr contains(lists_column_view const& lists, + column_view const& search_keys, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -361,45 +296,39 @@ std::unique_ptr contains(cudf::lists_column_view const& lists, index_of(lists, search_keys, duplicate_find_option::FIND_FIRST, stream), stream, mr); } -/** - * @copydoc cudf::lists::contain_nulls(cudf::lists_column_view const&, - * rmm::mr::device_memory_resource*) - * @param stream CUDA stream used for device memory operations and kernel launches. - */ -std::unique_ptr contains_nulls(cudf::lists_column_view const& input_lists, +std::unique_ptr contains_nulls(lists_column_view const& lists, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto const num_rows = input_lists.size(); - auto const d_lists = column_device_view::create(input_lists.parent()); - auto has_nulls_output = make_numeric_column( - data_type{type_id::BOOL8}, input_lists.size(), mask_state::UNALLOCATED, stream, mr); - auto const output_begin = has_nulls_output->mutable_view().begin(); - thrust::tabulate( - rmm::exec_policy(stream), - output_begin, - output_begin + num_rows, - [lists = cudf::detail::lists_column_device_view{*d_lists}] __device__(auto list_idx) { - auto list = list_device_view{lists, list_idx}; - auto list_begin = thrust::make_counting_iterator(size_type{0}); - return list.is_null() || - thrust::any_of(thrust::seq, list_begin, list_begin + list.size(), [&list](auto i) { - return list.is_null(i); - }); - }); - auto const validity_begin = cudf::detail::make_counting_transform_iterator( - 0, [lists = cudf::detail::lists_column_device_view{*d_lists}] __device__(auto list_idx) { - return not list_device_view{lists, list_idx}.is_null(); - }); - auto [null_mask, num_nulls] = cudf::detail::valid_if( - validity_begin, validity_begin + num_rows, thrust::identity{}, stream, mr); - has_nulls_output->set_null_mask(std::move(null_mask), num_nulls); - return has_nulls_output; + auto const lists_cv = lists.parent(); + auto output = make_numeric_column(data_type{type_to_id()}, + lists.size(), + copy_bitmask(lists_cv), + lists_cv.null_count(), + stream, + mr); + auto const out_begin = output->mutable_view().template begin(); + auto const lists_cdv_ptr = column_device_view::create(lists_cv, stream); + + thrust::tabulate(rmm::exec_policy(stream), + out_begin, + out_begin + lists.size(), + [lists = cudf::detail::lists_column_device_view{*lists_cdv_ptr}] __device__( + auto const list_idx) { + auto const list = list_device_view{lists, list_idx}; + return list.is_null() || + thrust::any_of(thrust::seq, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(list.size()), + [&list](auto const idx) { return list.is_null(idx); }); + }); + + return output; } } // namespace detail -std::unique_ptr contains(cudf::lists_column_view const& lists, +std::unique_ptr contains(lists_column_view const& lists, cudf::scalar const& search_key, rmm::mr::device_memory_resource* mr) { @@ -407,22 +336,22 @@ std::unique_ptr contains(cudf::lists_column_view const& lists, return detail::contains(lists, search_key, cudf::default_stream_value, mr); } -std::unique_ptr contains(cudf::lists_column_view const& lists, - cudf::column_view const& search_keys, +std::unique_ptr contains(lists_column_view const& lists, + column_view const& search_keys, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); return detail::contains(lists, search_keys, cudf::default_stream_value, mr); } -std::unique_ptr contains_nulls(cudf::lists_column_view const& input_lists, +std::unique_ptr contains_nulls(lists_column_view const& lists, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::contains_nulls(input_lists, cudf::default_stream_value, mr); + return detail::contains_nulls(lists, cudf::default_stream_value, mr); } -std::unique_ptr index_of(cudf::lists_column_view const& lists, +std::unique_ptr index_of(lists_column_view const& lists, cudf::scalar const& search_key, duplicate_find_option find_option, rmm::mr::device_memory_resource* mr) @@ -431,8 +360,8 @@ std::unique_ptr index_of(cudf::lists_column_view const& lists, return detail::index_of(lists, search_key, find_option, cudf::default_stream_value, mr); } -std::unique_ptr index_of(cudf::lists_column_view const& lists, - cudf::column_view const& search_keys, +std::unique_ptr index_of(lists_column_view const& lists, + column_view const& search_keys, duplicate_find_option find_option, rmm::mr::device_memory_resource* mr) { @@ -440,5 +369,4 @@ std::unique_ptr index_of(cudf::lists_column_view const& lists, return detail::index_of(lists, search_keys, find_option, cudf::default_stream_value, mr); } -} // namespace lists -} // namespace cudf +} // namespace cudf::lists diff --git a/cpp/tests/lists/contains_tests.cpp b/cpp/tests/lists/contains_tests.cpp index 066eb7eafc8..4cc0c4155b8 100644 --- a/cpp/tests/lists/contains_tests.cpp +++ b/cpp/tests/lists/contains_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -118,17 +118,18 @@ TYPED_TEST(TypedContainsTest, ScalarKeyWithNoNulls) { using T = TypeParam; - auto search_space = lists_column_view{lists_column_wrapper{{0, 1, 2, 1}, - {3, 4, 5}, - {6, 7, 8}, - {9, 0, 1, 3, 1}, - {2, 3, 4}, - {5, 6, 7}, - {8, 9, 0}, - {}, - {1, 2, 1, 3}, - {}}}; - auto search_key_one = create_scalar_search_key(1); + auto const search_space_col = lists_column_wrapper{{0, 1, 2, 1}, + {3, 4, 5}, + {6, 7, 8}, + {9, 0, 1, 3, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 1, 3}, + {}}; + auto const search_space = lists_column_view{search_space_col}; + auto search_key_one = create_scalar_search_key(1); { // CONTAINS @@ -161,19 +162,20 @@ TYPED_TEST(TypedContainsTest, ScalarKeyWithNullLists) // Test List columns that have NULL list rows. using T = TypeParam; - auto search_space = lists_column_view{lists_column_wrapper{{{0, 1, 2, 1}, - {3, 4, 5}, - {6, 7, 8}, - {}, - {9, 0, 1, 3, 1}, - {2, 3, 4}, - {5, 6, 7}, - {8, 9, 0}, - {}, - {1, 2, 2, 3}, - {}}, - nulls_at({3, 10})}}; - auto search_key_one = create_scalar_search_key(1); + auto const search_space_col = lists_column_wrapper{{{0, 1, 2, 1}, + {3, 4, 5}, + {6, 7, 8}, + {}, + {9, 0, 1, 3, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 2, 3}, + {}}, + nulls_at({3, 10})}; + auto const search_space = lists_column_view{search_space_col}; + auto search_key_one = create_scalar_search_key(1); { // CONTAINS auto result = lists::contains(search_space, *search_key_one);