Skip to content

Commit

Permalink
rand edge split
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 6, 2021
1 parent 446aba5 commit a3cd35b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 22 deletions.
6 changes: 2 additions & 4 deletions examples/link_prediction_pubmed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using MLDatasets: PubMed, Cora
using MLDatasets: PubMed
using Statistics, Random, LinearAlgebra
using CUDA
# using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

# arguments for the `train` function
Expand Down Expand Up @@ -47,8 +46,7 @@ function train(; kws...)
end

### LOAD DATA
data = Cora.dataset()
# data = PubMed.dataset()
data = PubMed.dataset()
g = GNNGraph(data.adjacency_list)

# Print some info
Expand Down
2 changes: 1 addition & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseArrays
using Functors: @functor
using CUDA
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops
import Flux
using Flux: batch
import NNlib
Expand Down
37 changes: 20 additions & 17 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,30 +386,33 @@ and `g2`. Both will have the same number of nodes as `g`.
while `g2` wil contain the rest.
If `bidirected = true` makes sure that an edge and its reverse go into the same split.
This option is supported only for bidirected graphs with no self-loops
and multi-edges.
Useful for train/test splits in link prediction tasks.
`rand_edge_split` is tipically used to create train/test splits in link prediction tasks.
"""
function rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g))
s, t = edge_index(g)
idx, idmax = edge_encoding(s, t, g.num_nodes, directed=!bidirected)
uidx = union(idx) # So that multi-edges (and reverse edges in the bidir case) go in the same split
nu = length(uidx)
eids = randperm(nu)
size1 = round(Int, nu * frac)
ne = bidirected ? g.num_edges ÷ 2 : g.num_edges
eids = randperm(ne)
size1 = round(Int, ne * frac)

s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)

s, t = edge_index(g)
eids = randperm(g.num_edges)
size1 = round(Int, g.num_edges * frac)

s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
if !bidirected
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
else
@assert is_bidirected(g)
@assert !has_self_loops(g)
@assert !has_multi_edges(g)
mask = s .< t
s, t = s[mask], t[mask]
s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
s1, t1 = [s1; t1], [t1; s1]
s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
s2, t2 = [s2; t2], [t2; s2]
end
g1 = GNNGraph(s1, t1, num_nodes=g.num_nodes)

s2, t2 = s[eids[size1+1:end]], t[eids[size1+1:end]]
g2 = GNNGraph(s2, t2, num_nodes=g.num_nodes)

return g1, g2
end

Expand Down
31 changes: 31 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,35 @@
@test intersect(g, gneg).num_edges == 0
end
end

@testset "rand_edge_split" begin
if GRAPH_T == :coo
n, m = 100,300

g = rand_graph(n, m, bidirected=true, graph_type=GRAPH_T)
# check bidirected=is_bidirected(g) default
g1, g2 = rand_edge_split(g, 0.9)
@test is_bidirected(g1)
@test is_bidirected(g2)
@test intersect(g1, g2).num_edges == 0
@test g1.num_edges + g2.num_edges == g.num_edges
@test g2.num_edges < 50

g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
# check bidirected=is_bidirected(g) default
g1, g2 = rand_edge_split(g, 0.9)
@test !is_bidirected(g1)
@test !is_bidirected(g2)
@test intersect(g1, g2).num_edges == 0
@test g1.num_edges + g2.num_edges == g.num_edges
@test g2.num_edges < 50

g1, g2 = rand_edge_split(g, 0.9, bidirected=false)
@test !is_bidirected(g1)
@test !is_bidirected(g2)
@test intersect(g1, g2).num_edges == 0
@test g1.num_edges + g2.num_edges == g.num_edges
@test g2.num_edges < 50
end
end
end

0 comments on commit a3cd35b

Please sign in to comment.