add AGNNConv #68

Merged · 2 commits · Nov 2, 2021
4 changes: 3 additions & 1 deletion docs/src/api/messagepassing.md
@@ -21,5 +21,7 @@ propagate
## Built-in message functions

```@docs
copyxj
copy_xi
copy_xj
xi_dot_xj
```
2 changes: 1 addition & 1 deletion docs/src/messagepassing.md
@@ -88,4 +88,4 @@ See the [`GATConv`](@ref) implementation [here](https://github.com/CarloLucibell
## Built-in message functions

In order to exploit optimized specializations of [`propagate`](@ref), it is recommended
to use built-in message functions such as [`copyxj`](@ref) whenever possible.
to use built-in message functions such as [`copy_xj`](@ref) whenever possible.
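Side note on this recommendation — a minimal sketch, not part of this diff (the toy edge list is an assumption; the `propagate` keyword form follows the calls shown elsewhere in this PR). Both calls compute the same neighborhood sum, but only the built-in `copy_xj` can dispatch to the fused specialization added in `src/msgpass.jl` below:

```julia
# Sketch only: the anonymous closure forces the generic gather/scatter path,
# while the built-in copy_xj may hit the fused matrix-multiply specialization.
using GraphNeuralNetworks

s, t = [1, 2, 3, 4, 1], [2, 3, 4, 1, 3]      # small directed edge list (assumed)
g = GNNGraph(s, t)
x = rand(Float32, 4, g.num_nodes)             # node features, 4 channels

m_generic = propagate((xi, xj, e) -> xj, g, +, xj = x)   # generic path
m_builtin = propagate(copy_xj, g, +, xj = x)             # may lower to xj * A

m_generic ≈ m_builtin   # true, up to floating-point reordering
```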
3 changes: 2 additions & 1 deletion src/GraphNeuralNetworks.jl
@@ -28,14 +28,15 @@ export

# msgpass
apply_edges, propagate,
copyxj,
copy_xj, copy_xi, xi_dot_xj,

# layers/basic
GNNLayer,
GNNChain,
WithGraph,

# layers/conv
AGNNConv,
CGConv,
ChebConv,
EdgeConv,
11 changes: 5 additions & 6 deletions src/deprecations.jl
@@ -1,10 +1,5 @@
# Deprecated in v0.1
## Deprecated in v0.2

@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)


# Deprecated in v0.2
# TODO check if argument order is exact
function compute_message end
function update_node end
function update_edge end
@@ -29,3 +24,7 @@ function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
e = update_edge(l, e, m)
return x, e
end

## Deprecated in v0.3

@deprecate copyxj(xi, xj, e) copy_xj(xi, xj, e)
65 changes: 59 additions & 6 deletions src/layers/conv.jl
@@ -50,7 +50,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
# @assert all(>(0), degree(g, T, dir=:in))
c = 1 ./ sqrt.(degree(g, T, dir=:in))
x = x .* c'
x = propagate(copyxj, g, +, xj=x)
x = propagate(copy_xj, g, +, xj=x)
x = x .* c'
if Dout >= Din
x = l.weight * x
@@ -179,7 +179,7 @@ end

function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m = propagate(copyxj, g, l.aggr, xj=x)
m = propagate(copy_xj, g, l.aggr, xj=x)
x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
return x
end
@@ -206,7 +206,7 @@ Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.

Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
@@ -338,7 +338,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
end
for i = 1:l.num_layers
M = view(l.weight, :, :, i) * H
M = propagate(copyxj, g, l.aggr; xj=M)
M = propagate(copy_xj, g, l.aggr; xj=M)
H, _ = l.gru(H, M)
end
H
@@ -420,7 +420,7 @@ GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)

function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m = propagate(copyxj, g, l.aggr, xj=x)
m = propagate(copy_xj, g, l.aggr, xj=x)
l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
end

@@ -542,7 +542,7 @@ end

function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m = propagate(copyxj, g, l.aggr, xj=x)
m = propagate(copy_xj, g, l.aggr, xj=x)
x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
return x
end
@@ -711,3 +711,56 @@ function Base.show(io::IO, l::CGConv)
print(io, ", residual=$(l.residual)")
print(io, ")")
end


@doc raw"""
AGNNConv(init_beta=1f0)

Attention-based Graph Neural Network layer from the paper [Attention-based
Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).

The forward pass is given by
```math
\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
\alpha_{ij} = \frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
                   {\sum_{j' \in N(i) \cup \{i\}} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
```
with the cosine similarity defined by
```math
\cos(\mathbf{x}_i, \mathbf{x}_j) =
  \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \, \lVert\mathbf{x}_j\rVert}
```
and ``\beta`` a trainable parameter.

# Arguments

- `init_beta`: The initial value of ``\beta``.
"""
struct AGNNConv{A<:AbstractVector} <: GNNLayer
β::A
end

@functor AGNNConv

function AGNNConv(init_beta = 1f0)
AGNNConv([init_beta])
end

function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
g = add_self_loops(g)

xn = x ./ sqrt.(sum(x.^2, dims=1))
cos_dist = apply_edges(xi_dot_xj, g, xi=xn, xj=xn)
α = softmax_edge_neighbors(g, l.β .* cos_dist)

x = propagate(g, +; xj=x, e=α) do xi, xj, α
α .* xj
end

return x
end

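For reviewers, a hedged usage sketch of the new layer (the toy graph below is an assumption, not part of this PR). Since `AGNNConv` carries no weight matrix, the output keeps the input feature dimension:

```julia
using GraphNeuralNetworks

s, t = [1, 2, 3, 4], [2, 3, 4, 1]        # assumed toy cycle graph
g = GNNGraph(s, t)
x = rand(Float32, 3, g.num_nodes)        # 3 input channels

l = AGNNConv()                           # trainable β initialized to [1f0]
y = l(g, x)                              # attention-weighted sum over N(i) ∪ {i}

size(y) == size(x)                       # true: no projection, only reweighting
```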
24 changes: 15 additions & 9 deletions src/msgpass.jl
@@ -139,26 +139,32 @@ _scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)

### SPECIALIZATIONS OF PROPAGATE ###
"""
copyxj(xi, xj, e) = xj
copy_xj(xi, xj, e) = xj
"""
copyxj(xi, xj, e) = xj
copy_xj(xi, xj, e) = xj

# copyxi(xi, xj, e) = xi
# ximulxj(xi, xj, e) = xi .* xj
# xiaddxj(xi, xj, e) = xi .+ xj
"""
copy_xi(xi, xj, e) = xi
"""
copy_xi(xi, xj, e) = xi

"""
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
"""
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)


function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
A = adjacency_matrix(g)
return xj * A
end

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(copyxj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
propagate((xi,xj,e)->copyxj(xi,xj,e), g, +, xi, xj, e)
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
propagate((xi,xj,e)->copy_xj(xi,xj,e), g, +, xi, xj, e)
end

# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# A = adjacency_matrix(g)
# D = compute_degree(A)
# return xj * A * D
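A quick sanity-check sketch of what the fast path computes (assuming `adjacency_matrix` is usable on a `GNNGraph` via Graphs.jl and a toy graph built by hand): the specialization replaces the per-edge gather/scatter with a single matrix product that sums source features into each target node.

```julia
using GraphNeuralNetworks, Graphs

s, t = [1, 2, 3, 4, 2], [2, 3, 4, 1, 4]
g = GNNGraph(s, t)
x = rand(Float32, 2, g.num_nodes)

A = adjacency_matrix(g)                          # A[j, i] ≠ 0 for edge j → i
m = propagate((xi, xj, e) -> xj, g, +, xj = x)   # generic scatter-sum over in-edges

m ≈ x * A                                        # true: matches the fused method's xj * A
```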
2 changes: 1 addition & 1 deletion test/deprecations.jl
@@ -23,7 +23,7 @@
end

function new_forward(l, g, x)
x = propagate(copyxj, g, +, xj=x)
x = propagate(copy_xj, g, +, xj=x)
return l.σ.(l.weight * x .+ l.bias)
end

9 changes: 9 additions & 0 deletions test/layers/conv.jl
@@ -158,4 +158,13 @@
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
end
end


@testset "AGNNConv" begin
l = AGNNConv()
@test l.β == [1f0]
for g in test_graphs
test_layer(l, g, rtol=1e-5, outsize=(in_channel, g.num_nodes))
end
end
end