Skip to content

Commit

Permalink
Merge branch 'concurrency' of github.com:afender/cugraph into concurr…
Browse files Browse the repository at this point in the history
…ency
  • Loading branch information
afender committed Mar 4, 2021
2 parents 412f377 + b8f0605 commit c47818d
Show file tree
Hide file tree
Showing 52 changed files with 327 additions and 326 deletions.
26 changes: 15 additions & 11 deletions cpp/include/compute_partition.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,27 +39,32 @@ class compute_partition_t {
using graph_view_t = graph_view_type;
using vertex_t = typename graph_view_type::vertex_type;

compute_partition_t(graph_view_t const &graph_view)
compute_partition_t(raft::handle_t const &handle, graph_view_t const &graph_view)
: vertex_partition_offsets_v_(0, handle.get_stream())
{
init<graph_view_t::is_multi_gpu>(graph_view);
init<graph_view_t::is_multi_gpu>(handle, graph_view);
}

private:
template <bool is_multi_gpu, typename std::enable_if_t<!is_multi_gpu> * = nullptr>
void init(graph_view_t const &graph_view)
void init(raft::handle_t const &handle, graph_view_t const &graph_view)
{
}

template <bool is_multi_gpu, typename std::enable_if_t<is_multi_gpu> * = nullptr>
void init(graph_view_t const &graph_view)
void init(raft::handle_t const &handle, graph_view_t const &graph_view)
{
auto partition = graph_view.get_partition();
row_size_ = partition.get_row_size();
col_size_ = partition.get_col_size();
size_ = row_size_ * col_size_;

vertex_partition_offsets_v_.resize(size_ + 1);
vertex_partition_offsets_v_ = partition.get_vertex_partition_offsets();
vertex_partition_offsets_v_.resize(size_ + 1, handle.get_stream());
auto vertex_partition_offsets = partition.get_vertex_partition_offsets();
raft::update_device(vertex_partition_offsets_v_.data(),
vertex_partition_offsets.data(),
vertex_partition_offsets.size(),
handle.get_stream());
}

public:
Expand Down Expand Up @@ -166,7 +171,7 @@ class compute_partition_t {
*/
vertex_device_view_t vertex_device_view() const
{
return vertex_device_view_t(vertex_partition_offsets_v_.data().get(), size_);
return vertex_device_view_t(vertex_partition_offsets_v_.data(), size_);
}

/**
Expand All @@ -176,12 +181,11 @@ class compute_partition_t {
*/
edge_device_view_t edge_device_view() const
{
return edge_device_view_t(
vertex_partition_offsets_v_.data().get(), row_size_, col_size_, size_);
return edge_device_view_t(vertex_partition_offsets_v_.data(), row_size_, col_size_, size_);
}

private:
rmm::device_vector<vertex_t> vertex_partition_offsets_v_{};
rmm::device_uvector<vertex_t> vertex_partition_offsets_v_;
int row_size_{1};
int col_size_{1};
int size_{1};
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/patterns/count_if_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ typename GraphViewType::edge_type count_if_e(
detail::count_if_e_for_all_block_size,
handle.get_device_properties().maxGridSize[0]);

rmm::device_vector<edge_t> block_counts(update_grid.num_blocks);
rmm::device_uvector<edge_t> block_counts(update_grid.num_blocks, handle.get_stream());

detail::for_all_major_for_all_nbr_low_degree<<<update_grid.num_blocks,
update_grid.block_size,
Expand All @@ -210,7 +210,7 @@ typename GraphViewType::edge_type count_if_e(
matrix_partition,
adj_matrix_row_value_input_first + row_value_input_offset,
adj_matrix_col_value_input_first + col_value_input_offset,
block_counts.data().get(),
block_counts.data(),
e_op);

// FIXME: we have several options to implement this. With cooperative group support
Expand Down
11 changes: 6 additions & 5 deletions cpp/include/patterns/transform_reduce_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ T transform_reduce_e(raft::handle_t const& handle,
detail::transform_reduce_e_for_all_block_size,
handle.get_device_properties().maxGridSize[0]);

rmm::device_vector<T> block_results(update_grid.num_blocks);
auto block_result_buffer =
allocate_dataframe_buffer<T>(update_grid.num_blocks, handle.get_stream());

detail::for_all_major_for_all_nbr_low_degree<<<update_grid.num_blocks,
update_grid.block_size,
Expand All @@ -215,7 +216,7 @@ T transform_reduce_e(raft::handle_t const& handle,
matrix_partition,
adj_matrix_row_value_input_first + row_value_input_offset,
adj_matrix_col_value_input_first + col_value_input_offset,
block_results.data(),
get_dataframe_buffer_begin<T>(block_result_buffer),
e_op);

// FIXME: we have several options to implement this. With cooperative group support
Expand All @@ -225,10 +226,10 @@ T transform_reduce_e(raft::handle_t const& handle,
// synchronization point in varying timings and the number of SMs is not very big)
auto partial_result =
thrust::reduce(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
block_results.begin(),
block_results.end(),
get_dataframe_buffer_begin<T>(block_result_buffer),
get_dataframe_buffer_begin<T>(block_result_buffer) + update_grid.num_blocks,
T(),
[] __device__(auto lhs, auto rhs) { return plus_edge_op_result(lhs, rhs); });
[] __device__(T lhs, T rhs) { return plus_edge_op_result(lhs, rhs); });

result = plus_edge_op_result(result, partial_result);
}
Expand Down
26 changes: 14 additions & 12 deletions cpp/include/patterns/update_frontier_v_push_if_out_nbr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <partition_manager.hpp>
#include <patterns/edge_op_utils.cuh>
#include <patterns/reduce_op.cuh>
#include <utilities/dataframe_buffer.cuh>
#include <utilities/device_comm.cuh>
#include <utilities/error.hpp>
#include <utilities/host_scalar_comm.cuh>
Expand Down Expand Up @@ -157,13 +158,14 @@ size_t reduce_buffer_elements(raft::handle_t const& handle,
// FIXME: if GraphViewType::is_multi_gpu is true, this should be executed on the GPU holding the
// vertex unless reduce_op is a pure function.
rmm::device_uvector<key_t> keys(num_buffer_elements, handle.get_stream());
rmm::device_vector<payload_t> values(num_buffer_elements);
auto value_buffer =
allocate_dataframe_buffer<payload_t>(num_buffer_elements, handle.get_stream());
auto it = thrust::reduce_by_key(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
buffer_key_output_first,
buffer_key_output_first + num_buffer_elements,
buffer_payload_output_first,
keys.begin(),
values.begin(),
get_dataframe_buffer_begin<payload_t>(value_buffer),
thrust::equal_to<key_t>(),
reduce_op);
auto num_reduced_buffer_elements =
Expand All @@ -173,13 +175,9 @@ size_t reduce_buffer_elements(raft::handle_t const& handle,
keys.begin() + num_reduced_buffer_elements,
buffer_key_output_first);
thrust::copy(rmm::exec_policy(handle.get_stream())->on(handle.get_stream()),
values.begin(),
values.begin() + num_reduced_buffer_elements,
get_dataframe_buffer_begin<payload_t>(value_buffer),
get_dataframe_buffer_begin<payload_t>(value_buffer) + num_reduced_buffer_elements,
buffer_payload_output_first);
// FIXME: this is unecessary if we use a tuple of rmm::device_uvector objects for values
CUDA_TRY(
cudaStreamSynchronize(handle.get_stream())); // this is necessary as values will become
// out-of-scope once this function returns
return num_reduced_buffer_elements;
}
}
Expand Down Expand Up @@ -673,15 +671,19 @@ void update_frontier_v_push_if_out_nbr(
num_buffer_elements,
vertex_value_input_first,
vertex_value_output_first,
std::get<0>(bucket_and_bucket_size_device_ptrs).get(),
std::get<1>(bucket_and_bucket_size_device_ptrs).get(),
std::get<0>(bucket_and_bucket_size_device_ptrs),
std::get<1>(bucket_and_bucket_size_device_ptrs),
VertexFrontierType::kInvalidBucketIdx,
invalid_vertex,
v_op);

