Skip to content

Commit

Permalink
- concatenate row items using a separator defined per row
Browse files Browse the repository at this point in the history
  - this Closes rapidsai#3726
  - this emulates `concatenate_ws` spark functionality
  - provides option for a global separator and global column null replacements
  - skips null values in a row to perform concatenation
  • Loading branch information
sriramch committed May 14, 2020
1 parent a872435 commit 6467dbb
Show file tree
Hide file tree
Showing 3 changed files with 550 additions and 0 deletions.
58 changes: 58 additions & 0 deletions cpp/include/cudf/strings/combine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,64 @@ std::unique_ptr<column> join_strings(
string_scalar const& narep = string_scalar("", false),
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @brief Concatenates a list of strings columns using separators for each row
* and returns the result as a string column.
*
* Each new string is created by concatenating the strings from the same
* row delimited by the row separator provided for that row. The following rules
* are applicable:
*
* - If row separator for a given row is null, output column for that row is null, unless
* there is a valid @p separator_narep
* - If all column values for a given row is null, output column for that row is null, unless
* there is a valid @p col_narep
* - null column values for a given row are skipped, if the column replacement isn't valid
* - The separator is only applied between two valid column values
* - If valid @p separator_narep and @p col_narep are provided, the output column is always
* non nullable
*
* @code{.pseudo}
* Example:
* c0 = ['aa', null, '', 'ee', null, 'ff']
* c1 = [null, 'cc', 'dd', null, null, 'gg']
* c2 = ['bb', '', null, null, null, 'hh']
* sep = ['::', '%%', '^^', '!', '*', null]
* out0 = concatenate([c0, c1, c2], sep)
* out0 is ['aa::bb', 'cc%%', '^^dd', 'ee', null, null]
*
* sep_rep = '+'
* out1 = concatenate([c0, c1, c2], sep, sep_rep)
* out1 is ['aa::bb', 'cc%%', '^^dd', 'ee', null, 'ff+gg+hh']
*
* col_rep = '-'
* out2 = concatenate([c0, c1, c2], sep, invalid_sep_rep, col_rep)
* out2 is ['aa::-::bb', '-%%cc%%', '^^dd^^-', 'ee!-!-', '-*-*-', null]
* @endcode
*
* @throw cudf::logic_error if no input columns are specified - table view is empty
* @throw cudf::logic_error if input columns are not all strings columns.
* @throw cudf::logic_error if the number of rows from @p separators and @p strings_columns
* do not match
*
* @param strings_columns List of string columns to concatenate.
* @param separators String column that provides the separator for a given row
* @param separator_narep String that should be used in place of a null separator for a given
* row. Default of invalid-scalar means no row separator value replacements.
* Default is an invalid string.
* @param col_narep String that should be used in place of any null strings
* found in any column. Default of invalid-scalar means no null column value replacements.
* Default is an invalid string.
* @param mr Resource for allocating device memory.
* @return New column with concatenated results.
*/
std::unique_ptr<column> concatenate(
table_view const& strings_columns,
strings_column_view const& separators,
string_scalar const& separator_narep = string_scalar("", false),
string_scalar const& col_narep = string_scalar("", false),
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/** @} */ // end of doxygen group
} // namespace strings
} // namespace cudf
199 changes: 199 additions & 0 deletions cpp/src/strings/combine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,195 @@ std::unique_ptr<column> join_strings(
mr);
}

