From a392d41d443b3e6acec1236b5b24b0db97721233 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Mon, 4 Mar 2024 15:45:07 +0530 Subject: [PATCH 1/5] gat hetero support --- src/layers/conv.jl | 16 +++++++++++----- src/utils.jl | 10 ++++++++-- test/layers/heteroconv.jl | 8 ++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 09efea74e..3169f714b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -356,25 +356,31 @@ end (l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) -function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, +function (l::GATConv)(g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" + xj, xi = expand_srcdst(g, x) + edge_t = g isa GNNHeteroGraph ? g.etypes[1] : nothing + if l.add_self_loops @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." - g = add_self_loops(g) + g = g isa GNNHeteroGraph ? add_self_loops(g, edge_t) : add_self_loops(g) end _, chout = l.channel heads = l.heads - Wx = l.dense_x(x) - Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes + Wxj = l.dense_x(xj) + Wxj = reshape(Wxj, chout, heads, :) + + Wxi = l.dense_x(xi) + Wxi = reshape(Wxi, chout, heads, :) # a hand-written message passing - m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e) + m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e) α = softmax_edge_neighbors(g, m.logα) β = α .* m.Wxj x = aggregate_neighbors(g, +, β) diff --git a/src/utils.jl b/src/utils.jl index b7876f2ba..8434c63c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -81,8 +81,14 @@ Softmax over each node's neighborhood of the edge features `e`. {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}. ``` """ -function softmax_edge_neighbors(g::GNNGraph, e) - @assert size(e)[end] == g.num_edges +function softmax_edge_neighbors(g::AbstractGNNGraph, e) + if g isa GNNHeteroGraph + for (key, value) in g.num_edges + @assert size(e)[end] == value + end + else + @assert size(e)[end] == g.num_edges + end s, t = edge_index(g) max_ = gather(scatter(max, e, t), t) num = exp.(e .- max_) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index e4d0fd40a..3f71c8c44 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -116,4 +116,12 @@ y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end + + @testset "GATConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GATConv(4 => 2), + (:B, :to, :A) => GATConv(4 => 2)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end end From 5867ac7327a491f29e746f8ee6cd51c9507bc41c Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:39:32 +0530 Subject: [PATCH 2/5] Update src/layers/conv.jl Co-authored-by: Carlo Lucibello --- src/layers/conv.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 22645d895..76adc7274 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -420,7 +420,11 @@ function (l::GATConv)(g::AbstractGNNGraph, x, if l.add_self_loops @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." - g = g isa GNNHeteroGraph ? add_self_loops(g, edge_t) : add_self_loops(g) + if g isa GNNHeteroGraph + g = add_self_loops(g, g.etypes[1]) + else + g = add_self_loops(g) + end end _, chout = l.channel From 7a392bb01d59975762a447a7a29456fa4eb5f2fa Mon Sep 17 00:00:00 2001 From: rbSparky Date: Tue, 5 Mar 2024 21:45:35 +0530 Subject: [PATCH 3/5] changes made --- src/layers/conv.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 76adc7274..437b17a4a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -433,8 +433,10 @@ function (l::GATConv)(g::AbstractGNNGraph, x, Wxj = l.dense_x(xj) Wxj = reshape(Wxj, chout, heads, :) - Wxi = l.dense_x(xi) - Wxi = reshape(Wxi, chout, heads, :) + if xi !== xj + Wxi = l.dense_x(xi) + Wxi = reshape(Wxi, chout, heads, :) + end # a hand-written message passing m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e) From d6f63f49945760caa84e1697448b18973a545c4d Mon Sep 17 00:00:00 2001 From: rbSparky Date: Tue, 5 Mar 2024 22:27:20 +0530 Subject: [PATCH 4/5] fix --- src/layers/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 437b17a4a..d423e3082 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -430,8 +430,8 @@ function (l::GATConv)(g::AbstractGNNGraph, x, _, chout = l.channel heads = l.heads - Wxj = l.dense_x(xj) - Wxj = reshape(Wxj, chout, heads, :) + Wxi = Wxj = l.dense_x(xj) + Wxi = Wxj = reshape(Wxj, chout, heads, :) if xi !== xj Wxi = l.dense_x(xi) From eeec6fac2b4accdd92a6638246a8e3a4ec2828ba Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Wed, 6 Mar 2024 00:35:02 +0530 Subject: [PATCH 5/5] Update src/layers/conv.jl Co-authored-by: Carlo Lucibello --- src/layers/conv.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index d423e3082..a48d48149 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -416,7 +416,6 @@ function (l::GATConv)(g::AbstractGNNGraph, x, @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" xj, xi = expand_srcdst(g, x) - edge_t = g isa GNNHeteroGraph ? g.etypes[1] : nothing if l.add_self_loops @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported."