From 0273b252c3e32ec98dbeb5451e56a18c46e09abc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 22 Sep 2021 19:22:05 +0200 Subject: [PATCH 1/6] add Parallel support in GNNChain --- Project.toml | 1 + docs/src/models.md | 19 +- src/GraphNeuralNetworks.jl | 1 + src/layers/basic.jl | 9 + src/layers/conv.jl | 76 ++++++- test/examples/node_classification_cora.jl | 5 +- test/layers/basic.jl | 37 ++-- test/layers/conv.jl | 239 ++++++++++++---------- test/runtests.jl | 8 +- test/test_utils.jl | 5 +- 10 files changed, 260 insertions(+), 140 deletions(-) diff --git a/Project.toml b/Project.toml index fa74c4b38..c6953f8b1 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"] version = "0.1.1" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" diff --git a/docs/src/models.md b/docs/src/models.md index 2f25372bf..de377c0ca 100644 --- a/docs/src/models.md +++ b/docs/src/models.md @@ -78,11 +78,26 @@ X = randn(Float32, din, 10) model = GNNChain(GCNConv(din => d), BatchNorm(d), x -> relu.(x), - GraphConv(d => d, relu), + GCNConv(d => d, relu), Dropout(0.5), Dense(d, dout)) -y = model(g, X) +y = model(g, X) # output size: (dout, g.num_nodes) ``` The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when also edge features are updated, have to be handled using the explicit definition of the forward pass. + +A `GNNChain` oppurtunely propagates the graph into the branches created by the `Flux.Parallel` layer: + +```julia +AddResidual(l) = Parallel(+, identity, l) + +model = GNNChain( AddResidual(ResGatedGraphConv(din => d, relu)), + BatchNorm(d), + AddResidual(ResGatedGraphConv(d => d, relu)), + BatchNorm(d), + GlobalPooling(mean), + Dense(d, dout)) + +y = model(g, X) # output size: (dout, g.num_graphs) +``` diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 1480b32a0..d3f70e439 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -46,6 +46,7 @@ export GINConv, GraphConv, NNConv, + ResGatedGraphConv, SAGEConv, # layers/pool diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 07f659a32..df8efbed0 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -63,6 +63,15 @@ Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...) applylayer(l, g::GNNGraph, x) = l(x) applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) +# Handle Flux.Parallel +applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(l, g, x), l.connection, l.layers) +applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(l, g, x), l.connection, l.layers, xs) +applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...) +applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(l, g, x), l.connection, l.layers) +applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(l, g, x), l.connection, l.layers, xs) + + + applychain(::Tuple{}, g::GNNGraph, x) = x applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x)) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 3c8e532cb..44b35ee21 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -224,7 +224,7 @@ with ``z_i`` a normalization factor. - `in`: The dimension of input features. 
- `out`: The dimension of output features.
-- `bias::Bool`: Keyword argument, whether to learn the additive bias.
+- `bias`: Learn the additive bias if true.
 - `heads`: Number attention heads.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope`: The parameter of LeakyReLU.
@@ -572,3 +572,77 @@ function Base.show(io::IO, l::SAGEConv)
     print(io, ", aggr=", l.aggr)
     print(io, ")")
 end
+
+
+@doc raw"""
+    ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true)
+
+The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.
+
+The layer's forward pass is given by
+
+```math
+\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
+```
+where the edge gates ``\eta_{ij}`` are given by
+
+```math
+\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
+```
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `act`: Activation function.
+- `init`: Weight matrices' initializing function.
+- `bias`: Learn an additive bias if true.
+"""
+struct ResGatedGraphConv <: GNNLayer
+    A
+    B
+    U
+    V
+    bias
+    σ
+end
+
+@functor ResGatedGraphConv
+
+function ResGatedGraphConv(ch::Pair{Int,Int}, σ=identity;
+                           init=glorot_uniform, bias::Bool=true)
+    in, out = ch
+    A = init(out, in)
+    B = init(out, in)
+    U = init(out, in)
+    V = init(out, in)
+    b = bias ? Flux.create_bias(A, true, out) : false
+    return ResGatedGraphConv(A, B, U, V, b, σ)
+end
+
+function compute_message(l::ResGatedGraphConv, di, dj)
+    η = sigmoid.(di.Ax .+ dj.Bx)
+    return η .* dj.Vx
+end
+
+update_node(l::ResGatedGraphConv, m, x) = m
+
+function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+
+    Ax = l.A * x
+    Bx = l.B * x
+    Vx = l.V * x
+
+    m, _ = propagate(l, g, +, (; Ax, Bx, Vx))
+
+    return l.σ.(l.U*x .+ m .+ l.bias)
+end
+
+
+function Base.show(io::IO, l::ResGatedGraphConv)
+    out_channel, in_channel = size(l.weight)
+    print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ")")
+end
diff --git a/test/examples/node_classification_cora.jl b/test/examples/node_classification_cora.jl
index caf775796..49b583d7d 100644
--- a/test/examples/node_classification_cora.jl
+++ b/test/examples/node_classification_cora.jl
@@ -70,8 +70,8 @@ function train(Layer; verbose=false, kws...)
ŷ = model(g, X) logitcrossentropy(ŷ[:,train_ids], ytrain) end - verbose && report(epoch) Flux.Optimise.update!(opt, ps, gs) + verbose && report(epoch) end train_res = eval_loss_accuracy(X, y, train_ids, model, g) @@ -87,11 +87,12 @@ for Layer in [ (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2), (nin, nout) -> GINConv(Dense(nin, nout, relu)), (nin, nout) -> ChebConv(nin => nout, 3), + (nin, nout) -> ResGatedGraphConv(nin => nout, relu), # (nin, nout) -> NNConv(nin => nout), # needs edge features # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the traning set but does not generalize well ] - train_res, test_res = train(Layer, verbose=true) + train_res, test_res = train(Layer, verbose=false) # @show Layer(2,2) train_res, test_res @test train_res.acc > 95 @test test_res.acc > 70 diff --git a/test/layers/basic.jl b/test/layers/basic.jl index a26c1ffae..02a52e44b 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -2,32 +2,35 @@ @testset "GNNChain" begin n, din, d, dout = 10, 3, 4, 2 - g = GNNGraph(random_regular_graph(n, 4), graph_type=GRAPH_T) + g = GNNGraph(random_regular_graph(n, 4), + graph_type=GRAPH_T, + ndata= randn(Float32, din, n)) gnn = GNNChain(GCNConv(din => d), BatchNorm(d), - x -> relu.(x), - GraphConv(d => d, relu), + x -> tanh.(x), + GraphConv(d => d, tanh), Dropout(0.5), Dense(d, dout)) + + testmode!(gnn) - X = randn(Float32, din, n) + test_layer(gnn, g, rtol=1e-5) # exclude BN buffers - y = gnn(g, X) - - @test y isa Matrix{Float32} - @test size(y) == (dout, n) - @test length(params(gnn)) == 9 - - gs = gradient(x -> sum(gnn(g, x)), X)[1] - @test gs isa Matrix{Float32} - @test size(gs) == size(X) + @testset "Parallel" begin + AddResidual(l) = Parallel(+, identity, l) + + gnn = GNNChain(AddResidual(ResGatedGraphConv(din => d, tanh)), + BatchNorm(d), + AddResidual(ResGatedGraphConv(d => d, tanh)), + BatchNorm(d), + Dense(d, dout)) - gs = gradient(() -> sum(gnn(g, X)), Flux.params(gnn)) - for p in Flux.params(gnn) - @test eltype(gs[p]) == Float32 - @test size(gs[p]) == size(p) + testmode!(gnn) + + test_layer(gnn, g, rtol=1e-5, verbose=true, + exclude_grad_fields=[:μ, :σ², :ϵ]) # exclude BN buffers end end end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 4599f0cdd..81d35c0bf 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -24,129 +24,142 @@ test_graphs = [g1, g_single_vertex] - @testset "GCNConv" begin - l = GCNConv(in_channel => out_channel) - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end - - l = GCNConv(in_channel => out_channel, tanh, bias=false) - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end - - l = GCNConv(in_channel => out_channel, add_self_loops=false) - test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes)) - end - - @testset "ChebConv" begin - k = 3 - l = ChebConv(in_channel => out_channel, k) - @test size(l.weight) == (out_channel, in_channel, k) - @test size(l.bias) == (out_channel,) - @test l.k == k - for g in test_graphs - g = add_self_loops(g) - test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes)) - if TEST_GPU - @test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes)) - end - end + # @testset "GCNConv" begin + # l = GCNConv(in_channel => out_channel) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + + # l = GCNConv(in_channel => 
out_channel, tanh, bias=false) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + + # l = GCNConv(in_channel => out_channel, add_self_loops=false) + # test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes)) + # end + + # @testset "ChebConv" begin + # k = 3 + # l = ChebConv(in_channel => out_channel, k) + # @test size(l.weight) == (out_channel, in_channel, k) + # @test size(l.bias) == (out_channel,) + # @test l.k == k + # for g in test_graphs + # g = add_self_loops(g) + # test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes)) + # if TEST_GPU + # @test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes)) + # end + # end - @testset "bias=false" begin - @test length(Flux.params(ChebConv(2=>3, 3))) == 2 - @test length(Flux.params(ChebConv(2=>3, 3, bias=false))) == 1 - end - end - - @testset "GraphConv" begin - l = GraphConv(in_channel => out_channel) - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end - - l = GraphConv(in_channel => out_channel, relu, bias=false, aggr=mean) - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end + # @testset "bias=false" begin + # @test length(Flux.params(ChebConv(2=>3, 3))) == 2 + # @test length(Flux.params(ChebConv(2=>3, 3, bias=false))) == 1 + # end + # end + + # @testset "GraphConv" begin + # l = GraphConv(in_channel => out_channel) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + + # l = GraphConv(in_channel => out_channel, relu, bias=false, aggr=mean) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end - @testset "bias=false" begin - @test length(Flux.params(GraphConv(2=>3))) == 3 - @test length(Flux.params(GraphConv(2=>3, bias=false))) == 2 - end - end - - @testset "GATConv" begin - - for heads in (1, 2), concat in (true, false) - l = GATConv(in_channel => out_channel; heads, concat) - for g in test_graphs - test_layer(l, g, rtol=1e-4, - outsize=(concat ? heads*out_channel : out_channel, g.num_nodes)) - end - end - - @testset "bias=false" begin - @test length(Flux.params(GATConv(2=>3))) == 3 - @test length(Flux.params(GATConv(2=>3, bias=false))) == 2 - end - end - - @testset "GatedGraphConv" begin - num_layers = 3 - l = GatedGraphConv(out_channel, num_layers) - @test size(l.weight) == (out_channel, out_channel, num_layers) - - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end - end - - @testset "EdgeConv" begin - l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+) - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end - end - - @testset "GINConv" begin - nn = Dense(in_channel, out_channel) - eps = 0.001f0 - l = GINConv(nn, eps=eps) - for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps]) - end + # @testset "bias=false" begin + # @test length(Flux.params(GraphConv(2=>3))) == 3 + # @test length(Flux.params(GraphConv(2=>3, bias=false))) == 2 + # end + # end + + # @testset "GATConv" begin + + # for heads in (1, 2), concat in (true, false) + # l = GATConv(in_channel => out_channel; heads, concat) + # for g in test_graphs + # test_layer(l, g, rtol=1e-4, + # outsize=(concat ? 
heads*out_channel : out_channel, g.num_nodes)) + # end + # end + + # @testset "bias=false" begin + # @test length(Flux.params(GATConv(2=>3))) == 3 + # @test length(Flux.params(GATConv(2=>3, bias=false))) == 2 + # end + # end + + # @testset "GatedGraphConv" begin + # num_layers = 3 + # l = GatedGraphConv(out_channel, num_layers) + # @test size(l.weight) == (out_channel, out_channel, num_layers) + + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + # end + + # @testset "EdgeConv" begin + # l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + # end + + # @testset "GINConv" begin + # nn = Dense(in_channel, out_channel) + # eps = 0.001f0 + # l = GINConv(nn, eps=eps) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps]) + # end - @test !in(:eps, Flux.trainable(l)) - end + # @test !in(:eps, Flux.trainable(l)) + # end - @testset "NNConv" begin - edim = 10 - nn = Dense(edim, out_channel * in_channel) + # @testset "NNConv" begin + # edim = 10 + # nn = Dense(edim, out_channel * in_channel) - l = NNConv(in_channel => out_channel, nn) - for g in test_graphs - g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end + # l = NNConv(in_channel => out_channel, nn) + # for g in test_graphs + # g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end - l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean) - for g in test_graphs - g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - end - end + # l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean) + # for g in test_graphs + # g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + # end + + # @testset "SAGEConv" begin + # l = SAGEConv(in_channel => out_channel) + # @test l.aggr == mean + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + + # l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+) + # for g in test_graphs + # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + # end + # end - @testset "SAGEConv" begin - l = SAGEConv(in_channel => out_channel) - @test l.aggr == mean + + @testset "ResGatedGraphConv" begin + l = ResGatedGraphConv(in_channel => out_channel) for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + test_layer(l, g, rtol=1e-5,) end - - l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+) + + l = ResGatedGraphConv(in_channel => out_channel, tanh, bias=false) for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + test_layer(l, g, rtol=1e-5,) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 79aef8031..8d2145e5b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,12 +16,12 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ - "gnngraph", - "msgpass", + # "gnngraph", + # "msgpass", "layers/basic", "layers/conv", - "layers/pool", - "examples/node_classification_cora", + # "layers/pool", + # "examples/node_classification_cora", ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") diff 
--git a/test/test_utils.jl b/test/test_utils.jl index 1716d49d9..7c3566154 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -94,6 +94,7 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5, # TEST LAYER GRADIENT - l(g, x) l̄ = gradient(l -> loss(l, g, x), l)[1] + l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1] test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose) @@ -104,6 +105,7 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5, # TEST LAYER GRADIENT - l(g) l̄ = gradient(l -> loss(l, g), l)[1] + l̄ = l̄ isa Base.RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1] test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose) @@ -140,7 +142,8 @@ function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5, end else verbose && println("C") - test_approx_structs(x, f̄, f̄2; broken_grad_fields) + f̄ = f̄ isa Base.RefValue ? f̄[] : f̄ # Zygote wraps gradient of mutables in RefValue + test_approx_structs(x, f̄, f̄2; exclude_grad_fields, broken_grad_fields, verbose) end end return true From e457ae55d41516313cf3ab8edcbc20a1f175ea06 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 22 Sep 2021 19:40:41 +0200 Subject: [PATCH 2/6] fix tests --- src/layers/basic.jl | 7 ++----- src/layers/conv.jl | 2 +- test/layers/basic.jl | 7 +++---- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index df8efbed0..7b0779a7d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -64,12 +64,9 @@ applylayer(l, g::GNNGraph, x) = l(x) applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) # Handle Flux.Parallel -applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(l, g, x), l.connection, l.layers) -applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(l, g, x), l.connection, l.layers, xs) +applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers) +applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs) applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...) 
-applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(l, g, x), l.connection, l.layers) -applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(l, g, x), l.connection, l.layers, xs) - applychain(::Tuple{}, g::GNNGraph, x) = x diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 44b35ee21..462b239bf 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -641,7 +641,7 @@ end function Base.show(io::IO, l::ResGatedGraphConv) - out_channel, in_channel = size(l.weight) + out_channel, in_channel = size(l.A) print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel) l.σ == identity || print(io, ", ", l.σ) print(io, ")") diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 02a52e44b..3ba63e51e 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -15,13 +15,13 @@ testmode!(gnn) - test_layer(gnn, g, rtol=1e-5) # exclude BN buffers + test_layer(gnn, g, rtol=1e-5) @testset "Parallel" begin AddResidual(l) = Parallel(+, identity, l) - gnn = GNNChain(AddResidual(ResGatedGraphConv(din => d, tanh)), + gnn = GNNChain(ResGatedGraphConv(din => d, tanh), BatchNorm(d), AddResidual(ResGatedGraphConv(d => d, tanh)), BatchNorm(d), @@ -29,8 +29,7 @@ testmode!(gnn) - test_layer(gnn, g, rtol=1e-5, verbose=true, - exclude_grad_fields=[:μ, :σ², :ϵ]) # exclude BN buffers + test_layer(gnn, g, rtol=1e-5) end end end From c15b750db47d5353809fc481df2a05b0475dfec0 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 22 Sep 2021 23:36:54 +0200 Subject: [PATCH 3/6] cleanup --- docs/src/models.md | 10 +- test/layers/conv.jl | 236 ++++++++++++++++++++++---------------------- test/runtests.jl | 8 +- 3 files changed, 127 insertions(+), 127 deletions(-) diff --git a/docs/src/models.md b/docs/src/models.md index de377c0ca..5f4a98e96 100644 --- a/docs/src/models.md +++ b/docs/src/models.md @@ -90,14 +90,14 @@ The `GNNChain` only propagates the graph and the node features. 
More complex sce A `GNNChain` oppurtunely propagates the graph into the branches created by the `Flux.Parallel` layer: ```julia -AddResidual(l) = Parallel(+, identity, l) +AddResidual(l) = Parallel(+, identity, l) # implementing a skip/residual connection -model = GNNChain( AddResidual(ResGatedGraphConv(din => d, relu)), - BatchNorm(d), +model = GNNChain( ResGatedGraphConv(din => d, relu), + AddResidual(ResGatedGraphConv(d => d, relu)), + AddResidual(ResGatedGraphConv(d => d, relu)), AddResidual(ResGatedGraphConv(d => d, relu)), - BatchNorm(d), GlobalPooling(mean), Dense(d, dout)) y = model(g, X) # output size: (dout, g.num_graphs) -``` +``` diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 81d35c0bf..e5e9e8d7a 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -24,131 +24,131 @@ test_graphs = [g1, g_single_vertex] - # @testset "GCNConv" begin - # l = GCNConv(in_channel => out_channel) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - - # l = GCNConv(in_channel => out_channel, tanh, bias=false) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - - # l = GCNConv(in_channel => out_channel, add_self_loops=false) - # test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes)) - # end - - # @testset "ChebConv" begin - # k = 3 - # l = ChebConv(in_channel => out_channel, k) - # @test size(l.weight) == (out_channel, in_channel, k) - # @test size(l.bias) == (out_channel,) - # @test l.k == k - # for g in test_graphs - # g = add_self_loops(g) - # test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes)) - # if TEST_GPU - # @test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes)) - # end - # end + @testset "GCNConv" begin + l = GCNConv(in_channel => out_channel) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + + l = GCNConv(in_channel => out_channel, tanh, bias=false) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + + l = GCNConv(in_channel => out_channel, add_self_loops=false) + test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes)) + end + + @testset "ChebConv" begin + k = 3 + l = ChebConv(in_channel => out_channel, k) + @test size(l.weight) == (out_channel, in_channel, k) + @test size(l.bias) == (out_channel,) + @test l.k == k + for g in test_graphs + g = add_self_loops(g) + test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes)) + if TEST_GPU + @test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes)) + end + end - # @testset "bias=false" begin - # @test length(Flux.params(ChebConv(2=>3, 3))) == 2 - # @test length(Flux.params(ChebConv(2=>3, 3, bias=false))) == 1 - # end - # end - - # @testset "GraphConv" begin - # l = GraphConv(in_channel => out_channel) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - - # l = GraphConv(in_channel => out_channel, relu, bias=false, aggr=mean) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end + @testset "bias=false" begin + @test length(Flux.params(ChebConv(2=>3, 3))) == 2 + @test length(Flux.params(ChebConv(2=>3, 3, bias=false))) == 1 + end + end + + @testset "GraphConv" begin + l = GraphConv(in_channel => out_channel) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + 
+ l = GraphConv(in_channel => out_channel, relu, bias=false, aggr=mean) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end - # @testset "bias=false" begin - # @test length(Flux.params(GraphConv(2=>3))) == 3 - # @test length(Flux.params(GraphConv(2=>3, bias=false))) == 2 - # end - # end - - # @testset "GATConv" begin - - # for heads in (1, 2), concat in (true, false) - # l = GATConv(in_channel => out_channel; heads, concat) - # for g in test_graphs - # test_layer(l, g, rtol=1e-4, - # outsize=(concat ? heads*out_channel : out_channel, g.num_nodes)) - # end - # end - - # @testset "bias=false" begin - # @test length(Flux.params(GATConv(2=>3))) == 3 - # @test length(Flux.params(GATConv(2=>3, bias=false))) == 2 - # end - # end - - # @testset "GatedGraphConv" begin - # num_layers = 3 - # l = GatedGraphConv(out_channel, num_layers) - # @test size(l.weight) == (out_channel, out_channel, num_layers) - - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - # end - - # @testset "EdgeConv" begin - # l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - # end - - # @testset "GINConv" begin - # nn = Dense(in_channel, out_channel) - # eps = 0.001f0 - # l = GINConv(nn, eps=eps) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps]) - # end + @testset "bias=false" begin + @test length(Flux.params(GraphConv(2=>3))) == 3 + @test length(Flux.params(GraphConv(2=>3, bias=false))) == 2 + end + end + + @testset "GATConv" begin + + for heads in (1, 2), concat in (true, false) + l = GATConv(in_channel => out_channel; heads, concat) + for g in test_graphs + test_layer(l, g, rtol=1e-4, + outsize=(concat ? 
heads*out_channel : out_channel, g.num_nodes)) + end + end + + @testset "bias=false" begin + @test length(Flux.params(GATConv(2=>3))) == 3 + @test length(Flux.params(GATConv(2=>3, bias=false))) == 2 + end + end + + @testset "GatedGraphConv" begin + num_layers = 3 + l = GatedGraphConv(out_channel, num_layers) + @test size(l.weight) == (out_channel, out_channel, num_layers) + + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + end + + @testset "EdgeConv" begin + l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + end + + @testset "GINConv" begin + nn = Dense(in_channel, out_channel) + eps = 0.001f0 + l = GINConv(nn, eps=eps) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps]) + end - # @test !in(:eps, Flux.trainable(l)) - # end + @test !in(:eps, Flux.trainable(l)) + end - # @testset "NNConv" begin - # edim = 10 - # nn = Dense(edim, out_channel * in_channel) + @testset "NNConv" begin + edim = 10 + nn = Dense(edim, out_channel * in_channel) - # l = NNConv(in_channel => out_channel, nn) - # for g in test_graphs - # g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end + l = NNConv(in_channel => out_channel, nn) + for g in test_graphs + g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end - # l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean) - # for g in test_graphs - # g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - # end - - # @testset "SAGEConv" begin - # l = SAGEConv(in_channel => out_channel) - # @test l.aggr == mean - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end + l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean) + for g in test_graphs + g = GNNGraph(g, edata=rand(T, edim, g.num_edges)) + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + end + + @testset "SAGEConv" begin + l = SAGEConv(in_channel => out_channel) + @test l.aggr == mean + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end - # l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+) - # for g in test_graphs - # test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) - # end - # end + l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+) + for g in test_graphs + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) + end + end @testset "ResGatedGraphConv" begin diff --git a/test/runtests.jl b/test/runtests.jl index 8d2145e5b..79aef8031 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,12 +16,12 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ - # "gnngraph", - # "msgpass", + "gnngraph", + "msgpass", "layers/basic", "layers/conv", - # "layers/pool", - # "examples/node_classification_cora", + "layers/pool", + "examples/node_classification_cora", ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") From d0d39d70bef4817429c2cde3421ff29f14f1d2d6 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 23 Sep 2021 06:56:04 +0200 Subject: [PATCH 4/6] adjustment to GINConv --- Project.toml | 2 ++ src/GraphNeuralNetworks.jl | 1 + 
src/deprecations.jl | 3 +++ src/layers/conv.jl | 26 +++++++++++++++-------- src/utils.jl | 3 +++ test/examples/node_classification_cora.jl | 20 ++++++++--------- test/layers/conv.jl | 6 +++--- 7 files changed, 39 insertions(+), 22 deletions(-) create mode 100644 src/deprecations.jl diff --git a/Project.toml b/Project.toml index c6953f8b1..8c9f19ff8 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" @@ -21,6 +22,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +Adapt = "3" CUDA = "3.3" ChainRulesCore = "1" DataStructures = "0.18" diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index d3f70e439..07c7f115e 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -63,5 +63,6 @@ include("msgpass.jl") include("layers/basic.jl") include("layers/conv.jl") include("layers/pool.jl") +include("deprecations.jl") end diff --git a/src/deprecations.jl b/src/deprecations.jl new file mode 100644 index 000000000..967543085 --- /dev/null +++ b/src/deprecations.jl @@ -0,0 +1,3 @@ +# Deprecated in v0.1 + +@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 462b239bf..5dc97321f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -407,7 +407,7 @@ end @doc raw""" - GINConv(f; eps = 0f0) + GINConv(f, ϵ; aggr=+) Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf) @@ -420,30 +420,38 @@ where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer o # Arguments - `f`: A (possibly learnable) function acting on node features. -- `eps`: Weighting factor. +- `ϵ`: Weighting factor. 
""" struct GINConv{R<:Real} <: GNNLayer nn - eps::R + ϵ::R + aggr end @functor GINConv -Flux.trainable(l::GINConv) = (nn=l.nn,) +Flux.trainable(l::GINConv) = (l.nn,) + +GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr) -function GINConv(nn; eps=0f0) - GINConv(nn, eps) -end compute_message(l::GINConv, x_i, x_j, e_ij) = x_j -update_node(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m) +update_node(l::GINConv, m, x) = l.nn((1 + ofeltype(x, l.ϵ)) * x + m) function (l::GINConv)(g::GNNGraph, X::AbstractMatrix) check_num_nodes(g, X) - X, _ = propagate(l, g, +, X) + X, _ = propagate(l, g, l.aggr, X) X end +function Base.show(io::IO, l::GINConv) + print(io, "GINConv($(l.nn)") + print(io, ", $(l.ϵ)") + print(io, ")") +end + + + @doc raw""" NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform) diff --git a/src/utils.jl b/src/utils.jl index 6b8b6b66e..d47a5913c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -67,3 +67,6 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee end return data end + + +ofeltype(x, y) = convert(float(eltype(x)), y) \ No newline at end of file diff --git a/test/examples/node_classification_cora.jl b/test/examples/node_classification_cora.jl index 49b583d7d..21fc9f131 100644 --- a/test/examples/node_classification_cora.jl +++ b/test/examples/node_classification_cora.jl @@ -17,11 +17,11 @@ end # arguments for the `train` function Base.@kwdef mutable struct Args - η = 1f-3 # learning rate - epochs = 20 # number of epochs + η = 5f-3 # learning rate + epochs = 10 # number of epochs seed = 17 # set seed > 0 for reproducibility usecuda = false # if true use cuda (if available) - nhidden = 128 # dimension of hidden features + nhidden = 64 # dimension of hidden features end function train(Layer; verbose=false, kws...) @@ -49,7 +49,7 @@ function train(Layer; verbose=false, kws...) 
## DEFINE MODEL model = GNNChain(Layer(nin, nhidden), - Dropout(0.5), + # Dropout(0.5), Layer(nhidden, nhidden), Dense(nhidden, nout)) |> device @@ -84,16 +84,16 @@ for Layer in [ (nin, nout) -> GraphConv(nin => nout, relu, aggr=mean), (nin, nout) -> SAGEConv(nin => nout, relu), (nin, nout) -> GATConv(nin => nout, relu), - (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2), - (nin, nout) -> GINConv(Dense(nin, nout, relu)), - (nin, nout) -> ChebConv(nin => nout, 3), - (nin, nout) -> ResGatedGraphConv(nin => nout, relu), + (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr=mean), + (nin, nout) -> ChebConv(nin => nout, 2), + (nin, nout) -> ResGatedGraphConv(nin => nout, relu), # (nin, nout) -> NNConv(nin => nout), # needs edge features # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the traning set but does not generalize well ] - train_res, test_res = train(Layer, verbose=false) - # @show Layer(2,2) train_res, test_res + + @show Layer(2,2) + train_res, test_res = train(Layer, verbose=true) @test train_res.acc > 95 @test test_res.acc > 70 end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index e5e9e8d7a..d133cdd6c 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -111,10 +111,10 @@ @testset "GINConv" begin nn = Dense(in_channel, out_channel) - eps = 0.001f0 - l = GINConv(nn, eps=eps) + + l = GINConv(nn, 0.01f0, aggr=mean) for g in test_graphs - test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps]) + test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes)) end @test !in(:eps, Flux.trainable(l)) From 4296694c8ae5411f47ae3632b9b2ad4f5bf5158c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 23 Sep 2021 06:56:30 +0200 Subject: [PATCH 5/6] cleanup --- test/examples/node_classification_cora.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/examples/node_classification_cora.jl b/test/examples/node_classification_cora.jl index 21fc9f131..57fec1dd3 100644 --- a/test/examples/node_classification_cora.jl +++ b/test/examples/node_classification_cora.jl @@ -92,8 +92,8 @@ for Layer in [ # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the traning set but does not generalize well ] - @show Layer(2,2) - train_res, test_res = train(Layer, verbose=true) + # @show Layer(2,2) + train_res, test_res = train(Layer, verbose=false) @test train_res.acc > 95 @test test_res.acc > 70 end From 1bfb633e2e6997f919240ebd24bcfad7e70ffd07 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 23 Sep 2021 07:09:26 +0200 Subject: [PATCH 6/6] cleanup --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8c9f19ff8..6126fed31 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
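
For reviewers, a minimal end-to-end sketch of the API this patch series introduces: `ResGatedGraphConv`, `Parallel` branches inside `GNNChain`, and the new `GINConv(nn, ϵ; aggr)` constructor. The sizes, graph, and data below are hypothetical placeholders chosen for illustration and are not part of the patch itself:

```julia
using GraphNeuralNetworks, Flux, LightGraphs
using Statistics: mean

# Hypothetical sizes and random data, for illustration only.
n, din, d, dout = 10, 3, 4, 2
g = GNNGraph(random_regular_graph(n, 4))   # 10-node random regular graph
X = randn(Float32, din, n)                 # node feature matrix

AddResidual(l) = Parallel(+, identity, l)  # residual branch; GNNChain forwards g into it

model = GNNChain(ResGatedGraphConv(din => d, relu),
                 AddResidual(ResGatedGraphConv(d => d, relu)),
                 GINConv(Dense(d, d, relu), 0.01f0, aggr=mean),
                 Dense(d, dout))

y = model(g, X)   # size (dout, n)
```

Without a global pooling layer the output stays per-node, matching the `(dout, g.num_nodes)` convention used in `docs/src/models.md`.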