//
std::unique_ptr<column> concatenate(table_view const& strings_columns,
strings_column_view const& separators,
string_scalar const& separator_narep,
string_scalar const& col_narep,
rmm::mr::device_memory_resource* mr,
cudaStream_t stream)
{
auto num_columns = strings_columns.num_columns();
CUDF_EXPECTS(num_columns > 0, "At least one column must be specified");
// Check if all columns are of type string
CUDF_EXPECTS(std::all_of(strings_columns.begin(),
strings_columns.end(),
[](auto c) { return c.type().id() == STRING; }),
"All columns must be of type string");

auto strings_count = strings_columns.num_rows();
CUDF_EXPECTS(strings_count == separators.size(),
"Separators column should be of the same size of string columns");
if (strings_count == 0) // Empty begets empty
return detail::make_empty_strings_column(mr, stream);

// Invalid output column strings - null rows
string_view const invalid_str{nullptr, 0};
auto const separator_rep = get_scalar_device_view(const_cast<string_scalar&>(separator_narep));
auto const col_rep = get_scalar_device_view(const_cast<string_scalar&>(col_narep));
auto const separator_col_view_ptr = column_device_view::create(separators.parent(), stream);
auto const separator_col_view = *separator_col_view_ptr;

if (num_columns == 1) {
// Shallow copy of the resultant strings
rmm::device_vector<string_view> out_col_strings(strings_count);

// Device view of the only column in the table view
auto const col0_ptr = column_device_view::create(strings_columns.column(0), stream);
auto const col0 = *col0_ptr;

// Execute it on every element
thrust::transform(
rmm::exec_policy(stream)->on(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(strings_count),
out_col_strings.data().get(),
// Output depends on the separator
[col0, invalid_str, separator_col_view, separator_rep, col_rep] __device__(auto ridx) {
if (!separator_col_view.is_valid(ridx) && !separator_rep.is_valid()) return invalid_str;
return (!col0.is_valid(ridx)) ? (col_rep.is_valid() ? col_rep.value() : invalid_str)
: col0.element<string_view>(ridx);
});

return make_strings_column(out_col_strings, invalid_str, stream, mr);
}

// Create device views from the strings columns.
auto table = table_device_view::create(strings_columns, stream);
auto d_table = *table;

// Create resulting null mask
auto valid_mask = cudf::experimental::detail::valid_if(
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
[d_table, separator_col_view, separator_rep, col_rep] __device__(size_type ridx) {
if (!separator_col_view.is_valid(ridx) && !separator_rep.is_valid()) return false;
bool all_nulls =
thrust::all_of(thrust::seq, d_table.begin(), d_table.end(), [ridx](auto const& col) {
return col.is_null(ridx);
});
return all_nulls ? col_rep.is_valid() : true;
},
stream,
mr);

auto null_count = valid_mask.second;

// Build offsets column by computing sizes of each string in the output
auto offsets_transformer = [d_table, separator_col_view, separator_rep, col_rep] __device__(
size_type ridx) {
// If the separator value for the row is null and if there aren't global separator
// replacements, this row does not have any value - null row
if (!separator_col_view.is_valid(ridx) && !separator_rep.is_valid()) return 0;

// For this row (idx), iterate over each column and add up the bytes
bool all_nulls =
thrust::all_of(thrust::seq, d_table.begin(), d_table.end(), [ridx](auto const& d_column) {
return d_column.is_null(ridx);
});
// If all column values are null and there isn't a global column replacement value, this row
// is a null row
if (all_nulls && !col_rep.is_valid()) return 0;

// There is at least one non-null column value (it can still be empty though)
auto separator_str = separator_col_view.is_valid(ridx)
? separator_col_view.element<string_view>(ridx)
: separator_rep.value();

size_type bytes = thrust::transform_reduce(
thrust::seq,
d_table.begin(),
d_table.end(),
[ridx, separator_str, col_rep] __device__(column_device_view const& d_column) {
// If column is null and there isn't a valid column replacement, this isn't used in
// final string concatenate
if (d_column.is_null(ridx) && !col_rep.is_valid()) return 0;
return separator_str.size_bytes() + (d_column.is_null(ridx)
? col_rep.size()
: d_column.element<string_view>(ridx).size_bytes());
},
0,
thrust::plus<size_type>());

// Null/empty separator and columns doesn't produce a non-empty string
if (bytes == 0) assert(separator_str.size_bytes() == 0);

// Separator goes only in between elements
return bytes - separator_str.size_bytes();
};
auto offsets_transformer_itr = thrust::make_transform_iterator(
thrust::make_counting_iterator<size_type>(0), offsets_transformer);
auto offsets_column = detail::make_offsets_child_column(
offsets_transformer_itr, offsets_transformer_itr + strings_count, mr, stream);
auto d_results_offsets = offsets_column->view().data<int32_t>();

// Create the chars column
size_type bytes = thrust::device_pointer_cast(d_results_offsets)[strings_count];
auto chars_column =
strings::detail::create_chars_child_column(strings_count, null_count, bytes, mr, stream);

// Fill the chars column
auto d_results_chars = chars_column->mutable_view().data<char>();
thrust::for_each_n(rmm::exec_policy(stream)->on(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
[d_table,
num_columns,
d_results_offsets,
d_results_chars,
separator_col_view,
separator_rep,
col_rep] __device__(size_type ridx) {
// If the separator for this row is null and if there isn't a valid separator
// to replace, do not write anything for this row
if (!separator_col_view.is_valid(ridx) && !separator_rep.is_valid()) return;

bool all_nulls = thrust::all_of(
thrust::seq, d_table.begin(), d_table.end(), [ridx](auto const& col) {
return col.is_null(ridx);
});

// If all column values are null and there isn't a valid column replacement,
// skip this row
if (all_nulls && !col_rep.is_valid()) return;

size_type offset = d_results_offsets[ridx];
char* d_buffer = d_results_chars + offset;
bool colval_written = false;

// There is at least one non-null column value (it can still be empty though)
auto separator_str = separator_col_view.is_valid(ridx)
? separator_col_view.element<string_view>(ridx)
: separator_rep.value();

// Write out each column's entry for this row
for (size_type col_idx = 0; col_idx < num_columns; ++col_idx) {
auto d_column = d_table.column(col_idx);
// If the column isn't valid and if there isn't a replacement for it, skip
// it
if (d_column.is_null(ridx) && !col_rep.is_valid()) continue;

// Separator goes only in between elements
if (colval_written)
d_buffer = detail::copy_string(d_buffer, separator_str);

string_view d_str = d_column.is_null(ridx)
? col_rep.value()
: d_column.element<string_view>(ridx);
d_buffer = detail::copy_string(d_buffer, d_str);
colval_written = true;
}
});

return make_strings_column(strings_count,
std::move(offsets_column),
std::move(chars_column),
null_count,
(null_count) ? std::move(valid_mask.first) : rmm::device_buffer{},
stream,
mr);
}

} // namespace detail

// APIs
Expand All @@ -267,5 +456,15 @@ std::unique_ptr<column> join_strings(strings_column_view const& strings,
return detail::join_strings(strings, separator, narep, mr);
}

std::unique_ptr<column> concatenate(table_view const& strings_columns,
strings_column_view const& separators,
string_scalar const& separator_narep,
string_scalar const& col_narep,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::concatenate(strings_columns, separators, separator_narep, col_narep, mr, 0);
}

} // namespace strings
} // namespace cudf
Loading

0 comments on commit 6467dbb

Please sign in to comment.