NeuralODE example working on cpu and gpu #67

Merged 3 commits on Nov 1, 2021
21 changes: 12 additions & 9 deletions examples/neural_ode.jl
@@ -1,11 +1,14 @@
 # Load the packages
-using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
-using Flux: onehotbatch, onecold, throttle
+using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations
+using Flux: onehotbatch, onecold
 using Flux.Losses: logitcrossentropy
 using Statistics: mean
 using MLDatasets: Cora
 using CUDA
+# CUDA.allowscalar(false) # Some scalar indexing is still done by DiffEqFlux
+
-device = cpu # `gpu` not working yet
+# device = cpu # `gpu` not working yet
+device = CUDA.functional() ? gpu : cpu
 
 # LOAD DATA
 data = Cora.dataset()
@@ -39,21 +42,21 @@ node = NeuralODE(WithGraph(node_chain, g),
 model = GNNChain(GCNConv(nin => nhidden, relu),
                  Dropout(0.5),
                  node,
-                 diffeqarray_to_array,
+                 diffeqsol_to_array,
                  Dense(nhidden, nout)) |> device
 
 # Loss
 loss(x, y) = logitcrossentropy(model(g, x), y)
 accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))
 
-# Training
-## Model Parameters
-ps = Flux.params(model, node.p);
+# # Training
+# ## Model Parameters
+ps = Flux.params(model);
 
-## Optimizer
+# ## Optimizer
 opt = ADAM(0.01)
 
-## Training Loop
+# ## Training Loop
 for epoch in 1:epochs
     gs = gradient(() -> loss(X, y), ps)
     Flux.Optimise.update!(opt, ps, gs)
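The device-selection idiom introduced above can be exercised on its own. A minimal sketch, assuming Flux and CUDA.jl are installed; the `Dense` model and random features are illustrative stand-ins, not part of this PR:

```julia
using Flux, CUDA

# Pick the device at runtime: fall back to the CPU when no functional GPU is present.
device = CUDA.functional() ? gpu : cpu

# Illustrative model and data; any Flux model and matching array work the same way.
model = Dense(16, 4) |> device
X = rand(Float32, 16, 32) |> device

y = model(X)   # runs on the GPU when available, otherwise on the CPU
```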
8 changes: 5 additions & 3 deletions src/GNNGraphs/generate.jl
@@ -1,5 +1,5 @@
"""
rand_graph(n, m; bidirected=true, kws...)
rand_graph(n, m; bidirected=true, seed=-1, kws...)

Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes
and `m` edges.
@@ -8,6 +8,8 @@ If `bidirected=true` the reverse edge of each edge will be present.
 If `bidirected=false` instead, `m` unrelated edges are generated.
 In any case, the output graph will contain no self-loops or multi-edges.
 
+Use a `seed > 0` for reproducibility.
+
 Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
 
 # Usage
@@ -43,10 +45,10 @@ julia> edge_index(g)
 
 ```
 """
-function rand_graph(n::Integer, m::Integer; bidirected=true, kws...)
+function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
     if bidirected
         @assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
     end
     m2 = bidirected ? m÷2 : m
-    return GNNGraph(Graphs.erdos_renyi(n, m2, is_directed=!bidirected); kws...)
+    return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
 end
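A small usage sketch of the new `seed` keyword (sizes are made up, not from the PR): passing the same positive seed twice should give identical edges, while the default `seed=-1` keeps the old behaviour of drawing from the global RNG.

```julia
using GraphNeuralNetworks

# Same positive seed -> same random graph.
g1 = rand_graph(10, 30, bidirected=false, seed=17)
g2 = rand_graph(10, 30, bidirected=false, seed=17)
@assert edge_index(g1) == edge_index(g2)

# Default seed=-1 -> edges are drawn from the global RNG and generally differ between calls.
g3 = rand_graph(10, 30, bidirected=false)
```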
4 changes: 2 additions & 2 deletions src/GNNGraphs/gnngraph.jl
@@ -5,8 +5,8 @@ https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
 https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
 =============================================#
 
-const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V}
-const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
+const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
+const ADJLIST_T = AbstractVector{T} where T <: AbstractVector{<:Integer}
 const ADJMAT_T = AbstractMatrix
 const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
 const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
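The tightened aliases only make explicit that COO and adjacency-list inputs must use integer node indices. A hedged sketch of a COO tuple matching the new `COO_T` (values are made up):

```julia
using GraphNeuralNetworks

s = [1, 2, 3]          # integer source indices
t = [2, 3, 1]          # integer target indices
w = [0.5, 1.0, 2.0]    # edge weights; this slot may also be `nothing`

g = GNNGraph((s, t, w))   # tuple of two integer vectors plus weights, as required by COO_T
@assert g.num_nodes == 3 && g.num_edges == 3
```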
42 changes: 28 additions & 14 deletions src/layers/basic.jl
@@ -13,7 +13,7 @@ abstract type GNNLayer end
 
 
 """
-WithGraph(model, g::GNNGraph; traingraph=false)
+    WithGraph(model, g::GNNGraph; traingraph=false)
 
 A type wrapping the `model` and tying it to the graph `g`.
 In the forward pass, can only take feature arrays as inputs,
@@ -38,17 +38,31 @@ x2 = rand(Float32, 2, 4)
 @assert wg(g2, x2) == model(g2, x2)
 ```
 """
-struct WithGraph{M}
-    model::M
-    g::GNNGraph
-    traingraph::Bool
+struct WithGraph{M, G<:GNNGraph}
+    model::M
+    g::G
+    traingraph::Bool
 end
 
 WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph)
 
 @functor WithGraph
 Flux.trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
 
+# Work around
+# https://github.com/FluxML/Flux.jl/issues/1733
+# Revisit after
+# https://github.com/FluxML/Flux.jl/pull/1742
+function Flux.destructure(m::WithGraph)
+    @assert m.traingraph == false # TODO
+    p, re = Flux.destructure(m.model)
+    function re_withgraph(x)
+        WithGraph(re(x), m.g, m.traingraph)
+    end
+
+    return p, re_withgraph
+end
+
 (l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
 (l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
 
@@ -86,15 +100,15 @@ julia> m(g, x)
 ```
 """
 struct GNNChain{T} <: GNNLayer
-  layers::T
-
-  GNNChain(xs...) = new{typeof(xs)}(xs)
-
-  function GNNChain(; kw...)
-    :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
-    isempty(kw) && return new{Tuple{}}(())
-    new{typeof(values(kw))}(values(kw))
-  end
+    layers::T
+
+    GNNChain(xs...) = new{typeof(xs)}(xs)
+    function GNNChain(; kw...)
+        :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
+        isempty(kw) && return new{Tuple{}}(())
+        new{typeof(values(kw))}(values(kw))
+    end
 end
 
 @forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
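A minimal round-trip sketch of the `Flux.destructure` workaround added above, assuming a small wrapped model and a random graph (both illustrative, not from the PR); only the model parameters are flattened while the tied graph stays fixed:

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)
wg = WithGraph(GNNChain(GCNConv(3 => 4)), g)   # traingraph=false by default

x = rand(Float32, 3, 10)
p, re = Flux.destructure(wg)   # flat parameter vector plus rebuilder

wg2 = re(p)                    # rebuilds a WithGraph tied to the same `g`
@assert wg2(x) ≈ wg(x)
```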
7 changes: 6 additions & 1 deletion test/GNNGraphs/generate.jl
@@ -4,6 +4,7 @@
     m2 = m ÷ 2
     x = rand(3, n)
     e = rand(4, m2)
+
     g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T)
     @test g.num_nodes == n
     @test g.num_edges == m
@@ -15,8 +16,12 @@
         @test g.edata.e[:,1:m2] == e
         @test g.edata.e[:,m2+1:end] == e
     end
-    g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
+
+    g = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
     @test g.num_nodes == n
     @test g.num_edges == m
+
+    g2 = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
+    @test edge_index(g2) == edge_index(g)
 end
 end
1 change: 0 additions & 1 deletion test/examples/node_classification_cora.jl
@@ -14,7 +14,6 @@ function eval_loss_accuracy(X, y, ids, model, g)
     return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
 end
 
-
 # arguments for the `train` function
 Base.@kwdef mutable struct Args
     η = 5f-3            # learning rate