Skip to content

Commit

Permalink
feat: Add CGConv support for HeteroGraphConv (#363)
Browse files Browse the repository at this point in the history
* add gnnheterograph support for cgconv

* remove changes for agnnconv for now

* add test
  • Loading branch information
askorupka authored Jan 28, 2024
1 parent 1dafb8d commit b07aaa2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -944,14 +944,16 @@ function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
return CGConv(ch, dense_f, dense_s, residual)
end

function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
function (l::CGConv)(g::AbstractGNNGraph, x,
e::Union{Nothing, AbstractMatrix} = nothing)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)

if e !== nothing
check_num_edges(g, e)
end

m = propagate(message, g, +, l, xi = x, xj = x, e = e)
m = propagate(message, g, +, l, xi = xi, xj = xj, e = e)

if l.residual
if size(x, 1) == size(m, 1)
Expand All @@ -964,6 +966,7 @@ function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
return m
end


function message(l::CGConv, xi, xj, e)
if e !== nothing
z = vcat(xi, xj, e)
Expand Down
4 changes: 2 additions & 2 deletions src/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ end
# https://github.com/JuliaLang/julia/issues/15276
## and zygote issues
# https://github.com/FluxML/Zygote.jl/issues/1317
function propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
e = nothing)
propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
end
function propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
end

Expand Down
9 changes: 9 additions & 0 deletions test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,13 @@
output = y2.A[:, [2]]
@test expected output
end

@testset "CGConv" begin
g = rand_bipartite_heterograph((2,3), 6)
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu),
(:B, :to, :A) => CGConv(4 => 2, relu));
y = layers(g, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end
end

0 comments on commit b07aaa2

Please sign in to comment.