Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change cudf::strings::find_multiple to return a lists column #10134

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions cpp/include/cudf/strings/find_multiple.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,30 +27,32 @@ namespace strings {
*/

/**
* @brief Returns a column with character position values where each
* @brief Returns a lists column with character position values where each
* of the target strings are found in each string.
*
* The size of the output column is targets.size() * strings.size().
* output[i] contains the position of target[i % targets.size()] in string[i/targets.size()]
* The size of the output column is `input.size()`.
* Each row of the output column is of size `target.size()`.
bdice marked this conversation as resolved.
Show resolved Hide resolved
*
* `output[i,j]` contains the position of `target[j]` in `input[i]`
*
* @code{.pseudo}
* Example:
* s = ["abc","def"]
* t = ["a","c","e"]
* r = find_multiple(s,t)
* r is now [ 0, 2,-1, // for "abc": "a" at pos 0, "c" at pos 2, "e" not found
* -1,-1, 1 ] // for "def": "a" and "b" not found, "e" at pos 1
* r is now {[ 0, 2,-1], // for "abc": "a" at pos 0, "c" at pos 2, "e" not found
* [-1,-1, 1 ]} // for "def": "a" and "b" not found, "e" at pos 1
* @endcode
*
* @throw cudf::logic_error targets is empty or contains nulls
* @throw cudf::logic_error if `targets` is empty or contains nulls
*
* @param strings Strings instance for this operation.
* @param input Strings instance for this operation.
* @param targets Strings to search for in each string.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New integer column with character position values.
* @return Lists column with character position values.
*/
std::unique_ptr<column> find_multiple(
strings_column_view const& strings,
strings_column_view const& input,
strings_column_view const& targets,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

Expand Down
47 changes: 28 additions & 19 deletions cpp/src/strings/find_multiple.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/sequence.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/find_multiple.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
Expand All @@ -31,37 +33,32 @@ namespace cudf {
namespace strings {
namespace detail {
std::unique_ptr<column> find_multiple(
strings_column_view const& strings,
strings_column_view const& input,
strings_column_view const& targets,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
auto strings_count = strings.size();
if (strings_count == 0) return make_empty_column(type_id::INT32);
auto targets_count = targets.size();
auto const strings_count = input.size();
auto const targets_count = targets.size();
CUDF_EXPECTS(targets_count > 0, "Must include at least one search target");
CUDF_EXPECTS(!targets.has_nulls(), "Search targets cannot contain null strings");

auto strings_column = column_device_view::create(strings.parent(), stream);
auto strings_column = column_device_view::create(input.parent(), stream);
auto d_strings = *strings_column;
auto targets_column = column_device_view::create(targets.parent(), stream);
auto d_targets = *targets_column;

auto const total_count = strings_count * targets_count;

// create output column
auto total_count = strings_count * targets_count;
auto results = make_numeric_column(data_type{type_id::INT32},
total_count,
rmm::device_buffer{0, stream, mr},
0,
stream,
mr); // no nulls
auto results_view = results->mutable_view();
auto d_results = results_view.data<int32_t>();
auto results = make_numeric_column(
data_type{type_id::INT32}, total_count, rmm::device_buffer{0, stream, mr}, 0, stream, mr);

// fill output column with position values
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(total_count),
d_results,
results->mutable_view().begin<int32_t>(),
[d_strings, d_targets, targets_count] __device__(size_type idx) {
size_type str_idx = idx / targets_count;
if (d_strings.is_null(str_idx)) return -1;
Expand All @@ -70,18 +67,30 @@ std::unique_ptr<column> find_multiple(
return d_str.find(d_tgt);
});
results->set_null_count(0);
return results;

auto offsets = cudf::detail::sequence(strings_count + 1,
numeric_scalar<offset_type>(0),
numeric_scalar<offset_type>(targets_count),
stream,
mr);
return make_lists_column(strings_count,
std::move(offsets),
std::move(results),
0,
rmm::device_buffer{0, stream, mr},
stream,
mr);
}

} // namespace detail

// external API
std::unique_ptr<column> find_multiple(strings_column_view const& strings,
std::unique_ptr<column> find_multiple(strings_column_view const& input,
strings_column_view const& targets,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::find_multiple(strings, targets, rmm::cuda_stream_default, mr);
return detail::find_multiple(input, targets, rmm::cuda_stream_default, mr);
}

} // namespace strings
Expand Down
17 changes: 10 additions & 7 deletions cpp/tests/strings/find_multiple_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,13 +41,16 @@ TEST_F(StringsFindMultipleTest, FindMultiple)
cudf::test::strings_column_wrapper targets(h_targets.begin(), h_targets.end());
auto targets_view = cudf::strings_column_view(targets);

auto results = cudf::strings::find_multiple(strings_view, targets_view);
cudf::size_type total_count = static_cast<cudf::size_type>(h_strings.size() * h_targets.size());
EXPECT_EQ(total_count, results->size());
auto results = cudf::strings::find_multiple(strings_view, targets_view);

using LCW = cudf::test::lists_column_wrapper<int32_t>;
LCW expected({LCW{1, -1, -1, -1, 4, -1, -1},
LCW{4, -1, 2, -1, -1, -1, 2},
LCW{-1, -1, -1, -1, -1, -1, -1},
LCW{-1, 2, 1, -1, -1, -1, -1},
LCW{-1, -1, 1, 8, -1, -1, 1},
LCW{-1, -1, -1, -1, -1, -1, -1}});

cudf::test::fixed_width_column_wrapper<int32_t> expected(
{1, -1, -1, -1, 4, -1, -1, 4, -1, 2, -1, -1, -1, 2, -1, -1, -1, -1, -1, -1, -1,
-1, 2, 1, -1, -1, -1, -1, -1, -1, 1, 8, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

Expand Down