Skip to content

Commit

Permalink
- compute substrings from beginning until delimiter or from a delimit…
Browse files Browse the repository at this point in the history
…er until end of string

  - this Closes rapidsai#5158
  - this emulates spark's `substring_index` function
  • Loading branch information
sriramch committed May 27, 2020
1 parent 4fe6b24 commit 7178642
Show file tree
Hide file tree
Showing 3 changed files with 560 additions and 0 deletions.
88 changes: 88 additions & 0 deletions cpp/include/cudf/strings/find.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,94 @@ std::unique_ptr<column> ends_with(
string_scalar const& target,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @brief Returns a column of strings holding, for each row of @p strings, the substring that
* lies before or after the @p count-th occurrence of @p delimiter.
*
* The source @p strings are searched forward if @p count is positive or backwards if @p count
* is negative. If @p count is positive, the result is the substring from the start of the
* source string up to, but not including, the @p count-th occurrence of the @p delimiter.
* If @p count is negative, the result is the substring that starts just past the
* `|count|`-th occurrence of the @p delimiter, counted from the end of the source string,
* and runs until the end of the string.
*
* The search for @p delimiter in @p strings is case sensitive.
* If the @p count is 0, every row in the output column will be null.
* If the row value of @p strings is null, the row value in the output column will be null.
* If the @p delimiter is invalid or null, every row in the output column will be null.
* If the @p delimiter or the column value for a row is empty, the row value in the output
* column will be empty.
* If @p count occurrences of @p delimiter aren't found, the row value in the output column will
* be the row value from the input @p strings column.
*
* @code{.pseudo}
* Example:
* in_s = ['www.nvidia.com', null, 'www.google.com', '', 'foo' ]
* r = substring_index(in_s, '.', 1)
* r is ['www', null, 'www', '', 'foo']
*
* in_s = ['www.nvidia.com', null, 'www.google.com', '', 'foo' ]
* r = substring_index(in_s, '.', -2)
* r is ['nvidia.com', null, 'google.com', '', 'foo']
* @endcode
*
* @param strings Strings instance for this operation.
* @param delimiter UTF-8 encoded string to search for in each string.
* @param count Number of times to search for delimiter in each string. If the value is positive,
* forward search of delimiter is performed; else, a backward search is performed.
* @param mr Resource for allocating device memory.
* @return New strings column containing the substrings.
*/
std::unique_ptr<column> substring_index(
strings_column_view const& strings,
string_scalar const& delimiter,
size_type count,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @brief Returns a column of strings holding, for each row of @p strings, the substring that
* lies before or after the @p count-th occurrence of that row's delimiter taken from
* @p delimiter_strings.
*
* The source @p strings are searched forward if @p count is positive or backwards if @p count
* is negative. If @p count is positive, the result is the substring from the start of the
* source string up to, but not including, the @p count-th occurrence of the delimiter for
* that row. If @p count is negative, the result is the substring that starts just past the
* `|count|`-th occurrence of the delimiter for that row, counted from the end of the source
* string, and runs until the end of the string.
*
* The search for @p delimiter_strings in @p strings is case sensitive.
* If the @p count is 0, every row in the output column will be null.
* If the row value of @p strings is null, the row value in the output column will be null.
* If the row value from @p delimiter_strings is invalid or null, the row value in the
* output column will be null.
* If the row value from @p delimiter_strings or the column value for a row is empty, the
* row value in the output column will be empty.
* If @p count occurrences of the delimiter aren't found, the row value in the output column
* will be the row value from the input @p strings column.
*
* @code{.pseudo}
* Example:
* in_s = ['www.nvidia.com', null, 'www.google.com', '', 'foo..bar....goo' ]
* delimiters = ['.', '..', '', null, '..']
* r = substring_index(in_s, delimiters, 2)
* r is ['www.nvidia', null, '', null, 'foo..bar']
*
* in_s = ['www.nvidia.com', null, 'www.google.com', '', 'foo..bar....goo', 'apache.org' ]
* delimiters = ['.', '..', '', null, '..', '.']
* r = substring_index(in_s, delimiters, -2)
* r is ['nvidia.com', null, '', null, '..goo', 'apache.org']
* @endcode
*
* @throw cudf::logic_error if the number of rows in @p strings and @p delimiter_strings do not
* match.
*
* @param strings Strings instance for this operation.
* @param delimiter_strings UTF-8 encoded string for each row.
* @param count Number of times to search for delimiter in each string. If the value is positive,
* forward search of delimiter is performed; else, a backward search is performed.
* @param mr Resource for allocating device memory.
* @return New strings column containing the substrings.
*/
std::unique_ptr<column> substring_index(
strings_column_view const& strings,
strings_column_view const& delimiter_strings,
size_type count,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/** @} */ // end of doxygen group
} // namespace strings
} // namespace cudf
153 changes: 153 additions & 0 deletions cpp/src/strings/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/find.hpp>
Expand Down Expand Up @@ -286,5 +288,156 @@ std::unique_ptr<column> ends_with(strings_column_view const& strings,
return detail::ends_with(strings, target, mr);
}

// For substring_index APIs
namespace detail {
// Internal helper class
namespace {

/**
 * @brief Computes, row by row, the substring bounded by the `delimiter_count`-th occurrence
 * of a delimiter and assembles the results into a new strings column.
 *
 * Both iterators are dereferenced to pair-like values: `.first` is the `string_view` for the
 * row and `.second` is tested as the row's validity flag.
 */
struct substring_index_functor {
  /**
   * @brief Runs the per-row substring search and builds the output column.
   *
   * @tparam ColItrT Iterator type over (string_view, valid) pairs of the source strings.
   * @tparam DelimiterItrT Iterator type over (string_view, valid) pairs of the delimiters.
   *
   * @param col_itr Iterator positioned at the first source row.
   * @param delim_itr Iterator positioned at the delimiter for the first source row.
   * @param delimiter_count Signed occurrence count; positive searches forward, negative
   *        searches backward, zero yields a null for every row.
   * @param mr Resource passed to `make_strings_column` for allocating the output.
   * @param stream Stream on which the transform is executed.
   * @param strings_count Number of rows to process.
   * @return New strings column containing the per-row substrings.
   */
  template <typename ColItrT, typename DelimiterItrT>
  std::unique_ptr<column> operator()(ColItrT const col_itr,
                                     DelimiterItrT const delim_itr,
                                     size_type delimiter_count,
                                     rmm::mr::device_memory_resource* mr,
                                     cudaStream_t stream,
                                     size_type strings_count) const
  {
    // Shallow copy of the resultant strings
    rmm::device_vector<string_view> out_col_strings(strings_count);

    // Invalid output column strings - null rows
    string_view const invalid_str{nullptr, 0};

    thrust::transform(
      rmm::exec_policy(stream)->on(stream),
      col_itr,
      col_itr + strings_count,
      delim_itr,
      out_col_strings.data().get(),
      [delimiter_count, invalid_str] __device__(auto col_val_pair, auto delim_val_pair) {
        // If the column value for this row or the delimiter is null or if the delimiter count
        // is 0, result is null
        if (!col_val_pair.second || !delim_val_pair.second || delimiter_count == 0)
          return invalid_str;

        auto const col_val   = col_val_pair.first;
        auto const delim_val = delim_val_pair.first;

        // If the global delimiter or the row specific delimiter or if the column value for
        // the row is empty, value is empty.
        if (delim_val.empty() || col_val.empty()) return string_view{};

        auto const col_val_len   = col_val.length();
        auto const delimiter_len = delim_val.length();

        bool const search_backward = delimiter_count < 0;
        size_type start_pos        = 0;            // forward search: where the next find begins
        size_type end_pos          = col_val_len;  // backward search: exclusive upper bound
        string_view out_str{};

        // Step the signed counter toward zero instead of negating it: negating the most
        // negative size_type overflows, which previously left the loop with a negative trip
        // count and produced an empty string for every row.
        for (size_type remaining = delimiter_count; remaining != 0;
             remaining += (search_backward ? 1 : -1)) {
          bool const last_search = (remaining == 1) || (remaining == -1);

          if (search_backward) {
            end_pos = col_val.rfind(delim_val, 0, end_pos);
            // Fewer occurrences than requested: the whole source string is the result
            if (end_pos == -1) {
              out_str = col_val;
              break;
            }
            // Everything past the located delimiter, up to the end of the string
            if (last_search)
              out_str =
                col_val.substr(end_pos + delimiter_len, col_val_len - end_pos - delimiter_len);
          } else {
            auto const char_pos = col_val.find(delim_val, start_pos);
            // Fewer occurrences than requested: the whole source string is the result
            if (char_pos == -1) {
              out_str = col_val;
              break;
            }
            if (last_search)
              // Everything before the located delimiter
              out_str = col_val.substr(0, char_pos);
            else
              // Resume searching just past this occurrence
              start_pos = char_pos + delimiter_len;
          }
        }

        return out_str.empty() ? string_view{} : out_str;
      });

    // Create an output column with the resultant strings
    return make_strings_column(out_col_strings, invalid_str, stream, mr);
  }
};

} // namespace

/**
 * @brief Dispatches the substring search over @p strings, selecting the pair iterator that
 * matches the nullability of the source column.
 *
 * @tparam DelimiterItrT Iterator type supplying the delimiter for each row.
 *
 * @param strings Strings instance for this operation.
 * @param delimiter_itr Iterator positioned at the delimiter for the first row.
 * @param count Signed number of delimiter occurrences to search for.
 * @param mr Resource for allocating device memory.
 * @param stream Stream used for the device view and the computation.
 * @return New strings column containing the substrings.
 */
template <typename DelimiterItrT>
std::unique_ptr<column> substring_index(strings_column_view const& strings,
                                        DelimiterItrT const delimiter_itr,
                                        size_type count,
                                        rmm::mr::device_memory_resource* mr,
                                        cudaStream_t stream = 0)
{
  auto num_rows = strings.size();

  // Nothing to search: hand back an empty strings column
  if (num_rows == 0) { return strings::detail::make_empty_strings_column(mr, stream); }

  // Device-side view of the source column; the owner keeps it alive for this scope
  auto d_strings_owner = column_device_view::create(strings.parent(), stream);
  auto d_strings       = *d_strings_owner;

  // Column without a null mask: use the iterator flavour that skips validity lookups
  if (!d_strings.nullable()) {
    return substring_index_functor{}(
      experimental::detail::make_pair_iterator<string_view, false>(d_strings),
      delimiter_itr,
      count,
      mr,
      stream,
      num_rows);
  }

  return substring_index_functor{}(
    experimental::detail::make_pair_iterator<string_view, true>(d_strings),
    delimiter_itr,
    count,
    mr,
    stream,
    num_rows);
}

} // namespace detail

// external APIs

// Public entry point: a single scalar delimiter applied to every row of the strings column.
std::unique_ptr<column> substring_index(strings_column_view const& strings,
                                        string_scalar const& delimiter,
                                        size_type count,
                                        rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();
  // Pair iterator built from the scalar; it supplies the delimiter for each row
  auto scalar_delim_itr = experimental::detail::make_pair_iterator<string_view>(delimiter);
  return detail::substring_index(strings, scalar_delim_itr, count, mr);
}

// Public entry point: a separate delimiter per row, taken from a second strings column.
std::unique_ptr<column> substring_index(strings_column_view const& strings,
                                        strings_column_view const& delimiters,
                                        size_type count,
                                        rmm::mr::device_memory_resource* mr)
{
  // Each source row needs exactly one delimiter row
  CUDF_EXPECTS(strings.size() == delimiters.size(),
               "Strings and delimiters column sizes do not match");

  CUDF_FUNC_RANGE();

  // Device-side view of the delimiters column, created on stream 0
  auto d_delims_owner = cudf::column_device_view::create(delimiters.parent(), 0);
  auto d_delims       = *d_delims_owner;

  // Pick the pair-iterator flavour matching the delimiter column's nullability
  if (d_delims.nullable()) {
    return detail::substring_index(
      strings,
      experimental::detail::make_pair_iterator<string_view, true>(d_delims),
      count,
      mr);
  }

  return detail::substring_index(
    strings,
    experimental::detail::make_pair_iterator<string_view, false>(d_delims),
    count,
    mr);
}

} // namespace strings
} // namespace cudf
Loading

0 comments on commit 7178642

Please sign in to comment.