Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stream parameter to Set Operations (Public List APIs) #14305

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/include/cudf/lists/set_operations.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,13 +53,15 @@ namespace cudf::lists {
* to be `UNEQUAL` which means only non-null elements are checked for overlapping
* @param nans_equal Flag to specify whether floating-point NaNs should be considered as equal
* @param mr Device memory resource used to allocate the returned object
* @param stream CUDA stream used for device memory operations and kernel launches
* @return A column of type BOOL containing the check results
*/
std::unique_ptr<column> have_overlap(
lists_column_view const& lhs,
lists_column_view const& rhs,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::ALL_EQUAL,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -87,6 +89,7 @@ std::unique_ptr<column> have_overlap(
* @param rhs The input lists column for the other side
* @param nulls_equal Flag to specify whether null elements should be considered as equal
* @param nans_equal Flag to specify whether floating-point NaNs should be considered as equal
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned object
* @return A lists column containing the intersection results
*/
Expand All @@ -95,6 +98,7 @@ std::unique_ptr<column> intersect_distinct(
lists_column_view const& rhs,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::ALL_EQUAL,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -122,6 +126,7 @@ std::unique_ptr<column> intersect_distinct(
* @param rhs The input lists column for the other side
* @param nulls_equal Flag to specify whether null elements should be considered as equal
* @param nans_equal Flag to specify whether floating-point NaNs should be considered as equal
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned object
* @return A lists column containing the union results
*/
Expand All @@ -130,6 +135,7 @@ std::unique_ptr<column> union_distinct(
lists_column_view const& rhs,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::ALL_EQUAL,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -157,6 +163,7 @@ std::unique_ptr<column> union_distinct(
* @param rhs The input lists column of elements to exclude
* @param nulls_equal Flag to specify whether null elements should be considered as equal
* @param nans_equal Flag to specify whether floating-point NaNs should be considered as equal
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned object
* @return A lists column containing the difference results
*/
Expand All @@ -165,6 +172,7 @@ std::unique_ptr<column> difference_distinct(
lists_column_view const& rhs,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::ALL_EQUAL,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/lists/set_operations.cu
Original file line number Diff line number Diff line change
Expand Up @@ -278,42 +278,44 @@ std::unique_ptr<column> have_overlap(lists_column_view const& lhs,
lists_column_view const& rhs,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::have_overlap(lhs, rhs, nulls_equal, nans_equal, cudf::get_default_stream(), mr);
return detail::have_overlap(lhs, rhs, nulls_equal, nans_equal, stream, mr);
}

std::unique_ptr<column> intersect_distinct(lists_column_view const& lhs,
lists_column_view const& rhs,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::intersect_distinct(
lhs, rhs, nulls_equal, nans_equal, cudf::get_default_stream(), mr);
return detail::intersect_distinct(lhs, rhs, nulls_equal, nans_equal, stream, mr);
}

std::unique_ptr<column> union_distinct(lists_column_view const& lhs,
lists_column_view const& rhs,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::union_distinct(lhs, rhs, nulls_equal, nans_equal, cudf::get_default_stream(), mr);
return detail::union_distinct(lhs, rhs, nulls_equal, nans_equal, stream, mr);
}

std::unique_ptr<column> difference_distinct(lists_column_view const& lhs,
lists_column_view const& rhs,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::difference_distinct(
lhs, rhs, nulls_equal, nans_equal, cudf::get_default_stream(), mr);
return detail::difference_distinct(lhs, rhs, nulls_equal, nans_equal, stream, mr);
}

} // namespace cudf::lists
45 changes: 45 additions & 0 deletions cpp/tests/streams/lists_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cudf/lists/filling.hpp>
#include <cudf/lists/gather.hpp>
#include <cudf/lists/reverse.hpp>
#include <cudf/lists/set_operations.hpp>
#include <cudf/lists/sorting.hpp>
#include <cudf/lists/stream_compaction.hpp>

Expand Down Expand Up @@ -166,3 +167,47 @@ TEST_F(ListTest, Distinct)
cudf::nan_equality::ALL_EQUAL,
cudf::test::get_default_stream());
}

TEST_F(ListTest, DifferenceDistinct)
{
cudf::test::lists_column_wrapper<int> list_col_a{{0, 1}, {2, 3, 7, 8}, {4, 5}};
cudf::test::lists_column_wrapper<int> list_col_b{{0, 1}, {1, 3, 6, 8}, {5}};
cudf::lists::difference_distinct(list_col_a,
list_col_b,
cudf::null_equality::EQUAL,
cudf::nan_equality::ALL_EQUAL,
cudf::test::get_default_stream());
}

TEST_F(ListTest, IntersectDistinct)
{
cudf::test::lists_column_wrapper<int> list_col_a{{0, 1}, {2, 3, 7, 8}, {4, 5}};
cudf::test::lists_column_wrapper<int> list_col_b{{0, 1}, {1, 3, 6, 8}, {5}};
cudf::lists::intersect_distinct(list_col_a,
list_col_b,
cudf::null_equality::EQUAL,
cudf::nan_equality::ALL_EQUAL,
cudf::test::get_default_stream());
}

TEST_F(ListTest, UnionDistinct)
{
cudf::test::lists_column_wrapper<int> list_col_a{{0, 1}, {2, 3, 7, 8}, {4, 5}};
cudf::test::lists_column_wrapper<int> list_col_b{{0, 1}, {1, 3, 6, 8}, {5}};
cudf::lists::union_distinct(list_col_a,
list_col_b,
cudf::null_equality::EQUAL,
cudf::nan_equality::ALL_EQUAL,
cudf::test::get_default_stream());
}

TEST_F(ListTest, HaveOverlap)
{
cudf::test::lists_column_wrapper<int> list_col_a{{0, 1}, {2, 3, 7, 8}, {4, 5}};
cudf::test::lists_column_wrapper<int> list_col_b{{0, 1}, {1, 3, 6, 8}, {5}};
cudf::lists::have_overlap(list_col_a,
list_col_b,
cudf::null_equality::EQUAL,
cudf::nan_equality::ALL_EQUAL,
cudf::test::get_default_stream());
}
Loading