From f228781c03ec006649c0ca774821979df8bac3a2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 9 Nov 2022 15:33:47 +0100 Subject: [PATCH] equivariant gnn --- docs/src/api/messagepassing.md | 2 + src/GraphNeuralNetworks.jl | 5 +- src/layers/conv.jl | 129 +++++++++++++++++++++++++++++++++ src/msgpass.jl | 20 ++++- test/layers/conv.jl | 13 ++++ 5 files changed, 164 insertions(+), 5 deletions(-) diff --git a/docs/src/api/messagepassing.md b/docs/src/api/messagepassing.md index a3d1b9708..e7ade6d5b 100644 --- a/docs/src/api/messagepassing.md +++ b/docs/src/api/messagepassing.md @@ -25,6 +25,8 @@ propagate copy_xi copy_xj xi_dot_xj +xi_sub_xj +xj_sub_xi e_mul_xj w_mul_xj ``` diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 42c20c258..a7d5af45e 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -36,6 +36,8 @@ export copy_xj, copy_xi, xi_dot_xj, + xi_sub_xj, + xj_sub_xi, e_mul_xj, w_mul_xj, @@ -50,17 +52,18 @@ export CGConv, ChebConv, EdgeConv, + EGNNConv, GATConv, GATv2Conv, GatedGraphConv, GCNConv, GINConv, + GMMConv, GraphConv, MEGNetConv, NNConv, ResGatedGraphConv, SAGEConv, - GMMConv, SGConv, # layers/pool diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c1409c71f..570b4df8a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1083,6 +1083,7 @@ The input to the layer is a node feature array `x` of size `(num_features, num_n edge pseudo-coordinate array `e` of size `(num_features, num_edges)` The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same as the input size. + # Arguments - `in`: Number of input node features. @@ -1298,3 +1299,131 @@ function Base.show(io::IO, l::SGConv) l.k == 1 || print(io, ", ", l.k) print(io, ")") end + + + +@doc raw""" + EdgeConv((in, ein) => out, hidden_size) + EdgeConv(in => out, hidden_size=2*in) + +Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph +Neural Networks](https://arxiv.org/abs/2102.09844). + +The layer performs the following operation: + +```math +\mathbf{m}_{j\to i}=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\ +\mathbf{x}_i' = \mathbf{h}_i{x_i} + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\ +\mathbf{m}_i = C_i\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\ +\mathbf{h}_i' = \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i) +``` +where ``h_i``, ``x_i``, ``e_{ij}`` are invariant node features, equivariance node +features, and edge features respectively. ``\phi_e``, ``\phi_h``, and +``\phi_x`` are two-layer MLPs. :math:`C` is a constant for normalization, +computed as ``1/|\mathcal{N}(i)|``. + + +# Constructor Arguments + +- `in`: Number of input features for `h`. +- `out`: Number of output features for `h`. +- `ein`: Number of input edge features. +- `hidden_size`: Hidden representation size. +- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`. + +# Forward Pass + + l(g, x, h, e=nothing) + +## Forward Pass Arguments: + +- `g` : The graph. +- `x` : Matrix of equivariant node coordinates. +- `h` : Matrix of invariant node features. +- `e` : Matrix of invariant edge features. Default `nothing`. + +Returns updated `h` and `x`. + +# Examples + +```julia +g = rand_graph(10, 10) +h = randn(Float32, 5, g.num_nodes) +x = randn(Float32, 3, g.num_nodes) +egnn = EGNNConv(5 => 6, 10) +hnew, xnew = egnn(g, h, x) +``` +""" +struct EGNNConv <: GNNLayer + ϕe::Chain + ϕx::Chain + ϕh::Chain + num_features::NamedTuple + residual::Bool +end + +@functor EGNNConv + +EGNNConv(ch::Pair{Int,Int}, hidden_size=2*ch[1]) = EGNNConv((ch[1], 0) => ch[2], hidden_size) + +#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py +function EGNNConv(ch::Pair{NTuple{2, Int}, Int}, hidden_size::Int, residual=false) + (in_size, edge_feat_size), out_size = ch + act_fn = swish + + # +1 for the radial feature: ||x_i - x_j||^2 + ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn), + Dense(hidden_size => hidden_size, act_fn)) + + ϕh = Chain(Dense(in_size + hidden_size, hidden_size, swish), + Dense(hidden_size, out_size)) + + ϕx = Chain(Dense(hidden_size, hidden_size, swish), + Dense(hidden_size, 1, bias=false)) + + num_features = (in=in_size, edge=edge_feat_size, out=out_size) + if residual + @assert in_size == out_size "Residual connection only possible if in_size == out_size" + end + return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) +end + +function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e=nothing) + if l.num_features.edge > 0 + @assert e !== nothing "Edge features must be provided." + end + @assert size(h, 1) == l.num_features.in "Input features must match layer input size." + + + @show size(x) size(h) + + function message(xi, xj, e) + if l.num_features.edge > 0 + f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e) + else + f = vcat(xi.h, xj.h, e.sqnorm_xdiff) + end + + msg_h = l.ϕe(f) + msg_x = l.ϕx(msg_h) .* e.x_diff + return (; x=msg_x, h=msg_h) + end + + x_diff = apply_edges(xi_sub_xj, g, x, x) + sqnorm_xdiff = sum(x_diff .^ 2, dims=1) + x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1f-6) + + msg = apply_edges(message, g, xi=(; h), xj=(; h), e=(; e, x_diff, sqnorm_xdiff)) + h_aggr = aggregate_neighbors(g, +, msg.h) + x_aggr = aggregate_neighbors(g, mean, msg.x) + + hnew = l.ϕh(vcat(h, h_aggr)) + if l.residual + h = h .+ hnew + else + h = hnew + end + x = x .+ x_aggr + + return h, x +end diff --git a/src/msgpass.jl b/src/msgpass.jl index d6f6df6d7..61af0c15b 100644 --- a/src/msgpass.jl +++ b/src/msgpass.jl @@ -1,5 +1,6 @@ """ - propagate(f, g, aggr; xi, xj, e) -> m̄ + propagate(f, g, aggr; [xi, xj, e]) -> m̄ + propagate(f, g, aggr, xi, xj, e=nothing) Performs message passing on graph `g`. Takes care of materializing the node features on each edge, applying the message function, and returning an aggregated message ``\\bar{\\mathbf{m}}`` @@ -68,7 +69,7 @@ function propagate end propagate(l, g::GNNGraph, aggr; xi=nothing, xj=nothing, e=nothing) = propagate(l, g, aggr, xi, xj, e) -function propagate(l, g::GNNGraph, aggr, xi, xj, e) +function propagate(l, g::GNNGraph, aggr, xi, xj, e=nothing) m = apply_edges(l, g, xi, xj, e) m̄ = aggregate_neighbors(g, aggr, m) return m̄ @@ -77,8 +78,8 @@ end ## APPLY EDGES """ - apply_edges(f, g, xi, xj, e) apply_edges(f, g; [xi, xj, e]) + apply_edges(f, g, xi, xj, e=nothing) Returns the message from node `j` to node `i` . In the message-passing scheme, the incoming messages @@ -110,7 +111,7 @@ function apply_edges end apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) = apply_edges(l, g, xi, xj, e) -function apply_edges(f, g::GNNGraph, xi, xj, e) +function apply_edges(f, g::GNNGraph, xi, xj, e=nothing) check_num_nodes(g, xi) check_num_nodes(g, xj) check_num_edges(g, e) @@ -158,6 +159,17 @@ copy_xi(xi, xj, e) = xi """ xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1) +""" + xi_sub_xj(xi, xj, e) = xi .- xj +""" +xi_sub_xj(xi, xj, e) = xi .- xj + +""" + xj_sub_xi(xi, xj, e) = xj .- xi +""" +xj_sub_xi(xi, xj, e) = xj .- xi + + """ e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj diff --git a/test/layers/conv.jl b/test/layers/conv.jl index a8706f0d7..6d8687a84 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -288,4 +288,17 @@ end end end + + @testset "EGNNConv" begin + hin = 5 + hout = 5 + hidden = 5 + l = EGNNConv(hin => hout, hidden) + g = rand_graph(10, 20, graph_type=GRAPH_T) + x = rand(T, in_channel, g.num_nodes) + h = randn(T, hin, g.num_nodes) + hnew, xnew = l(g, h, x) + @test size(hnew) == (hout, g.num_nodes) + @test size(xnew) == (in_channel, g.num_nodes) + end end