diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 006b03091..1afc1d1fb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -801,10 +801,11 @@ function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, SAGEConv(W, b, σ, aggr) end -function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix) +function (l::SAGEConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - m = propagate(copy_xj, g, l.aggr, xj = x) - x = l.σ.(l.weight * vcat(x, m) .+ l.bias) + xj, xi = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) + x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) return x end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index f1a07b7a7..e4d0fd40a 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -109,4 +109,11 @@ @test size(y.A) == (2,2) && size(y.B) == (2,3) end + @testset "SAGEConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end end