Refactor Uniform Neighborhood Sampling #2258
…inal, add new functions with cleaned up API
Review Part 1.
cpp/include/cugraph/algorithms.hpp
Outdated
* @param with_replacement boolean flag specifying if random sampling is done with replacement
* (true); or, without replacement (false); default = true;
* @return tuple of tuple of device vectors and counts:
* ((vertex_t source_vertex, vertex_t destination_vertex, int rank, edge_t index), rx_counts)
I guess this comment is an outdated copy-and-paste from the previous implementation. I assume we are returning a tuple of edge source, edge destination, and edge weight vectors (the last might actually be the edge ID at the moment?).
Fixed in next push
cpp/include/cugraph/algorithms.hpp
Outdated
* @return tuple of tuple of device vectors and counts:
* ((vertex_t source_vertex, vertex_t destination_vertex, int rank, edge_t index), rx_counts)
*/
template <typename graph_view_t>
Yeah... and we are sort of mixing two styles: template <typename graph_view_t> with typename graph_view_t::vertex_type, ..., and template <typename vertex_t, typename edge_t, typename weight_t, bool store_transposed, bool multi_gpu> with graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu>.
I think we'd better be consistent. Any preference for one over the other?
No strong preference for me.
There is, I think, an advantage to the template <typename graph_view_t> approach: if we change the implementation of graph_view (adding or removing a template parameter), the API works without modification as long as typename graph_view_t::vertex_type is still defined. I believe Andrei copied this from my Louvain definition, which uses this approach. I implemented Louvain this way so that I could support both the legacy graph and graph_t with the same API.
But the syntax is a bit cleaner with your original approach. I don't think it's likely that we will frequently change the template signature of the API, and we will eventually get rid of the legacy graph class.
I'd be happy to change this back to your original approach, or if we like the template <typename graph_view_t> approach better I can add that to the list of things to gradually update in the code.
Yeah... I don't have a strong preference either, but I do have a strong preference for consistency.
I am also using the template <typename graph_view_t> style for primitives, but I wonder whether I should use graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu> instead.
I am getting more inclined to the graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu> approach, as this code does not work for a general graph view type but only with our graph_view_t (e.g. the implementation depends on multiple member functions that only exist in graph_view_t).
And hopefully we can eliminate the legacy code sooner rather than later; at that point, I slightly prefer graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu>, even though this will have very minimal impact on the end-user experience.
Sounds good. I will make those changes in the next push. I will leave Louvain as it is for now. I plan to create a PR to add Louvain to the C API, and I will refactor the Louvain API in that PR.
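For readers following the thread, the two template styles under discussion can be sketched as below. This is a minimal host-only illustration: the graph_view_t here is a hypothetical stand-in with a single num_vertices member, and last_vertex_style_a / last_vertex_style_b are made-up names, not cuGraph functions.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for cuGraph's graph view; only what this sketch needs.
template <typename vertex_t, typename edge_t, typename weight_t, bool store_transposed, bool multi_gpu>
struct graph_view_t {
  using vertex_type = vertex_t;
  using edge_type   = edge_t;
  using weight_type = weight_t;
  vertex_t num_vertices;
};

// Style A: one opaque template parameter; types are recovered through member aliases.
// Any type that defines vertex_type and num_vertices is accepted.
template <typename GraphViewType>
typename GraphViewType::vertex_type last_vertex_style_a(GraphViewType const& graph_view)
{
  assert(graph_view.num_vertices > 0);
  return graph_view.num_vertices - 1;
}

// Style B: template parameters are spelled out; only graph_view_t instantiations match.
template <typename vertex_t, typename edge_t, typename weight_t, bool store_transposed, bool multi_gpu>
vertex_t last_vertex_style_b(
  graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu> const& graph_view)
{
  assert(graph_view.num_vertices > 0);
  return graph_view.num_vertices - 1;
}
```

Style A keeps compiling if graph_view_t gains or loses a template parameter, while style B documents in the signature that only graph_view_t is supported, which is the trade-off the comments above weigh.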
cpp/include/cugraph/algorithms.hpp
Outdated
uniform_nbr_sample(raft::handle_t const& handle,
graph_view_t const& graph_view,
raft::device_span<typename graph_view_t::vertex_type> d_starting_vertices,
raft::host_span<const int> h_fan_out,
I guess the d_ and h_ prefixes here are a bit redundant (especially with device_span and host_span). Otherwise we should use this naming convention in all the functions in the public API. My current practice is to use d_ and h_ only when we have both host and device vectors with the same name, but I'm open to discussion.
Yeah.... and this API is way more intuitive than the previous one!!!
I love how the span variants clean up the API. I'll drop the extra prefixes in the next push
cpp/include/cugraph/algorithms.hpp
Outdated
@@ -1536,6 +1537,32 @@ uniform_nbr_sample(raft::handle_t const& handle,
std::vector<int> const& h_fan_out,
bool with_replacement = true);

/**
* @brief Multi-GPU Uniform Neighborhood Sampling.
Is this really a multi-GPU-only thing, or is it for both SG and MG?
Both. Updated the comment.
cpp/src/detail/shuffle_wrappers.cu
Outdated
handle.get_stream());

return d_rx_vertices;
}

template <typename vertex_t>
rmm::device_uvector<vertex_t> shuffle_vertices_by_gpu_id(raft::handle_t const& handle,
Should we rename this to shuffle_ext_vertices_by_gpu_id?
Done in next push
@@ -47,6 +48,22 @@ struct compute_gpu_id_from_vertex_t {
}
};

template <typename vertex_t>
struct compute_gpu_id_from_int_vertex_t {
Should we also rename the other functors working on external vertex IDs to use ext_vertex_t and ext_edge_t?
Done for vertex in the next push.
Do we ever try to use these functors on an int_edge_t? I'm inclined not to add the ext to the name unless we need to distinguish.
Gotcha, agreed.
template <typename vertex_t>
struct compute_gpu_id_from_int_vertex_t {
vertex_t const* vertex_partition_range_lasts;
size_t num_vertex_partitions;
Yeah... maybe just a FIXME statement for now, but we should eventually replace these (pointer, size) pairs with raft::device_span.
Changed to span in the next push.
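To illustrate the (pointer, size) to span change, here is a minimal host-only sketch. span_t is a hypothetical stand-in for raft::device_span, and compute_partition_id_from_int_vertex_t is a simplified, made-up functor that only finds the owning vertex partition with std::upper_bound; the mapping from partition to GPU ID in the actual functor is not shown in this thread and is not reproduced here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal span (stand-in for raft::device_span): pointer and size travel together.
template <typename T>
struct span_t {
  T const* ptr;
  std::size_t count;
  T const* begin() const { return ptr; }
  T const* end() const { return ptr + count; }
  std::size_t size() const { return count; }
};

// Assumption: vertex_partition_range_lasts[i] is one past the last vertex owned by partition i.
template <typename vertex_t>
struct compute_partition_id_from_int_vertex_t {
  span_t<vertex_t> vertex_partition_range_lasts;  // replaces the separate pointer and size members

  std::size_t operator()(vertex_t v) const
  {
    // The first range end strictly greater than v identifies the owning partition.
    auto it = std::upper_bound(
      vertex_partition_range_lasts.begin(), vertex_partition_range_lasts.end(), v);
    assert(it != vertex_partition_range_lasts.end());  // v must fall inside some partition
    return static_cast<std::size_t>(it - vertex_partition_range_lasts.begin());
  }
};
```

With the span as a single member, the functor can no longer be constructed with a pointer and a size that disagree, which is the main benefit of the change.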
cpp/src/detail/utility_wrappers.cu
Outdated
zip_iter,
zip_iter + d_vertices.size(),
zip_iter,
[] __device__(auto pair) { return thrust::get<1>(pair) > 0; });
FYI: https://github.com/NVIDIA/thrust/issues/1302
Maybe do copy_if in chunks, or add a check on d_vertices.size() and throw an exception if d_vertices.size() overflows a 32-bit integer (if you expect this is unlikely to happen and we'd better wait for the thrust folks to fix it).
And I guess thrust::remove_if is more intuitive than copy_if here (unless you are willing to copy in chunks). You may look at https://github.com/rapidsai/cugraph/pull/2253/files#diff-ce8c8b8ffdc670a97313ca4ce20de7bf8a18daa81f5a1fde50f3b162bf75b75bR1238
#if 1 // FIXME: work-around for the 32 bit integer overflow issue in thrust::remove,
// thrust::remove_if, and thrust::copy_if (https://github.com/NVIDIA/thrust/issues/1302)
rmm::device_uvector<vertex_t> tmp_indices(
thrust::count_if(handle.get_thrust_policy(),
nbr_intersection_indices.begin(),
nbr_intersection_indices.end(),
detail::not_equal_t<vertex_t>{invalid_vertex_id<vertex_t>::value}),
handle.get_stream());
size_t num_copied{0};
size_t num_scanned{0};
while (num_scanned < nbr_intersection_indices.size()) {
size_t this_scan_size = std::min(
size_t{1} << 30,
static_cast<size_t>(thrust::distance(nbr_intersection_indices.begin() + num_scanned,
nbr_intersection_indices.end())));
num_copied += static_cast<size_t>(thrust::distance(
tmp_indices.begin() + num_copied,
thrust::copy_if(handle.get_thrust_policy(),
nbr_intersection_indices.begin() + num_scanned,
nbr_intersection_indices.begin() + num_scanned + this_scan_size,
tmp_indices.begin() + num_copied,
detail::not_equal_t<vertex_t>{invalid_vertex_id<vertex_t>::value})));
num_scanned += this_scan_size;
}
nbr_intersection_indices = std::move(tmp_indices);
#else
nbr_intersection_indices.resize(
thrust::distance(nbr_intersection_indices.begin(),
thrust::remove(handle.get_thrust_policy(),
nbr_intersection_indices.begin(),
nbr_intersection_indices.end(),
invalid_vertex_id<vertex_t>::value)),
handle.get_stream());
#endif
Switched to remove_if.
Seems unlikely to have an overflow issue, at least with current memory sizes, as the number of elements in a vertex array on each partition is likely to be < 2^31-1. But I added a FIXME so we can remember.
Maybe add CUGRAPH_EXPECTS(d_vertices.size() < std::numeric_limits<int32_t>::max()) as well. I agree that this is unlikely to happen, but if it does happen on the user side or in large-scale testing, it will be very difficult for us to figure out that the failure is actually due to the overflow. With the check, it will be far easier to figure out what went awry.
Added the CUGRAPH_EXPECTS here and in the other two places where I call remove_if.
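The two mitigations discussed in this thread (a size guard and chunked filtering) can be sketched on the host as follows. This is an illustration only: it uses std::copy_if in place of thrust::copy_if, and expect_fits_in_int32 and copy_if_in_chunks are hypothetical helper names, not part of cuGraph or thrust.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Guard in the spirit of the suggested CUGRAPH_EXPECTS: fail loudly instead of overflowing silently.
inline void expect_fits_in_int32(std::size_t n)
{
  if (n >= static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())) {
    throw std::length_error("input size overflows a 32-bit integer");
  }
}

// Chunked filtering: each copy_if call sees at most max_chunk elements,
// so no single call has to index past the chunk size.
template <typename T, typename Pred>
std::vector<T> copy_if_in_chunks(std::vector<T> const& in, Pred pred, std::size_t max_chunk)
{
  if (max_chunk == 0) { throw std::invalid_argument("max_chunk must be positive"); }
  // Size the output up front by counting the elements that will be kept.
  std::vector<T> out(static_cast<std::size_t>(std::count_if(in.begin(), in.end(), pred)));
  std::size_t num_copied{0};
  std::size_t num_scanned{0};
  while (num_scanned < in.size()) {
    std::size_t this_scan_size = std::min(max_chunk, in.size() - num_scanned);
    auto first    = in.begin() + num_scanned;
    auto out_last = std::copy_if(first, first + this_scan_size, out.begin() + num_copied, pred);
    num_copied    = static_cast<std::size_t>(out_last - out.begin());
    num_scanned += this_scan_size;
  }
  assert(num_copied == out.size());
  return out;
}
```

The guard is the cheaper option and matches what was added in this PR; the chunked loop mirrors the structure of the work-around quoted above with a chunk size of 2^30.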
Review Part 2
namespace detail {

/**
* @brief Compute local out degrees of the majors belonging to the adjacency matrices
Need to double check but I guess this computes out-degrees if major == source and in-degrees if major == destination.
Yes, I think that's correct. The sampling code forces store_transposed=false, so this function assumes that.
I'm not sure that's a good long-term assumption (feels like sampling on incoming vertices would be a reasonable thing to do). But at the moment this is sufficient.
Perhaps a FIXME to address this later?
I added a FIXME near the beginning of these function definitions to reflect that we should revisit this.
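As a small illustration of why "out degree" depends on the storage orientation, the sketch below computes per-major degrees from compressed-sparse offsets and labels them according to store_transposed. It is a host-only sketch under the assumption stated in the comments above (major == source when store_transposed is false); compute_major_degrees, degree_kind, and major_degree_kind are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class degree_kind { out_degree, in_degree };

// With store_transposed == false the majors are edge sources, so a major degree is an out-degree;
// with store_transposed == true the majors are edge destinations, so it is an in-degree.
constexpr degree_kind major_degree_kind(bool store_transposed)
{
  return store_transposed ? degree_kind::in_degree : degree_kind::out_degree;
}

// Degree of each major in a compressed-sparse layout: difference of consecutive offsets.
template <typename edge_t>
std::vector<edge_t> compute_major_degrees(std::vector<edge_t> const& offsets)
{
  std::vector<edge_t> degrees(offsets.empty() ? 0 : offsets.size() - 1);
  for (std::size_t i = 0; i < degrees.size(); ++i) {
    assert(offsets[i + 1] >= offsets[i]);  // offsets must be non-decreasing
    degrees[i] = offsets[i + 1] - offsets[i];
  }
  return degrees;
}
```

The computation itself is orientation-agnostic; only the name attached to the result changes, which is why the documentation wording was questioned.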
* @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and
* handles to various CUDA libraries) to run graph algorithms.
* @param graph_view Non-owning graph object.
* @return A single vector containing the local out degrees of the majors belong to the adjacency
"out degrees" may not be accurate here.
Same observation as above: store_transposed=false for the sampling algorithms.
const rmm::device_uvector<typename GraphViewType::edge_type>& global_out_degrees);

/**
* @brief Gather active majors across gpus in a column communicator
Is this gather or allgather (will the results be stored only in the root, or in every process in the column communicator)? If allgather, better to rename it to avoid confusion.
Fixed in next push
rmm::device_uvector<vertex_t>&& d_in);

/**
* @brief Return global out degrees of active majors
Need to double check "out" degrees here is correct.
Sampling forces store_transposed=false
template <typename vertex_t>
rmm::device_uvector<vertex_t> gather_active_majors(raft::handle_t const& handle,
rmm::device_uvector<vertex_t>&& d_in)
OK, this is using allgatherv, so this function should better be renamed to "allgather_active_majors".
Fixed in next push
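The naming concern comes down to where the result lives after the collective. The host-only simulation below contrasts the two semantics; it involves no real communicator, and per_rank_t, simulate_gatherv, and simulate_allgatherv are hypothetical names used only for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simulated communicator state: per_rank[r] is the buffer held by rank r.
template <typename T>
using per_rank_t = std::vector<std::vector<T>>;

// Concatenate the send buffers in rank order.
template <typename T>
std::vector<T> concatenate(per_rank_t<T> const& send)
{
  std::vector<T> all;
  for (auto const& buf : send) { all.insert(all.end(), buf.begin(), buf.end()); }
  return all;
}

// gather(v): only the root ends up with the concatenated data.
template <typename T>
per_rank_t<T> simulate_gatherv(per_rank_t<T> const& send, std::size_t root)
{
  assert(root < send.size());
  per_rank_t<T> recv(send.size());
  recv[root] = concatenate(send);
  return recv;
}

// allgather(v): every rank ends up with the concatenated data.
template <typename T>
per_rank_t<T> simulate_allgatherv(per_rank_t<T> const& send)
{
  return per_rank_t<T>(send.size(), concatenate(send));
}
```

Since the function under review leaves the full result on every rank, the "allgather" prefix describes it more accurately.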
template <typename GraphViewType>
rmm::device_uvector<typename GraphViewType::edge_type> compute_local_major_degrees(
raft::handle_t const& handle, GraphViewType const& graph_view)
{
So, this code is pretty much https://github.com/rapidsai/cugraph/blob/branch-22.06/cpp/src/structure/graph_view_impl.cuh#L88 less col_comm.reduce(...) (https://github.com/rapidsai/cugraph/blob/branch-22.06/cpp/src/structure/graph_view_impl.cuh#L155).
Better to refactor (or at least add a FIXME) to avoid code duplication?
I'll add a FIXME.
I actually think much of this logic should be moved into the graph_view; we assume too much regarding implementation in these functions.
auto compacted_length = thrust::distance(
input_iter,
thrust::remove_if(
FYI: The current version of thrust::remove_if does not work properly if minors.size() overflows a 32-bit integer.
Added FIXME to both of these remove_if calls in this file (both branches of the if)
thrust::make_optional(rmm::device_uvector<weight_t>(0, handle.get_stream()));

size_t level{0};
size_t num_rows{1};
Better to rename this to row_comm_size (this is more of a consistency thing).
Done
Looks good to me except for a few minor complaints about documentation.
cpp/include/cugraph/algorithms.hpp
Outdated
* @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and
* handles to various CUDA libraries) to run graph algorithms.
* @param graph_view Graph View object to generate NBR Sampling on.
* @param d_starting_vertices Device span of starting vertex IDs for the NBR Sampling.
d_starting_vertices => starting_vertices, as we renamed the input parameters.
* @param graph_view Graph View object to generate NBR Sampling on.
* @param d_starting_vertices Device span of starting vertex IDs for the NBR Sampling.
* @param h_fan_out Host span defining branching out (fan-out) degree per source vertex for each
* level
h_fan_out => fan_out.
@@ -350,6 +365,10 @@ void partially_decompress_edge_partition_to_fill_edgelist(
thrust::fill(
thrust::seq, majors + major_offset, majors + major_offset + local_degree, major);
thrust::copy(thrust::seq, indices, indices + local_degree, minors + major_offset);
if (weights)
This can lead to thread divergence if local_degree values vary significantly among the threads in a single warp. Maybe add a FIXME statement. I have the same issue in the Triangle Counting implementation (https://github.com/rapidsai/cugraph/pull/2253/files#diff-ce8c8b8ffdc670a97313ca4ce20de7bf8a18daa81f5a1fde50f3b162bf75b75bR434).
You may add a similar FIXME. Later, we may address this together by adding something like a (delayed) segmented_copy (or fill).
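One possible shape for such a segmented fill is sketched below on the host. Instead of one worker per major looping over local_degree elements, each output slot looks up its owning segment with a binary search, so every unit of work costs about the same regardless of the degree distribution. This is an assumption about how the idea could be realized, not the approach cuGraph adopted; segmented_fill_majors is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// offsets has majors.size() + 1 entries; segment s covers output slots [offsets[s], offsets[s + 1]).
template <typename vertex_t>
std::vector<vertex_t> segmented_fill_majors(std::vector<vertex_t> const& majors,
                                            std::vector<std::size_t> const& offsets)
{
  assert(offsets.size() == majors.size() + 1);
  std::vector<vertex_t> out(offsets.back());
  // On a GPU each iteration would map to one thread; the loop stands in for that here.
  for (std::size_t i = 0; i < out.size(); ++i) {
    // The first offset strictly greater than i, minus one, is the segment that owns slot i.
    auto it             = std::upper_bound(offsets.begin(), offsets.end(), i);
    std::size_t segment = static_cast<std::size_t>(it - offsets.begin()) - 1;
    out[i]              = majors[segment];
  }
  return out;
}
```

The binary search adds a logarithmic cost per element, so whether this beats the per-major loop depends on how skewed the degrees are.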
@gpucibot merge
This PR will refactor the Uniform Neighborhood Sampling implementation to meet the new C API.
Major elements: