From c473c9e8cbfa0d7d28e0f9e9e32421b4861fdbf8 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Sun, 24 Oct 2021 12:12:47 +0200
Subject: [PATCH 1/2] add GlobalAttentionPool

---
 src/GraphNeuralNetworks.jl |  1 +
 src/layers/pool.jl         | 44 ++++++++++++++++++++++++++++++++++++++
 test/layers/pool.jl        | 31 +++++++++++++++++++++++----
 3 files changed, 72 insertions(+), 4 deletions(-)

diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index b76ec5c00..9eb2f03e6 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -60,6 +60,7 @@ export
 
   # layers/pool
   GlobalPool,
+  GlobalAttentionPool,
   TopKPool,
   topk_index
 
diff --git a/src/layers/pool.jl b/src/layers/pool.jl
index f1cb6ad70..0580b350f 100644
--- a/src/layers/pool.jl
+++ b/src/layers/pool.jl
@@ -42,6 +42,47 @@ end
 
 (l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
 
+
+@doc raw"""
+    GlobalAttentionPool(fgate, ffeat=identity)
+
+Global soft attention layer from the [Gated Graph Sequence Neural
+Networks](https://arxiv.org/abs/1511.05493) paper
+
+```math
+\mathbf{u}_V} = \sum_{i\in V} \mathrm{softmax} \left(
+    f_{\mathrm{gate}} ( \mathbf{x}_i ) \right) \odot
+    f_{\mathrm{feat}} ( \mathbf{x}_i ),
+```
+
+where ``f_{\mathrm{gate}} \colon \mathbb{R}^F \to
+\mathbb{R}`` and ``f_{\mathbf{feat}}` denote neural networks.
+
+# Arguments
+
+fgate:
+ffeat:
+"""
+struct GlobalAttentionPool{G,F}
+    fgate::G
+    ffeat::F
+end
+
+@functor GlobalAttentionPool
+
+GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
+
+
+function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
+    weights = softmax_nodes(g, l.fgate(x))
+    feats = l.ffeat(x)
+    u = reduce_nodes(+, g, weights .* feats)
+    return u
+end
+
+(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
+
+
 """
     TopKPool(adj, k, in_channel)
 
@@ -60,6 +101,9 @@ struct TopKPool{T,S}
     Ã::AbstractMatrix{T}
 end
 
+
+
+
 function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init=glorot_uniform)
     TopKPool(adj, k, init(in_channel), similar(adj, k, k))
 end
diff --git a/test/layers/pool.jl b/test/layers/pool.jl
index e86ce530b..8129814e7 100644
--- a/test/layers/pool.jl
+++ b/test/layers/pool.jl
@@ -1,14 +1,37 @@
 @testset "pool" begin
     @testset "GlobalPool" begin
+        p = GlobalPool(+)
         n = 10
-        X = rand(Float32, 16, n)
+        chin = 6
+        X = rand(Float32, 6, n)
         g = GNNGraph(random_regular_graph(n, 4), ndata=X)
-        p = GlobalPool(+)
-        y = p(g, X)
-        @test y ≈ NNlib.scatter(+, X, ones(Int, n))
+        u = p(g, X)
+        @test u ≈ sum(X, dims=2)
+
+        ng = 3
+        g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
+                                 ndata=rand(Float32, chin, n))
+                        for i=1:ng])
+        u = p(g, g.ndata.x)
+        @test size(u) == (chin, ng)
+        @test u[:,[1]] ≈ sum(g.ndata.x[:,1:n], dims=2)
+        @test p(g).gdata.u == u
+        test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
     end
 
+    @testset "GlobalAttentionPool" begin
+        n = 10
+        chin = 16
+        X = rand(Float32, chin, n)
+        g = GNNGraph(random_regular_graph(n, 4), ndata=X)
+        fgate = Dense(chin, 1, sigmoid)
+        p = GlobalAttentionPool(fgate)
+        y = p(g, X)
+        test_layer(p, g, rtol=1e-5, outtype=:graph)
+    end
+
+
     @testset "TopKPool" begin
         N = 10
         k, in_channel = 4, 7

From bfa25ce4bab3b89b558c4b79481e35374370a046 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Sun, 24 Oct 2021 16:29:44 +0200
Subject: [PATCH 2/2] fixes

---
 Project.toml        |  2 ++
 src/layers/pool.jl  | 50 ++++++++++++++++++++++++++++++++-------------
 test/layers/pool.jl | 27 ++++++++++++++++--------
 test/test_utils.jl  | 22 ++++++++++++--------
 4 files changed, 70 insertions(+), 31 deletions(-)

diff --git a/Project.toml b/Project.toml
index 15158bff7..daacedab6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,6 +19,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b"
 
 [compat]
 Adapt = "3"
@@ -33,6 +34,7 @@ MacroTools = "0.5"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 julia = "1.6"
+TestEnv = "1"
 
 [extras]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/layers/pool.jl b/src/layers/pool.jl
index 0580b350f..7bb386506 100644
--- a/src/layers/pool.jl
+++ b/src/layers/pool.jl
@@ -10,6 +10,7 @@ and performs the operation
 ```math
 \mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
 ```
+
 where ``V`` is the set of nodes of the input graph and
 the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
 Commonly used aggregations are `mean`, `max`, and `+`.
@@ -17,6 +18,7 @@ Commonly used aggregations are `mean`, `max`, and `+`.
 
 See also [`reduce_nodes`](@ref).
 
 # Examples
+
 ```julia
 using Flux, GraphNeuralNetworks, Graphs
@@ -50,18 +52,42 @@ Global soft attention layer from the [Gated Graph Sequence Neural
 Networks](https://arxiv.org/abs/1511.05493) paper
 
 ```math
-\mathbf{u}_V} = \sum_{i\in V} \mathrm{softmax} \left(
-    f_{\mathrm{gate}} ( \mathbf{x}_i ) \right) \odot
-    f_{\mathrm{feat}} ( \mathbf{x}_i ),
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{\mathrm{feat}}(\mathbf{x}_i)
 ```
 
-where ``f_{\mathrm{gate}} \colon \mathbb{R}^F \to
-\mathbb{R}`` and ``f_{\mathbf{feat}}` denote neural networks.
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
+operation:
+
+```math
+\alpha_i = \frac{e^{f_{\mathrm{gate}}(\mathbf{x}_i)}}
+           {\sum_{i'\in V} e^{f_{\mathrm{gate}}(\mathbf{x}_{i'})}}.
+```
 
 # Arguments
 
-fgate:
-ffeat:
+- `fgate`: The function ``f_{\mathrm{gate}} \colon \mathbb{R}^{D_{in}} \to
+\mathbb{R}``. It is typically a neural network.
+
+- `ffeat`: The function ``f_{\mathrm{feat}} \colon \mathbb{R}^{D_{in}} \to
+\mathbb{R}^{D_{out}}``. It is typically a neural network.
+
+# Examples
+
+```julia
+chin = 6
+chout = 5
+
+fgate = Dense(chin, 1)
+ffeat = Dense(chin, chout)
+pool = GlobalAttentionPool(fgate, ffeat)
+
+g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
+                         ndata=rand(Float32, chin, 10))
+                for i=1:3])
+
+u = pool(g, g.ndata.x)
+
+@assert size(u) == (chout, g.num_graphs)
+```
 """
 struct GlobalAttentionPool{G,F}
     fgate::G
     ffeat::F
 end
@@ -72,11 +98,10 @@ end
 
 @functor GlobalAttentionPool
 
 GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)
 
-
 function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
-    weights = softmax_nodes(g, l.fgate(x))
-    feats = l.ffeat(x)
-    u = reduce_nodes(+, g, weights .* feats)
+    α = softmax_nodes(g, l.fgate(x))
+    feats = α .* l.ffeat(x)
+    u = reduce_nodes(+, g, feats)
     return u
 end
@@ -101,9 +126,6 @@ struct TopKPool{T,S}
     Ã::AbstractMatrix{T}
 end
 
-
-
-
 function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init=glorot_uniform)
     TopKPool(adj, k, init(in_channel), similar(adj, k, k))
 end
diff --git a/test/layers/pool.jl b/test/layers/pool.jl
index 8129814e7..f7bb74a83 100644
--- a/test/layers/pool.jl
+++ b/test/layers/pool.jl
@@ -4,13 +4,14 @@
         n = 10
         chin = 6
         X = rand(Float32, 6, n)
-        g = GNNGraph(random_regular_graph(n, 4), ndata=X)
+        g = GNNGraph(random_regular_graph(n, 4), ndata=X, graph_type=GRAPH_T)
         u = p(g, X)
         @test u ≈ sum(X, dims=2)
 
         ng = 3
         g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
-                                 ndata=rand(Float32, chin, n))
+                                 ndata=rand(Float32, chin, n),
+                                 graph_type=GRAPH_T)
                         for i=1:ng])
         u = p(g, g.ndata.x)
         @test size(u) == (chin, ng)
@@ -22,13 +23,21 @@
 
     @testset "GlobalAttentionPool" begin
         n = 10
-        chin = 16
-        X = rand(Float32, chin, n)
-        g = GNNGraph(random_regular_graph(n, 4), ndata=X)
-        fgate = Dense(chin, 1, sigmoid)
-        p = GlobalAttentionPool(fgate)
-        y = p(g, X)
-        test_layer(p, g, rtol=1e-5, outtype=:graph)
+        chin = 6
+        chout = 5
+        ng = 3
+
+        fgate = Dense(chin, 1)
+        ffeat = Dense(chin, chout)
+        p = GlobalAttentionPool(fgate, ffeat)
+        @test length(Flux.params(p)) == 4
+
+        g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
+                                 ndata=rand(Float32, chin, n),
+                                 graph_type=GRAPH_T)
+                        for i=1:ng])
+
+        test_layer(p, g, rtol=1e-5, outtype=:graph, outsize=(chout, ng))
     end
 
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 74c58469e..509301176 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -34,14 +34,21 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
     x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
     xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])
 
-    f(l, g) = l(g)
-    f(l, g, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
-    f(l, g, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
-    f(l, g, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
+    f(l, g::GNNGraph) = l(g)
+    f(l, g::GNNGraph, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
+    f(l, g::GNNGraph, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
+    f(l, g::GNNGraph, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
 
-    loss(l, g) = sum(node_features(f(l, g)))
-    loss(l, g, x) = sum(f(l, g, x))
-    loss(l, g, x, e) = sum(l(g, x, e))
+    loss(l, g::GNNGraph) = if outtype == :node
+            sum(node_features(f(l, g)))
+        elseif outtype == :edge
+            sum(edge_features(f(l, g)))
+        elseif outtype == :graph
+            sum(graph_features(f(l, g)))
+        end
+
+    loss(l, g::GNNGraph, x) = sum(f(l, g, x))
+    loss(l, g::GNNGraph, x, e) = sum(l(g, x, e))
 
 
     # TEST OUTPUT
@@ -117,7 +124,6 @@
 
     # TEST LAYER GRADIENT - l(g)
     l̄ = gradient(l -> loss(l, g), l)[1]
-
    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64), l64)[1]
    test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
    return true
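
For reference, a minimal usage sketch of the new layer as the readout stage of a graph-level model. Only `GlobalAttentionPool` itself is introduced by the patches above; the `GCNConv`/`Dense` wiring, the sizes, and the variable names below are illustrative assumptions rather than part of the patch.

```julia
using Flux, Graphs, GraphNeuralNetworks

chin, nhidden, nclasses = 6, 16, 3          # illustrative sizes, not from the patch

conv = GCNConv(chin => nhidden, relu)       # node-level message passing
pool = GlobalAttentionPool(Dense(nhidden, 1),        # fgate: one attention score per node
                           Dense(nhidden, nhidden))  # ffeat: transformed node features
head = Dense(nhidden, nclasses)             # graph-level classifier

# a small batch of random graphs with random node features
g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
                         ndata=rand(Float32, chin, 10)) for _ in 1:4])

h = conv(g, g.ndata.x)     # node embeddings, size (nhidden, num_nodes)
u = pool(g, h)             # one pooled vector per graph, size (nhidden, num_graphs)
ŷ = head(u)                # class scores, size (nclasses, num_graphs)
@assert size(ŷ) == (nclasses, g.num_graphs)
```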