From 81fb4d16a29e097027c23ca14638b957ecfefd0b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 1 Nov 2021 20:05:22 +0100 Subject: [PATCH] implement negative sampling --- examples/Project.toml | 3 +- ...tion_cora.jl => link_prediction_pubmed.jl} | 79 ++++++++----------- 2 files changed, 37 insertions(+), 45 deletions(-) rename examples/{link_prediction_cora.jl => link_prediction_pubmed.jl} (54%) diff --git a/examples/Project.toml b/examples/Project.toml index b4a89fa64..3d950f665 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,9 +1,10 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" +DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" diff --git a/examples/link_prediction_cora.jl b/examples/link_prediction_pubmed.jl similarity index 54% rename from examples/link_prediction_cora.jl rename to examples/link_prediction_pubmed.jl index d6477fae0..f5b043f52 100644 --- a/examples/link_prediction_cora.jl +++ b/examples/link_prediction_pubmed.jl @@ -2,21 +2,18 @@ # Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py using Flux +# Link prediction task +# https://arxiv.org/pdf/2102.12557.pdf + using Flux: onecold, onehotbatch using Flux.Losses: logitbinarycrossentropy using GraphNeuralNetworks -using GraphNeuralNetworks: ones_like, zeros_like -using MLDatasets: Cora +using MLDatasets: PubMed, Cora using Statistics, Random, LinearAlgebra using CUDA -using MLJBase: AreaUnderCurve +# using MLJBase: AreaUnderCurve CUDA.allowscalar(false) -""" -Transform vector of cartesian indexes into a tuple of vectors containing integers. -""" -ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims) - # arguments for the `train` function Base.@kwdef mutable struct Args η = 1f-3 # learning rate @@ -34,6 +31,8 @@ function (::DotPredictor)(g, x) return vec(z) end +using ChainRulesCore + function train(; kws...) # args = Args(; kws...) args = Args() @@ -54,67 +53,59 @@ function train(; kws...) g = GNNGraph(data.adjacency_list) |> device X = data.node_features |> device + #### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES - # Split edge set for training and testing s, t = edge_index(g) eids = randperm(g.num_edges) test_size = round(Int, g.num_edges * 0.1) - train_size = g.num_edges - test_size + test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]] - train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]] - - # Find all negative edges and split them for training and testing - adj = adjacency_matrix(g) - adj_neg = 1 .- adj - I - neg_s, neg_t = ci2t(findall(adj_neg .> 0), 2) - - neg_eids = randperm(length(neg_s))[1:g.num_edges] - test_neg_s, test_neg_t = neg_s[neg_eids[1:test_size]], neg_t[neg_eids[1:test_size]] - train_neg_s, train_neg_t = neg_s[neg_eids[test_size+1:end]], neg_t[neg_eids[test_size+1:end]] - # train_neg_s, train_neg_t = neg_s[neg_eids[train_size+1:end]], neg_t[neg_eids[train_size+1:end]] + test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes=g.num_nodes) - train_pos_g = GNNGraph((train_pos_s, train_pos_t), num_nodes=g.num_nodes) - train_neg_g = GNNGraph((train_neg_s, train_neg_t), num_nodes=g.num_nodes) + train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]] + train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes=g.num_nodes) - test_pos_g = GNNGraph((test_pos_s, test_pos_t), num_nodes=g.num_nodes) - test_neg_g = GNNGraph((test_neg_s, test_neg_t), num_nodes=g.num_nodes) + test_neg_g = negative_sample(g, num_neg_edges=test_size) - @show train_pos_g test_pos_g train_neg_g test_neg_g - - ### DEFINE MODEL + ### DEFINE MODEL ######### nin, nhidden = size(X,1), args.nhidden - model = GNNChain(GCNConv(nin => nhidden, relu), - GCNConv(nhidden => nhidden)) |> device + model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu), + GCNConv(nhidden => nhidden)), + train_pos_g) |> device pred = DotPredictor() ps = Flux.params(model) opt = ADAM(args.η) - ### LOSS FUNCTION + ### LOSS FUNCTION ############ - function loss(pos_g, neg_g) - h = model(train_pos_g, X) + function loss(pos_g, neg_g = nothing) + h = model(X) + if neg_g === nothing + # we sample a negative graph at each training step + neg_g = negative_sample(pos_g) + end pos_score = pred(pos_g, h) neg_score = pred(neg_g, h) scores = [pos_score; neg_score] - labels = [ones_like(pos_score); zeros_like(neg_score)] + labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)] return logitbinarycrossentropy(scores, labels) end - function accuracy(pos_g, neg_g) - h = model(train_pos_g, X) - pos_score = pred(pos_g, h) - neg_score = pred(neg_g, h) - scores = [pos_score; neg_score] - labels = [ones_like(pos_score); zeros_like(neg_score)] - return logitbinarycrossentropy(scores, labels) - end + # function accuracy(pos_g, neg_g) + # h = model(train_pos_g, X) + # pos_score = pred(pos_g, h) + # neg_score = pred(neg_g, h) + # scores = [pos_score; neg_score] + # labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)] + # return logitbinarycrossentropy(scores, labels) + # end ### LOGGING FUNCTION function report(epoch) - train_loss = loss(train_pos_g, train_neg_g) + train_loss = loss(train_pos_g) test_loss = loss(test_pos_g, test_neg_g) println("Epoch: $epoch Train: $(train_loss) Test: $(test_loss)") end @@ -122,7 +113,7 @@ function train(; kws...) ### TRAINING report(0) for epoch in 1:args.epochs - gs = Flux.gradient(() -> loss(train_pos_g, train_neg_g), ps) + gs = Flux.gradient(() -> loss(train_pos_g), ps) Flux.Optimise.update!(opt, ps, gs) epoch % args.infotime == 0 && report(epoch) end