diff --git a/examples/Project.toml b/examples/Project.toml
index 21702b05d..fce85a059 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -1,5 +1,7 @@
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
+DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
diff --git a/examples/neural_ode.jl b/examples/neural_ode.jl
new file mode 100644
index 000000000..28a0a65d2
--- /dev/null
+++ b/examples/neural_ode.jl
@@ -0,0 +1,61 @@
+# Load the packages
+using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations
+using Flux
+using Flux: onehotbatch, onecold
+using Flux.Losses: logitcrossentropy
+using Statistics: mean
+using MLDatasets: Cora
+
+device = cpu # `gpu` not working yet
+
+# LOAD DATA
+data = Cora.dataset()
+g = GNNGraph(data.adjacency_list) |> device
+X = data.node_features |> device
+y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
+train_ids = data.train_indices |> device
+val_ids = data.val_indices |> device
+test_ids = data.test_indices |> device
+ytrain = y[:, train_ids]
+
+# Model and Data Configuration
+nin = size(X, 1)
+nhidden = 16
+nout = data.num_classes
+epochs = 40
+
+# Define the Neural GDE
+# The ODE solution carries a trailing time dimension of size 1
+# (only the final state is saved), so drop it to recover a matrix.
+diffeqsol_to_array(x) = reshape(device(x), size(x)[1:2])
+
+node_chain = GNNChain(GCNConv(nhidden => nhidden, relu),
+                      GCNConv(nhidden => nhidden, relu)) |> device
+
+node = NeuralODE(WithGraph(node_chain, g),
+                 (0.0f0, 1.0f0), Tsit5(), save_everystep = false,
+                 reltol = 1e-3, abstol = 1e-3, save_start = false) |> device
+
+model = GNNChain(GCNConv(nin => nhidden, relu),
+                 Dropout(0.5),
+                 node,
+                 diffeqsol_to_array,
+                 Dense(nhidden, nout)) |> device
+
+# Loss
+loss(x, y) = logitcrossentropy(model(g, x), y)
+accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))
+
+# Training
+## Model Parameters
+ps = Flux.params(model, node.p);
+
+## Optimizer
+opt = ADAM(0.01)
+
+## Training Loop
+for epoch in 1:epochs
+    gs = gradient(() -> loss(X, y), ps)
+    Flux.Optimise.update!(opt, ps, gs)
+    @show(accuracy(X, y))
+end
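
Usage note (a minimal sketch, not part of the diff): assuming the repository root as the working directory, the example can be run against the examples environment like so:

    using Pkg
    Pkg.activate("examples")    # use examples/Project.toml
    Pkg.instantiate()           # install DiffEqFlux, DifferentialEquations, etc.
    include("examples/neural_ode.jl")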