From 664e682d2daf2eaf2143357ec937ca5845592ef4 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Wed, 21 Feb 2024 18:00:34 +0530 Subject: [PATCH 1/8] SAGEConv Hetero Layer --- src/layers/conv.jl | 5 +++-- test/layers/heteroconv.jl | 8 ++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 006b03091..8517ebca9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -801,9 +801,10 @@ function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, SAGEConv(W, b, σ, aggr) end -function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix) +function (l::SAGEConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - m = propagate(copy_xj, g, l.aggr, xj = x) + xj, _ = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) x = l.σ.(l.weight * vcat(x, m) .+ l.bias) return x end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index f1a07b7a7..e86e8ec6f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -109,4 +109,12 @@ @test size(y.A) == (2,2) && size(y.B) == (2,3) end + @testset "SAGEConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu), + (:B, :to, :A) => SAGEConv(4 => 2, relu)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + end From 7daefc637c2370a6df233eeba65bf3d8e38af3c2 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 22 Feb 2024 00:19:22 +0530 Subject: [PATCH 2/8] without tests --- test/layers/heteroconv.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index e86e8ec6f..d94a4c363 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -108,13 +108,4 @@ y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end - - @testset "SAGEConv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu), - (:B, :to, :A) => SAGEConv(4 => 2, relu)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - end From 7fa51cf5356a4219a36e3fde1e9adf5a7c686131 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 22 Feb 2024 18:16:29 +0530 Subject: [PATCH 3/8] test doesnt work yet --- test/layers/heteroconv.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index d94a4c363..f4ae23968 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -108,4 +108,12 @@ y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end + + @testset "SAGEConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), + (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end end From e6f826d4e5bbec3a56839b800cfbae94edc97ca6 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:19:59 +0530 Subject: [PATCH 4/8] tests --- test/layers/heteroconv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index f4ae23968..28c7f930d 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), - (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +), + (:B, :to, :A) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end From 4f4565c98703c6de100039fbb1ed8eafd9bde46a Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:27:22 +0530 Subject: [PATCH 5/8] tests should work --- src/layers/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8517ebca9..1afc1d1fb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -803,9 +803,9 @@ end function (l::SAGEConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - xj, _ = expand_srcdst(g, x) + xj, xi = expand_srcdst(g, x) m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.σ.(l.weight * vcat(x, m) .+ l.bias) + x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) return x end From 61cc6646a0f309234c84c4a90c5e0643a9340105 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:30:20 +0530 Subject: [PATCH 6/8] test update --- test/layers/heteroconv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 28c7f930d..e4d0fd40a 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +), - (:B, :to, :A) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end From d8dde07b415c57d34ddc716de6f3d6743bee670f Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:32:08 +0530 Subject: [PATCH 7/8] temporary testing fast --- test/runtests.jl | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 271373ecc..e2d5d9cd9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,27 +25,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ - "GNNGraphs/chainrules", - "GNNGraphs/datastore", - "GNNGraphs/gnngraph", - "GNNGraphs/convert", - "GNNGraphs/transform", - "GNNGraphs/operators", - "GNNGraphs/generate", - "GNNGraphs/query", - "GNNGraphs/sampling", - "GNNGraphs/gnnheterograph", - "GNNGraphs/temporalsnapshotsgnngraph", - "utils", - "msgpass", - "layers/basic", - "layers/conv", "layers/heteroconv", - "layers/temporalconv", - "layers/pool", - "mldatasets", - "examples/node_classification_cora", - "deprecations", ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") From fe683edf13b5227301fab739aeacabe9ffe086fb Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:38:46 +0530 Subject: [PATCH 8/8] final tests --- test/runtests.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index e2d5d9cd9..271373ecc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,7 +25,27 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ + "GNNGraphs/chainrules", + "GNNGraphs/datastore", + "GNNGraphs/gnngraph", + "GNNGraphs/convert", + "GNNGraphs/transform", + "GNNGraphs/operators", + "GNNGraphs/generate", + "GNNGraphs/query", + "GNNGraphs/sampling", + "GNNGraphs/gnnheterograph", + "GNNGraphs/temporalsnapshotsgnngraph", + "utils", + "msgpass", + "layers/basic", + "layers/conv", "layers/heteroconv", + "layers/temporalconv", + "layers/pool", + "mldatasets", + "examples/node_classification_cora", + "deprecations", ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")