diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index 5258352e3..321322069 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -50,6 +50,7 @@ export
     ChebConv,
     EdgeConv,
     GATConv,
+    GATv2Conv,
     GatedGraphConv,
     GCNConv,
     GINConv,
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 744f02a73..96e585e12 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -175,7 +175,7 @@ function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where T
     check_num_nodes(g, X)
     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
-    L̃ = scaled_laplacian(g, eltype(X))
+    L̃ = scaled_laplacian(g, eltype(X))

     Z_prev = X
     Z = X * L̃
@@ -333,9 +333,9 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
         x = mean(x, dims=2)
     end
     x = reshape(x, :, size(x, 3))  # return a matrix
-    x = l.σ.(x .+ l.bias)
+    x = l.σ.(x .+ l.bias)
-    return x
+    return x
 end

@@ -346,6 +346,108 @@ function Base.show(io::IO, l::GATConv)
     print(io, "))")
 end

+@doc raw"""
+    GATv2Conv(in => out, σ=identity;
+              heads=1,
+              concat=true,
+              init=glorot_uniform,
+              bias=true,
+              negative_slope=0.2f0)
+
+GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
+
+Implements the operation
+```math
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j
+```
+where the attention coefficients ``\alpha_{ij}`` are given by
+```math
+\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T \mathrm{LeakyReLU}([W_2 \mathbf{x}_i; W_1 \mathbf{x}_j]))
+```
+with ``z_i`` a normalization factor.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `σ`: Activation function applied to the output.
+- `heads`: Number of attention heads.
+- `concat`: If `true`, the outputs of the heads are concatenated; otherwise they are averaged.
+- `init`: Weight initialization function.
+- `bias`: Learn the additive bias if `true`.
+- `negative_slope`: The parameter of the LeakyReLU.
+"""
+struct GATv2Conv{T, A1, A2, B, C<:AbstractMatrix} <: GNNLayer
+    dense_i::A1
+    dense_j::A2
+    bias::B
+    a::C
+    σ
+    negative_slope::T
+    channel::Pair{Int, Int}
+    heads::Int
+    concat::Bool
+end
+
+@functor GATv2Conv
+Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.bias, l.a)
+
+function GATv2Conv(channel::Pair{Int,Int}, σ=identity;
+                   heads::Int=1,
+                   concat::Bool=true,
+                   negative_slope=0.2,
+                   init=glorot_uniform,
+                   bias::Bool=true)
+    in, out = channel
+    dense_i = Dense(in, out*heads; bias=bias, init=init)
+    dense_j = Dense(in, out*heads; bias=false, init=init)
+    # Output bias has length out*heads when concatenating, out when averaging;
+    # create_bias returns false when bias=false, so no branching is needed.
+    b = Flux.create_bias(dense_i.weight, bias, concat ? out*heads : out)
+    a = init(out, heads)
+    negative_slope = convert(eltype(dense_i.weight), negative_slope)
+    GATv2Conv(dense_i, dense_j, b, a, σ, negative_slope, channel, heads, concat)
+end
+
+function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    g = add_self_loops(g)
+    _, out = l.channel
+    heads = l.heads
+
+    Wix = reshape(l.dense_i(x), out, heads, :)    # out × heads × nnodes
+    Wjx = reshape(l.dense_j(x), out, heads, :)    # out × heads × nnodes
+
+    function message(Wix, Wjx, e)
+        eij = sum(l.a .* leakyrelu.(Wix + Wjx, l.negative_slope), dims=1)    # 1 × heads × nedges
+        α = exp.(eij)
+        return (α = α, β = α .* Wjx)
+    end
+
+    m = propagate(message, g, +; xi=Wix, xj=Wjx)
+    x = m.β ./ m.α    # softmax normalization, out × heads × nnodes
+
+    if !l.concat
+        x = mean(x, dims=2)
+    end
+    x = reshape(x, :, size(x, 3))    # return a matrix
+    x = l.σ.(x .+ l.bias)
+    return x
+end
+
+function Base.show(io::IO, l::GATv2Conv)
+    out, in = size(l.dense_i.weight)    # weight is (out*heads) × in
+    print(io, "GATv2Conv(", in, "=>", out ÷ l.heads)
+    print(io, ", LeakyReLU(λ=", l.negative_slope)
+    print(io, "))")
+end
+
 @doc raw"""
     GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform)

diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index e17569fda..294db3f1f 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -5,9 +5,9 @@
     T = Float32

     adj1 = [0 1 0 1
-            1 0 1 0
-            0 1 0 1
-            1 0 1 0]
+            1 0 1 0
+            0 1 0 1
+            1 0 1 0]

     g1 = GNNGraph(adj1,
             ndata=rand(T, in_channel, N),
@@ -109,6 +109,23 @@
     end

+    @testset "GATv2Conv" begin
+        for heads in (1, 2), concat in (true, false)
+            l = GATv2Conv(in_channel => out_channel; heads, concat)
+            for g in test_graphs
+                test_layer(l, g, rtol=1e-4,
+                           outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
+            end
+        end
+
+        @testset "bias=false" begin
+            @test length(Flux.params(GATv2Conv(2=>3))) == 5
+            @test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
+        end
+    end
+
     @testset "GatedGraphConv" begin
        num_layers = 3
        l = GatedGraphConv(out_channel, num_layers)
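Usage sketch (illustrative, not part of the diff): a minimal forward pass through the new layer. The adjacency matrix mirrors `adj1` from the tests above; the feature sizes (8 in, 16 out) and the `relu` activation are arbitrary choices for demonstration.

```julia
using Flux, GraphNeuralNetworks

# 4-node cycle graph, same adjacency as adj1 in the tests
adj = [0 1 0 1
       1 0 1 0
       0 1 0 1
       1 0 1 0]
g = GNNGraph(adj)

x = rand(Float32, 8, g.num_nodes)   # 8 input features per node

l = GATv2Conv(8 => 16, relu; heads=2, concat=true)
y = l(g, x)                         # size (2*16, g.num_nodes): head outputs concatenated

l2 = GATv2Conv(8 => 16, relu; heads=2, concat=false)
y2 = l2(g, x)                       # size (16, g.num_nodes): head outputs averaged
```

With `concat=true` the output dimension scales with the number of heads, which is why the test above expects `outsize = heads*out_channel` in that branch and `out_channel` otherwise.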