auto bucket_sizes_device_ptr = std::get<1>(bucket_and_bucket_size_device_ptrs);
thrust::host_vector<size_t> bucket_sizes(
bucket_sizes_device_ptr, bucket_sizes_device_ptr + VertexFrontierType::kNumBuckets);
std::vector<size_t> bucket_sizes(VertexFrontierType::kNumBuckets);
raft::update_host(bucket_sizes.data(),
bucket_sizes_device_ptr,
VertexFrontierType::kNumBuckets,
handle.get_stream());
CUDA_TRY(cudaStreamSynchronize(handle.get_stream()));
for (size_t i = 0; i < VertexFrontierType::kNumBuckets; ++i) {
vertex_frontier.get_bucket(i).set_size(bucket_sizes[i]);
}
Expand Down
51 changes: 34 additions & 17 deletions cpp/include/patterns/vertex_frontier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,17 @@ template <typename vertex_t, bool is_multi_gpu = false>
class Bucket {
public:
Bucket(raft::handle_t const& handle, size_t capacity)
: handle_ptr_(&handle), elements_(capacity, invalid_vertex_id<vertex_t>::value)
: handle_ptr_(&handle), elements_(capacity, handle.get_stream())
{
thrust::fill(rmm::exec_policy(handle_ptr_->get_stream())->on(handle_ptr_->get_stream()),
elements_.begin(),
elements_.end(),
invalid_vertex_id<vertex_t>::value);
}

void insert(vertex_t v)
{
elements_[size_] = v;
raft::update_device(elements_.data() + size_, &v, 1, handle_ptr_->get_stream());
++size_;
}

Expand All @@ -177,9 +181,9 @@ class Bucket {

size_t capacity() const { return elements_.size(); }

auto const data() const { return elements_.data().get(); }
auto const data() const { return elements_.data(); }

auto data() { return elements_.data().get(); }
auto data() { return elements_.data(); }

auto const begin() const { return elements_.begin(); }

Expand All @@ -191,7 +195,7 @@ class Bucket {

private:
raft::handle_t const* handle_ptr_{nullptr};
rmm::device_vector<vertex_t> elements_{};
rmm::device_uvector<vertex_t> elements_;
size_t size_{0};
};

Expand All @@ -206,13 +210,21 @@ class VertexFrontier {

VertexFrontier(raft::handle_t const& handle, std::vector<size_t> bucket_capacities)
: handle_ptr_(&handle),
tmp_bucket_ptrs_(num_buckets, nullptr),
tmp_bucket_sizes_(num_buckets, 0),
tmp_bucket_ptrs_(num_buckets, handle.get_stream()),
tmp_bucket_sizes_(num_buckets, handle.get_stream()),
buffer_ptrs_(kReduceInputTupleSize + 1 /* to store destination column number */, nullptr),
buffer_idx_(0, handle_ptr_->get_stream())
{
CUGRAPH_EXPECTS(bucket_capacities.size() == num_buckets,
"invalid input argument bucket_capacities (size mismatch)");
thrust::fill(rmm::exec_policy(handle_ptr_->get_stream())->on(handle_ptr_->get_stream()),
tmp_bucket_ptrs_.begin(),
tmp_bucket_ptrs_.end(),
static_cast<vertex_t*>(nullptr));
thrust::fill(rmm::exec_policy(handle_ptr_->get_stream())->on(handle_ptr_->get_stream()),
tmp_bucket_sizes_.begin(),
tmp_bucket_sizes_.end(),
size_t{0});
for (size_t i = 0; i < num_buckets; ++i) {
buckets_.emplace_back(handle, bucket_capacities[i]);
}
Expand Down Expand Up @@ -251,8 +263,8 @@ class VertexFrontier {
0,
handle_ptr_->get_stream()>>>(this_bucket.begin(),
this_bucket.end(),
std::get<0>(bucket_and_bucket_size_device_ptrs).get(),
std::get<1>(bucket_and_bucket_size_device_ptrs).get(),
std::get<0>(bucket_and_bucket_size_device_ptrs),
std::get<1>(bucket_and_bucket_size_device_ptrs),
bucket_idx,
kInvalidBucketIdx,
invalid_vertex,
Expand All @@ -269,8 +281,10 @@ class VertexFrontier {
[] __device__(auto value) { return value == invalid_vertex; });
auto bucket_sizes_device_ptr = std::get<1>(bucket_and_bucket_size_device_ptrs);
thrust::host_vector<size_t> bucket_sizes(bucket_sizes_device_ptr,
bucket_sizes_device_ptr + kNumBuckets);
std::vector<size_t> bucket_sizes(kNumBuckets);
raft::update_host(
bucket_sizes.data(), bucket_sizes_device_ptr, kNumBuckets, handle_ptr_->get_stream());
CUDA_TRY(cudaStreamSynchronize(handle_ptr_->get_stream()));
for (size_t i = 0; i < kNumBuckets; ++i) {
if (i != bucket_idx) { get_bucket(i).set_size(bucket_sizes[i]); }
}
Expand All @@ -283,14 +297,17 @@ class VertexFrontier {
auto get_bucket_and_bucket_size_device_pointers()
{
thrust::host_vector<vertex_t*> tmp_ptrs(buckets_.size(), nullptr);
thrust::host_vector<size_t> tmp_sizes(buckets_.size(), 0);
std::vector<vertex_t*> tmp_ptrs(buckets_.size(), nullptr);
std::vector<size_t> tmp_sizes(buckets_.size(), 0);
for (size_t i = 0; i < buckets_.size(); ++i) {
tmp_ptrs[i] = get_bucket(i).data();
tmp_sizes[i] = get_bucket(i).size();
}
tmp_bucket_ptrs_ = tmp_ptrs;
tmp_bucket_sizes_ = tmp_sizes;
raft::update_device(
tmp_bucket_ptrs_.data(), tmp_ptrs.data(), tmp_ptrs.size(), handle_ptr_->get_stream());
raft::update_device(
tmp_bucket_sizes_.data(), tmp_sizes.data(), tmp_sizes.size(), handle_ptr_->get_stream());
CUDA_TRY(cudaStreamSynchronize(handle_ptr_->get_stream()));
return std::make_tuple(tmp_bucket_ptrs_.data(), tmp_bucket_sizes_.data());
}
Expand Down Expand Up @@ -345,8 +362,8 @@ class VertexFrontier {
raft::handle_t const* handle_ptr_{nullptr};
std::vector<Bucket<vertex_t, is_multi_gpu>> buckets_{};
rmm::device_vector<vertex_t*> tmp_bucket_ptrs_{};
rmm::device_vector<size_t> tmp_bucket_sizes_{};
rmm::device_uvector<vertex_t*> tmp_bucket_ptrs_;
rmm::device_uvector<size_t> tmp_bucket_sizes_;
std::array<size_t, kReduceInputTupleSize> tuple_element_sizes_ =
compute_thrust_tuple_element_sizes<ReduceInputTupleType>()();
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/experimental/louvain.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ class Louvain {
handle_(handle),
dendrogram_(std::make_unique<Dendrogram<vertex_t>>()),
current_graph_view_(graph_view),
compute_partition_(graph_view),
compute_partition_(handle, graph_view),
local_num_vertices_(graph_view.get_number_of_local_vertices()),
local_num_rows_(graph_view.get_number_of_local_adj_matrix_partition_rows()),
local_num_cols_(graph_view.get_number_of_local_adj_matrix_partition_cols()),
Expand Down
8 changes: 1 addition & 7 deletions python/cugraph/centrality/betweenness_centrality_wrapper.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -17,18 +17,12 @@
# cython: language_level = 3

from cugraph.centrality.betweenness_centrality cimport betweenness_centrality as c_betweenness_centrality
from cugraph.centrality.betweenness_centrality cimport handle_t
from cugraph.structure.graph import DiGraph
from cugraph.structure.graph_primtypes cimport *
from libc.stdint cimport uintptr_t
from libcpp cimport bool
import cudf
import numpy as np
import numpy.ctypeslib as ctypeslib

import dask_cudf
import dask_cuda

import cugraph.comms.comms as Comms
from cugraph.dask.common.mg_utils import get_client
import dask.distributed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -24,8 +24,6 @@ from libc.stdint cimport uintptr_t
from libcpp cimport bool
import cudf
import numpy as np
import numpy.ctypeslib as ctypeslib

from cugraph.dask.common.mg_utils import get_client
import cugraph.comms.comms as Comms
import dask.distributed
Expand Down
4 changes: 2 additions & 2 deletions python/cugraph/centrality/katz_centrality.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -16,7 +16,7 @@
# cython: embedsignature = True
# cython: language_level = 3

from cugraph.structure.graph_primtypes cimport *
from cugraph.structure.graph_utilities cimport *
from libcpp cimport bool

cdef extern from "utilities/cython.hpp" namespace "cugraph::cython":
Expand Down
7 changes: 2 additions & 5 deletions python/cugraph/centrality/katz_centrality_wrapper.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -17,13 +17,10 @@
# cython: language_level = 3

from cugraph.centrality.katz_centrality cimport call_katz_centrality
from cugraph.structure.graph_primtypes cimport *
from cugraph.structure.graph_utilities cimport *
from cugraph.structure import graph_primtypes_wrapper
from libcpp cimport bool
from libc.stdint cimport uintptr_t

import cudf
import rmm
import numpy as np


Expand Down
Loading

0 comments on commit c47818d

Please sign in to comment.