Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc improvements #206

Merged
merged 2 commits into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 66 additions & 69 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
GCNConv([graph,] in => out, σ=identity; bias=true, init=glorot_uniform)
GCNConv([fg,] in => out, σ=identity; bias=true, init=glorot_uniform)
Graph convolutional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
Expand Down Expand Up @@ -41,7 +41,7 @@ function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
l.σ.(l.weight * x *.+ l.bias)
end

(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GCNConv)(x::AbstractMatrix) = l(l.fg, x)

function Base.show(io::IO, l::GCNConv)
Expand All @@ -53,13 +53,13 @@ end


"""
ChebConv([graph, ]in=>out, k; bias=true, init=glorot_uniform)
ChebConv([fg,] in=>out, k; bias=true, init=glorot_uniform)
Chebyshev spectral graph convolutional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `k`: The order of Chebyshev polynomial.
Expand All @@ -71,16 +71,14 @@ struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph}
bias::B
fg::S
k::Int
in_channel::Int
out_channel::Int
end

function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int;
init=glorot_uniform, bias::Bool=true)
in, out = ch
W = init(out, in, k)
b = Flux.create_bias(W, bias, out)
ChebConv(W, b, fg, k, in, out)
ChebConv(W, b, fg, k)
end

ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
Expand All @@ -89,11 +87,11 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
@functor ChebConv

function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
= scaled_laplacian(fg, T)

@assert size(X, 1) == c.in_channel "Input feature size must match input channel size."
GraphSignals.check_num_node(L̃, X)
check_num_nodes(fg, X)
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
yuehhua marked this conversation as resolved.
Show resolved Hide resolved

= scaled_laplacian(fg, eltype(X))

Z_prev = X
Z = X *
Y = view(c.weight,:,:,1) * Z_prev
Expand All @@ -105,24 +103,25 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
return Y .+ c.bias
end

(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::ChebConv)(x::AbstractMatrix) = l(l.fg, x)

function Base.show(io::IO, l::ChebConv)
print(io, "ChebConv(", l.in_channel, " => ", l.out_channel)
print(io, ", k=", l.k)
out, in, k = size(l.weight)
print(io, "ChebConv(", in, " => ", out)
print(io, ", k=", k)
print(io, ")")
end


"""
GraphConv([graph,] in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
GraphConv([fg,] in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
Graph neural network layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
Expand Down Expand Up @@ -154,18 +153,17 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) =

@functor GraphConv

message(g::GraphConv, x_i, x_j::AbstractVector, e_ij) = g.weight2 * x_j
message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j

update(g::GraphConv, m::AbstractVector, x::AbstractVector) = g.σ.(g.weight1*x .+ m .+ g.bias)
update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)

function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
g = graph(fg)
GraphSignals.check_num_node(g, x)
_, x = propagate(gc, adjacency_list(g), Fill(0.f0, 0, ne(g)), x, +)
check_num_nodes(fg, x)
_, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +)
x
end

(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)

function Base.show(io::IO, l::GraphConv)
Expand All @@ -180,7 +178,7 @@ end


"""
GATConv([graph,] in => out;
GATConv([fg,] in => out;
heads=1,
concat=true,
init=glorot_uniform
Expand All @@ -191,7 +189,7 @@ Graph attentional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `bias::Bool`: Keyword argument, whether to learn the additive bias.
Expand Down Expand Up @@ -225,56 +223,55 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...)
@functor GATConv

# Here the α that has not been softmaxed is the first number of the output message
function message(g::GATConv, x_i::AbstractVector, x_j::AbstractVector)
x_i = reshape(g.weight*x_i, :, g.heads)
x_j = reshape(g.weight*x_j, :, g.heads)
function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector)
x_i = reshape(gat.weight*x_i, :, gat.heads)
x_j = reshape(gat.weight*x_j, :, gat.heads)
x_ij = vcat(x_i, x_j+zero(x_j))
e = sum(x_ij .* g.a, dims=1) # inner product for each head, output shape: (1, g.heads)
e_ij = leakyrelu.(e, g.negative_slope)
vcat(e_ij, x_j) # shape: (n+1, g.heads)
e = sum(x_ij .* gat.a, dims=1) # inner product for each head, output shape: (1, gat.heads)
e_ij = leakyrelu.(e, gat.negative_slope)
vcat(e_ij, x_j) # shape: (n+1, gat.heads)
end

# After some reshaping due to the multihead, we get the α from each message,
# then get the softmax over every α, and eventually multiply the message by α
function apply_batch_message(g::GATConv, i, js, X::AbstractMatrix)
e_ij = mapreduce(j -> GeometricFlux.message(g, _view(X, i), _view(X, j)), hcat, js)
function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js)
n = size(e_ij, 1)
αs = Flux.softmax(reshape(view(e_ij, 1, :), g.heads, :), dims=2)
αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2)
msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
reshape(msgs, (n-1)*g.heads, :)
reshape(msgs, (n-1)*gat.heads, :)
end

update_batch_edge(g::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(g, adj, X)
update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(gat, adj, X)

function update_batch_edge(g::GATConv, adj, X::AbstractMatrix)
function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix)
n = size(adj, 1)
# a vertex must always receive a message from itself
Zygote.ignore() do
GraphLaplacians.add_self_loop!(adj, n)
end
mapreduce(i -> apply_batch_message(g, i, adj[i], X), hcat, 1:n)
mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n)
end

# The same as update function in batch manner
update_batch_vertex(g::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(g, M)
update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(gat, M)

function update_batch_vertex(g::GATConv, M::AbstractMatrix)
M = M .+ g.bias
if !g.concat
function update_batch_vertex(gat::GATConv, M::AbstractMatrix)
M = M .+ gat.bias
if !gat.concat
N = size(M, 2)
M = reshape(mean(reshape(M, :, g.heads, N), dims=2), :, N)
M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N)
end
return M
end

function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
g = graph(fg)
GraphSignals.check_num_node(g, X)
_, X = propagate(gat, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, +)
check_num_nodes(fg, X)
_, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +)
X
end

(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)

function Base.show(io::IO, l::GATConv)
Expand All @@ -287,13 +284,13 @@ end


"""
GatedGraphConv([graph,] out, num_layers; aggr=+, init=glorot_uniform)
GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform)
Gated graph convolution layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `out`: The dimension of output features.
- `num_layers`: The number of gated recurrent unit.
- `aggr`: An aggregate function applied to the result of message function. `+`, `-`,
Expand All @@ -320,27 +317,29 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) =

