Merge pull request #206 from CarloLucibello/cl/misc
misc improvements
yuehhua authored Jul 27, 2021
2 parents 8344085 + c079737 commit 0935c3a
Showing 9 changed files with 130 additions and 168 deletions.
135 changes: 66 additions & 69 deletions src/layers/conv.jl
@@ -1,11 +1,11 @@
"""
GCNConv([graph,] in => out, σ=identity; bias=true, init=glorot_uniform)
GCNConv([fg,] in => out, σ=identity; bias=true, init=glorot_uniform)
Graph convolutional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
@@ -41,7 +41,7 @@ function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
l.σ.(l.weight * x * L̃ .+ l.bias)
end

(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GCNConv)(x::AbstractMatrix) = l(l.fg, x)
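For orientation, a minimal usage sketch follows. The adjacency-matrix constructor for `FeaturedGraph` and all sizes are illustrative assumptions, not part of this diff:

using Flux, GraphSignals, GeometricFlux
adj = [0 1 1; 1 0 1; 1 1 0]          # hypothetical 3-node graph
fg = FeaturedGraph(adj)              # graph carried inside the layer
gcn = GCNConv(fg, 3 => 8, relu)      # 3 input features, 8 output features
x = rand(Float32, 3, 3)              # node features: (num_features, num_nodes)
y = gcn(x)                           # 8 × 3 matrix of convolved node features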

function Base.show(io::IO, l::GCNConv)
@@ -53,13 +53,13 @@ end


"""
ChebConv([graph, ]in=>out, k; bias=true, init=glorot_uniform)
ChebConv([fg,] in=>out, k; bias=true, init=glorot_uniform)
Chebyshev spectral graph convolutional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `k`: The order of Chebyshev polynomial.
@@ -71,16 +71,14 @@ struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph}
bias::B
fg::S
k::Int
in_channel::Int
out_channel::Int
end

function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int;
init=glorot_uniform, bias::Bool=true)
in, out = ch
W = init(out, in, k)
b = Flux.create_bias(W, bias, out)
ChebConv(W, b, fg, k, in, out)
ChebConv(W, b, fg, k)
end

ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
@@ -89,11 +87,11 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
@functor ChebConv

function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
L̃ = scaled_laplacian(fg, T)

@assert size(X, 1) == c.in_channel "Input feature size must match input channel size."
GraphSignals.check_num_node(L̃, X)
check_num_nodes(fg, X)
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."

L̃ = scaled_laplacian(fg, eltype(X))

Z_prev = X
Z = X * L̃
Y = view(c.weight,:,:,1) * Z_prev
@@ -105,24 +103,25 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
return Y .+ c.bias
end

(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::ChebConv)(x::AbstractMatrix) = l(l.fg, x)
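A comparable sketch for `ChebConv` (sizes are again illustrative); note that after this change the channel sizes are recovered from `size(l.weight)` rather than stored as extra struct fields:

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])   # hypothetical graph
cheb = ChebConv(fg, 3 => 16, 3)             # k = 3 Chebyshev terms of the scaled Laplacian
y = cheb(rand(Float32, 3, 3))               # 16 × 3 output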

function Base.show(io::IO, l::ChebConv)
print(io, "ChebConv(", l.in_channel, " => ", l.out_channel)
print(io, ", k=", l.k)
out, in, k = size(l.weight)
print(io, "ChebConv(", in, " => ", out)
print(io, ", k=", k)
print(io, ")")
end


"""
GraphConv([graph,] in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
GraphConv([fg,] in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
Graph neural network layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
@@ -154,18 +153,17 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) =

@functor GraphConv

message(g::GraphConv, x_i, x_j::AbstractVector, e_ij) = g.weight2 * x_j
message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j

update(g::GraphConv, m::AbstractVector, x::AbstractVector) = g.σ.(g.weight1*x .+ m .+ g.bias)
update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)

function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
g = graph(fg)
GraphSignals.check_num_node(g, x)
_, x = propagate(gc, adjacency_list(g), Fill(0.f0, 0, ne(g)), x, +)
check_num_nodes(fg, x)
_, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +)
x
end

(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)
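A hedged usage sketch: neighbor messages `weight2 * x_j` are combined with the chosen aggregator (`+` below) and added to the self term `weight1 * x` in `update`:

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])   # hypothetical graph
gc = GraphConv(fg, 3 => 8, relu, +)         # aggregate neighbor messages with +
y = gc(rand(Float32, 3, 3))                 # 8 × 3 output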

function Base.show(io::IO, l::GraphConv)
@@ -180,7 +178,7 @@ end


"""
GATConv([graph,] in => out;
GATConv([fg,] in => out;
heads=1,
concat=true,
init=glorot_uniform
@@ -191,7 +189,7 @@ Graph attentional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `bias::Bool`: Keyword argument, whether to learn the additive bias.
@@ -225,56 +223,55 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...)
@functor GATConv

# The not-yet-softmaxed α is carried as the first entry of the output message
function message(g::GATConv, x_i::AbstractVector, x_j::AbstractVector)
x_i = reshape(g.weight*x_i, :, g.heads)
x_j = reshape(g.weight*x_j, :, g.heads)
function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector)
x_i = reshape(gat.weight*x_i, :, gat.heads)
x_j = reshape(gat.weight*x_j, :, gat.heads)
x_ij = vcat(x_i, x_j+zero(x_j))
e = sum(x_ij .* g.a, dims=1) # inner product for each head, output shape: (1, g.heads)
e_ij = leakyrelu.(e, g.negative_slope)
vcat(e_ij, x_j) # shape: (n+1, g.heads)
e = sum(x_ij .* gat.a, dims=1) # inner product for each head, output shape: (1, gat.heads)
e_ij = leakyrelu.(e, gat.negative_slope)
vcat(e_ij, x_j) # shape: (n+1, gat.heads)
end

# After reshaping for the multiple heads, we take the α from each message,
# apply a softmax over all α values, and finally scale each message by its α
function apply_batch_message(g::GATConv, i, js, X::AbstractMatrix)
e_ij = mapreduce(j -> GeometricFlux.message(g, _view(X, i), _view(X, j)), hcat, js)
function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
n = size(e_ij, 1)
αs = Flux.softmax(reshape(view(e_ij, 1, :), g.heads, :), dims=2)
αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
reshape(msgs, (n-1)*g.heads, :)
reshape(msgs, (n-1)*gat.heads, :)
end

update_batch_edge(g::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(g, adj, X)
update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(gat, adj, X)

function update_batch_edge(g::GATConv, adj, X::AbstractMatrix)
function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix)
n = size(adj, 1)
# a vertex must always receive a message from itself
Zygote.ignore() do
GraphLaplacians.add_self_loop!(adj, n)
end
mapreduce(i -> apply_batch_message(g, i, adj[i], X), hcat, 1:n)
mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n)
end

# The same as the update function, applied in batch
update_batch_vertex(g::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(g, M)
update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(gat, M)

function update_batch_vertex(g::GATConv, M::AbstractMatrix)
M = M .+ g.bias
if !g.concat
function update_batch_vertex(gat::GATConv, M::AbstractMatrix)
M = M .+ gat.bias
if !gat.concat
N = size(M, 2)
M = reshape(mean(reshape(M, :, g.heads, N), dims=2), :, N)
M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N)
end
return M
end

function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
g = graph(fg)
GraphSignals.check_num_node(g, X)
_, X = propagate(gat, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, +)
check_num_nodes(fg, X)
_, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +)
X
end

(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)
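A sketch of the layer in use; with `concat=true` the head outputs are concatenated, so the output width is `out * heads` (the numbers below are only an example):

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])   # hypothetical graph
gat = GATConv(fg, 3 => 8, heads=4, concat=true)
y = gat(rand(Float32, 3, 3))                # (8 * 4) × 3 when heads are concatenated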

function Base.show(io::IO, l::GATConv)
@@ -287,13 +284,13 @@ end


"""
GatedGraphConv([graph,] out, num_layers; aggr=+, init=glorot_uniform)
GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform)
Gated graph convolution layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `out`: The dimension of output features.
- `num_layers`: The number of gated recurrent units.
- `aggr`: An aggregate function applied to the result of message function. `+`, `-`,
@@ -320,27 +317,29 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) =

@functor GatedGraphConv

message(g::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j

update(g::GatedGraphConv, m::AbstractVector, x) = m
update(ggc::GatedGraphConv, m::AbstractVector, x) = m


function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
adj = adjacency_list(fg)
check_num_nodes(fg, H)
m, n = size(H)
@assert (m <= ggc.out_ch) "number of input features must be less than or equal to the number of output features."
GraphSignals.check_num_node(adj, H)
(m < ggc.out_ch) && (H = vcat(H, zeros(S, ggc.out_ch - m, n)))

adj = adjacency_list(fg)
if m < ggc.out_ch
Hpad = similar(H, S, ggc.out_ch - m, n)
H = vcat(H, fill!(Hpad, 0))
end
for i = 1:ggc.num_layers
M = view(ggc.weight, :, :, i) * H
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(adj)), M, +)
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +)
H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381
end
H
end

(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)
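A usage sketch: the hidden state has width `out`, and inputs with fewer features are zero-padded (which is what the `similar`/`fill!` rewrite above does) before `num_layers` GRU steps are run:

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])   # hypothetical graph
ggc = GatedGraphConv(fg, 8, 3)              # hidden width 8, 3 GRU iterations
y = ggc(rand(Float32, 3, 3))                # 3-feature input is padded up to width 8; output is 8 × 3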


@@ -353,16 +352,15 @@ end


"""
EdgeConv(graph, nn; aggr=max)
EdgeConv([fg,] nn; aggr=max)
Edge convolutional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `nn`: A neural network or a layer. It can be, e.g., MLP.
- `aggr`: An aggregate function applied to the result of message function. `+`, `-`,
`*`, `/`, `max`, `min` and `mean` are available.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `nn`: A neural network (e.g. a Dense layer or an MLP).
- `aggr`: An aggregate function applied to the result of message function. `+`, `max` and `mean` are available.
"""
struct EdgeConv{V<:AbstractFeaturedGraph} <: MessagePassing
fg::V
@@ -375,17 +373,16 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...)

@functor EdgeConv

message(e::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = e.nn(vcat(x_i, x_j .- x_i))
update(e::EdgeConv, m::AbstractVector, x) = m
message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
update(ec::EdgeConv, m::AbstractVector, x) = m

function (e::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
g = graph(fg)
GraphSignals.check_num_node(g, X)
_, X = propagate(e, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, e.aggr)
function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
check_num_nodes(fg, X)
_, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr)
X
end

(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)
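A sketch with an illustrative MLP: `nn` receives `vcat(x_i, x_j .- x_i)`, so its input width must be twice the node-feature width:

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0])   # hypothetical graph
mlp = Chain(Dense(2 * 3, 16, relu), Dense(16, 16))
ec = EdgeConv(fg, mlp, aggr=max)            # max-aggregate edge messages per node
y = ec(rand(Float32, 3, 3))                 # 16 × 3 output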

function Base.show(io::IO, l::EdgeConv)
9 changes: 6 additions & 3 deletions src/layers/gn.jl
@@ -35,24 +35,27 @@ end
end

@inline function aggregate_neighbors(gn::GraphNet, aggr::Nothing, E, accu_edge)
@nospecialize E accu_edge num_V num_E
@nospecialize E accu_edge
return nothing
end

@inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E)

@inline function aggregate_edges(gn::GraphNet, aggr::Nothing, E)
@nospecialize E
return nothing
end

@inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V)

@inline function aggregate_vertices(gn::GraphNet, aggr::Nothing, V)
@nospecialize V
return nothing
end

function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr)
FeaturedGraph(graph(fg), nf=V, ef=E, gf=u)
FeaturedGraph(fg, nf=V, ef=E, gf=u)
end

function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P,
@@ -63,5 +66,5 @@ function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P,
ē = aggregate_edges(gn, eaggr, E)
v̄ = aggregate_vertices(gn, vaggr, V)
u = update_global(gn, ē, v̄, u)
E, V, u
return E, V, u
end
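The recurring pattern in this commit is building the result as `FeaturedGraph(fg, nf=...)`, i.e. from the existing FeaturedGraph rather than from its raw `fg.graph` field. A hedged end-to-end sketch (keyword construction of a `FeaturedGraph` with node features is assumed):

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1 1; 1 0 1; 1 1 0], nf=rand(Float32, 3, 3))
gc = GraphConv(3 => 8, relu)                # graph-free constructor; graph supplied at call time
fg′ = gc(fg)                                # returns a new FeaturedGraph carrying the same graph
size(node_feature(fg′))                     # (8, 3)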
2 changes: 1 addition & 1 deletion src/layers/misc.jl
@@ -5,7 +5,7 @@ Bypassing graph in FeaturedGraph and let other layer process (node, edge and glo
"""
function bypass_graph(nf_func=identity, ef_func=identity, gf_func=identity)
return function (fg::FeaturedGraph)
FeaturedGraph(graph(fg),
FeaturedGraph(fg,
nf=nf_func(node_feature(fg)),
ef=ef_func(edge_feature(fg)),
gf=gf_func(global_feature(fg)))
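A sketch of `bypass_graph` in use: it lifts feature-level functions to a FeaturedGraph-to-FeaturedGraph map while leaving the graph structure untouched (the Dense layer and sizes are illustrative):

using Flux, GraphSignals, GeometricFlux
fg = FeaturedGraph([0 1; 1 0], nf=rand(Float32, 4, 2))
layer = bypass_graph(Dense(4, 8, relu))     # only the node features are transformed
fg′ = layer(fg)                             # same graph, node features now 8 × 2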