Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stream synchronization in MTMG graph construction #4275

Merged
merged 7 commits into from
May 23, 2024
76 changes: 38 additions & 38 deletions cpp/include/cugraph/mtmg/detail/per_device_edgelist.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,13 +62,13 @@ class per_device_edgelist_t {
/**
 * @brief Construct a new per_device_edgelist_t object
*
* @param handle MTMG resource handle - used to identify GPU resources
* @param stream_view CUDA stream view
* @param device_buffer_size Number of edges to store in each device buffer
* @param use_weight Whether or not the edgelist will have weights
* @param use_edge_id Whether or not the edgelist will have edge ids
* @param use_edge_type Whether or not the edgelist will have edge types
*/
per_device_edgelist_t(cugraph::mtmg::handle_t const& handle,
per_device_edgelist_t(rmm::cuda_stream_view stream_view,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something tedious, but I guess our convention is to pass the handle as the first input argument and stream_view as the last input argument. Should we better keep this convention?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I should have caught that. I will make that change.

size_t device_buffer_size,
bool use_weight,
bool use_edge_id,
Expand All @@ -89,7 +89,7 @@ class per_device_edgelist_t {
edge_type_ = std::make_optional(std::vector<rmm::device_uvector<edge_type_t>>());
}

create_new_buffers(handle);
create_new_buffers(stream_view);
}

/**
Expand All @@ -111,14 +111,14 @@ class per_device_edgelist_t {
/**
* @brief Append a list of edges to the edge list
*
* @param handle The resource handle
* @param src Source vertex id
* @param dst Destination vertex id
* @param wgt Edge weight
* @param edge_id Edge id
* @param edge_type Edge type
* @param stream_view CUDA stream view
* @param src Source vertex id
* @param dst Destination vertex id
* @param wgt Edge weight
* @param edge_id Edge id
* @param edge_type Edge type
*/
void append(handle_t const& handle,
void append(rmm::cuda_stream_view stream_view,
raft::host_span<vertex_t const> src,
raft::host_span<vertex_t const> dst,
std::optional<raft::host_span<weight_t const>> wgt,
Expand All @@ -142,13 +142,13 @@ class per_device_edgelist_t {
pos += copy_count;
current_pos_ += copy_count;

if (current_pos_ == src_.back().size()) { create_new_buffers(handle); }
if (current_pos_ == src_.back().size()) { create_new_buffers(stream_view); }
}
}

std::for_each(copy_positions.begin(),
copy_positions.end(),
[&handle,
[&stream_view,
&this_src = src_,
&src,
&this_dst = dst_,
Expand All @@ -164,47 +164,45 @@ class per_device_edgelist_t {
raft::update_device(this_src[buffer_idx].begin() + buffer_pos,
src.begin() + input_pos,
copy_count,
handle.get_stream());
stream_view);

raft::update_device(this_dst[buffer_idx].begin() + buffer_pos,
dst.begin() + input_pos,
copy_count,
handle.get_stream());
stream_view);

if (this_wgt)
raft::update_device((*this_wgt)[buffer_idx].begin() + buffer_pos,
wgt->begin() + input_pos,
copy_count,
handle.get_stream());
stream_view);

if (this_edge_id)
raft::update_device((*this_edge_id)[buffer_idx].begin() + buffer_pos,
edge_id->begin() + input_pos,
copy_count,
handle.get_stream());
stream_view);

if (this_edge_type)
raft::update_device((*this_edge_type)[buffer_idx].begin() + buffer_pos,
edge_type->begin() + input_pos,
copy_count,
handle.get_stream());
stream_view);
});

handle.sync_stream();
}

/**
* @brief Mark the edgelist as ready for reading (all writes are complete)
*
* @param handle The resource handle
* @param stream_view CUDA stream view
*/
void finalize_buffer(handle_t const& handle)
void finalize_buffer(rmm::cuda_stream_view stream_view)
{
src_.back().resize(current_pos_, handle.get_stream());
dst_.back().resize(current_pos_, handle.get_stream());
if (wgt_) wgt_->back().resize(current_pos_, handle.get_stream());
if (edge_id_) edge_id_->back().resize(current_pos_, handle.get_stream());
if (edge_type_) edge_type_->back().resize(current_pos_, handle.get_stream());
src_.back().resize(current_pos_, stream_view);
dst_.back().resize(current_pos_, stream_view);
if (wgt_) wgt_->back().resize(current_pos_, stream_view);
if (edge_id_) edge_id_->back().resize(current_pos_, stream_view);
if (edge_type_) edge_type_->back().resize(current_pos_, stream_view);
}

