Adds GATv2 layer (#97)
* Adds GATv2 layer
abieler authored Jan 7, 2022
1 parent 4dec29c commit 957ce36
Showing 3 changed files with 126 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -50,6 +50,7 @@ export
ChebConv,
EdgeConv,
GATConv,
GATv2Conv,
GatedGraphConv,
GCNConv,
GINConv,
108 changes: 105 additions & 3 deletions src/layers/conv.jl
@@ -175,7 +175,7 @@ function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where T
check_num_nodes(g, X)
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."

L̃ = scaled_laplacian(g, eltype(X))

Z_prev = X
Z = X * L̃
@@ -333,9 +333,9 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
x = mean(x, dims=2)
end
x = reshape(x, :, size(x, 3)) # return a matrix
x = l.σ.(x .+ l.bias)

return x
end


@@ -346,6 +346,108 @@ function Base.show(io::IO, l::GATConv)
print(io, "))")
end

@doc raw"""
GATv2Conv(in => out, σ=identity;
heads=1,
concat=true,
init=glorot_uniform,
bias=true,
negative_slope=0.2f0)
GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T \mathrm{LeakyReLU}([W_2 \mathbf{x}_i; W_1 \mathbf{x}_j]))
```
with ``z_i`` a normalization factor.
# Arguments
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function applied to the aggregated output.
- `bias`: Learn the additive bias if true.
- `heads`: Number of attention heads.
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
- `negative_slope`: The parameter of the LeakyReLU activation.
- `init`: Weight matrices' initializing function.
"""
struct GATv2Conv{T, A1, A2, B, C<:AbstractMatrix} <: GNNLayer
dense_i::A1
dense_j::A2
bias::B
a::C
σ
negative_slope::T
channel::Pair{Int, Int}
heads::Int
concat::Bool
end

@functor GATv2Conv
Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.bias, l.a)

function GATv2Conv(
channel::Pair{Int,Int},
σ=identity;
heads::Int=1,
concat::Bool=true,
negative_slope=0.2,
init=glorot_uniform,
bias::Bool=true,
)
in, out = channel
dense_i = Dense(in, out*heads; bias=bias, init=init)
dense_j = Dense(in, out*heads; bias=false, init=init)
b = Flux.create_bias(dense_i.weight, bias, concat ? out*heads : out)
a = init(out, heads)

negative_slope = convert(eltype(dense_i.weight), negative_slope)
GATv2Conv(dense_i, dense_j, b, a, σ, negative_slope, channel, heads, concat)
end
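
A minimal usage sketch of the new layer (the `rand_graph` call and the feature sizes are illustrative assumptions, not part of this commit):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 30)                  # 10 nodes, 30 edges
x = rand(Float32, 8, 10)                # 8 input features per node
l = GATv2Conv(8 => 16, relu; heads=4)   # concat=true by default
y = l(g, x)                             # size(y) == (16 * 4, 10)
```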

function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
g = add_self_loops(g)
in, out = l.channel
heads = l.heads

Wix = reshape(l.dense_i(x), out, heads, :) # out × heads × nnodes
Wjx = reshape(l.dense_j(x), out, heads, :) # out × heads × nnodes


function message(Wix, Wjx, e)
eij = sum(l.a .* leakyrelu.(Wix + Wjx, l.negative_slope), dims=1) # 1 × heads × nedges
α = exp.(eij)
return (α = α, β = α .* Wjx)
end

m = propagate(message, g, +; xi=Wix, xj=Wjx) # out × heads × nnodes
x = m.β ./ m.α

if !l.concat
x = mean(x, dims=2)
end
x = reshape(x, :, size(x, 3))
x = l.σ.(x .+ l.bias)
return x
end
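
Note how the neighborhood softmax is computed implicitly: `propagate` sums both the unnormalized weights `α = exp.(eij)` and the weighted features `β = α .* Wjx` over each node's incoming edges, so `m.β ./ m.α` is exactly the softmax-weighted sum of neighbor features (without the max-subtraction stabilization a library softmax would apply). A small numeric sketch of the same identity on plain arrays, with illustrative sizes:

```julia
using Flux: softmax

e = randn(Float32, 3)        # unnormalized attention scores of 3 neighbors
v = randn(Float32, 4, 3)     # one 4-dimensional feature vector per neighbor

α = exp.(e)
out = vec(sum(v .* α', dims=2)) ./ sum(α)    # sum-of-exp trick, as above
out2 = vec(sum(v .* softmax(e)', dims=2))    # explicit softmax weighting
@assert out ≈ out2
```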


function Base.show(io::IO, l::GATv2Conv)
out, in = size(l.dense_i.weight)
print(io, "GATv2Conv(", in, "=>", out ÷ l.heads)
print(io, ", LeakyReLU(λ=", l.negative_slope)
print(io, "))")
end
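
With the field access fixed as above (the struct has no `weight_i` field; the weight lives in `dense_i`), a layer built as `GATv2Conv(8 => 16; heads=4)` would print along these lines (illustrative REPL output):

```julia
julia> GATv2Conv(8 => 16; heads=4)
GATv2Conv(8=>16, LeakyReLU(λ=0.2))
```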


@doc raw"""
GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform)
23 changes: 20 additions & 3 deletions test/layers/conv.jl
@@ -5,9 +5,9 @@
T = Float32

adj1 = [0 1 0 1
        1 0 1 0
        0 1 0 1
        1 0 1 0]

g1 = GNNGraph(adj1,
ndata=rand(T, in_channel, N),
@@ -109,6 +109,23 @@
end
end

@testset "GATv2Conv" begin

for heads in (1, 2), concat in (true, false)
l = GATv2Conv(in_channel => out_channel; heads, concat)
for g in test_graphs
test_layer(l, g, rtol=1e-4,
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
end
end

@testset "bias=false" begin
@test length(Flux.params(GATv2Conv(2=>3))) == 5
@test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
end
end
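
For reference, the parameter counts checked above follow from `Flux.trainable(l) = (l.dense_i, l.dense_j, l.bias, l.a)`; a sketch of the accounting:

```julia
using Flux, GraphNeuralNetworks

l = GATv2Conv(2 => 3)
# Flux.params flattens the trainable fields into arrays:
#   dense_i.weight, dense_i.bias, dense_j.weight (no bias), l.bias, l.a -> 5
# with bias=false, dense_i.bias and l.bias drop out                     -> 3
@assert length(Flux.params(l)) == 5
@assert length(Flux.params(GATv2Conv(2 => 3, bias=false))) == 3
```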


@testset "GatedGraphConv" begin
num_layers = 3
l = GatedGraphConv(out_channel, num_layers)
