From 13468e38e7da313f226eb5eb4e602502a7cff37e Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 1 Nov 2021 20:05:45 +0100
Subject: [PATCH] implement negative sampling

---
Review note: line breaks and context-line indentation were reconstructed from a
whitespace-collapsed copy of this patch; if a hunk is rejected, retry with
`git apply --ignore-whitespace`.

 src/GNNGraphs/GNNGraphs.jl |  3 ++-
 src/GNNGraphs/transform.jl | 27 +++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl
index 51e8891c6..d30af6c18 100644
--- a/src/GNNGraphs/GNNGraphs.jl
+++ b/src/GNNGraphs/GNNGraphs.jl
@@ -23,7 +23,8 @@ export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
        graph_indicator
 
 include("transform.jl")
-export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph
+export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph,
+       negative_sample
 
 include("generate.jl")
 export rand_graph
diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl
index 825b22eed..ff86a0720 100644
--- a/src/GNNGraphs/transform.jl
+++ b/src/GNNGraphs/transform.jl
@@ -324,5 +324,32 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
     end
 end
 
+
+"""
+    negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
+
+Return a graph containing random negative edges (i.e. non-edges) from graph `g`.
+
+Self-loops are excluded from the candidates. If `g` has fewer than `num_neg_edges`
+non-edges, all of them are returned, so the result may hold fewer edges than requested.
+"""
+function negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
+    adj = adjacency_matrix(g)
+    # Entries left positive are off-diagonal positions where `adj` is zero
+    # (assumes `adj` holds non-negative integer entries -- TODO confirm).
+    adj_neg = 1 .- adj - I
+    neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
+    # Clamp the request: indexing `1:num_neg_edges` past the number of candidate
+    # non-edges would throw a BoundsError (e.g. on dense graphs).
+    num_sampled = min(num_neg_edges, length(neg_s))
+    neg_eids = randperm(length(neg_s))[1:num_sampled]
+    neg_s, neg_t = neg_s[neg_eids], neg_t[neg_eids]
+    return GNNGraph(neg_s, neg_t, num_nodes=g.num_nodes)
+end
+
+# Transform a vector of cartesian indexes into a tuple of `dims` vectors of integers.
+ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
+
+@non_differentiable negative_sample(x...)
 @non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
 @non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule