Skip to content

Commit

Permalink
implement negative sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 1, 2021
1 parent 81fb4d1 commit 13468e3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
graph_indicator

include("transform.jl")
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph,
negative_sample

include("generate.jl")
export rand_graph
Expand Down
21 changes: 21 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,5 +324,26 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
end
end


"""
negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
Return a graph containing random negative edges (i.e. non-edges) from graph `g`.
"""
function negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
adj = adjacency_matrix(g)
adj_neg = 1 .- adj - I
neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)
neg_eids = randperm(length(neg_s))[1:num_neg_edges]
neg_s, neg_t = neg_s[neg_eids], neg_t[neg_eids]
return GNNGraph(neg_s, neg_t, num_nodes=g.num_nodes)
end

# """
# Transform vector of cartesian indexes into a tuple of vectors containing integers.
# """
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)

@non_differentiable negative_sample(x...)
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule

0 comments on commit 13468e3

Please sign in to comment.