Adds GATv2 layer #97

Merged
merged 9 commits on Jan 7, 2022
Changes from 2 commits
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -50,6 +50,7 @@ export
ChebConv,
EdgeConv,
GATConv,
GATv2Conv,
GatedGraphConv,
GCNConv,
GINConv,
108 changes: 105 additions & 3 deletions src/layers/conv.jl
@@ -175,7 +175,7 @@ function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where T
check_num_nodes(g, X)
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."

L̃ = scaled_laplacian(g, eltype(X))

Z_prev = X
Z = X * L̃
@@ -333,9 +333,9 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
x = mean(x, dims=2)
end
x = reshape(x, :, size(x, 3)) # return a matrix
x = l.σ.(x .+ l.bias)

return x
end


@@ -346,6 +346,108 @@ function Base.show(io::IO, l::GATConv)
print(io, "))")
end

@doc raw"""
GATv2Conv(in => out, σ=identity;
heads=1,
concat=true,
init=glorot_uniform,
bias=true,
negative_slope=0.2f0)

GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).

Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_j \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T \mathrm{LeakyReLU}(W_i \mathbf{x}_i + W_j \mathbf{x}_j))
```
with ``W_i`` and ``W_j`` the weight matrices applied to the target and source node features respectively,
and the normalization factor ``z_i = \sum_{k \in N(i) \cup \{i\}} \exp(\mathbf{a}^T \mathrm{LeakyReLU}(W_i \mathbf{x}_i + W_k \mathbf{x}_k))``
ensuring that the coefficients over each neighborhood sum to one.

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function applied to the output. Default `identity`.
- `bias`: Add a learnable bias if `true`. Default `true`.
- `heads`: Number of attention heads. Default `1`.
- `concat`: If `true`, the outputs from each head are concatenated; otherwise they are averaged. Default `true`.
- `negative_slope`: The slope of the LeakyReLU used in the attention scores. Default `0.2`.
- `init`: Weight initializer. Default `glorot_uniform`.
"""
struct GATv2Conv{T, A<:AbstractMatrix, B} <: GNNLayer
Wi::A
Wj::A
bias::B
a::A
σ
negative_slope::T
channel::Pair{Int, Int}
heads::Int
concat::Bool
end

@functor GATv2Conv
Flux.trainable(l::GATv2Conv) = (l.Wi, l.Wj, l.bias, l.a)

function GATv2Conv(
channel::Pair{Int,Int},
σ=identity;
heads::Int=1,
concat::Bool=true,
negative_slope=0.2,
init=glorot_uniform,
bias::Bool=true,
)
in, out = channel
Wi = init(out*heads, in) # projects x_i; all heads stacked along the first dimension
Wj = init(out*heads, in) # projects x_j; all heads stacked along the first dimension
b = Flux.create_bias(Wi, bias, concat ? out*heads : out)
a = init(out, heads) # attention vector, one column per head

negative_slope = convert(eltype(Wi), negative_slope)
GATv2Conv(Wi, Wj, b, a, σ, negative_slope, channel, heads, concat)
end

function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
g = add_self_loops(g)
in, out = l.channel
heads = l.heads

Wix = reshape(l.Wi * x, out, heads, :) # out × heads × nnodes
Wjx = reshape(l.Wj * x, out, heads, :) # out × heads × nnodes

function message(Wix, Wjx, e)
eij = sum(l.a .* leakyrelu.(Wix + Wjx, l.negative_slope), dims=1) # 1 × heads × nedges
α = exp.(eij)
return (α = α, β = α .* Wjx)
end

m = propagate(message, g, +; xi=Wix, xj=Wjx) # sums α and β over neighbors; β: out × heads × nnodes
x = m.β ./ m.α # softmax normalization: weighted sum divided by z_i

if !l.concat
x = mean(x, dims=2)
end
x = reshape(x, :, size(x, 3))
x = l.σ.(x .+ l.bias)
return x
end


function Base.show(io::IO, l::GATv2Conv)
out, in = size(l.Wi)
print(io, "GATv2Conv(", in, "=>", out ÷ l.heads)
print(io, ", LeakyReLU(λ=", l.negative_slope)
print(io, "))")
end


@doc raw"""
GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform)
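For reference, a minimal usage sketch of the layer added above. The toy graph, feature sizes, and activation are made up for illustration; the constructor and call signature follow this diff, and node features are assumed to be stored under `ndata.x` as in the package's tests:

```julia
using GraphNeuralNetworks, Flux

# Toy 4-node ring graph with 3-dimensional node features (illustrative data).
adj = [0 1 0 1
       1 0 1 0
       0 1 0 1
       1 0 1 0]
g = GNNGraph(adj, ndata=rand(Float32, 3, 4))

l = GATv2Conv(3 => 8, relu; heads=2, concat=true)
y = l(g, g.ndata.x)        # size (heads*out, num_nodes) = (16, 4)

lavg = GATv2Conv(3 => 8; heads=2, concat=false)
yavg = lavg(g, g.ndata.x)  # heads averaged: size (8, 4)
```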
23 changes: 20 additions & 3 deletions test/layers/conv.jl
@@ -5,9 +5,9 @@
T = Float32

adj1 = [0 1 0 1
        1 0 1 0
        0 1 0 1
        1 0 1 0]

g1 = GNNGraph(adj1,
ndata=rand(T, in_channel, N),
@@ -109,6 +109,23 @@
end
end

@testset "GATv2Conv" begin

for heads in (1, 2), concat in (true, false)
l = GATv2Conv(in_channel => out_channel; heads, concat)
for g in test_graphs
test_layer(l, g, rtol=1e-4,
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
end
end

@testset "bias=false" begin
@test length(Flux.params(GATv2Conv(2=>3))) == 4
@test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
end
end


@testset "GatedGraphConv" begin
num_layers = 3
l = GatedGraphConv(out_channel, num_layers)
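A side note on the forward pass in `conv.jl` above: `propagate` aggregates both the unnormalized attention weights `α` and the weighted features `β` with `+`, so `m.β ./ m.α` recovers a softmax over each neighborhood without materializing it edge by edge. A standalone single-head sketch of that trick with made-up numbers (not part of the PR):

```julia
using Flux: leakyrelu

out = 4                                   # feature dimension (made up)
a   = randn(Float32, out)                 # attention vector
Wix = randn(Float32, out)                 # Wi * x_i for the target node i
Wjx = [randn(Float32, out) for _ in 1:3]  # Wj * x_j for three neighbors j

e = [sum(a .* leakyrelu.(Wix .+ w, 0.2f0)) for w in Wjx]  # scores e_ij
α = exp.(e)                                               # unnormalized coefficients
x′ = sum(α[k] * Wjx[k] for k in 1:3) / sum(α)             # Σβ / z_i: softmax-weighted sum
```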