From bdbb5a458fc73113accf105d180f54e428bb2825 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Fri, 23 Feb 2024 05:06:57 +0530 Subject: [PATCH 1/6] Add GINConv support to HeteroGraphConv --- src/layers/conv.jl | 6 ++++-- test/layers/heteroconv.jl | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 006b03091..efd6ac5d4 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -680,9 +680,11 @@ Flux.trainable(l::GINConv) = (nn = l.nn,) GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) -function (l::GINConv)(g::GNNGraph, x::AbstractMatrix) +function (l::GINConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - m = propagate(copy_xj, g, l.aggr, xj = x) + xj, _ = expand_srcdst(g, x) + + m = propagate(copy_xj, g, l.aggr, xj = xj) l.nn((1 + ofeltype(x, l.ϵ)) * x + m) end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index f1a07b7a7..069afd23f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -109,4 +109,14 @@ @test size(y.A) == (2,2) && size(y.B) == (2,3) end + @testset "GINConv" begin + g = rand_bipartite_heterograph((2, 5), 10) + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 5)); + layer = HeteroGraphConv((:A, :to, :B) => GraphConv(4 => 2, relu), + (:B, :to, :A) => GraphConv(4 => 2, relu)); #just temporary + y = layer(g, x); + out = GINConv(y, 1e-5) # not sure how to use it exactly + # continue test + end + end From 6447379aae6d2047db71dc1dbb4206e4839e82d1 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sat, 2 Mar 2024 15:42:21 +0530 Subject: [PATCH 2/6] gin hetero --- src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index efd6ac5d4..1ba56484e 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -682,7 +682,7 @@ GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) function (l::GINConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - xj, _ = expand_srcdst(g, x) + xj, _ = expand_srcdst(g, x) # hetero graphs m = propagate(copy_xj, g, l.aggr, xj = xj) l.nn((1 + ofeltype(x, l.ϵ)) * x + m) From be0f921b9d811b324d9789a2df9575e7dd0da624 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sat, 2 Mar 2024 19:10:14 +0530 Subject: [PATCH 3/6] trying tests --- src/layers/conv.jl | 11 ++++++++--- test/layers/heteroconv.jl | 12 +++++------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 1fb6221c6..642ed52e1 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -682,10 +682,15 @@ GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) function (l::GINConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - xj, _ = expand_srcdst(g, x) # hetero graphs - + xj, xi = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) - l.nn((1 + ofeltype(x, l.ϵ)) * x + m) + + if g isa GNNHeteroGraph + # + else + l.nn((1 + ofeltype(x, l.ϵ)) * x + m) + end end function Base.show(io::IO, l::GINConv) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index a7310111a..b617f9d36 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -117,13 +117,11 @@ @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end @testset "GINConv" begin - g = rand_bipartite_heterograph((2, 5), 10) - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 5)); - layer = HeteroGraphConv((:A, :to, :B) => GraphConv(4 => 2, relu), - (:B, :to, :A) => GraphConv(4 => 2, relu)); #just temporary - y = layer(g, x); - out = GINConv(y, 1e-5) # not sure how to use it exactly - # continue test + x = (A = rand(4,2), B = rand(4, 3)) + layers = HeteroGraphConv( (:A, :to, :B) => GINConv(Dense(2 * 4, 2), 0.4), + (:B, :to, :A) => GINConv(Dense(2 * 4, 2), 0.4)); + y = layers(hg, x); + @test size(y.A) == (2,2) && size(y.B) == (2,3) end end From 6dd268d8897cf321c2795a46cf7b8cf3a34bcfd1 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sat, 2 Mar 2024 19:36:14 +0530 Subject: [PATCH 4/6] tests pass --- src/layers/conv.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 642ed52e1..0a32d2375 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -687,7 +687,7 @@ function (l::GINConv)(g::AbstractGNNGraph, x) m = propagate(copy_xj, g, l.aggr, xj = xj) if g isa GNNHeteroGraph - # + l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) else l.nn((1 + ofeltype(x, l.ϵ)) * x + m) end From 7fefafb857bdfbc699d3e62d5ac7ec7c5b70481d Mon Sep 17 00:00:00 2001 From: rbSparky Date: Sat, 2 Mar 2024 19:38:54 +0530 Subject: [PATCH 5/6] tests pass --- test/layers/heteroconv.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index b617f9d36..0e7bb782b 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -117,11 +117,10 @@ @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end @testset "GINConv" begin - x = (A = rand(4,2), B = rand(4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => GINConv(Dense(2 * 4, 2), 0.4), - (:B, :to, :A) => GINConv(Dense(2 * 4, 2), 0.4)); + x = (A = rand(4, 2), B = rand(4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4), + (:B, :to, :A) => GINConv(Dense(4, 2), 0.4)); y = layers(hg, x); - @test size(y.A) == (2,2) && size(y.B) == (2,3) + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end - end From 34ad15c9332561a669f1336464eb0f3ac6e5bf51 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 4 Mar 2024 12:35:54 +0530 Subject: [PATCH 6/6] fix --- src/layers/conv.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0a32d2375..22e980ae3 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -686,11 +686,7 @@ function (l::GINConv)(g::AbstractGNNGraph, x) m = propagate(copy_xj, g, l.aggr, xj = xj) - if g isa GNNHeteroGraph - l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) - else - l.nn((1 + ofeltype(x, l.ϵ)) * x + m) - end + l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) end function Base.show(io::IO, l::GINConv)