Skip to content

Commit

Permalink
Merge pull request #223 from CarloLucibello/cl/egnn
Browse files Browse the repository at this point in the history
equivariant gnn
  • Loading branch information
CarloLucibello authored Nov 9, 2022
2 parents 377fcf1 + f228781 commit 24a22aa
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 5 deletions.
2 changes: 2 additions & 0 deletions docs/src/api/messagepassing.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ propagate
copy_xi
copy_xj
xi_dot_xj
xi_sub_xj
xj_sub_xi
e_mul_xj
w_mul_xj
```
5 changes: 4 additions & 1 deletion src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ export
copy_xj,
copy_xi,
xi_dot_xj,
xi_sub_xj,
xj_sub_xi,
e_mul_xj,
w_mul_xj,

Expand All @@ -50,17 +52,18 @@ export
CGConv,
ChebConv,
EdgeConv,
EGNNConv,
GATConv,
GATv2Conv,
GatedGraphConv,
GCNConv,
GINConv,
GMMConv,
GraphConv,
MEGNetConv,
NNConv,
ResGatedGraphConv,
SAGEConv,
GMMConv,
SGConv,

# layers/pool
Expand Down
129 changes: 129 additions & 0 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,7 @@ The input to the layer is a node feature array `x` of size `(num_features, num_n
edge pseudo-coordinate array `e` of size `(num_features, num_edges)`
The residual ``\mathbf{x}_i`` is added only if `residual=true` and the output size is the same
as the input size.
# Arguments
- `in`: Number of input node features.
Expand Down Expand Up @@ -1298,3 +1299,131 @@ function Base.show(io::IO, l::SGConv)
l.k == 1 || print(io, ", ", l.k)
print(io, ")")
end



@doc raw"""
    EGNNConv((in, ein) => out, hidden_size)
    EGNNConv(in => out, hidden_size=2*in)

Equivariant Graph Convolutional Layer from [E(n) Equivariant Graph
Neural Networks](https://arxiv.org/abs/2102.09844).

The layer performs the following operation:

```math
\mathbf{m}_{j\to i}=\phi_e(\mathbf{h}_i, \mathbf{h}_j, \lVert\mathbf{x}_i-\mathbf{x}_j\rVert^2, \mathbf{e}_{j\to i}),\\
\mathbf{x}_i' = \mathbf{x}_i + C_i\sum_{j\in\mathcal{N}(i)}(\mathbf{x}_i-\mathbf{x}_j)\phi_x(\mathbf{m}_{j\to i}),\\
\mathbf{m}_i = \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{j\to i},\\
\mathbf{h}_i' = \mathbf{h}_i + \phi_h(\mathbf{h}_i, \mathbf{m}_i)
```
where ``\mathbf{h}_i``, ``\mathbf{x}_i``, ``\mathbf{e}_{j\to i}`` are invariant node features,
equivariant node coordinates, and edge features respectively. ``\phi_e``, ``\phi_h``, and
``\phi_x`` are two-layer MLPs. ``C_i`` is a constant for normalization,
computed as ``1/|\mathcal{N}(i)|``. The residual term ``\mathbf{h}_i`` in the last
equation is added only if the layer is constructed with `residual=true`.

# Constructor Arguments

- `in`: Number of input features for `h`.
- `out`: Number of output features for `h`.
- `ein`: Number of input edge features.
- `hidden_size`: Hidden representation size.
- `residual`: If `true`, add a residual connection. Only possible if `in == out`. Default `false`.

# Forward Pass

    l(g, h, x, e=nothing)

## Forward Pass Arguments:

- `g` : The graph.
- `h` : Matrix of invariant node features.
- `x` : Matrix of equivariant node coordinates.
- `e` : Matrix of invariant edge features. Default `nothing`.

Returns updated `h` and `x`.

# Examples

```julia
g = rand_graph(10, 10)
h = randn(Float32, 5, g.num_nodes)
x = randn(Float32, 3, g.num_nodes)
egnn = EGNNConv(5 => 6, 10)
hnew, xnew = egnn(g, h, x)
```
"""
struct EGNNConv <: GNNLayer
    ϕe::Chain   # edge (message) MLP
    ϕx::Chain   # coordinate-update MLP
    ϕh::Chain   # node-feature-update MLP
    num_features::NamedTuple  # (in=..., edge=..., out=...)
    residual::Bool
end

# Register the MLP fields as trainable parameters for Flux.
@functor EGNNConv

# Convenience constructor for graphs without edge features (ein = 0).
EGNNConv(ch::Pair{Int,Int}, hidden_size=2*ch[1]) = EGNNConv((ch[1], 0) => ch[2], hidden_size)

# Follows the reference implementation at
# https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}, hidden_size::Int, residual=false)
    (in_size, edge_feat_size), out_size = ch
    act_fn = swish

    # Residual connection requires matching input/output sizes;
    # validate with a thrown error rather than @assert (which may be disabled).
    if residual && in_size != out_size
        throw(ArgumentError("Residual connection only possible if in_size == out_size"))
    end

    # Edge MLP ϕ_e. The +1 accounts for the radial feature ||x_i - x_j||^2.
    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
               Dense(hidden_size => hidden_size, act_fn))

    # Node-feature MLP ϕ_h: takes [h_i; aggregated messages].
    ϕh = Chain(Dense(in_size + hidden_size => hidden_size, act_fn),
               Dense(hidden_size => out_size))

    # Coordinate MLP ϕ_x: maps each message to a scalar weight (no bias,
    # as in the reference implementation).
    ϕx = Chain(Dense(hidden_size => hidden_size, act_fn),
               Dense(hidden_size => 1, bias=false))

    num_features = (in=in_size, edge=edge_feat_size, out=out_size)
    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end

