Commit: implement negative sampling

CarloLucibello committed Nov 1, 2021
1 parent 53f79cd commit 81fb4d1

Showing 2 changed files with 37 additions and 45 deletions.
3 changes: 2 additions & 1 deletion examples/Project.toml
@@ -1,9 +1,10 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -2,21 +2,18 @@
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py

using Flux
# Link prediction task
# https://arxiv.org/pdf/2102.12557.pdf

using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using GraphNeuralNetworks: ones_like, zeros_like
using MLDatasets: Cora
using MLDatasets: PubMed, Cora
using Statistics, Random, LinearAlgebra
using CUDA
using MLJBase: AreaUnderCurve
# using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

"""
Transform a vector of Cartesian indices into a tuple of integer vectors.
"""
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)

# arguments for the `train` function
Base.@kwdef mutable struct Args
η = 1f-3 # learning rate
@@ -34,6 +31,8 @@ function (::DotPredictor)(g, x)
return vec(z)
end
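
For orientation, here is a hedged sketch of the full dot-product predictor whose tail appears above; it assumes GraphNeuralNetworks.jl's apply_edges and is illustrative, not taken verbatim from this commit:

struct DotPredictor end

function (::DotPredictor)(g, x)
    # score each edge (src, dst) by the dot product of its endpoint embeddings
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims=1), g, xi=x, xj=x)
    return vec(z)
end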

using ChainRulesCore

function train(; kws...)
# args = Args(; kws...)
args = Args()
@@ -54,75 +53,67 @@ function train(; kws...)
g = GNNGraph(data.adjacency_list) |> device
X = data.node_features |> device


#### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
# Split edge set for training and testing
s, t = edge_index(g)
eids = randperm(g.num_edges)
test_size = round(Int, g.num_edges * 0.1)
train_size = g.num_edges - test_size

test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]

# Find all negative edges and split them for training and testing
adj = adjacency_matrix(g)
adj_neg = 1 .- adj - I
neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2)

neg_eids = randperm(length(neg_s))[1:g.num_edges]
test_neg_s, test_neg_t = neg_s[neg_eids[1:test_size]], neg_t[neg_eids[1:test_size]]
train_neg_s, train_neg_t = neg_s[neg_eids[test_size+1:end]], neg_t[neg_eids[test_size+1:end]]
# train_neg_s, train_neg_t = neg_s[neg_eids[train_size+1:end]], neg_t[neg_eids[train_size+1:end]]
test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes)

train_pos_g = GNNGraph((train_pos_s, train_pos_t), num_nodes=g.num_nodes)
train_neg_g = GNNGraph((train_neg_s, train_neg_t), num_nodes=g.num_nodes)
train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes)

test_pos_g = GNNGraph((test_pos_s, test_pos_t), num_nodes=g.num_nodes)
test_neg_g = GNNGraph((test_neg_s, test_neg_t), num_nodes=g.num_nodes)
test_neg_g = negative_sample(g, num_neg_edges=test_size)

@show train_pos_g test_pos_g train_neg_g test_neg_g
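
For reference, a hedged conceptual sketch of what negative sampling amounts to: draw node pairs uniformly and keep those that are not already edges. The helper below is a CPU-only illustration (its name and keyword are assumptions), not the library's negative_sample implementation:

# Illustrative only: self-loops are skipped, duplicate negatives are not filtered.
function naive_negative_sample(g; num_neg_edges = g.num_edges)
    adj = adjacency_matrix(g)
    neg_s, neg_t = Int[], Int[]
    while length(neg_s) < num_neg_edges
        i, j = rand(1:g.num_nodes), rand(1:g.num_nodes)
        if i != j && adj[i, j] == 0
            push!(neg_s, i)
            push!(neg_t, j)
        end
    end
    return GNNGraph(neg_s, neg_t, num_nodes = g.num_nodes)
end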

### DEFINE MODEL
### DEFINE MODEL #########
nin, nhidden = size(X,1), args.nhidden

model = GNNChain(GCNConv(nin => nhidden, relu),
GCNConv(nhidden => nhidden)) |> device
model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
GCNConv(nhidden => nhidden)),
train_pos_g) |> device

pred = DotPredictor()
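
A brief note on the WithGraph wrapper used above: it stores train_pos_g next to the GNNChain so that, inside the loss, the model can be applied to the node features alone. A hedged usage sketch:

# h = model(X)   # uses the stored train_pos_g, like calling the chain as chain(train_pos_g, X)
# size(h) == (nhidden, g.num_nodes)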

ps = Flux.params(model)
opt = ADAM(args.η)

### LOSS FUNCTION
### LOSS FUNCTION ############

function loss(pos_g, neg_g)
h = model(train_pos_g, X)
function loss(pos_g, neg_g = nothing)
h = model(X)
if neg_g === nothing
# we sample a negative graph at each training step
neg_g = negative_sample(pos_g)
end
pos_score = pred(pos_g, h)
neg_score = pred(neg_g, h)
scores = [pos_score; neg_score]
labels = [ones_like(pos_score); zeros_like(neg_score)]
labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
return logitbinarycrossentropy(scores, labels)
end

function accuracy(pos_g, neg_g)
h = model(train_pos_g, X)
pos_score = pred(pos_g, h)
neg_score = pred(neg_g, h)
scores = [pos_score; neg_score]
labels = [ones_like(pos_score); zeros_like(neg_score)]
return logitbinarycrossentropy(scores, labels)
end
# function accuracy(pos_g, neg_g)
# h = model(train_pos_g, X)
# pos_score = pred(pos_g, h)
# neg_score = pred(neg_g, h)
# scores = [pos_score; neg_score]
# labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
# return logitbinarycrossentropy(scores, labels)
# end
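
Since the evaluation helper above is commented out (and, as written, returns the cross-entropy rather than an accuracy), here is a hedged sketch of a simple threshold-based link-prediction accuracy that could stand in for an AUC metric; the name is illustrative:

# `mean` comes from Statistics (imported above); `sigmoid` is provided by Flux/NNlib.
function link_accuracy(pos_g, neg_g, h)
    pos_score = pred(pos_g, h)
    neg_score = pred(neg_g, h)
    scores = [pos_score; neg_score]
    labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
    # fraction of candidate edges classified correctly at a 0.5 probability threshold
    return mean((sigmoid.(scores) .> 0.5f0) .== (labels .> 0.5f0))
end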

### LOGGING FUNCTION
function report(epoch)
train_loss = loss(train_pos_g, train_neg_g)
train_loss = loss(train_pos_g)
test_loss = loss(test_pos_g, test_neg_g)
println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)")
end

### TRAINING
report(0)
for epoch in 1:args.epochs
gs = Flux.gradient(() -> loss(train_pos_g, train_neg_g), ps)
gs = Flux.gradient(() -> loss(train_pos_g), ps)
Flux.Optimise.update!(opt, ps, gs)
epoch % args.infotime == 0 && report(epoch)
end
