From fe08d5c15fc062a0bd53e5927874bf156678d5f0 Mon Sep 17 00:00:00 2001 From: NeuralGraphPDE Date: Wed, 13 Jul 2022 14:00:09 -0600 Subject: [PATCH 1/4] format --- src/utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 16341b2..f70014f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -34,6 +34,6 @@ function updategraph(st::NamedTuple, g::GNNGraph) return st end -@inline _flatten(x:: AbstractMatrix) = x -@inline _flatten(x:: AbstractArray{T,3}) where {T} = reshape(x, siz(x,2) * size(x,3)) -@inline _flatten(x:: NamedTuple) = map(d -> _flatten(d), x) +@inline _flatten(x::AbstractMatrix) = x +@inline _flatten(x::AbstractArray{T, 3}) where {T} = reshape(x, siz(x, 2) * size(x, 3)) +@inline _flatten(x::NamedTuple) = map(d -> _flatten(d), x) From 07c174700d532ce8a648d73526becaa37033e195 Mon Sep 17 00:00:00 2001 From: NeuralGraphPDE Date: Wed, 13 Jul 2022 14:09:42 -0600 Subject: [PATCH 2/4] fix --- src/layers.jl | 5 ++++- src/utils.jl | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index c91cca9..b494073 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -314,6 +314,7 @@ function (l::VMHConv)(x::AbstractArray, ps, st::NamedTuple) end function (l::VMHConv)(x::NamedTuple, ps, st::NamedTuple) + x = _flatten(x) function message(xi, xj, e) posi, posj = xi.x, xj.x hi, hj = values(drop(xi, :x)), values(drop(xj, :x)) @@ -391,6 +392,7 @@ function MPPDEConv(ϕ::AbstractExplicitLayer, ψ::AbstractExplicitLayer; end function (l::MPPDEConv)(x::AbstractArray, ps, st::NamedTuple) + x = _flatten(x) g = st.graph num_nodes = g.num_nodes num_edges = g.num_edges @@ -519,7 +521,8 @@ function GNOConv(ch::Pair{Int, Int}, ϕ::AbstractExplicitLayer, activation = ide GNOConv{bias, typeof(aggr)}(first(ch), last(ch), initialgraph, aggr, linear, ϕ) end -function (l::GNOConv{bias})(x::AbstractMatrix, ps, st::NamedTuple) where {bias} +function (l::GNOConv{bias})(x::AbstractArray, ps, st::NamedTuple) where {bias} + x = _flatten(x) g = st.graph s = g.ndata nkeys = keys(s) diff --git a/src/utils.jl b/src/utils.jl index fcba246..f70014f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,7 +13,7 @@ copy(g::GNNGraph, kwarg...) = GNNGraph(g, kwarg...) wrapgraph(g::GNNGraph) = () -> copy(g) wrapgraph(f::Function) = f -Creater a function wrapper of the input graph. +Creater a function wrapper of the input graph. """ wrapgraph(g::GNNGraph) = () -> copy(g) wrapgraph(f::Function) = f @@ -36,4 +36,4 @@ end @inline _flatten(x::AbstractMatrix) = x @inline _flatten(x::AbstractArray{T, 3}) where {T} = reshape(x, siz(x, 2) * size(x, 3)) -@inline _flatten(x::NamedTuple) = map(d -> _flatten(d), x) \ No newline at end of file +@inline _flatten(x::NamedTuple) = map(d -> _flatten(d), x) From 88a7d2392963cfd02902e0c91cf84783754be400 Mon Sep 17 00:00:00 2001 From: NeuralGraphPDE Date: Wed, 13 Jul 2022 14:17:16 -0600 Subject: [PATCH 3/4] generalize --- src/utils.jl | 5 ++++- test/runtests.jl | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index f70014f..31ae6a4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,5 +35,8 @@ function updategraph(st::NamedTuple, g::GNNGraph) end @inline _flatten(x::AbstractMatrix) = x -@inline _flatten(x::AbstractArray{T, 3}) where {T} = reshape(x, siz(x, 2) * size(x, 3)) +@inline function _flatten(x::AbstractArray{T, N}) where {T, N} + s = size(x) + return reshape(x, s[1:(end - 2)]..., s[end - 1] * s[end]) +end @inline _flatten(x::NamedTuple) = map(d -> _flatten(d), x) diff --git a/test/runtests.jl b/test/runtests.jl index a8d0eb1..09c00f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -107,6 +107,13 @@ using SafeTestsets ps, st = Lux.setup(rng, l) y, st = l(h, ps, st) @test size(y) == (7, gh.num_nodes) + + h = randn(T, 5, g.num_nodes, 2) + + ps, st = Lux.setup(rng, l) + y, st = l(h, ps, st) + @test size(y) == (7, gh.num_nodes) + end @testset "Without theta" begin From fc8ab99ffc2c92b413683a0dfc87946241a05be1 Mon Sep 17 00:00:00 2001 From: NeuralGraphPDE Date: Wed, 13 Jul 2022 14:30:51 -0600 Subject: [PATCH 4/4] format --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 09c00f3..6a55f86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -113,7 +113,6 @@ using SafeTestsets ps, st = Lux.setup(rng, l) y, st = l(h, ps, st) @test size(y) == (7, gh.num_nodes) - end @testset "Without theta" begin