Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MG neighborhood sampling and add SG implementation #2285

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
98fd1b2
move the current implementation of mg neighborhood sampling to proto
May 18, 2022
063e443
remove experimental prefix
May 18, 2022
e6ed994
refactor mg neighborhood sampling bindings
May 18, 2022
4581645
add and test mechanism for creating graph with edge index as weight
ChuckHastings May 19, 2022
57680f6
Merge mechanism for creating graph with edge index as weight
May 19, 2022
16cea30
rename create*_with_ids to create*_with_edge_ids
ChuckHastings May 19, 2022
2fd99b5
rename create*_with_ids to create*_with_edge_ids from Chuck
May 19, 2022
3f90963
update python bindings to create graph with edge index as weight
May 20, 2022
d460193
fix bug in MG case... cugraph_ops function doesn't handle an empty re…
ChuckHastings May 20, 2022
c7d0a11
merge bug fix in MG case by Chuck
May 20, 2022
e6210a2
Merge remote-tracking branch 'upstream/branch-22.06' into branch-22.0…
May 23, 2022
3b76a49
add bindings for SG uniform_neighbor_sample
May 24, 2022
e3b5fe4
Merge remote-tracking branch 'upstream/branch-22.06' into branch-22.0…
May 24, 2022
9933120
remove debug print
May 24, 2022
604ab0f
update pylibcugraph uniform_neighbor_sample tests because of the API …
May 24, 2022
c579261
drop the directory proto
May 24, 2022
626a833
enable support for weigths
May 25, 2022
6d91c39
remove debug prints, address PR comments
May 26, 2022
f038688
move uniform_neighbor_sample to stable API, convert edge_ids to weigh…
May 27, 2022
d97ce67
update uniform neighborhood sampling tests
May 30, 2022
c9482c2
merge latest change and update branch
jnke2016 May 30, 2022
720b05d
remove uniform neighbor sample older mechanism
May 30, 2022
78f7dd6
add end of line
May 30, 2022
ce97653
resolve merge conflict
jnke2016 Jun 1, 2022
7fdc09d
remove merge labels
jnke2016 Jun 1, 2022
b629ca9
remove outdated fixme
jnke2016 Jun 1, 2022
8a8f063
remove unused import
jnke2016 Jun 1, 2022
1af8e7b
add end of line
Jun 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions cpp/include/cugraph_c/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,30 @@ void* cugraph_type_erased_device_array_release(cugraph_type_erased_device_array_
cugraph_type_erased_device_array_view_t* cugraph_type_erased_device_array_view(
cugraph_type_erased_device_array_t* array);

/**
* @brief Create a type erased device array view with a different type
*
* Create a type erased device array view from
* a type erased device array treating the underlying
* pointer as a different type.
*
* Note: This is only viable when the underlying types are the same size. That
* is, you can switch between INT32 and FLOAT32, or between INT64 and FLOAT64.
* But if the types are different sizes this will be an error.
*
* @param [in] array Pointer to the type erased device array
* @param [in] dtype The type to cast the pointer to
* @param [out] result_view Address where to put the allocated device view
* @param [out] error Pointer to an error object storing details of any error. Will
* be populated if error code is not CUGRAPH_SUCCESS
* @return error code
*/
cugraph_error_code_t cugraph_type_erased_device_array_view_as_type(
cugraph_type_erased_device_array_t* array,
data_type_id_t dtype,
cugraph_type_erased_device_array_view_t** result_view,
cugraph_error_t** error);

/**
* @brief Create a type erased device array view from
* a raw device pointer.
Expand Down
25 changes: 25 additions & 0 deletions cpp/src/c_api/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,28 @@ extern "C" cugraph_error_code_t cugraph_type_erased_device_array_view_copy(
return CUGRAPH_UNKNOWN_ERROR;
}
}

extern "C" cugraph_error_code_t cugraph_type_erased_device_array_view_as_type(
cugraph_type_erased_device_array_t* array,
data_type_id_t dtype,
cugraph_type_erased_device_array_view_t** result_view,
cugraph_error_t** error)
{
auto internal_pointer =
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_t*>(array);

if (data_type_sz[dtype] == data_type_sz[internal_pointer->type_]) {
*result_view = reinterpret_cast<cugraph_type_erased_device_array_view_t*>(
new cugraph::c_api::cugraph_type_erased_device_array_view_t{internal_pointer->data_.data(),
internal_pointer->size_,
internal_pointer->data_.size(),
dtype});
return CUGRAPH_SUCCESS;
} else {
std::stringstream ss;
ss << "Could not treat type " << internal_pointer->type_ << " as type " << dtype;
auto tmp_error = new cugraph::c_api::cugraph_error_t{ss.str().c_str()};
*error = reinterpret_cast<cugraph_error_t*>(tmp_error);
return CUGRAPH_INVALID_INPUT;
}
}
4 changes: 2 additions & 2 deletions cpp/src/sampling/detail/graph_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ rmm::device_uvector<typename GraphViewType::edge_type> get_active_major_global_d
template <typename GraphViewType>
std::tuple<rmm::device_uvector<typename GraphViewType::vertex_type>,
rmm::device_uvector<typename GraphViewType::vertex_type>,
thrust::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
std::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
gather_local_edges(
raft::handle_t const& handle,
GraphViewType const& graph_view,
Expand All @@ -169,7 +169,7 @@ gather_local_edges(
template <typename GraphViewType>
std::tuple<rmm::device_uvector<typename GraphViewType::vertex_type>,
rmm::device_uvector<typename GraphViewType::vertex_type>,
thrust::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
std::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
gather_one_hop_edgelist(
raft::handle_t const& handle,
GraphViewType const& graph_view,
Expand Down
24 changes: 17 additions & 7 deletions cpp/src/sampling/detail/sampling_utils_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ partition_information(raft::handle_t const& handle, GraphViewType const& graph_v
template <typename GraphViewType>
std::tuple<rmm::device_uvector<typename GraphViewType::vertex_type>,
rmm::device_uvector<typename GraphViewType::vertex_type>,
thrust::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
std::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
gather_local_edges(
raft::handle_t const& handle,
GraphViewType const& graph_view,
Expand All @@ -451,10 +451,10 @@ gather_local_edges(

rmm::device_uvector<vertex_t> majors(edge_count, handle.get_stream());
rmm::device_uvector<vertex_t> minors(edge_count, handle.get_stream());
thrust::optional<rmm::device_uvector<weight_t>> weights =
std::optional<rmm::device_uvector<weight_t>> weights =
graph_view.is_weighted()
? thrust::make_optional(rmm::device_uvector<weight_t>(edge_count, handle.get_stream()))
: thrust::nullopt;
? std::make_optional(rmm::device_uvector<weight_t>(edge_count, handle.get_stream()))
: std::nullopt;

// FIXME: This should be the global constant
vertex_t invalid_vertex_id = graph_view.number_of_vertices();
Expand All @@ -477,6 +477,7 @@ gather_local_edges(
glbl_adj_list_offsets = global_adjacency_list_offsets.data(),
majors = majors.data(),
minors = minors.data(),
weights = weights ? weights->data() : nullptr,
partitions = partitions.data(),
hypersparse_begin = hypersparse_begin.data(),
invalid_vertex_id,
Expand Down Expand Up @@ -524,6 +525,10 @@ gather_local_edges(
(g_dst_index < g_degree_offset + local_out_degree)) {
minors[index] = adjacency_list[g_dst_index - g_degree_offset];
edge_index_first[index] = g_dst_index - g_degree_offset + glbl_adj_list_offsets[location];
if (weights != nullptr) {
weight_t const* edge_weights = *(partitions[partition_id].weights()) + sparse_offset;
weights[index] = edge_weights[g_dst_index];
}
} else {
minors[index] = invalid_vertex_id;
}
Expand All @@ -542,6 +547,7 @@ gather_local_edges(
glbl_degree_offsets = global_degree_offsets.data(),
majors = majors.data(),
minors = minors.data(),
weights = weights ? weights->data() : nullptr,
partitions = partitions.data(),
hypersparse_begin = hypersparse_begin.data(),
invalid_vertex_id,
Expand Down Expand Up @@ -585,7 +591,11 @@ gather_local_edges(
auto location = location_in_segment + vertex_count_offsets[partition_id];
auto g_dst_index = edge_index_first[index];
if (g_dst_index >= 0) {
minors[index] = adjacency_list[g_dst_index];
minors[index] = adjacency_list[g_dst_index];
if (weights != nullptr) {
weight_t const* edge_weights = *(partitions[partition_id].weights()) + sparse_offset;
weights[index] = edge_weights[g_dst_index];
}
edge_index_first[index] = g_dst_index;
} else {
minors[index] = invalid_vertex_id;
Expand Down Expand Up @@ -758,7 +768,7 @@ void local_major_degree(
template <typename GraphViewType>
std::tuple<rmm::device_uvector<typename GraphViewType::vertex_type>,
rmm::device_uvector<typename GraphViewType::vertex_type>,
thrust::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
std::optional<rmm::device_uvector<typename GraphViewType::weight_type>>>
gather_one_hop_edgelist(
raft::handle_t const& handle,
GraphViewType const& graph_view,
Expand All @@ -771,7 +781,7 @@ gather_one_hop_edgelist(
rmm::device_uvector<vertex_t> majors(0, handle.get_stream());
rmm::device_uvector<vertex_t> minors(0, handle.get_stream());

auto weights = thrust::make_optional<rmm::device_uvector<weight_t>>(0, handle.get_stream());
auto weights = std::make_optional<rmm::device_uvector<weight_t>>(0, handle.get_stream());

if constexpr (GraphViewType::is_multi_gpu == true) {
std::vector<std::vector<vertex_t>> active_majors_segments;
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/sampling/detail/sampling_utils_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ partition_information(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, float, false, true> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -191,7 +191,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, float, false, true> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -202,7 +202,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, float, false, true> const& graph_view,
const rmm::device_uvector<int64_t>& active_majors,
Expand All @@ -213,7 +213,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, double, false, true> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -224,7 +224,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, double, false, true> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -235,7 +235,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, double, false, true> const& graph_view,
const rmm::device_uvector<int64_t>& active_majors,
Expand All @@ -246,42 +246,42 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, float, false, true> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, float, false, true> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, float, false, true> const& graph_view,
rmm::device_uvector<int64_t> const& active_majors);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, double, false, true> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, double, false, true> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, double, false, true> const& graph_view,
rmm::device_uvector<int64_t> const& active_majors);
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/sampling/detail/sampling_utils_sg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ template rmm::device_uvector<int64_t> get_active_major_global_degrees(

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, float, false, false> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -134,7 +134,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, float, false, false> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -145,7 +145,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, float, false, false> const& graph_view,
const rmm::device_uvector<int64_t>& active_majors,
Expand All @@ -156,7 +156,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, double, false, false> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -167,7 +167,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, double, false, false> const& graph_view,
const rmm::device_uvector<int32_t>& active_majors,
Expand All @@ -178,7 +178,7 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_local_edges(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, double, false, false> const& graph_view,
const rmm::device_uvector<int64_t>& active_majors,
Expand All @@ -189,42 +189,42 @@ gather_local_edges(raft::handle_t const& handle,

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, float, false, false> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, float, false, false> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<float>>>
std::optional<rmm::device_uvector<float>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, float, false, false> const& graph_view,
rmm::device_uvector<int64_t> const& active_majors);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int32_t, double, false, false> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int32_t, int64_t, double, false, false> const& graph_view,
rmm::device_uvector<int32_t> const& active_majors);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
thrust::optional<rmm::device_uvector<double>>>
std::optional<rmm::device_uvector<double>>>
gather_one_hop_edgelist(raft::handle_t const& handle,
graph_view_t<int64_t, int64_t, double, false, false> const& graph_view,
rmm::device_uvector<int64_t> const& active_majors);
Expand Down
Loading