diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 04dbc442f..19c200e7a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -808,10 +808,13 @@ Flux.trainable(l::GINConv) = (nn = l.nn,) GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) -function (l::GINConv)(g::GNNGraph, x::AbstractMatrix) +function (l::GINConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - m = propagate(copy_xj, g, l.aggr, xj = x) - l.nn((1 + ofeltype(x, l.ϵ)) * x + m) + xj, xi = expand_srcdst(g, x) + + m = propagate(copy_xj, g, l.aggr, xj = xj) + + l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) end function Base.show(io::IO, l::GINConv) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 8352ab85b..3d5f2c09c 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -125,6 +125,14 @@ @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end + @testset "GINConv" begin + x = (A = rand(4, 2), B = rand(4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4), + (:B, :to, :A) => GINConv(Dense(4, 2), 0.4)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + @testset "ResGatedGraphConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv((:A, :to, :B) => ResGatedGraphConv(4 => 2),