Skip to content

Commit

Permalink
strings::join_list_elements options for empty list inputs (#8285)
Browse files Browse the repository at this point in the history
This PR implements a new option for `strings::join_list_elements` on top of #8282. In particular, the new option is:
```
/**
 * @brief Setting for specifying what will be output from `join_list_elements` when an input list
 * is empty.
 */
enum class output_if_empty_list {
  EMPTY_STRING,  ///< Empty list will result in empty string
  NULL_ELEMENT   ///< Empty list will result in a null
};
```
This new option is necessary for implementing `concat_ws` in Spark, since the behavior of the output string is required to be different depending on the situation.

Currently blocked from merging by #8282.

Authors:
  - Nghia Truong (https://github.com/ttnghia)
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Keith Kraus (https://github.com/kkraus14)
  - Mike Wilson (https://github.com/hyperbolic2346)
  - David Wendt (https://github.com/davidwendt)
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Ashwin Srinath (https://github.com/shwina)

URL: #8285
  • Loading branch information
ttnghia authored May 26, 2021
1 parent cd7fe6f commit 773fc7a
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 76 deletions.
48 changes: 38 additions & 10 deletions cpp/include/cudf/strings/combine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ enum class separator_on_nulls {
NO ///< Do not add separators if an element is null
};

/**
* @brief Setting for specifying what will be output from `join_list_elements` when an input list
* is empty.
*/
enum class output_if_empty_list {
EMPTY_STRING, ///< Empty list will result in empty string
NULL_ELEMENT ///< Empty list will result in a null
};

/**
* @brief Concatenates all strings in the column into one new string delimited
* by an optional separator string.
Expand Down Expand Up @@ -203,8 +212,15 @@ std::unique_ptr<column> concatenate(
* column will also result in a null output row unless a valid @p separator_narep scalar is provided
* to be used in place of the null separators.
*
* If @p separate_nulls is set to `NO` and @p narep is valid then separators are not added to the
* output between null elements. Otherwise, separators are always added if @p narep is valid.
* If @p separate_nulls is set to `NO` and @p string_narep is valid then separators are not added to
* the output between null elements. Otherwise, separators are always added if @p string_narep is
* valid.
*
* If @p empty_list_policy is set to `EMPTY_STRING`, any row that is an empty list will result in
* an empty output string. Otherwise, the output will be a null.
*
* In the special case when the input list row contains all null elements, the output will be the
* same as in case of empty input list regardless of @p string_narep and @p separate_nulls values.
*
* @code{.pseudo}
* Example:
Expand Down Expand Up @@ -234,16 +250,19 @@ std::unique_ptr<column> concatenate(
* default is an invalid-scalar denoting that list rows containing null strings will result
* in null string in the corresponding output rows.
* @param separate_nulls If YES, then the separator is included for null rows if `narep` is valid.
* @param empty_list_policy if set to EMPTY_STRING, any input row that is an empty list will
* result in an empty string. Otherwise, it will result in a null.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column with concatenated results.
*/
std::unique_ptr<column> join_list_elements(
const lists_column_view& lists_strings_column,
const strings_column_view& separators,
string_scalar const& separator_narep = string_scalar("", false),
string_scalar const& string_narep = string_scalar("", false),
separator_on_nulls separate_nulls = separator_on_nulls::YES,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
string_scalar const& separator_narep = string_scalar("", false),
string_scalar const& string_narep = string_scalar("", false),
separator_on_nulls separate_nulls = separator_on_nulls::YES,
output_if_empty_list empty_list_policy = output_if_empty_list::EMPTY_STRING,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Given a lists column of strings (each row is a list of strings), concatenates the strings
Expand All @@ -259,6 +278,12 @@ std::unique_ptr<column> join_list_elements(
* If @p separate_nulls is set to `NO` and @p narep is valid then separators are not added to the
* output between null elements. Otherwise, separators are always added if @p narep is valid.
*
* If @p empty_list_policy is set to `EMPTY_STRING`, any row that is an empty list will result in
* an empty output string. Otherwise, the output will be a null.
*
* In the special case when the input list row contains all null elements, the output will be the
* same as in case of empty input list regardless of @p narep and @p separate_nulls values.
*
* @code{.pseudo}
* Example:
* s = [ ['aa', 'bb', 'cc'], null, ['', 'dd'], ['ee', null], ['ff'] ]
Expand All @@ -283,15 +308,18 @@ std::unique_ptr<column> join_list_elements(
* is an invalid-scalar denoting that list rows containing null strings will result in null
* string in the corresponding output rows.
* @param separate_nulls If YES, then the separator is included for null rows if `narep` is valid.
* @param empty_list_policy if set to EMPTY_STRING, any input row that is an empty list will result
* in an empty string. Otherwise, it will result in a null.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column with concatenated results.
*/
std::unique_ptr<column> join_list_elements(
const lists_column_view& lists_strings_column,
string_scalar const& separator = string_scalar(""),
string_scalar const& narep = string_scalar("", false),
separator_on_nulls separate_nulls = separator_on_nulls::YES,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
string_scalar const& separator = string_scalar(""),
string_scalar const& narep = string_scalar("", false),
separator_on_nulls separate_nulls = separator_on_nulls::YES,
output_if_empty_list empty_list_policy = output_if_empty_list::EMPTY_STRING,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
} // namespace strings
Expand Down
65 changes: 47 additions & 18 deletions cpp/src/strings/combine/join_list_elements.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct compute_size_and_concatenate_fn {
column_device_view const strings_dv;
string_scalar_device_view const string_narep_dv;
separator_on_nulls const separate_nulls;
output_if_empty_list const empty_list_policy;

offset_type* d_offsets{nullptr};

Expand All @@ -63,27 +64,40 @@ struct compute_size_and_concatenate_fn {
// We need to set `1` or `0` for the validities of the output strings.
int8_t* d_validities{nullptr};

__device__ void operator()(size_type const idx)
__device__ bool output_is_null(size_type const idx,
size_type const start_idx,
size_type const end_idx) const noexcept
{
if (func.is_null_list(lists_dv, idx)) { return true; }
return empty_list_policy == output_if_empty_list::NULL_ELEMENT && start_idx == end_idx;
}

__device__ void operator()(size_type const idx) const noexcept
{
// If this is the second pass, and the row `idx` is known to be a null string
if (d_chars and not d_validities[idx]) { return; }
if (d_chars && !d_validities[idx]) { return; }

// Indices of the strings within the list row
auto const start_idx = list_offsets[idx];
auto const end_idx = list_offsets[idx + 1];

if (not d_chars and func.is_null_list(lists_dv, idx)) {
if (!d_chars && output_is_null(idx, start_idx, end_idx)) {
d_offsets[idx] = 0;
d_validities[idx] = false;
return;
}

auto const separator = func.separator(idx);
auto size_bytes = size_type{0};
char* output_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr;
bool write_separator = false;
auto const separator = func.separator(idx);
auto size_bytes = size_type{0};
char* output_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr;
bool has_valid_element = false;
bool write_separator = false;

for (size_type str_idx = list_offsets[idx], idx_end = list_offsets[idx + 1]; str_idx < idx_end;
++str_idx) {
for (size_type str_idx = start_idx; str_idx < end_idx; ++str_idx) {
bool null_element = strings_dv.is_null(str_idx);
has_valid_element = has_valid_element || !null_element;

if (not d_chars and (null_element and not string_narep_dv.is_valid())) {
if (!d_chars && (null_element && !string_narep_dv.is_valid())) {
d_offsets[idx] = 0;
d_validities[idx] = false;
return; // early termination: the entire list of strings will result in a null string
Expand All @@ -104,9 +118,12 @@ struct compute_size_and_concatenate_fn {
write_separator || (separate_nulls == separator_on_nulls::YES) || !null_element;
}

if (not d_chars) {
d_offsets[idx] = size_bytes;
d_validities[idx] = true;
// If there are all null elements, the output should be the same as having an empty list input:
// a null or an empty string
if (!d_chars) {
d_offsets[idx] = has_valid_element ? size_bytes : 0;
d_validities[idx] =
has_valid_element || empty_list_policy == output_if_empty_list::EMPTY_STRING;
}
}
};
Expand Down Expand Up @@ -134,6 +151,7 @@ std::unique_ptr<column> join_list_elements(lists_column_view const& lists_string
string_scalar const& separator,
string_scalar const& narep,
separator_on_nulls separate_nulls,
output_if_empty_list empty_list_policy,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
Expand Down Expand Up @@ -161,7 +179,8 @@ std::unique_ptr<column> join_list_elements(lists_column_view const& lists_string
lists_strings_column.offsets_begin(),
*strings_dv_ptr,
string_narep_dv,
separate_nulls};
separate_nulls,
empty_list_policy};
auto [offsets_column, chars_column, null_mask, null_count] =
make_strings_children_with_null_mask(comp_fn, num_rows, num_rows, stream, mr);

Expand All @@ -187,7 +206,7 @@ struct column_separators_fn {
__device__ bool is_null_list(column_device_view const& lists_dv, size_type const idx) const
noexcept
{
return lists_dv.is_null(idx) or (separators_dv.is_null(idx) and not sep_narep_dv.is_valid());
return lists_dv.is_null(idx) || (separators_dv.is_null(idx) && !sep_narep_dv.is_valid());
}

__device__ string_view separator(size_type const idx) const noexcept
Expand All @@ -204,6 +223,7 @@ std::unique_ptr<column> join_list_elements(lists_column_view const& lists_string
string_scalar const& separator_narep,
string_scalar const& string_narep,
separator_on_nulls separate_nulls,
output_if_empty_list empty_list_policy,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
Expand Down Expand Up @@ -233,7 +253,8 @@ std::unique_ptr<column> join_list_elements(lists_column_view const& lists_string
lists_strings_column.offsets_begin(),
*strings_dv_ptr,
string_narep_dv,
separate_nulls};
separate_nulls,
empty_list_policy};
auto [offsets_column, chars_column, null_mask, null_count] =
make_strings_children_with_null_mask(comp_fn, num_rows, num_rows, stream, mr);

Expand All @@ -252,18 +273,25 @@ std::unique_ptr<column> join_list_elements(lists_column_view const& lists_string
string_scalar const& separator,
string_scalar const& narep,
separator_on_nulls separate_nulls,
output_if_empty_list empty_list_policy,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::join_list_elements(
lists_strings_column, separator, narep, separate_nulls, rmm::cuda_stream_default, mr);
return detail::join_list_elements(lists_strings_column,
separator,
narep,
separate_nulls,
empty_list_policy,
rmm::cuda_stream_default,
mr);
}

std::unique_ptr<column> join_list_elements(lists_column_view const& lists_strings_column,
strings_column_view const& separators,
string_scalar const& separator_narep,
string_scalar const& string_narep,
separator_on_nulls separate_nulls,
output_if_empty_list empty_list_policy,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
Expand All @@ -272,6 +300,7 @@ std::unique_ptr<column> join_list_elements(lists_column_view const& lists_string
separator_narep,
string_narep,
separate_nulls,
empty_list_policy,
rmm::cuda_stream_default,
mr);
}
Expand Down
99 changes: 79 additions & 20 deletions cpp/tests/strings/combine/join_list_elements_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>

#include <thrust/iterator/transform_iterator.h>
#include <cudf_test/iterator_utilities.hpp>

struct StringsListsConcatenateTest : public cudf::test::BaseFixture {
};
Expand All @@ -35,14 +34,13 @@ using INT_LISTS = cudf::test::lists_column_wrapper<int32_t>;

constexpr bool print_all{false};

auto null_at(cudf::size_type idx)
{
return cudf::detail::make_counting_transform_iterator(0, [idx](auto i) { return i != idx; });
}
auto all_nulls() { return cudf::test::iterator_all_nulls(); }

auto all_nulls()
auto null_at(cudf::size_type idx) { return cudf::test::iterator_with_null_at(idx); }

auto null_at(std::vector<cudf::size_type> const& indices)
{
return cudf::detail::make_counting_transform_iterator(0, [](auto) { return false; });
return cudf::test::iterator_with_null_at(cudf::host_span<cudf::size_type const>{indices});
}

auto nulls_from_nullptr(std::vector<const char*> const& strs)
Expand Down Expand Up @@ -99,14 +97,81 @@ TEST_F(StringsListsConcatenateTest, ZeroSizeStringsInput)
auto const string_lists =
STR_LISTS{STR_LISTS{""}, STR_LISTS{"", "", ""}, STR_LISTS{"", ""}, STR_LISTS{}}.release();
auto const string_lv = cudf::lists_column_view(string_lists->view());
auto const expected = STR_COL{"", "", "", ""};

auto results = cudf::strings::join_list_elements(string_lv);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
// Empty list results in empty string
{
auto const expected = STR_COL{"", "", "", ""};

auto const separators = STR_COL{"", "", "", ""}.release();
results = cudf::strings::join_list_elements(string_lv, separators->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
auto results = cudf::strings::join_list_elements(string_lv);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);

auto const separators = STR_COL{"", "", "", ""}.release();
results = cudf::strings::join_list_elements(string_lv, separators->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
}

// Empty list results in null
{
auto const expected = STR_COL{{"", "", "", "" /*NULL*/}, null_at(3)};
auto results =
cudf::strings::join_list_elements(string_lv,
cudf::string_scalar(""),
cudf::string_scalar(""),
cudf::strings::separator_on_nulls::NO,
cudf::strings::output_if_empty_list::NULL_ELEMENT);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);

auto const separators = STR_COL{"", "", "", ""}.release();
results = cudf::strings::join_list_elements(string_lv,
separators->view(),
cudf::string_scalar(""),
cudf::string_scalar(""),
cudf::strings::separator_on_nulls::NO,
cudf::strings::output_if_empty_list::NULL_ELEMENT);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
}
}

TEST_F(StringsListsConcatenateTest, ColumnHasEmptyListAndNullListInput)
{
auto const string_lists =
STR_LISTS{{STR_LISTS{"abc", "def", ""}, STR_LISTS{} /*NULL*/, STR_LISTS{}, STR_LISTS{"gh"}},
null_at(1)}
.release();
auto const string_lv = cudf::lists_column_view(string_lists->view());

// Empty list results in empty string
{
auto const expected = STR_COL{{"abc-def-", "" /*NULL*/, "", "gh"}, null_at(1)};

auto results = cudf::strings::join_list_elements(string_lv, cudf::string_scalar("-"));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);

auto const separators = STR_COL{"-", "", "", ""}.release();
results = cudf::strings::join_list_elements(string_lv, separators->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
}

// Empty list results in null
{
auto const expected = STR_COL{{"abc-def-", "" /*NULL*/, "" /*NULL*/, "gh"}, null_at({1, 2})};
auto results =
cudf::strings::join_list_elements(string_lv,
cudf::string_scalar("-"),
cudf::string_scalar(""),
cudf::strings::separator_on_nulls::NO,
cudf::strings::output_if_empty_list::NULL_ELEMENT);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);

auto const separators = STR_COL{"-", "", "", ""}.release();
results = cudf::strings::join_list_elements(string_lv,
separators->view(),
cudf::string_scalar(""),
cudf::string_scalar(""),
cudf::strings::separator_on_nulls::NO,
cudf::strings::output_if_empty_list::NULL_ELEMENT);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
}
}

TEST_F(StringsListsConcatenateTest, AllNullsStringsInput)
Expand All @@ -127,12 +192,6 @@ TEST_F(StringsListsConcatenateTest, AllNullsStringsInput)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected, print_all);
}

auto null_at(std::initializer_list<cudf::size_type> indices)
{
return cudf::detail::make_counting_transform_iterator(
0, [indices](auto i) { return std::find(indices.begin(), indices.end(), i) == indices.end(); });
}

TEST_F(StringsListsConcatenateTest, ScalarSeparator)
{
auto const string_lists = STR_LISTS{{STR_LISTS{{"a", "bb" /*NULL*/, "ccc"}, null_at(1)},
Expand Down
Loading

0 comments on commit 773fc7a

Please sign in to comment.