From 00dcf54f4a3c183a312bbea623717e22cdb4b1cb Mon Sep 17 00:00:00 2001
From: NeuralGraphPDE
Date: Tue, 28 Jun 2022 01:01:25 -0600
Subject: [PATCH 1/3] GNOConv

---
 src/NeuralGraphPDE.jl |   2 +-
 src/layers.jl         | 158 +++++++++++++++++++++++++++++++++++++-----
 test/runtests.jl      |  16 +++++
 3 files changed, 159 insertions(+), 17 deletions(-)

diff --git a/src/NeuralGraphPDE.jl b/src/NeuralGraphPDE.jl
index bfc017e..bd2ce59 100644
--- a/src/NeuralGraphPDE.jl
+++ b/src/NeuralGraphPDE.jl
@@ -14,7 +14,7 @@ include("utils.jl")
 include("layers.jl")
 
 export AbstractGNNLayer, AbstractGNNContainerLayer
-export ExplicitEdgeConv, ExplicitGCNConv, VMHConv, MPPDEConv
+export ExplicitEdgeConv, ExplicitGCNConv, VMHConv, MPPDEConv, GNOConv
 export updategraph
 
 end
diff --git a/src/layers.jl b/src/layers.jl
index fe79c8a..c7d33d1 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -339,7 +339,6 @@ end
 @doc raw"""
     MPPDEConv(ϕ, ψ; initialgraph = initialgraph, aggr = sum, local_features = (:u, :x))
-
 Convolutional layer from [Message Passing Neural PDE Solvers](https://arxiv.org/abs/2202.03376), without the temporal bulking trick.
 ```math
 \begin{aligned}
@@ -347,49 +346,34 @@
     \mathbf{h}_i'&=\psi (\mathbf{h}_i,\mathbf{m}_i,\theta )\\
 \end{aligned}
 ```
-
 # Arguments
-
 - `ϕ`: The neural network for the message function.
 - `ψ`: The neural network for the update function.
 - `initialgraph`: `GNNGraph` or a function that returns a `GNNGraph`
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 - `local_features`: The features that will be differentiated in the message function.
-
 # Inputs
-
 - `h`: Trainable node embeddings, `Array`.
-
 # Returns
-
 - `NamedTuple` or `Array` that has the same struct with `x` with different a size of channels.
-
 # Parameters
-
 - Parameters of `ϕ`.
 - Parameters of `ψ`.
-
 # States
-
 - `graph`: `GNNGraph` where `graph.ndata.x` represents the spatial coordinates of nodes, `graph.ndata.u` represents the initial condition, and `graph.gdata.θ` represents the graph level features of the underlying PDE. `θ` should be a matrix of the size `(num_feats, num_graphs)`. If `g` is a batched graph, then all graphs need to have the same structure.
-
 # Examples
 ```julia
 g = rand_graph(10, 6)
-
 g = GNNGraph(g, ndata = (; u = rand(2, 10), x = rand(3, 10)), gdata = (; θ = rand(4)))
 h = randn(5, 10)
 ϕ = Dense(5 + 5 + 2 + 3 + 4 => 5)
 ψ = Dense(5 + 5 + 4 => 7)
 l = MPPDEConv(ϕ, ψ, initialgraph = g)
-
 rng = Random.default_rng()
 ps, st = Lux.setup(rng, l)
-
 y, st = l(h, ps, st)
 ```
-
 """
 struct MPPDEConv{F, L, M1, M2, A} <: AbstractGNNContainerLayer{(:ϕ, :ψ)}
     initialgraph::F
@@ -435,3 +419,145 @@ function (l::MPPDEConv)(x::AbstractArray, ps, st::NamedTuple)
 
     return y, st
 end
+
+@doc raw"""
+    GNOConv(in_chs => out_chs, ϕ; initialgraph = initialgraph, aggr = mean)
+
+Convolutional layer from [Neural Operator: Graph Kernel Network for Partial Differential Equations](https://openreview.net/forum?id=5fbUEUTZEn7).
+```math
+\begin{aligned}
+    \mathbf{m}_i&=\Box _{j\in N(i)}\,\phi (\mathbf{a}_i,\mathbf{a}_j,\mathbf{x}_i,\mathbf{x}_j)\mathbf{h}_j\\
+    \mathbf{h}_i'&=\,\,\sigma \left( \mathbf{Wh}_i+\mathbf{m}_i+\mathbf{b} \right)\\
+\end{aligned}
+```
+
+# Arguments
+
+- `in_chs`: The number of input channels.
+- `out_chs`: The number of output channels.
+- `ϕ`: The neural network for the message function. The output size of `ϕ` should be `in_chs * out_chs`.
+- `initialgraph`: `GNNGraph` or a function that returns a `GNNGraph`
+- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+
+# Inputs
+
+- `h`: `Array` of the size `(in_chs, num_nodes)`.
+
+# Returns
+
+- `Array` of the size `(out_chs, num_nodes)`.
+
+# Parameters
+
+- Parameters of `ϕ`.
+- `W`.
+- `b`.
+
+# States
+
+- `graph`: `GNNGraph`. All features in `graph.ndata` will be concatenated and then fed into `ϕ`.
+
+# Examples
+```julia
+g = rand_graph(10, 6)
+
+g = GNNGraph(g, ndata = (; a = rand(2, 10), x = rand(3, 10)))
+in_chs, out_chs = 5, 7
+h = randn(in_chs, 10)
+ϕ = Dense(2 + 2 + 3 + 3 => in_chs * out_chs)
+l = GNOConv(5 => 7, ϕ, initialgraph = g)
+
+rng = Random.default_rng()
+ps, st = Lux.setup(rng, l)
+
+y, st = l(h, ps, st)
+```
+
+"""
+struct GNOConv{bias, A} <: AbstractGNNContainerLayer{(:linear, :ϕ)}
+    in_chs::Int
+    out_chs::Int
+    initialgraph::Function
+    aggr::A
+    linear::Dense
+    ϕ::AbstractExplicitLayer
+end
+
+function GNOConv(in_chs::Int, out_chs::Int, ϕ::AbstractExplicitLayer, activation = identity;
+                 initialgraph = initialgraph,
+                 init_weight = glorot_uniform,
+                 init_bias = zeros32,
+                 aggr = mean,
+                 bias::Bool = true)
+    GNOConv(in_chs => out_chs, ϕ, activation,
+            initialgraph = initialgraph, init_weight = init_weight, init_bias = init_bias,
+            aggr = aggr, bias = bias)
+end
+
+function GNOConv(ch::Pair{Int, Int}, ϕ::AbstractExplicitLayer, activation = identity;
+                 initialgraph = initialgraph,
+                 init_weight = glorot_uniform,
+                 init_bias = zeros32,
+                 aggr = mean,
+                 bias::Bool = true)
+    initialgraph = wrapgraph(initialgraph)
+    linear = Dense(ch, activation,
+                   init_weight = init_weight,
+                   init_bias = init_bias,
+                   bias = bias)
+    GNOConv{bias, typeof(aggr)}(first(ch), last(ch), initialgraph, aggr, linear, ϕ)
+end
+
+function (l::GNOConv{true})(x::AbstractArray, ps, st::NamedTuple)
+    g = st.graph
+    s = g.ndata
+    edge_features = keys(s)
+
+    function message(xi, xj, e)
+        si, sj = xi[edge_features], xj[edge_features]
+        si, sj = reduce(vcat, values(si)), reduce(vcat, values(sj))
+
+        W, st_ϕ = ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
+        st = merge(st, (; ϕ = st_ϕ))
+
+        hj = xj.h
+        nin, nedges = size(hj)
+        W = reshape(W, :, nin, nedges)
+        hj = reshape(hj, (nin, 1, nedges))
+        m = NNlib.batched_mul(W, xj)
+        return reshape(m, :, nedges)
+    end
+
+    xs = merge((; h = x), s)
+    m = propagate(message, g, l.aggr, xi = xs, xj = xs)
+
+    y = l.linear.activation(ps.linear.weight * x .+ m .+ ps.linear.bias)
+    return y, st
+end
+
+function (l::GNOConv{false})(x::AbstractArray, ps, st::NamedTuple)
+    g = st.graph
+    s = g.ndata
+    edge_features = keys(s)
+
+    function message(xi, xj, e)
+        si, sj = xi[edge_features], xj[edge_features]
+        si, sj = reduce(vcat, values(si)), reduce(vcat, values(sj))
+
+        W, st_ϕ = ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
+        st = merge(st, (; ϕ = st_ϕ))
+
+        hj = xj.h
+        nin, nedges = size(hj)
+        W = reshape(W, :, nin, nedges)
+        hj = reshape(hj, (nin, 1, nedges))
+        m = NNlib.batched_mul(W, xj)
+        return reshape(m, :, nedges)
+    end
+
+    xs = merge((; h = x), s)
+    m = propagate(message, g, l.aggr, xi = xs, xj = xs)
+
+    y = l.linear.activation(ps.linear.weight * x .+ m)
+    return y, st
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6f12f38..077c9ec 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -112,6 +112,22 @@ using Flux: batch, unbatch
             @test size(y) == (7, gh.num_nodes)
         end
     end
+
+    @testset "GNOConv" begin
+        g = rand_graph(10, 6)
+
+        g = GNNGraph(g, ndata = (; a = rand(2, g.num_nodes), x = rand(3, g.num_nodes)))
+        in_chs, out_chs = 5, 7
+        h = randn(in_chs, g.num_nodes)
+        ϕ = Dense(2 + 2 + 3 + 3 => in_chs * out_chs)
+        l = GNOConv(5 => 7, ϕ, initialgraph = g)
+
+        rng = Random.default_rng()
+        ps, st = Lux.setup(rng, l)
+
+        y, st = l(h, ps, st)
+        @test size(y) == (7, g.num_nodes)
+    end
 end
 end

From 77e639467a1d0b0408c9fd317ddd10b6784e27b4 Mon Sep 17 00:00:00 2001
From: NeuralGraphPDE
Date: Tue, 28 Jun 2022 01:03:43 -0600
Subject: [PATCH 2/3] test

---
 src/layers.jl    | 8 ++++----
 test/runtests.jl | 7 +++++++
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/layers.jl b/src/layers.jl
index c7d33d1..c3e6340 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -517,14 +517,14 @@ function (l::GNOConv{true})(x::AbstractArray, ps, st::NamedTuple)
         si, sj = xi[edge_features], xj[edge_features]
         si, sj = reduce(vcat, values(si)), reduce(vcat, values(sj))
 
-        W, st_ϕ = ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
+        W, st_ϕ = l.ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
         st = merge(st, (; ϕ = st_ϕ))
 
         hj = xj.h
         nin, nedges = size(hj)
         W = reshape(W, :, nin, nedges)
         hj = reshape(hj, (nin, 1, nedges))
-        m = NNlib.batched_mul(W, xj)
+        m = NNlib.batched_mul(W, hj)
         return reshape(m, :, nedges)
     end
 
@@ -544,14 +544,14 @@ function (l::GNOConv{false})(x::AbstractArray, ps, st::NamedTuple)
         si, sj = xi[edge_features], xj[edge_features]
         si, sj = reduce(vcat, values(si)), reduce(vcat, values(sj))
 
-        W, st_ϕ = ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
+        W, st_ϕ = l.ϕ(vcat(si, sj), ps.ϕ, st.ϕ)
         st = merge(st, (; ϕ = st_ϕ))
 
         hj = xj.h
         nin, nedges = size(hj)
         W = reshape(W, :, nin, nedges)
         hj = reshape(hj, (nin, 1, nedges))
-        m = NNlib.batched_mul(W, xj)
+        m = NNlib.batched_mul(W, hj)
         return reshape(m, :, nedges)
     end
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 077c9ec..e49d5f1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -127,6 +127,13 @@ using Flux: batch, unbatch
 
         y, st = l(h, ps, st)
         @test size(y) == (7, g.num_nodes)
+
+        l = GNOConv(5 => 7, ϕ, initialgraph = g, bias = false)
+        rng = Random.default_rng()
+        ps, st = Lux.setup(rng, l)
+
+        y, st = l(h, ps, st)
+        @test size(y) == (7, g.num_nodes)
     end
 end
 end

From dc43222165b5bd03626930370a314599f082a893 Mon Sep 17 00:00:00 2001
From: NeuralGraphPDE
Date: Tue, 28 Jun 2022 01:05:34 -0600
Subject: [PATCH 3/3] import batch

---
 test/runtests.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/runtests.jl b/test/runtests.jl
index e49d5f1..7a6e834 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,7 +4,7 @@ using Random
 using Lux
 using Lux: parameterlength
 using Test
-using Flux: batch, unbatch
+import Flux: batch, unbatch
 
 @testset "layers" begin
     rng = Random.default_rng()
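
A note on the core of the new layer: `GNOConv`'s message function asks `ϕ` for a flat vector of length `in_chs * out_chs` on every edge, reshapes it into a per-edge kernel, and contracts that kernel with the sender node's features through a batched matrix product. The sketch below replays just that reshape-and-multiply step outside the layer, with made-up toy sizes (`in_chs = 5`, `out_chs = 7`, `nedges = 6`); it is only an illustration of the tensor shapes involved, not part of the patch itself.

```julia
using NNlib: batched_mul   # same batched product the patched message function uses

# Toy sizes, assumed only for this sketch.
in_chs, out_chs, nedges = 5, 7, 6

# What ϕ would emit: one column of length in_chs * out_chs per edge.
W_flat = randn(Float32, in_chs * out_chs, nedges)
# Sender-node features gathered per edge.
hj = randn(Float32, in_chs, nedges)

# Reshape the flat kernels to (out_chs, in_chs, nedges) and hj to a batch of
# column vectors, mirroring the reshapes inside the message function.
W  = reshape(W_flat, out_chs, in_chs, nedges)
hb = reshape(hj, in_chs, 1, nedges)

# Per-edge messages m_e = W_e * h_j; drop the singleton dimension afterwards.
m = reshape(batched_mul(W, hb), out_chs, nedges)

@assert size(m) == (out_chs, nedges)
```

Aggregating these per-edge messages with `aggr` and adding the node-wise affine term then reproduces the layer's update `σ(Wh_i + m_i + b)` from the docstring.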