# Forward pass: returns the updated invariant features `h` and
# equivariant coordinates `x`.
function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e=nothing)
    # Validate inputs with thrown errors (not @assert, which may be
    # disabled at higher optimization levels).
    if l.num_features.edge > 0 && e === nothing
        throw(ArgumentError("Edge features must be provided."))
    end
    size(h, 1) == l.num_features.in ||
        throw(DimensionMismatch("Input features must match layer input size."))

    # Per-edge message: ϕ_e produces the invariant message; ϕ_x scales the
    # normalized coordinate difference to get the equivariant message.
    function message(xi, xj, e)
        if l.num_features.edge > 0
            f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e)
        else
            f = vcat(xi.h, xj.h, e.sqnorm_xdiff)
        end

        msg_h = l.ϕe(f)
        msg_x = l.ϕx(msg_h) .* e.x_diff
        return (; x=msg_x, h=msg_h)
    end

    # Radial feature ||x_i - x_j||^2 and normalized coordinate differences.
    x_diff = apply_edges(xi_sub_xj, g, x, x)
    sqnorm_xdiff = sum(x_diff .^ 2, dims=1)
    # Small epsilon avoids division by zero for coincident coordinates.
    x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1f-6)

    msg = apply_edges(message, g, xi=(; h), xj=(; h), e=(; e, x_diff, sqnorm_xdiff))
    # Invariant messages are summed; equivariant ones are mean-normalized,
    # matching the reference implementation.
    h_aggr = aggregate_neighbors(g, +, msg.h)
    x_aggr = aggregate_neighbors(g, mean, msg.x)

    hnew = l.ϕh(vcat(h, h_aggr))
    h = l.residual ? h .+ hnew : hnew
    x = x .+ x_aggr

    return h, x
end
20 changes: 16 additions & 4 deletions src/msgpass.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
propagate(f, g, aggr; xi, xj, e) -> m̄
propagate(f, g, aggr; [xi, xj, e]) -> m̄
propagate(f, g, aggr, xi, xj, e=nothing)
Performs message passing on graph `g`. Takes care of materializing the node features on each edge,
applying the message function, and returning an aggregated message ``\\bar{\\mathbf{m}}``
Expand Down Expand Up @@ -68,7 +69,7 @@ function propagate end
# Keyword-argument form of `propagate`; forwards to the positional method.
propagate(l, g::GNNGraph, aggr; xi=nothing, xj=nothing, e=nothing) =
    propagate(l, g, aggr, xi, xj, e)

function propagate(l, g::GNNGraph, aggr, xi, xj, e)
function propagate(l, g::GNNGraph, aggr, xi, xj, e=nothing)
m = apply_edges(l, g, xi, xj, e)
= aggregate_neighbors(g, aggr, m)
return
Expand All @@ -77,8 +78,8 @@ end
## APPLY EDGES

"""
apply_edges(f, g, xi, xj, e)
apply_edges(f, g; [xi, xj, e])
apply_edges(f, g, xi, xj, e=nothing)
Returns the message from node `j` to node `i` .
In the message-passing scheme, the incoming messages
Expand Down Expand Up @@ -110,7 +111,7 @@ function apply_edges end
# Keyword-argument form of `apply_edges`; forwards to the positional method.
apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) =
    apply_edges(l, g, xi, xj, e)

function apply_edges(f, g::GNNGraph, xi, xj, e)
function apply_edges(f, g::GNNGraph, xi, xj, e=nothing)
check_num_nodes(g, xi)
check_num_nodes(g, xj)
check_num_edges(g, e)
Expand Down Expand Up @@ -158,6 +159,17 @@ copy_xi(xi, xj, e) = xi
"""
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)

"""
    xi_sub_xj(xi, xj, e)

Built-in message function returning the broadcasted difference `xi .- xj`
between destination-node and source-node features.
"""
function xi_sub_xj(xi, xj, e)
    return xi .- xj
end

"""
    xj_sub_xi(xi, xj, e)

Built-in message function returning the broadcasted difference `xj .- xi`
between source-node and destination-node features.
"""
function xj_sub_xi(xi, xj, e)
    return xj .- xi
end


"""
e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj
Expand Down
13 changes: 13 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,17 @@
end
end
end

# Smoke test for EGNNConv: checks only the output shapes of the updated
# invariant features `h` and equivariant coordinates `x`.
# `T`, `GRAPH_T`, and `in_channel` come from the enclosing testset scope.
@testset "EGNNConv" begin
    hin = 5      # input node-feature size
    hout = 5     # output node-feature size (== hin, so residual would also be legal)
    hidden = 5   # hidden width of the internal MLPs
    l = EGNNConv(hin => hout, hidden)
    g = rand_graph(10, 20, graph_type=GRAPH_T)
    # Coordinates have `in_channel` rows; the layer preserves this dimension.
    x = rand(T, in_channel, g.num_nodes)
    h = randn(T, hin, g.num_nodes)
    hnew, xnew = l(g, h, x)
    @test size(hnew) == (hout, g.num_nodes)
    @test size(xnew) == (in_channel, g.num_nodes)
end
end

0 comments on commit 24a22aa

Please sign in to comment.