diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 56ab12018..c1aac5d6e 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -944,14 +944,16 @@ function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, return CGConv(ch, dense_f, dense_s, residual) end -function (l::CGConv)(g::GNNGraph, x::AbstractMatrix, +function (l::CGConv)(g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + if e !== nothing check_num_edges(g, e) end - m = propagate(message, g, +, l, xi = x, xj = x, e = e) + m = propagate(message, g, +, l, xi = xi, xj = xj, e = e) if l.residual if size(x, 1) == size(m, 1) @@ -964,6 +966,7 @@ function (l::CGConv)(g::GNNGraph, x::AbstractMatrix, return m end + function message(l::CGConv, xi, xj, e) if e !== nothing z = vcat(xi, xj, e) diff --git a/src/msgpass.jl b/src/msgpass.jl index 232aff36c..2118cdac6 100644 --- a/src/msgpass.jl +++ b/src/msgpass.jl @@ -87,11 +87,11 @@ end # https://github.com/JuliaLang/julia/issues/15276 ## and zygote issues # https://github.com/FluxML/Zygote.jl/issues/1317 -function propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing, +function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing, e = nothing) propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) end -function propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing) +function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing) propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index ff1203497..6cbd8563f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -91,4 +91,13 @@ output = y2.A[:, [2]] @test expected ≈ output end + + @testset "CGConv" begin + g = rand_bipartite_heterograph((2,3), 6) + x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu), + (:B, :to, :A) => CGConv(4 => 2, relu)); + y = layers(g, x); + @test size(y.A) == (2,2) && size(y.B) == (2,3) + end end