add softmax_edge_neighbors (#59)
CarloLucibello authored Oct 24, 2021
1 parent 5ad064f commit 24512ee
Showing 6 changed files with 51 additions and 14 deletions.
6 changes: 6 additions & 0 deletions docs/src/api/utils.md
@@ -25,6 +25,12 @@ GraphNeuralNetworks.broadcast_nodes
GraphNeuralNetworks.broadcast_edges
```

### Neighborhood operations

```@docs
GraphNeuralNetworks.softmax_edge_neighbors
```

### NNlib

Primitive functions implemented in NNlib.jl.
10 changes: 5 additions & 5 deletions docs/src/messagepassing.md
@@ -14,14 +14,14 @@ A generic message passing on graph takes the form
where we refer to ``\phi`` as the message function,
and to ``\gamma_x`` and ``\gamma_e`` as the node update and edge update functions
respectively. The aggregation ``\square`` is over the neighborhood ``N(i)`` of node ``i``,
-and it is usually set to summation ``\sum``, a max or a mean operation.
+and it is usually ``\sum``, `max`, or `mean`.

In GNN.jl, the function [`propagate`](@ref) takes care of materializing the
node features on each edge, applying the message function, performing the
aggregation, and returning ``\bar{\mathbf{m}}``.
It is then left to the user to perform further node and edge updates,
-manypulating arrays of size ``D_{node} \times num_nodes`` and
-``D_{edge} \times num_edges``.
+manipulating arrays of size ``D_{node} \times num\_nodes`` and
+``D_{edge} \times num\_edges``.

As part of the [`propagate`](@ref) pipeline, we have the function
[`apply_edges`](@ref). It can be independently used to materialize
@@ -34,9 +34,9 @@ and [`NNlib.scatter`](@ref) methods.

## Examples

-### Basic use propagate and apply_edges
+### Basic use of propagate and apply_edges

TODO

### Implementing a custom Graph Convolutional Layer

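Since the "Basic use of propagate and apply_edges" section above is still a TODO, here is a minimal sketch of the pipeline the prose describes. It is not part of this commit; it assumes the `propagate(f, g, aggr; xi, xj, e)` and `apply_edges(f, g; xi, xj, e)` forms exported by this package, and the graph, features, and `message` function are illustrative.

```julia
using GraphNeuralNetworks

s, t = [1, 2, 3], [2, 3, 1]          # a small directed cycle
g = GNNGraph(s, t)
x = rand(Float32, 4, g.num_nodes)    # D_node × num_nodes feature matrix

# message function ϕ: here it simply forwards the source-node features xj
message(xi, xj, e) = xj

# materialize features on edges, apply ϕ, aggregate with + over each N(i)
m̄ = propagate(message, g, +; xj = x)
@assert size(m̄) == size(x)

# apply_edges runs only the materialization + message step, no aggregation
m = apply_edges(message, g; xj = x)
@assert size(m) == (4, g.num_edges)
```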
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -35,6 +35,7 @@ export
reduce_nodes, reduce_edges,
softmax_nodes, softmax_edges,
broadcast_nodes, broadcast_edges,
softmax_edge_neighbors,

# msgpass
apply_edges, propagate,
17 changes: 9 additions & 8 deletions src/layers/pool.jl
@@ -52,24 +52,24 @@
Global soft attention layer from the [Gated Graph Sequence Neural
Networks](https://arxiv.org/abs/1511.05493) paper
```math
-\mathbf{u}_V} = \sum_{i\in V} \alpha_i\, f_{\mathrm{feat}}(\mathbf{x}_i)
+\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
```
-where the coefficients ``alpha_i`` are given by a [`softmax_nodes`](@ref)
+where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
operation:
```math
-\alpha_i = \frac{e^{f_{\mathrm{feat}}(\mathbf{x}_i)}}
-{\sum_{i'\in V} e^{f_{\mathrm{feat}}(\mathbf{x}_{i'})}}.
+\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
+{\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
```
# Arguments
-- `fgate`: The function ``f_{\mathrm{gate}} \colon \mathbb{R}^{D_{in}} \to
-\mathbb{R}``. It is tipically a neural network.
+- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
+It is typically expressed by a neural network.
-- `ffeat`: The function ``f_{\mathrm{feat}} \colon \mathbb{R}^{D_{in}} \to
-\mathbb{R}^{D_{out}}``. It is tipically a neural network.
+- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
+It is typically expressed by a neural network.
# Examples
@@ -88,6 +88,7 @@ g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
u = pool(g, g.ndata.x)
@assert size(u) == (chout, g.num_graphs)
```
"""
struct GlobalAttentionPool{G,F}
fgate::G
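Not part of this commit, but for readers following the docstring changes above, here is a sketch of the computation `GlobalAttentionPool`'s formulas describe, written in terms of the package's `softmax_nodes` and `reduce_nodes` utilities; the dimensions and `Dense` gate/feature networks are illustrative.

```julia
using Flux, GraphNeuralNetworks

chin, chout = 6, 5
fgate = Dense(chin, 1)        # f_gate : ℝ^Din → ℝ, scores each node
ffeat = Dense(chin, chout)    # f_feat : ℝ^Din → ℝ^Dout, transforms features

g = GNNGraph(rand(1:10, 30), rand(1:10, 30); num_nodes = 10,
             ndata = rand(Float32, chin, 10))

# α_i: softmax of the gate scores over each graph's nodes
α = softmax_nodes(g, fgate(g.ndata.x))
# u_V = Σ_i α_i f_feat(x_i), one column per graph in the batch
u = reduce_nodes(+, g, α .* ffeat(g.ndata.x))
@assert size(u) == (chout, g.num_graphs)
```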
19 changes: 19 additions & 0 deletions src/utils.jl
@@ -190,6 +190,25 @@ function softmax_edges(g::GNNGraph, e)
return num ./ den
end

@doc raw"""
    softmax_edge_neighbors(g, e)

Softmax of the edge features `e`, normalized over the neighborhood ``N(i)`` of each node ``i`` (i.e. over the edges pointing into ``i``).
```math
\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}}
                    {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}.
```
"""
function softmax_edge_neighbors(g::GNNGraph, e)
@assert size(e)[end] == g.num_edges
s, t = edge_index(g)
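# per-destination max, gathered back onto each edge, for numerical stability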
max_ = gather(scatter(max, e, t), t)
num = exp.(e .- max_)
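# denominator: sum of exponentials over each destination's incoming edges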
den = gather(scatter(+, num, t), t)
return num ./ den
end

"""
broadcast_nodes(g, x)
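A quick usage sketch of the new function (the graph and features are hypothetical, mirroring the test added below): columns of the result that share a target node form a probability distribution along the node dimension.

```julia
using GraphNeuralNetworks

s, t = [1, 2, 3, 4], [5, 5, 6, 6]   # edges 1-2 enter node 5, edges 3-4 enter node 6
g = GNNGraph(s, t)
e = randn(Float32, 3, g.num_edges)

z = softmax_edge_neighbors(g, e)
sum(z[:, 1:2], dims = 2)            # ≈ ones(3): normalized over node 5's incoming edges
sum(z[:, 3:4], dims = 2)            # ≈ ones(3): normalized over node 6's incoming edges
```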
12 changes: 11 additions & 1 deletion test/utils.jl
@@ -31,7 +31,6 @@
@test r[:,1:60] ≈ softmax(getgraph(g, 1).edata.e, dims=2)
end


@testset "broadcast_nodes" begin
z = rand(4, g.num_graphs)
r = broadcast_nodes(g, z)
@@ -49,4 +48,15 @@
@test r[:,60] ≈ z[:,1]
@test r[:,61] ≈ z[:,2]
end

@testset "softmax_edge_neighbors" begin
s = [1,2,3,4]
t = [5,5,6,6]
g2 = GNNGraph(s, t)
e2 = randn(Float32, 3, g2.num_edges)
z = softmax_edge_neighbors(g2, e2)
@test size(z) == size(e2)
@test z[:,1:2] ≈ NNlib.softmax(e2[:,1:2], dims=2)
@test z[:,3:4] ≈ NNlib.softmax(e2[:,3:4], dims=2)
end
end
