Merge pull request #68 from CarloLucibello/cl/agnn
add AGNNConv
CarloLucibello authored Nov 2, 2021
2 parents 651f761 + c2a1642 commit e5e919c
Showing 8 changed files with 95 additions and 25 deletions.
4 changes: 3 additions & 1 deletion docs/src/api/messagepassing.md
@@ -21,5 +21,7 @@ propagate
## Built-in message functions

```@docs
copyxj
copy_xi
copy_xj
xi_dot_xj
```
2 changes: 1 addition & 1 deletion docs/src/messagepassing.md
@@ -88,4 +88,4 @@ See the [`GATConv`](@ref) implementation [here](https://github.com/CarloLucibell
## Built-in message functions

In order to exploit the optimized specializations of [`propagate`](@ref), it is recommended
to use built-in message functions such as [`copyxj`](@ref) whenever possible.
to use built-in message functions such as [`copy_xj`](@ref) whenever possible.
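
As a quick illustration (a sketch, not part of this diff; the toy graph and feature sizes below are invented), the two calls compute the same neighborhood sum, but only the `copy_xj` version is eligible for the specialized `propagate` methods (e.g. a single sparse matrix multiplication):

```julia
using GraphNeuralNetworks

s, t = [1, 2, 3], [2, 3, 1]          # a directed triangle
g = GNNGraph(s, t)
x = rand(Float32, 4, g.num_nodes)    # 4 features per node

# Generic path: the anonymous message function is evaluated edge by edge.
m_generic = propagate((xi, xj, e) -> xj, g, +, xj = x)

# Built-in message function: can dispatch to an optimized specialization.
m_builtin = propagate(copy_xj, g, +, xj = x)

m_generic ≈ m_builtin    # expected to hold
```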
3 changes: 2 additions & 1 deletion src/GraphNeuralNetworks.jl
@@ -28,14 +28,15 @@ export

# msgpass
apply_edges, propagate,
copyxj,
copy_xj, copy_xi, xi_dot_xj,

# layers/basic
GNNLayer,
GNNChain,
WithGraph,

# layers/conv
AGNNConv,
CGConv,
ChebConv,
EdgeConv,
11 changes: 5 additions & 6 deletions src/deprecations.jl
@@ -1,10 +1,5 @@
# Deprecated in v0.1
## Deprecated in v0.2

@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)


# Deprecated in v0.2
# TODO check if argument order is exact
function compute_message end
function update_node end
function update_edge end
@@ -29,3 +24,7 @@ function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
e = update_edge(l, e, m)
return x, e
end

## Deprecated in v0.3

@deprecate copyxj(xi, xj, e) copy_xj(xi, xj, e)
65 changes: 59 additions & 6 deletions src/layers/conv.jl
@@ -50,7 +50,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
# @assert all(>(0), degree(g, T, dir=:in))
c = 1 ./ sqrt.(degree(g, T, dir=:in))
x = x .* c'
x = propagate(copyxj, g, +, xj=x)
x = propagate(copy_xj, g, +, xj=x)
x = x .* c'
if Dout >= Din
x = l.weight * x
@@ -179,7 +179,7 @@ end

function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m = propagate(copyxj, g, l.aggr, xj=x)
m = propagate(copy_xj, g, l.aggr, xj=x)
x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
return x
end
@@ -206,7 +206,7 @@ Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.
Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
@@ -338,7 +338,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
end
for i = 1:l.num_layers
M = view(l.weight, :, :, i) * H
M = propagate(copyxj, g, l.aggr; xj=M)
M = propagate(copy_xj, g, l.aggr; xj=M)
H, _ = l.gru(H, M)
end
H
@@ -420,7 +420,7 @@ GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)

function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m = propagate(copyxj, g, l.aggr, xj=x)
m = propagate(copy_xj, g, l.aggr, xj=x)
l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
end

@@ -542,7 +542,7 @@ end

function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m = propagate(copyxj, g, l.aggr, xj=x)
m = propagate(copy_xj, g, l.aggr, xj=x)
x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
return x
end
@@ -711,3 +711,56 @@ function Base.show(io::IO, l::CGConv)
print(io, ", residual=$(l.residual)")
print(io, ")")
end


@doc raw"""
AGNNConv(init_beta=1f0)
Attention-based Graph Neural Network layer from the paper [Attention-based
Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
The forward pass is given by
```math
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} = \frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
              {\sum_{j' \in N(i) \cup \{i\}} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
```
with the cosine similarity defined by
```math
\cos(\mathbf{x}_i, \mathbf{x}_j) =
  \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \, \lVert\mathbf{x}_j\rVert}
```
and ``\beta`` a trainable parameter.
# Arguments
- `init_beta`: The initial value of ``\beta``.
"""
struct AGNNConv{A<:AbstractVector} <: GNNLayer
β::A
end

@functor AGNNConv

function AGNNConv(init_beta = 1f0)
AGNNConv([init_beta])
end

function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
g = add_self_loops(g)

xn = x ./ sqrt.(sum(x.^2, dims=1))
cos_dist = apply_edges(xi_dot_xj, g, xi=xn, xj=xn)
α = softmax_edge_neighbors(g, l.β .* cos_dist)

x = propagate(g, +; xj=x, e=α) do xi, xj, α
α .* xj
end

return x
end
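
For reference, a minimal usage sketch (not part of the commit; the toy graph and feature size are invented). Since `AGNNConv` has no weight matrix, the output feature dimension matches the input one:

```julia
using GraphNeuralNetworks

g = GNNGraph([1, 2, 3], [2, 3, 1])   # a directed triangle
x = rand(Float32, 8, g.num_nodes)    # 8 features per node

l = AGNNConv()          # β initialized to [1f0]
y = l(g, x)

size(y)                 # expected (8, 3): same feature dimension as the input
```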

24 changes: 15 additions & 9 deletions src/msgpass.jl
@@ -139,26 +139,32 @@ _scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)

### SPECIALIZATIONS OF PROPAGATE ###
"""
copyxj(xi, xj, e) = xj
copy_xj(xi, xj, e) = xj
"""
copyxj(xi, xj, e) = xj
copy_xj(xi, xj, e) = xj

# copyxi(xi, xj, e) = xi
# ximulxj(xi, xj, e) = xi .* xj
# xiaddxj(xi, xj, e) = xi .+ xj
"""
copy_xi(xi, xj, e) = xi
"""
copy_xi(xi, xj, e) = xi

"""
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
"""
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)


function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
A = adjacency_matrix(g)
return xj * A
end

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(copyxj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
propagate((xi,xj,e)->copyxj(xi,xj,e), g, +, xi, xj, e)
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
propagate((xi,xj,e)->copy_xj(xi,xj,e), g, +, xi, xj, e)
end
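
As a small consistency sketch (not part of this commit; the toy graph is invented), the specialized `copy_xj`/`+` method above should agree with the generic gather/scatter path:

```julia
using GraphNeuralNetworks

g = GNNGraph([1, 2, 3, 1], [2, 3, 1, 3])
x = rand(Float32, 5, g.num_nodes)

fast = propagate(copy_xj, g, +, xj = x)             # specialized: xj * adjacency_matrix(g)
slow = propagate((xi, xj, e) -> xj, g, +, xj = x)   # generic edgewise path

fast ≈ slow    # expected to hold
```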

# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# A = adjacency_matrix(g)
# D = compute_degree(A)
# return xj * A * D
2 changes: 1 addition & 1 deletion test/deprecations.jl
@@ -23,7 +23,7 @@
end

function new_forward(l, g, x)
x = propagate(copyxj, g, +, xj=x)
x = propagate(copy_xj, g, +, xj=x)
return l.σ.(l.weight * x .+ l.bias)
end

9 changes: 9 additions & 0 deletions test/layers/conv.jl
@@ -158,4 +158,13 @@
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
end
end


@testset "AGNNConv" begin
l = AGNNConv()
@test l.β == [1f0]
for g in test_graphs
test_layer(l, g, rtol=1e-5, outsize=(in_channel, g.num_nodes))
end
end
end
