Skip to content

Commit

Permalink
Make Leiden runnable (still doesn't work correctly)
Browse files Browse the repository at this point in the history
  • Loading branch information
gitbuda committed Sep 15, 2024
1 parent bda68b0 commit 5dae101
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
4 changes: 1 addition & 3 deletions cpp/cugraph_module/algorithms/betweenness_centrality.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
#include "mg_cugraph_utility.hpp"

namespace {
// TODO: Check Betweenness instances. Update in new cuGraph API.
using vertex_t = int64_t;
using edge_t = int64_t;
using weight_t = double;
using result_t = double;

constexpr char const *kProcedureBetweenness = "get";

Expand Down Expand Up @@ -62,7 +60,7 @@ void BetweennessProc(mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_m
auto mg_graph =
mg_utility::GetWeightedGraphView(graph, result, memory, mg_graph_type, weight_property, kDefaultWeight);
if (mg_graph->Empty()) return;
auto cu_graph = mg_cugraph::CreateCugraphFromMemgraph<vertex_t, edge_t, result_t, false, false>(
auto cu_graph = mg_cugraph::CreateCugraphFromMemgraph<vertex_t, edge_t, weight_t, false, false>(
*mg_graph.get(), mg_graph_type, handle);
auto cu_graph_view = cu_graph.view();

Expand Down
17 changes: 7 additions & 10 deletions cpp/cugraph_module/algorithms/leiden.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
#include "mg_cugraph_utility.hpp"

namespace {
// TODO: Check Leiden instances. Update in new cuGraph API.
using vertex_t = int32_t;
using edge_t = int32_t;
using vertex_t = int64_t;
using edge_t = int64_t;
using weight_t = double;

constexpr char const *kProcedureLeiden = "get";
Expand Down Expand Up @@ -50,25 +49,23 @@ void LeidenProc(mgp_list *args, mgp_graph *graph, mgp_result *result, mgp_memory
auto max_iterations = mgp::value_get_int(mgp::list_at(args, 0));
auto resulution = mgp::value_get_double(mgp::list_at(args, 1));

auto mg_graph = mg_utility::GetGraphView(graph, result, memory, mg_graph::GraphType::kUndirectedGraph);
if (mg_graph->Empty()) return;

// Define handle and operation stream
raft::handle_t handle{};
auto stream = handle.get_stream();
auto mg_graph_type = mg_graph::GraphType::kUndirectedGraph;
auto mg_graph = mg_utility::GetGraphView(graph, result, memory, mg_graph_type);
if (mg_graph->Empty()) return;
// TODO(gitbuda): Inject the valid seed.
raft::random::RngState rng_state(0);

auto cu_graph = mg_cugraph::CreateCugraphFromMemgraph<vertex_t, edge_t, weight_t, false, false>(
*mg_graph.get(), mg_graph::GraphType::kUndirectedGraph, handle);
*mg_graph.get(), mg_graph_type, handle);
auto cu_graph_view = cu_graph.view();
auto n_vertices = cu_graph_view.number_of_vertices();

rmm::device_uvector<vertex_t> clustering_result(n_vertices, stream);
// TODO(gitbuda): Leiden weights and other arguments. Add theta argument.
cugraph::leiden<vertex_t, edge_t, weight_t, false>(handle, rng_state, cu_graph_view, std::nullopt,
clustering_result.data(), (size_t)max_iterations, resulution,
1.0);

for (vertex_t node_id = 0; node_id < clustering_result.size(); ++node_id) {
auto partition = clustering_result.element(node_id, stream);
InsertLeidenRecord(graph, result, memory, mg_graph->GetMemgraphNodeId(node_id), partition);
Expand Down

0 comments on commit 5dae101

Please sign in to comment.