diff --git a/docs/src/api/utils.md b/docs/src/api/utils.md
index a4633e001..a7b1ac684 100644
--- a/docs/src/api/utils.md
+++ b/docs/src/api/utils.md
@@ -25,6 +25,12 @@ GraphNeuralNetworks.broadcast_nodes
 GraphNeuralNetworks.broadcast_edges
 ```
 
+### Neighborhood operations
+
+```@docs
+GraphNeuralNetworks.softmax_edge_neighbors
+```
+
 ### NNlib
 
 Primitive functions implemented in NNlib.jl.
diff --git a/docs/src/messagepassing.md b/docs/src/messagepassing.md
index ca9b94cf9..e058575a0 100644
--- a/docs/src/messagepassing.md
+++ b/docs/src/messagepassing.md
@@ -14,14 +14,14 @@ A generic message passing on graph takes the form
 where we refer to ``\phi`` as to the message function,
 and to ``\gamma_x`` and ``\gamma_e`` as to the node update and edge update function
 respectively. The aggregation ``\square`` is over the neighborhood ``N(i)`` of node ``i``,
-and it is usually set to summation ``\sum``, a max or a mean operation.
+and it is usually equal either to ``\sum``, to `max`, or to a `mean` operation.
 
 In GNN.jl, the function [`propagate`](@ref) takes care of materializing the
 node features on each edge, applying the message function, performing the
 aggregation, and returning ``\bar{\mathbf{m}}``.
 It is then left to the user to perform further node and edge updates,
-manypulating arrays of size ``D_{node} \times num_nodes`` and
-``D_{edge} \times num_edges``.
+manipulating arrays of size ``D_{node} \times num\_nodes`` and
+``D_{edge} \times num\_edges``.
 
 As part of the [`propagate`](@ref) pipeline, we have the function
 [`apply_edges`](@ref). It can be independently used to materialize
@@ -34,9 +34,9 @@ and [`NNlib.scatter`](@ref) methods.
 
 ## Examples
 
-### Basic use propagate and apply_edges
-
+### Basic use of propagate and apply_edges
+TODO
 
 ### Implementing a custom Graph Convolutional Layer
diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index 9eb2f03e6..c05183545 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -35,6 +35,7 @@ export
     reduce_nodes, reduce_edges,
     softmax_nodes, softmax_edges,
     broadcast_nodes, broadcast_edges,
+    softmax_edge_neighbors,
 
     # msgpass
     apply_edges, propagate,
diff --git a/src/layers/pool.jl b/src/layers/pool.jl
index 7bb386506..88ed3ddd0 100644
--- a/src/layers/pool.jl
+++ b/src/layers/pool.jl
@@ -52,24 +52,24 @@ Global soft attention layer from the [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) paper
 
 ```math
-\mathbf{u}_V} = \sum_{i\in V} \alpha_i\, f_{\mathrm{feat}}(\mathbf{x}_i)
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
 ```
 
-where the coefficients ``alpha_i`` are given by a [`softmax_nodes`](@ref)
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
 operation:
 
 ```math
-\alpha_i = \frac{e^{f_{\mathrm{feat}}(\mathbf{x}_i)}}
-          {\sum_{i'\in V} e^{f_{\mathrm{feat}}(\mathbf{x}_{i'})}}.
+\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+          {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
 ```
 
 # Arguments
 
-- `fgate`: The function ``f_{\mathrm{gate}} \colon \mathbb{R}^{D_{in}} \to
-\mathbb{R}``. It is tipically a neural network.
+- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+  It is typically expressed by a neural network.
 
-- `ffeat`: The function ``f_{\mathrm{feat}} \colon \mathbb{R}^{D_{in}} \to
-\mathbb{R}^{D_{out}}``. It is tipically a neural network.
+- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+  It is typically expressed by a neural network.
 
 # Examples
@@ -88,6 +88,7 @@ g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
 u = pool(g, g.ndata.x)
 
 @assert size(u) == (chout, g.num_graphs)
+```
 """
 struct GlobalAttentionPool{G,F}
     fgate::G
diff --git a/src/utils.jl b/src/utils.jl
index 680958da9..7f9506d27 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -190,6 +190,25 @@ function softmax_edges(g::GNNGraph, e)
     return num ./ den
 end
 
+@doc raw"""
+    softmax_edge_neighbors(g, e)
+
+Softmax of the edge features `e` over each node's neighborhood.
+
+```math
+\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}}
+                    {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}.
+```
+"""
+function softmax_edge_neighbors(g::GNNGraph, e)
+    @assert size(e)[end] == g.num_edges
+    s, t = edge_index(g)
+    max_ = gather(scatter(max, e, t), t)
+    num = exp.(e .- max_)
+    den = gather(scatter(+, num, t), t)
+    return num ./ den
+end
+
 """
     broadcast_nodes(g, x)
diff --git a/test/utils.jl b/test/utils.jl
index 3a86b51ce..c61831e5a 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -31,7 +31,6 @@
         @test r[:,1:60] ≈ softmax(getgraph(g, 1).edata.e, dims=2)
     end
 
-
     @testset "broadcast_nodes" begin
         z = rand(4, g.num_graphs)
         r = broadcast_nodes(g, z)
@@ -49,4 +48,15 @@
         @test r[:,60] ≈ z[:,1]
         @test r[:,61] ≈ z[:,2]
     end
+
+    @testset "softmax_edge_neighbors" begin
+        s = [1,2,3,4]
+        t = [5,5,6,6]
+        g2 = GNNGraph(s, t)
+        e2 = randn(Float32, 3, g2.num_edges)
+        z = softmax_edge_neighbors(g2, e2)
+        @test size(z) == size(e2)
+        @test z[:,1:2] ≈ NNlib.softmax(e2[:,1:2], dims=2)
+        @test z[:,3:4] ≈ NNlib.softmax(e2[:,3:4], dims=2)
+    end
 end
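
Usage sketch (not part of the diff above): the snippet below mirrors the new test to show how `softmax_edge_neighbors` is meant to be called. The graph, feature dimension, and variable names are illustrative assumptions; the check relies on edges 1 and 2 sharing target node 5 and edges 3 and 4 sharing target node 6, so within each neighborhood the result reduces to a plain column-wise softmax.

```julia
using GraphNeuralNetworks, NNlib

# Four edges 1→5, 2→5, 3→6, 4→6: nodes 5 and 6 each receive two edges.
s, t = [1, 2, 3, 4], [5, 5, 6, 6]
g = GNNGraph(s, t)

e = randn(Float32, 3, g.num_edges)   # a 3-dimensional feature per edge
α = softmax_edge_neighbors(g, e)     # normalized over each target node's incoming edges

# Within each neighborhood the result matches a plain softmax over its columns.
@assert α[:, 1:2] ≈ NNlib.softmax(e[:, 1:2], dims = 2)
@assert α[:, 3:4] ≈ NNlib.softmax(e[:, 3:4], dims = 2)
```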