implement ResGatedGraphConv and support Parallel in GNNChain #46

Merged · 6 commits · Sep 23, 2021
2 changes: 2 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"]
version = "0.1.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -20,6 +21,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Adapt = "3"
CUDA = "3.3"
ChainRulesCore = "1"
DataStructures = "0.18"
19 changes: 17 additions & 2 deletions docs/src/models.md
@@ -78,11 +78,26 @@ X = randn(Float32, din, 10)
model = GNNChain(GCNConv(din => d),
BatchNorm(d),
x -> relu.(x),
GraphConv(d => d, relu),
GCNConv(d => d, relu),
Dropout(0.5),
Dense(d, dout))

y = model(g, X)
y = model(g, X) # output size: (dout, g.num_nodes)
```

The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled by defining the forward pass explicitly, as sketched below.
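
For instance, a minimal sketch of such an explicit model might look like the following (the struct and its fields here are illustrative, not part of the library API):

```julia
using Flux, GraphNeuralNetworks

# Illustrative custom model that updates edge features by hand.
struct MyGNN
    conv1
    conv2
    dense
end

Flux.@functor MyGNN

function (m::MyGNN)(g::GNNGraph, x, e)
    x = m.conv1(g, x)
    e = relu.(e)          # edge features handled explicitly
    x = m.conv2(g, x)
    return m.dense(x), e
end

model = MyGNN(GCNConv(din => d, relu), GCNConv(d => d, relu), Dense(d, dout))
```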

A `GNNChain` conveniently propagates the graph into the branches created by a `Flux.Parallel` layer:

```julia
AddResidual(l) = Parallel(+, identity, l) # implementing a skip/residual connection

model = GNNChain( ResGatedGraphConv(din => d, relu),
AddResidual(ResGatedGraphConv(d => d, relu)),
AddResidual(ResGatedGraphConv(d => d, relu)),
AddResidual(ResGatedGraphConv(d => d, relu)),
GlobalPooling(mean),
Dense(d, dout))

y = model(g, X) # output size: (dout, g.num_graphs)
```
2 changes: 2 additions & 0 deletions src/GraphNeuralNetworks.jl
@@ -46,6 +46,7 @@ export
GINConv,
GraphConv,
NNConv,
ResGatedGraphConv,
SAGEConv,

# layers/pool
@@ -62,5 +63,6 @@ include("msgpass.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/pool.jl")
include("deprecations.jl")

end
3 changes: 3 additions & 0 deletions src/deprecations.jl
@@ -0,0 +1,3 @@
# Deprecated in v0.1

@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)
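
For context, a hedged sketch of what this deprecation means in user code (the `Dense` layer and the value of `eps` are illustrative):

```julia
using Flux, GraphNeuralNetworks

nn = Dense(16, 16, relu)

l_old = GINConv(nn; eps=0.1)       # deprecated keyword form: warns, then forwards to
l_new = GINConv(nn, 0.1; aggr=+)   # the new positional constructor
```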
6 changes: 6 additions & 0 deletions src/layers/basic.jl
@@ -63,6 +63,12 @@ Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
applylayer(l, g::GNNGraph, x) = l(x)
applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)

# Handle Flux.Parallel
applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs)
applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...)


applychain(::Tuple{}, g::GNNGraph, x) = x
applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))

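To make the `Parallel` dispatch above concrete, here is a hedged sketch of the two call forms (the graph is built with the COO `GNNGraph(s, t)` constructor and is purely illustrative):

```julia
using Flux, GraphNeuralNetworks

s, t = [1, 2, 3, 4], [2, 3, 4, 1]          # small directed cycle, COO constructor assumed
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)

p = Parallel(+, GCNConv(3 => 3), GCNConv(3 => 3))
chain = GNNChain(p)

y1 = chain(g, x)        # a single array: every branch receives x
y2 = chain(g, (x, x))   # a tuple: branch i receives its own input xs[i]
```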
102 changes: 92 additions & 10 deletions src/layers/conv.jl
@@ -224,7 +224,7 @@ with ``z_i`` a normalization factor.

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `bias::Bool`: Keyword argument, whether to learn the additive bias.
- `bias`: Learn the additive bias if true.
- `heads`: Number of attention heads.
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
- `negative_slope`: The parameter of LeakyReLU.
@@ -407,7 +407,7 @@ end


@doc raw"""
GINConv(f; eps = 0f0)
GINConv(f, ϵ; aggr=+)

Graph Isomorphism convolutional layer from the paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf)

@@ -420,30 +420,38 @@ where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer o
# Arguments

- `f`: A (possibly learnable) function acting on node features.
- `eps`: Weighting factor.
- `ϵ`: Weighting factor.
"""
struct GINConv{R<:Real} <: GNNLayer
nn
eps::R
ϵ::R
aggr
end

@functor GINConv
Flux.trainable(l::GINConv) = (nn=l.nn,)
Flux.trainable(l::GINConv) = (l.nn,)

GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)

function GINConv(nn; eps=0f0)
GINConv(nn, eps)
end

compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
update_node(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
update_node(l::GINConv, m, x) = l.nn((1 + ofeltype(x, l.ϵ)) * x + m)

function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
check_num_nodes(g, X)
X, _ = propagate(l, g, +, X)
X, _ = propagate(l, g, l.aggr, X)
X
end


function Base.show(io::IO, l::GINConv)
print(io, "GINConv($(l.nn)")
print(io, ", $(l.ϵ)")
print(io, ")")
end
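
As a quick orientation on the updated `GINConv` signature, a hedged usage sketch (graph construction and sizes are illustrative):

```julia
using Flux, GraphNeuralNetworks, Statistics

s, t = [1, 2, 3, 4], [2, 3, 4, 1]   # illustrative graph, COO constructor assumed
g = GNNGraph(s, t)
x = randn(Float32, 16, g.num_nodes)

l = GINConv(Dense(16, 16, relu), 0.01f0, aggr=mean)   # ϵ = 0.01, mean neighborhood aggregation
y = l(g, x)                                           # size (16, g.num_nodes)
```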



@doc raw"""
NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)

@@ -572,3 +580,77 @@ function Base.show(io::IO, l::SAGEConv)
print(io, ", aggr=", l.aggr)
print(io, ")")
end


@doc raw"""
ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true)

The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.

The layer's forward pass is given by

```math
\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
```
where the edge gates ``\eta_{ij}`` are given by

```math
\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
```

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `act`: Activation function.
- `init`: Weight matrices' initializing function.
- `bias`: Learn an additive bias if true.
"""
struct ResGatedGraphConv <: GNNLayer
A
B
U
V
bias
σ
end

@functor ResGatedGraphConv

function ResGatedGraphConv(ch::Pair{Int,Int}, σ=identity;
init=glorot_uniform, bias::Bool=true)
in, out = ch
A = init(out, in)
B = init(out, in)
U = init(out, in)
V = init(out, in)
b = bias ? Flux.create_bias(A, true, out) : false
return ResGatedGraphConv(A, B, U, V, b, σ)
end

function compute_message(l::ResGatedGraphConv, di, dj)
η = sigmoid.(di.Ax .+ dj.Bx)
return η .* dj.Vx
end

update_node(l::ResGatedGraphConv, m, x) = m

function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)

Ax = l.A * x
Bx = l.B * x
Vx = l.V * x

m, _ = propagate(l, g, +, (; Ax, Bx, Vx))

return l.σ.(l.U*x .+ m .+ l.bias)
end


function Base.show(io::IO, l::ResGatedGraphConv)
out_channel, in_channel = size(l.A)
print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
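
A hedged usage sketch of the new layer (graph construction and sizes are illustrative):

```julia
using Flux, GraphNeuralNetworks

s, t = [1, 2, 3, 4], [2, 3, 4, 1]   # illustrative graph, COO constructor assumed
g = GNNGraph(s, t)
x = randn(Float32, 8, g.num_nodes)

l = ResGatedGraphConv(8 => 16, relu)
y = l(g, x)                          # size (16, g.num_nodes)
```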
3 changes: 3 additions & 0 deletions src/utils.jl
@@ -67,3 +67,6 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
end
return data
end


ofeltype(x, y) = convert(float(eltype(x)), y)
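
A brief illustration of the helper's behavior (values are illustrative):

```julia
x = rand(Float32, 3)
ofeltype(x, 0.01)           # 0.01f0: converts y to the floating-point element type of x
ofeltype(rand(Int, 3), 1)   # 1.0: integer element types are promoted via float()
```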
21 changes: 11 additions & 10 deletions test/examples/node_classification_cora.jl
@@ -17,11 +17,11 @@ end

# arguments for the `train` function
Base.@kwdef mutable struct Args
η = 1f-3 # learning rate
epochs = 20 # number of epochs
η = 5f-3 # learning rate
epochs = 10 # number of epochs
seed = 17 # set seed > 0 for reproducibility
usecuda = false # if true use cuda (if available)
nhidden = 128 # dimension of hidden features
nhidden = 64 # dimension of hidden features
end

function train(Layer; verbose=false, kws...)
@@ -49,7 +49,7 @@ function train(Layer; verbose=false, kws...)

## DEFINE MODEL
model = GNNChain(Layer(nin, nhidden),
Dropout(0.5),
# Dropout(0.5),
Layer(nhidden, nhidden),
Dense(nhidden, nout)) |> device

@@ -70,8 +70,8 @@ function train(Layer; verbose=false, kws...)
ŷ = model(g, X)
logitcrossentropy(ŷ[:,train_ids], ytrain)
end
verbose && report(epoch)
Flux.Optimise.update!(opt, ps, gs)
verbose && report(epoch)
end

train_res = eval_loss_accuracy(X, y, train_ids, model, g)
@@ -84,15 +84,16 @@ for Layer in [
(nin, nout) -> GraphConv(nin => nout, relu, aggr=mean),
(nin, nout) -> SAGEConv(nin => nout, relu),
(nin, nout) -> GATConv(nin => nout, relu),
(nin, nout) -> GATConv(nin => nout÷2, relu, heads=2),
(nin, nout) -> GINConv(Dense(nin, nout, relu)),
(nin, nout) -> ChebConv(nin => nout, 3),
(nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr=mean),
(nin, nout) -> ChebConv(nin => nout, 2),
(nin, nout) -> ResGatedGraphConv(nin => nout, relu),
# (nin, nout) -> NNConv(nin => nout), # needs edge features
# (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout
# (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the training set but does not generalize well
]
train_res, test_res = train(Layer, verbose=true)
# @show Layer(2,2) train_res, test_res

# @show Layer(2,2)
train_res, test_res = train(Layer, verbose=false)
@test train_res.acc > 95
@test test_res.acc > 70
end
36 changes: 19 additions & 17 deletions test/layers/basic.jl
@@ -2,32 +2,34 @@
@testset "GNNChain" begin
n, din, d, dout = 10, 3, 4, 2

g = GNNGraph(random_regular_graph(n, 4), graph_type=GRAPH_T)
g = GNNGraph(random_regular_graph(n, 4),
graph_type=GRAPH_T,
ndata= randn(Float32, din, n))

gnn = GNNChain(GCNConv(din => d),
BatchNorm(d),
x -> relu.(x),
GraphConv(d => d, relu),
x -> tanh.(x),
GraphConv(d => d, tanh),
Dropout(0.5),
Dense(d, dout))

testmode!(gnn)

X = randn(Float32, din, n)
test_layer(gnn, g, rtol=1e-5)

y = gnn(g, X)

@test y isa Matrix{Float32}
@test size(y) == (dout, n)

@test length(params(gnn)) == 9

gs = gradient(x -> sum(gnn(g, x)), X)[1]
@test gs isa Matrix{Float32}
@test size(gs) == size(X)
@testset "Parallel" begin
AddResidual(l) = Parallel(+, identity, l)

gnn = GNNChain(ResGatedGraphConv(din => d, tanh),
BatchNorm(d),
AddResidual(ResGatedGraphConv(d => d, tanh)),
BatchNorm(d),
Dense(d, dout))

gs = gradient(() -> sum(gnn(g, X)), Flux.params(gnn))
for p in Flux.params(gnn)
@test eltype(gs[p]) == Float32
@test size(gs[p]) == size(p)
testmode!(gnn)

test_layer(gnn, g, rtol=1e-5)
end
end
end
19 changes: 16 additions & 3 deletions test/layers/conv.jl
@@ -111,10 +111,10 @@

@testset "GINConv" begin
nn = Dense(in_channel, out_channel)
eps = 0.001f0
l = GINConv(nn, eps=eps)

l = GINConv(nn, 0.01f0, aggr=mean)
for g in test_graphs
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps])
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
end

@test !in(:eps, Flux.trainable(l))
@@ -149,4 +149,17 @@
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
end
end


@testset "ResGatedGraphConv" begin
l = ResGatedGraphConv(in_channel => out_channel)
for g in test_graphs
test_layer(l, g, rtol=1e-5,)
end

l = ResGatedGraphConv(in_channel => out_channel, tanh, bias=false)
for g in test_graphs
test_layer(l, g, rtol=1e-5,)
end
end
end
5 changes: 4 additions & 1 deletion test/test_utils.jl
@@ -94,6 +94,7 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,

# TEST LAYER GRADIENT - l(g, x)
l̄ = gradient(l -> loss(l, g, x), l)[1]
l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue
l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1]
test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)

@@ -104,6 +105,7 @@

# TEST LAYER GRADIENT - l(g)
l̄ = gradient(l -> loss(l, g), l)[1]
l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue
l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
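
For readers unfamiliar with the `RefValue` unwrapping above, a hedged standalone sketch of the behavior it guards against (the struct is illustrative):

```julia
using Zygote

mutable struct Affine   # mutable on purpose: Zygote may wrap its gradient in a RefValue
    W
    b
end
(m::Affine)(x) = m.W * x .+ m.b

m = Affine(randn(2, 2), randn(2))
ḡ = gradient(m -> sum(m(ones(2))), m)[1]
ḡ isa Base.RefValue && (ḡ = ḡ[])   # unwrap to the NamedTuple of field gradients
ḡ.W                                # gradient with respect to m.W
```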

@@ -140,7 +142,8 @@ function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
end
else
verbose && println("C")
test_approx_structs(x, f̄, f̄2; broken_grad_fields)
f̄ = f̄ isa Base.RefValue ? f̄[] : f̄ # Zygote wraps gradient of mutables in RefValue
test_approx_structs(x, f̄, f̄2; exclude_grad_fields, broken_grad_fields, verbose)
end
end
return true