@functor GatedGraphConv

message(g::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j

update(g::GatedGraphConv, m::AbstractVector, x) = m
update(ggc::GatedGraphConv, m::AbstractVector, x) = m


function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
adj = adjacency_list(fg)
check_num_nodes(fg, H)
m, n = size(H)
@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
GraphSignals.check_num_node(adj, H)
(m < ggc.out_ch) && (H = vcat(H, zeros(S, ggc.out_ch - m, n)))

adj = adjacency_list(fg)
if m < ggc.out_ch
Hpad = similar(H, S, ggc.out_ch - m, n)
H = vcat(H, fill!(Hpad, 0))
end
for i = 1:ggc.num_layers
M = view(ggc.weight, :, :, i) * H
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(adj)), M, +)
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +)
H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381
end
H
end

(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)


Expand All @@ -353,16 +352,15 @@ end


"""
EdgeConv(graph, nn; aggr=max)
EdgeConv([fg,] nn; aggr=max)
Edge convolutional layer.
# Arguments
- `graph`: Optionally pass a FeaturedGraph.
- `nn`: A neural network or a layer. It can be, e.g., MLP.
- `aggr`: An aggregate function applied to the result of message function. `+`, `-`,
`*`, `/`, `max`, `min` and `mean` are available.
- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
- `nn`: A neural network (e.g. a Dense layer or a MLP).
- `aggr`: An aggregate function applied to the result of message function. `+`, `max` and `mean` are available.
"""
struct EdgeConv{V<:AbstractFeaturedGraph} <: MessagePassing
fg::V
Expand All @@ -375,17 +373,16 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...)

@functor EdgeConv

message(e::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = e.nn(vcat(x_i, x_j .- x_i))
update(e::EdgeConv, m::AbstractVector, x) = m
message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
update(ec::EdgeConv, m::AbstractVector, x) = m

function (e::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
g = graph(fg)
GraphSignals.check_num_node(g, X)
_, X = propagate(e, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, e.aggr)
function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
check_num_nodes(fg, X)
_, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr)
X
end

(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)

function Base.show(io::IO, l::EdgeConv)
Expand Down
9 changes: 6 additions & 3 deletions src/layers/gn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,27 @@ end
end

@inline function aggregate_neighbors(gn::GraphNet, aggr::Nothing, E, accu_edge)
@nospecialize E accu_edge num_V num_E
@nospecialize E accu_edge
return nothing
end

@inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E)

@inline function aggregate_edges(gn::GraphNet, aggr::Nothing, E)
@nospecialize E
return nothing
end

@inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V)

@inline function aggregate_vertices(gn::GraphNet, aggr::Nothing, V)
@nospecialize V
return nothing
end

function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr)
FeaturedGraph(graph(fg), nf=V, ef=E, gf=u)
FeaturedGraph(fg, nf=V, ef=E, gf=u)
end

function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P,
Expand All @@ -63,5 +66,5 @@ function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P,
= aggregate_edges(gn, eaggr, E)
= aggregate_vertices(gn, vaggr, V)
u = update_global(gn, ē, v̄, u)
E, V, u
return E, V, u
end
2 changes: 1 addition & 1 deletion src/layers/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Bypassing graph in FeaturedGraph and let other layer process (node, edge and glo
"""
function bypass_graph(nf_func=identity, ef_func=identity, gf_func=identity)
return function (fg::FeaturedGraph)
FeaturedGraph(graph(fg),
FeaturedGraph(fg,
nf=nf_func(node_feature(fg)),
ef=ef_func(edge_feature(fg)),
gf=gf_func(global_feature(fg)))
Expand Down
Loading