
Commit

Merge pull request #34 from MilkshakeForReal/datatype
fix MPPDEConv
YichengDWu authored Jun 28, 2022
2 parents 6e08a35 + 5eb3694 commit 8d6de4b
Showing 3 changed files with 73 additions and 26 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -18,8 +18,9 @@ NNlib = "0.8"
 julia = "1.7"
 
 [extras]
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 
 [targets]
-test = ["Test","Random"]
+test = ["Test", "Random", "Flux"]
22 changes: 13 additions & 9 deletions src/layers.jl
@@ -340,7 +340,7 @@
 @doc raw"""
     MPPDEConv(ϕ, ψ; initialgraph = initialgraph, aggr = sum, local_features = (:u, :x))
-Convolutional layer from [Message Passing Neural PDE Solvers](https://arxiv.org/abs/2202.03376), without the temporal bulking trick.
+Convolutional layer from [Message Passing Neural PDE Solvers](https://arxiv.org/abs/2202.03376), without the temporal bulking trick.
 ```math
 \begin{aligned}
 \mathbf{m}_i&=\Box _{j\in N(i)}\,\phi (\mathbf{h}_i,\mathbf{h}_j;\mathbf{u}_i-\mathbf{u}_j;\mathbf{x}_i-\mathbf{x}_j;\theta )\\
@@ -371,7 +371,8 @@ Convolutional layer from [Message Passing Neural PDE Solvers](https://arxiv.org/
 # States
-- `graph`: `GNNGraph` where `graph.ndata.x` represents the spatial coordinates of nodes, `graph.ndata.u` represents the initial condition, and `graph.gdata.θ` represents the graph level features of the underlying PDE. `θ` should be a vector.
+- `graph`: `GNNGraph` where `graph.ndata.x` represents the spatial coordinates of nodes, `graph.ndata.u` represents the initial condition, and `graph.gdata.θ` represents the graph level features of the underlying PDE. `θ` should be a matrix
+  of the size `(num_feats, num_graphs)`. If `g` is a batched graph, then all graphs need to have the same structure.
 # Examples
 ```julia
@@ -408,25 +409,28 @@ function MPPDEConv(ϕ::AbstractExplicitLayer, ψ::AbstractExplicitLayer;
 end
 
 function (l::MPPDEConv)(x::AbstractArray, ps, st::NamedTuple)
-    num_nodes = st.graph.num_nodes
-    num_edges = st.graph.num_edges
-    θ = vcat(values(st.graph.gdata)...)
+    g = st.graph
+    num_nodes = g.num_nodes
+    num_edges = g.num_edges
+    num_graphs = g.num_graphs
+    θ = reduce(vcat, values(st.graph.gdata), init = similar(x, 0, num_graphs))
 
     function message(xi, xj, e)
-        di, dj = values(xi[l.local_features]), values(xj[l.local_features])
+        di, dj = reduce(vcat, values(xi[l.local_features])),
+                 reduce(vcat, values(xj[l.local_features]))
         hi, hj = xi.h, xj.h
-        m, st_ϕ = l.ϕ(vcat(hi, hj, (di .- dj)..., repeat(θ, 1, num_edges)), ps.ϕ, st.ϕ)
+        m, st_ϕ = l.ϕ(vcat(hi, hj, di .- dj,
+                           repeat(θ, inner = (1, num_edges ÷ num_graphs))), ps.ϕ, st.ϕ)
         st = merge(st, (; ϕ = st_ϕ))
         return m
     end
 
-    g = st.graph
     s = g.ndata
 
     xs = merge((; h = x), s)
     m = propagate(message, g, l.aggr, xi = xs, xj = xs)
 
-    y, st_ψ = l.ψ(vcat(x, m, repeat(θ, 1, num_nodes)), ps.ψ, st.ψ)
+    y, st_ψ = l.ψ(vcat(x, m, repeat(θ, inner = (1, num_nodes ÷ num_graphs))), ps.ψ, st.ψ)
     st = merge(st, (; ψ = st_ψ))
 
     return y, st
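The core of the fix in src/layers.jl is how the graph-level features `θ` are tiled before entering `ϕ` and `ψ`. Below is a minimal sketch of that broadcasting, assuming a batch of equally sized graphs whose nodes are ordered graph by graph; the matrix and sizes are illustrative and not part of the commit.

```julia
# Illustration only (not from the commit): how repeat(θ, inner = ...) broadcasts
# graph-level features to the nodes (and, analogously, edges) of a batched graph.
num_graphs, num_nodes = 2, 6                  # 3 nodes per graph
θ = Float32[1 10; 2 20; 3 30; 4 40]           # (num_feats, num_graphs) = (4, 2)

# Old code: repeat(θ, 1, num_nodes) tiles every column num_nodes times, giving a
# 4×12 matrix that no longer lines up with the 6 nodes of the batched graph.
# New code: each graph's column is repeated once per node of that graph, giving 4×6.
θ_nodes = repeat(θ, inner = (1, num_nodes ÷ num_graphs))
@assert size(θ_nodes) == (4, num_nodes)
@assert θ_nodes[:, 3] == θ[:, 1]              # nodes 1-3 carry graph 1's θ
@assert θ_nodes[:, 4] == θ[:, 2]              # nodes 4-6 carry graph 2's θ

# With no graph-level data at all, reduce(vcat, ...; init = similar(x, 0, num_graphs))
# yields an empty matrix with the element type of the node features, so the later
# vcat calls are effectively no-ops and the feature eltype is preserved.
θ_empty = reduce(vcat, values((;)), init = similar(θ, 0, num_graphs))
@assert size(θ_empty) == (0, num_graphs)
```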
72 changes: 57 additions & 15 deletions test/runtests.jl
@@ -4,6 +4,7 @@ using Random
 using Lux
 using Lux: parameterlength
 using Test
+using Flux: batch, unbatch
 
 @testset "layers" begin
     rng = Random.default_rng()
@@ -54,21 +55,62 @@ using Test
         end
 
         @testset "MPPDE" begin
-            gh = GNNGraph(g, ndata = (; u = rand(2, g.num_nodes), x = rand(3, g.num_nodes)),
-                          gdata = (; θ = rand(4)))
-            h = randn(T, 5, g.num_nodes)
-            ϕ = Dense(5 + 5 + 2 + 3 + 4 => 5)
-            ψ = Dense(5 + 5 + 4 => 7)
-            l = MPPDEConv(ϕ, ψ, initialgraph = gh)
-
-            rng = Random.default_rng()
-            ps, st = Lux.setup(rng, l)
-            @test st.graph == gh
-
-            y, st = l(h, ps, st)
-            @test st.graph == gh
-
-            @test size(y) == (7, g.num_nodes)
+            @testset "With theta" begin
+                gh = GNNGraph(g,
+                              ndata = (; u = rand(2, g.num_nodes),
+                                         x = rand(3, g.num_nodes)),
+                              gdata = (; θ = rand(4)))
+
+                h = randn(T, 5, g.num_nodes)
+                ϕ = Dense(5 + 5 + 2 + 3 + 4 => 5)
+                ψ = Dense(5 + 5 + 4 => 7)
+                l = MPPDEConv(ϕ, ψ, initialgraph = gh)
+
+                ps, st = Lux.setup(rng, l)
+                @test st.graph == gh
+
+                y, st = l(h, ps, st)
+
+                @test st.graph == gh
+                @test size(y) == (7, g.num_nodes)
+            end
+
+            @testset "batched graph" begin
+                gh = GNNGraph(g,
+                              ndata = (; u = rand(2, g.num_nodes),
+                                         x = rand(3, g.num_nodes)),
+                              gdata = (; θ = rand(4)))
+                gh = batch([gh, copy(gh)])
+
+                h = randn(T, 5, gh.num_nodes)
+                ϕ = Dense(5 + 5 + 2 + 3 + 4 => 5)
+                ψ = Dense(5 + 5 + 4 => 7)
+                l = MPPDEConv(ϕ, ψ, initialgraph = gh)
+
+                ps, st = Lux.setup(rng, l)
+                y, st = l(h, ps, st)
+                @test size(y) == (7, gh.num_nodes)
+            end
+
+            @testset "Without theta" begin
+                gh = GNNGraph(g,
+                              ndata = (; u = rand(2, g.num_nodes),
+                                         x = rand(3, g.num_nodes)))
+
+                h = randn(T, 5, gh.num_nodes)
+                ϕ = Dense(5 + 5 + 2 + 3 => 5)
+                ψ = Dense(5 + 5 => 7)
+                l = MPPDEConv(ϕ, ψ, initialgraph = gh)
+
+                rng = Random.default_rng()
+                ps, st = Lux.setup(rng, l)
+                @test st.graph == gh
+
+                y, st = l(h, ps, st)
+                @test st.graph == gh
+
+                @test size(y) == (7, gh.num_nodes)
+            end
         end
     end
 end
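For completeness, here is a hedged end-to-end sketch of the batched case the new tests exercise, but with per-graph `θ` values that actually differ. The package and helper names (`NeuralGraphPDE` for `MPPDEConv`, `rand_graph` from GraphNeuralNetworks.jl, the feature sizes) are assumptions made for illustration, not fixed by this commit.

```julia
# Hypothetical usage sketch (not part of the commit): MPPDEConv on a batch of graphs
# carrying different PDE parameters θ, mirroring the "batched graph" test above.
using NeuralGraphPDE, GraphNeuralNetworks, Lux, Random   # assumed package names
using Flux: batch

rng = Random.default_rng()

# Two graphs with the same structure but different graph-level θ (4 parameters each).
make_graph(θ) = rand_graph(6, 12;
                           ndata = (; u = rand(Float32, 2, 6), x = rand(Float32, 3, 6)),
                           gdata = (; θ))
gh = batch([make_graph(rand(Float32, 4)), make_graph(rand(Float32, 4))])

ϕ = Dense(5 + 5 + 2 + 3 + 4 => 5)   # message input: hᵢ, hⱼ, uᵢ - uⱼ, xᵢ - xⱼ, θ
ψ = Dense(5 + 5 + 4 => 7)           # update input: h, aggregated message, θ
l = MPPDEConv(ϕ, ψ, initialgraph = gh)

ps, st = Lux.setup(rng, l)
h = randn(Float32, 5, gh.num_nodes)
y, st = l(h, ps, st)
@assert size(y) == (7, gh.num_nodes)
```

After `batch`, each graph's `θ` occupies one column of a `(num_feats, num_graphs)` matrix, which is the layout the updated docstring requires; the layer then assigns that column to the graph's own nodes and edges via the `inner` repeat shown in the earlier sketch.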
