diff --git a/Project.toml b/Project.toml index e1bd8a30e..bf1c16a48 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GraphNeuralNetworks" uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" authors = ["Carlo Lucibello and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7b0779a7d..862f69e12 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -60,19 +60,29 @@ end Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...) +# input from graph +applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g))) +applylayer(l::GNNLayer, g::GNNGraph) = l(g) + +# explicit input applylayer(l, g::GNNGraph, x) = l(x) applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) # Handle Flux.Parallel +applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=applylayer(l, g, node_features(g))) applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers) -applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs) -applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...) +# input from graph +applychain(::Tuple{}, g::GNNGraph) = g +applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g)) +# explicit input applychain(::Tuple{}, g::GNNGraph, x) = x applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x)) (c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x) +(c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g) + Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...) Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) = diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 24bebab41..d42c4492f 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -31,6 +31,18 @@ test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[:μ, :σ²]) end + + @testset "Only graph input" begin + nin, nout = 2, 4 + ndata = rand(nin, 3) + edata = rand(nin, 3) + g = GNNGraph([1,1,2], [2, 3, 3], ndata=ndata, edata=edata) + m = NNConv(nin => nout, Dense(2, nin*nout, tanh)) + chain = GNNChain(m) + y = m(g, g.ndata.x, g.edata.e) + @test m(g).ndata.x == y + @test chain(g).ndata.x == y + end end end