Skip to content

Commit

Permalink
Merge pull request #55 from CarloLucibello/cl/withgraph
Browse files Browse the repository at this point in the history
add graph NeuralODE example
  • Loading branch information
CarloLucibello authored Oct 16, 2021
2 parents e1114a2 + cac775d commit 63745a5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
Expand Down
61 changes: 61 additions & 0 deletions examples/neural_ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Load the packages
using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
using Flux: onehotbatch, onecold, throttle
using Flux.Losses: logitcrossentropy
using Statistics: mean
using MLDatasets: Cora

device = cpu # `gpu` not working yet

# LOAD DATA
data = Cora.dataset()
g = GNNGraph(data.adjacency_list) |> device
X = data.node_features |> device
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
train_ids = data.train_indices |> device
val_ids = data.val_indices |> device
test_ids = data.test_indices |> device
ytrain = y[:, train_ids]


# Model and Data Configuration
nin = size(X, 1)
nhidden = 16
nout = data.num_classes
epochs = 40

# Define the Neural GDE
diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2])

# GCNConv(nhidden => nhidden, graph=g),

node_chain = GNNChain(GCNConv(nhidden => nhidden, relu),
GCNConv(nhidden => nhidden, relu)) |> device

node = NeuralODE(WithGraph(node_chain, g),
(0.f0, 1.f0), Tsit5(), save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false) |> device

model = GNNChain(GCNConv(nin => nhidden, relu),
Dropout(0.5),
node,
diffeqarray_to_array,
Dense(nhidden, nout)) |> device

# Loss
loss(x, y) = logitcrossentropy(model(g, x), y)
accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))

# Training
## Model Parameters
ps = Flux.params(model, node.p);

## Optimizer
opt = ADAM(0.01)

## Training Loop
for epoch in 1:epochs
gs = gradient(() -> loss(X, y), ps)
Flux.Optimise.update!(opt, ps, gs)
@show(accuracy(X, y))
end

0 comments on commit 63745a5

Please sign in to comment.