bool use_weight() const { return wgt_.has_value(); }
Expand All @@ -230,16 +228,18 @@ class per_device_edgelist_t {
void consolidate_and_shuffle(cugraph::mtmg::handle_t const& handle, bool store_transposed)
{
if (src_.size() > 1) {
auto stream = handle.raft_handle().get_stream();

size_t total_size = std::transform_reduce(
src_.begin(), src_.end(), size_t{0}, std::plus<size_t>(), [](auto& d_vector) {
return d_vector.size();
});

resize_and_copy_buffers(handle.get_stream(), src_, total_size);
resize_and_copy_buffers(handle.get_stream(), dst_, total_size);
if (wgt_) resize_and_copy_buffers(handle.get_stream(), *wgt_, total_size);
if (edge_id_) resize_and_copy_buffers(handle.get_stream(), *edge_id_, total_size);
if (edge_type_) resize_and_copy_buffers(handle.get_stream(), *edge_type_, total_size);
resize_and_copy_buffers(stream, src_, total_size);
resize_and_copy_buffers(stream, dst_, total_size);
if (wgt_) resize_and_copy_buffers(stream, *wgt_, total_size);
if (edge_id_) resize_and_copy_buffers(stream, *edge_id_, total_size);
if (edge_type_) resize_and_copy_buffers(stream, *edge_type_, total_size);
}

auto tmp_wgt = wgt_ ? std::make_optional(std::move((*wgt_)[0])) : std::nullopt;
Expand Down Expand Up @@ -286,16 +286,16 @@ class per_device_edgelist_t {
buffer = std::move(new_buffer);
}

void create_new_buffers(cugraph::mtmg::handle_t const& handle)
void create_new_buffers(rmm::cuda_stream_view stream_view)
{
src_.emplace_back(device_buffer_size_, handle.get_stream());
dst_.emplace_back(device_buffer_size_, handle.get_stream());
src_.emplace_back(device_buffer_size_, stream_view);
dst_.emplace_back(device_buffer_size_, stream_view);

if (wgt_) { wgt_->emplace_back(device_buffer_size_, handle.get_stream()); }
if (wgt_) { wgt_->emplace_back(device_buffer_size_, stream_view); }

if (edge_id_) { edge_id_->emplace_back(device_buffer_size_, handle.get_stream()); }
if (edge_id_) { edge_id_->emplace_back(device_buffer_size_, stream_view); }

if (edge_type_) { edge_type_->emplace_back(device_buffer_size_, handle.get_stream()); }
if (edge_type_) { edge_type_->emplace_back(device_buffer_size_, stream_view); }

current_pos_ = 0;
}
Expand Down
3 changes: 1 addition & 2 deletions cpp/include/cugraph/mtmg/edge_property.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,6 @@

#include <cugraph/mtmg/detail/device_shared_wrapper.hpp>
#include <cugraph/mtmg/edge_property_view.hpp>
#include <cugraph/mtmg/handle.hpp>

namespace cugraph {
namespace mtmg {
Expand Down
3 changes: 1 addition & 2 deletions cpp/include/cugraph/mtmg/edge_property_view.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,7 +17,6 @@
#pragma once

#include <cugraph/mtmg/detail/device_shared_wrapper.hpp>
#include <cugraph/mtmg/handle.hpp>

namespace cugraph {
namespace mtmg {
Expand Down
10 changes: 7 additions & 3 deletions cpp/include/cugraph/mtmg/edgelist.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,7 +39,7 @@ class edgelist_t : public detail::device_shared_wrapper_t<
bool use_edge_type)
{
detail::per_device_edgelist_t<vertex_t, weight_t, edge_t, edge_type_t> tmp(
handle, device_buffer_size, use_weight, use_edge_id, use_edge_type);
handle.get_stream(), device_buffer_size, use_weight, use_edge_id, use_edge_type);

detail::device_shared_wrapper_t<
detail::per_device_edgelist_t<vertex_t, weight_t, edge_t, edge_type_t>>::set(handle,
Expand All @@ -49,7 +49,11 @@ class edgelist_t : public detail::device_shared_wrapper_t<
/**
* @brief Stop inserting edges into this edgelist so we can use the edges
*/
void finalize_buffer(handle_t const& handle) { this->get(handle).finalize_buffer(handle); }
void finalize_buffer(handle_t const& handle)
{
handle.sync_stream_pool();
this->get(handle).finalize_buffer(handle.get_stream());
}

/**
* @brief Consolidate for the edgelist edges into a single edgelist and then
Expand Down
7 changes: 6 additions & 1 deletion cpp/include/cugraph/mtmg/handle.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -79,6 +79,11 @@ class handle_t {
*/
void sync_stream() const { sync_stream(get_stream()); }

/**
* @brief Sync all streams in the stream pool
*/
void sync_stream_pool() const { raft::resource::sync_stream_pool(raft_handle_); }

/**
* @brief get thrust policy for the stream
*
Expand Down
35 changes: 19 additions & 16 deletions cpp/include/cugraph/mtmg/per_thread_edgelist.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,6 @@

#include <cugraph/mtmg/detail/device_shared_wrapper.hpp>
#include <cugraph/mtmg/detail/per_device_edgelist.hpp>
#include <cugraph/mtmg/handle.hpp>

namespace cugraph {
namespace mtmg {
Expand Down Expand Up @@ -70,21 +69,21 @@ class per_thread_edgelist_t {
/**
* @brief Append an edge to the edge list
*
* @param handle The resource handle
* @param src Source vertex id
* @param dst Destination vertex id
* @param wgt Edge weight
* @param edge_id Edge id
* @param edge_type Edge type
* @param stream_view The cuda stream
* @param src Source vertex id
* @param dst Destination vertex id
* @param wgt Edge weight
* @param edge_id Edge id
* @param edge_type Edge type
*/
void append(handle_t const& handle,
void append(rmm::cuda_stream_view stream_view,
vertex_t src,
vertex_t dst,
std::optional<weight_t> wgt,
std::optional<edge_t> edge_id,
std::optional<edge_type_t> edge_type)
{
if (current_pos_ == src_.size()) { flush(handle); }
if (current_pos_ == src_.size()) { flush(stream_view); }

src_[current_pos_] = src;
dst_[current_pos_] = dst;
Expand All @@ -98,14 +97,14 @@ class per_thread_edgelist_t {
/**
* @brief Append a list of edges to the edge list
*
* @param handle The resource handle
* @param stream_view The cuda stream
* @param src Source vertex id
* @param dst Destination vertex id
* @param wgt Edge weight
* @param edge_id Edge id
* @param edge_type Edge type
*/
void append(handle_t const& handle,
void append(rmm::cuda_stream_view stream_view,
raft::host_span<vertex_t const> src,
raft::host_span<vertex_t const> dst,
std::optional<raft::host_span<weight_t const>> wgt,
Expand All @@ -131,7 +130,7 @@ class per_thread_edgelist_t {
edge_type.begin() + pos + copy_count,
edge_type_->begin() + current_pos_);

if (current_pos_ == src_.size()) { flush(handle); }
if (current_pos_ == src_.size()) { flush(stream_view); }

count -= copy_count;
pos += copy_count;
Expand All @@ -141,12 +140,14 @@ class per_thread_edgelist_t {
/**
* @brief Flush thread data from host to GPU memory
*
* @param handle The resource handle
* @param stream_view The cuda stream
* @param sync If true, synchronize the asynchronous copy of data;
* defaults to false.
*/
void flush(handle_t const& handle)
void flush(rmm::cuda_stream_view stream_view, bool sync = false)
{
edgelist_.append(
handle,
stream_view,
raft::host_span<vertex_t const>{src_.data(), current_pos_},
raft::host_span<vertex_t const>{dst_.data(), current_pos_},
wgt_ ? std::make_optional(raft::host_span<weight_t const>{wgt_->data(), current_pos_})
Expand All @@ -158,6 +159,8 @@ class per_thread_edgelist_t {
: std::nullopt);

current_pos_ = 0;

if (sync) stream_view.synchronize();
}

private:
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/mtmg/multi_node_threaded_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,15 @@ class Tests_Multithreaded

for (size_t j = starting_edge_offset; j < h_src_v.size(); j += stride) {
per_thread_edgelist.append(
thread_handle,
thread_handle.get_stream(),
h_src_v[j],
h_dst_v[j],
h_weights_v ? std::make_optional((*h_weights_v)[j]) : std::nullopt,
std::nullopt,
std::nullopt);
}

per_thread_edgelist.flush(thread_handle);
per_thread_edgelist.flush(thread_handle.get_stream());
});
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/mtmg/threaded_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ class Tests_Multithreaded

for (size_t j = i; j < h_src_v.size(); j += num_threads) {
per_thread_edgelist.append(
thread_handle,
thread_handle.get_stream(),
h_src_v[j],
h_dst_v[j],
h_weights_v ? std::make_optional((*h_weights_v)[j]) : std::nullopt,
std::nullopt,
std::nullopt);
}

per_thread_edgelist.flush(thread_handle);
per_thread_edgelist.flush(thread_handle.get_stream());
});
}

Expand Down
Loading
Loading