Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add graph NeuralODE example #55

Merged
merged 3 commits into from
Oct 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
Expand Down
61 changes: 61 additions & 0 deletions examples/neural_ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Load the packages
using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
using Flux: onehotbatch, onecold, throttle
using Flux.Losses: logitcrossentropy
using Statistics: mean
using MLDatasets: Cora

device = cpu # `gpu` not working yet

# LOAD DATA
data = Cora.dataset()
g = GNNGraph(data.adjacency_list) |> device
X = data.node_features |> device
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
train_ids = data.train_indices |> device
val_ids = data.val_indices |> device
test_ids = data.test_indices |> device
ytrain = y[:, train_ids]


# Model and Data Configuration
nin = size(X, 1)
nhidden = 16
nout = data.num_classes
epochs = 40

# Define the Neural GDE
diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2])

# GCNConv(nhidden => nhidden, graph=g),

node_chain = GNNChain(GCNConv(nhidden => nhidden, relu),
GCNConv(nhidden => nhidden, relu)) |> device

node = NeuralODE(WithGraph(node_chain, g),
(0.f0, 1.f0), Tsit5(), save_everystep = false,
reltol = 1e-3, abstol = 1e-3, save_start = false) |> device

model = GNNChain(GCNConv(nin => nhidden, relu),
Dropout(0.5),
node,
diffeqarray_to_array,
Dense(nhidden, nout)) |> device

# Loss
loss(x, y) = logitcrossentropy(model(g, x), y)
accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))

# Training
## Model Parameters
ps = Flux.params(model, node.p);

## Optimizer
opt = ADAM(0.01)

## Training Loop
for epoch in 1:epochs
gs = gradient(() -> loss(X, y), ps)
Flux.Optimise.update!(opt, ps, gs)
@show(accuracy(X, y))
end