Gno conv #37

Merged · 3 commits · Jun 28, 2022
2 changes: 1 addition & 1 deletion src/NeuralGraphPDE.jl
@@ -14,7 +14,7 @@ include("utils.jl")
include("layers.jl")

export AbstractGNNLayer, AbstractGNNContainerLayer
-export ExplicitEdgeConv, ExplicitGCNConv, VMHConv, MPPDEConv
+export ExplicitEdgeConv, ExplicitGCNConv, VMHConv, MPPDEConv, GNOConv
export updategraph

end
158 changes: 142 additions & 16 deletions src/layers.jl
@@ -339,57 +339,41 @@ end

@doc raw"""
MPPDEConv(ϕ, ψ; initialgraph = initialgraph, aggr = sum, local_features = (:u, :x))

Convolutional layer from [Message Passing Neural PDE Solvers](https://arxiv.org/abs/2202.03376), without the temporal bundling trick.
```math
\begin{aligned}
\mathbf{m}_i&=\Box_{j\in N(i)}\,\phi (\mathbf{h}_i,\mathbf{h}_j,\mathbf{u}_i-\mathbf{u}_j,\mathbf{x}_i-\mathbf{x}_j,\theta )\\
\mathbf{h}_i'&=\psi (\mathbf{h}_i,\mathbf{m}_i,\theta )\\
\end{aligned}
```

# Arguments

- `ϕ`: The neural network for the message function.
- `ψ`: The neural network for the update function.
- `initialgraph`: `GNNGraph` or a function that returns a `GNNGraph`.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `local_features`: The features that will be differentiated in the message function.

# Inputs

- `h`: Trainable node embeddings, `Array`.

# Returns

- `NamedTuple` or `Array` with the same structure as the input, but with a different number of channels.

# Parameters

- Parameters of `ϕ`.
- Parameters of `ψ`.

# States

- `graph`: `GNNGraph` where `graph.ndata.x` contains the spatial coordinates of the nodes, `graph.ndata.u` contains the initial condition, and `graph.gdata.θ` contains the graph-level features of the underlying PDE. `θ` should be a matrix of size `(num_feats, num_graphs)`. If `g` is a batched graph, then all graphs must have the same structure.

# Examples
```julia
g = rand_graph(10, 6)

g = GNNGraph(g, ndata = (; u = rand(2, 10), x = rand(3, 10)), gdata = (; θ = rand(4)))
h = randn(5, 10)
ϕ = Dense(5 + 5 + 2 + 3 + 4 => 5)
ψ = Dense(5 + 5 + 4 => 7)
l = MPPDEConv(ϕ, ψ, initialgraph = g)

rng = Random.default_rng()
ps, st = Lux.setup(rng, l)

y, st = l(h, ps, st)
```

"""
struct MPPDEConv{F, L, M1, M2, A} <: AbstractGNNContainerLayer{(:ϕ, :ψ)}
initialgraph::F
@@ -435,3 +419,145 @@ function (l::MPPDEConv)(x::AbstractArray, ps, st::NamedTuple)

return y, st
end

@doc raw"""
GNOConv(in_chs => out_chs, ϕ; initialgraph = initialgraph, aggr = mean)

Convolutional layer from [Neural Operator: Graph Kernel Network for Partial Differential Equations](https://openreview.net/forum?id=5fbUEUTZEn7).
```math
\begin{aligned}
\mathbf{m}_i&=\Box _{j\in N(i)}\,\phi (\mathbf{a}_i,\mathbf{a}_j,\mathbf{x}_i,\mathbf{x}_j)\mathbf{h}_j\\
\mathbf{h}_i'&=\sigma \left( \mathbf{W}\mathbf{h}_i+\mathbf{m}_i+\mathbf{b} \right)\\
\end{aligned}
```

# Arguments

- `in_chs`: The number of input channels.
- `out_chs`: The number of output channels.
- `ϕ`: The neural network for the message function. Its output must have `in_chs * out_chs` rows; each column is reshaped into an `(out_chs, in_chs)` kernel matrix for the corresponding edge.
- `initialgraph`: `GNNGraph` or a function that returns a `GNNGraph`.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).

# Inputs

- `h`: `Array` of size `(in_chs, num_nodes)`.

# Returns

- `Array` of size `(out_chs, num_nodes)`.

# Parameters

- Parameters of `ϕ`.
- `W`.
- `b`.

# States

- `graph`: `GNNGraph`. All features in `graph.ndata` will be concatenated and then fed into `ϕ`.

# Examples
```julia
g = rand_graph(10, 6)

g = GNNGraph(g, ndata = (; a = rand(2, 10), x = rand(3, 10)))
in_chs, out_chs = 5, 7
h = randn(in_chs, 10)
ϕ = Dense(2 + 2 + 3 + 3 => in_chs * out_chs)
l = GNOConv(5 => 7, ϕ, initialgraph = g)

rng = Random.default_rng()
ps, st = Lux.setup(rng, l)

y, st = l(h, ps, st)
```

"""
struct GNOConv{bias, A} <: AbstractGNNContainerLayer{(:linear, :ϕ)}
in_chs::Int
out_chs::Int
initialgraph::Function
aggr::A
linear::Dense
ϕ::AbstractExplicitLayer
end

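# Convenience method: GNOConv(in_chs, out_chs, ϕ, ...) forwards to the Pair form below.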
function GNOConv(in_chs::Int, out_chs::Int, ϕ::AbstractExplicitLayer, activation = identity;
initialgraph = initialgraph,
init_weight = glorot_uniform,
init_bias = zeros32,
aggr = mean,
bias::Bool = true)
GNOConv(in_chs => out_chs, ϕ, activation,
initialgraph = initialgraph, init_weight = init_weight, init_bias = init_bias,
aggr = aggr, bias = bias)
end

function GNOConv(ch::Pair{Int, Int}, ϕ::AbstractExplicitLayer, activation = identity;
initialgraph = initialgraph,
init_weight = glorot_uniform,
init_bias = zeros32,
aggr = mean,
bias::Bool = true)
initialgraph = wrapgraph(initialgraph)
linear = Dense(ch, activation,
init_weight = init_weight,
init_bias = init_bias,
bias = bias)
GNOConv{bias, typeof(aggr)}(first(ch), last(ch), initialgraph, aggr, linear, ϕ)
end

function (l::GNOConv{true})(x::AbstractArray, ps, st::NamedTuple)
g = st.graph
s = g.ndata
edge_features = keys(s)

function message(xi, xj, e)
# Concatenate the endpoint features stored in the graph (e.g. a and x)
# to form the input of the kernel network ϕ.
si, sj = xi[edge_features], xj[edge_features]
si, sj = reduce(vcat, values(si)), reduce(vcat, values(sj))

# ϕ returns one flattened (out_chs × in_chs) kernel per edge; keep its
# updated state by rebinding the captured `st`.
W, st_ϕ = l.ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
st = merge(st, (; ϕ = st_ϕ))

# Apply each edge's kernel to the neighbor features h_j with a single
# batched matrix multiplication.
hj = xj.h
nin, nedges = size(hj)
W = reshape(W, :, nin, nedges)
hj = reshape(hj, (nin, 1, nedges))
m = NNlib.batched_mul(W, hj)
return reshape(m, :, nedges)
end

xs = merge((; h = x), s)
m = propagate(message, g, l.aggr, xi = xs, xj = xs)

y = l.linear.activation(ps.linear.weight * x .+ m .+ ps.linear.bias)
return y, st
end

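# Bias-free variant: the message computation is identical to the method above;
# only the final node update drops the bias term.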
function (l::GNOConv{false})(x::AbstractArray, ps, st::NamedTuple)
g = st.graph
s = g.ndata
edge_features = keys(s)

function message(xi, xj, e)
si, sj = xi[edge_features], xj[edge_features]
si, sj = reduce(vcat, values(si)), reduce(vcat, values(sj))

W, st_ϕ = l.ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
st = merge(st, (; ϕ = st_ϕ))

hj = xj.h
nin, nedges = size(hj)
W = reshape(W, :, nin, nedges)
hj = reshape(hj, (nin, 1, nedges))
m = NNlib.batched_mul(W, hj)
return reshape(m, :, nedges)
end

xs = merge((; h = x), s)
m = propagate(message, g, l.aggr, xi = xs, xj = xs)

y = l.linear.activation(ps.linear.weight * x .+ m)
return y, st
end
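
The per-edge kernel application in `message` above is compact but easy to misread. Below is a minimal standalone sketch (not part of this PR; it assumes only `NNlib` and uses arbitrary sizes) that checks the batched form against an explicit loop over edges:

```julia
using NNlib

in_chs, out_chs, nedges = 3, 4, 6
Wflat = randn(in_chs * out_chs, nedges)  # flattened per-edge kernels, as ϕ would emit
hj = randn(in_chs, nedges)               # neighbor features gathered on each edge

# Reshape each column of Wflat into an (out_chs, in_chs) kernel and apply it
# to the corresponding column of hj in one batched multiplication.
W = reshape(Wflat, out_chs, in_chs, nedges)
m = reshape(NNlib.batched_mul(W, reshape(hj, in_chs, 1, nedges)), out_chs, nedges)

# Reference: the same computation edge by edge.
m_loop = reduce(hcat, [W[:, :, e] * hj[:, e] for e in 1:nedges])
@assert m ≈ m_loop
```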
25 changes: 24 additions & 1 deletion test/runtests.jl
@@ -4,7 +4,7 @@ using Random
using Lux
using Lux: parameterlength
using Test
-using Flux: batch, unbatch
+import Flux: batch, unbatch

@testset "layers" begin
rng = Random.default_rng()
@@ -112,6 +112,29 @@ using Flux: batch, unbatch
@test size(y) == (7, gh.num_nodes)
end
end

@testset "GNOConv" begin
g = rand_graph(10, 6)

g = GNNGraph(g, ndata = (; a = rand(2, g.num_nodes), x = rand(3, g.num_nodes)))
in_chs, out_chs = 5, 7
h = randn(in_chs, g.num_nodes)
ϕ = Dense(2 + 2 + 3 + 3 => in_chs * out_chs)
l = GNOConv(5 => 7, ϕ, initialgraph = g)

rng = Random.default_rng()
ps, st = Lux.setup(rng, l)

y, st = l(h, ps, st)
@test size(y) == (7, g.num_nodes)

l = GNOConv(5 => 7, ϕ, initialgraph = g, bias = false)
rng = Random.default_rng()
ps, st = Lux.setup(rng, l)

y, st = l(h, ps, st)
@test size(y) == (7, g.num_nodes)
end
end
end
