diff --git a/Project.toml b/Project.toml index f4f02e75a..f909a8201 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,10 @@ version = "0.7.6" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphLaplacians = "a1251efa-393a-423f-9d7b-faaecba535dc" GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8" -GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -23,11 +21,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CUDA = "3.3" DataStructures = "0.18" -FillArrays = "0.11, 0.12" Flux = "0.12" -GraphLaplacians = "0.1" GraphMLDatasets = "0.1" -GraphSignals = "0.2" LightGraphs = "1.3" NNlib = "0.7" NNlibCUDA = "0.1" diff --git a/docs/src/abstractions/msgpass.md b/docs/src/abstractions/msgpass.md index cad3a4f0d..8ab9a4bef 100644 --- a/docs/src/abstractions/msgpass.md +++ b/docs/src/abstractions/msgpass.md @@ -20,20 +20,20 @@ A message function accepts feature vector representing node state `x_i`, feature Messages from message function are aggregated by an aggregate function. An aggregated message is passed to update function for node-level computation. An aggregate function is given by the following: ``` -propagate(mp, fg::FeaturedGraph, aggr::Symbol=:add) +propagate(mp, fg::FeaturedGraph; aggr::Symbol=+) ``` -`propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `:add` represents an aggregate function of addition of all messages. +`propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `+` represents an aggregate function of addition of all messages. 
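As a concrete illustration (a minimal sketch, not part of this patch, built on the `MessagePassing`, `message`, `update`, and `propagate` methods added in `src/layers/msgpass.jl` below — note that the implementation there takes the aggregation as a positional argument, `propagate(mp, fg, aggr)`; the layer name `NeighborSum` is hypothetical):

```julia
using GeometricFlux

# Hypothetical layer: each node sums a linear transform of its neighbors' features.
struct NeighborSum <: MessagePassing
    weight
end
NeighborSum(nin, nout) = NeighborSum(randn(Float32, nout, nin))

# `x_j` holds the gathered neighbor features, one column per edge.
GeometricFlux.message(l::NeighborSum, x_i, x_j, e_ij) = l.weight * x_j
GeometricFlux.update(l::NeighborSum, m̄, x) = m̄

# Forward pass: aggregate the incoming messages with `+`.
(l::NeighborSum)(fg::FeaturedGraph) = GeometricFlux.propagate(l, fg, +)

l = NeighborSum(4, 5)
fg = FeaturedGraph([1, 2, 3, 1], [2, 3, 1, 3]; nf = rand(Float32, 4, 3))
size(node_feature(l(fg)))   # (5, 3): one aggregated message per node
```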
The following `aggr` are available aggregate functions: -`:add`: sum over all messages -`:sub`: negative of sum over all messages -`:mul`: multiplication over all messages -`:div`: inverse of multiplication over all messages -`:max`: the maximum of all messages -`:min`: the minimum of all messages -`:mean`: the average of all messages +`+`: sum over all messages +`-`: negative of sum over all messages +`*`: multiplication over all messages +`/`: inverse of multiplication over all messages +`max`: the maximum of all messages +`min`: the minimum of all messages +`mean`: the average of all messages ## Update function diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index 69be76864..bc37f8e09 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -1,22 +1,30 @@ module GeometricFlux +using NNlib: similar +using ChainRulesCore: eltype, reshape +using LinearAlgebra: similar using Statistics: mean -using LinearAlgebra: Adjoint, norm, Transpose -using Reexport +using LinearAlgebra using CUDA -using FillArrays: Fill using Flux using Flux: glorot_uniform, leakyrelu, GRUCell, @functor using NNlib, NNlibCUDA -using GraphLaplacians -@reexport using GraphSignals -using LightGraphs using Zygote +using ChainRulesCore +import LightGraphs +using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv, + adjacency_matrix, degree export - # layers/gn - GraphNet, + # featured_graph + FeaturedGraph, + edge_index, + node_feature, edge_feature, global_feature, + adjacency_list, normalized_laplacian, scaled_laplacian, + + # from LightGraphs + ne, nv, adjacency_matrix, # layers/msgpass MessagePassing, @@ -44,16 +52,13 @@ export sample, # layer/selector - bypass_graph, - - # utils - generate_cluster + bypass_graph +include("featured_graph.jl") include("datasets.jl") include("utils.jl") -include("layers/gn.jl") include("layers/msgpass.jl") include("layers/conv.jl") @@ -61,9 +66,6 @@ include("layers/pool.jl") include("models.jl") include("layers/misc.jl") -include("cuda/msgpass.jl") -include("cuda/conv.jl") - using .Datasets diff --git a/src/cuda/conv.jl b/src/cuda/conv.jl deleted file mode 100644 index d29c2fce3..000000000 --- a/src/cuda/conv.jl +++ /dev/null @@ -1,26 +0,0 @@ -(g::GCNConv)(L̃::AbstractMatrix, X::CuArray) = g(cu(L̃), X) - -(g::GCNConv)(L̃::CuArray, X::CuArray) = g.σ.(g.weight * X * L̃ .+ g.bias) - -(c::ChebConv)(L̃::AbstractMatrix, X::CuArray) = c(cu(L̃), X) - -function (c::ChebConv)(L̃::CuArray, X::CuArray) - @assert size(X, 1) == c.in_channel "Input feature size must match input channel size." - @assert size(X, 2) == size(L̃, 1) "Input vertex number must match Laplacian matrix size." 
- - Z_prev = X - Z = X * L̃ - Y = view(c.weight,:,:,1) * Z_prev - Y += view(c.weight,:,:,2) * Z - for k = 3:c.k - Z, Z_prev = 2*Z*L̃ - Z_prev, Z - Y += view(c.weight,:,:,k) * Z - end - return Y .+ c.bias -end - - -# Avoid ambiguity -update_batch_edge(g::GATConv, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} = update_batch_edge(g, adj, X) - -update_batch_vertex(g::GATConv, M::CuMatrix, X::CuMatrix, u) = update_batch_vertex(g, M) diff --git a/src/cuda/msgpass.jl b/src/cuda/msgpass.jl deleted file mode 100644 index 2996656a5..000000000 --- a/src/cuda/msgpass.jl +++ /dev/null @@ -1,41 +0,0 @@ -@inline function update_batch_edge(mp::MessagePassing, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} - E = fill(E.value, E.axes) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::Fill{S,2,Axes}, u) where {S,Axes} - X = fill(X.value, X.axes) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::AbstractMatrix, X::CuMatrix, u) - E = convert(typeof(X), E) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::AbstractMatrix, u) - X = convert(typeof(E), X) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::CuMatrix, u) - n = size(adj, 1) - edge_idx = edge_index_table(adj) - mapreduce(i -> apply_batch_message(mp, i, adj[i], edge_idx, E, X, u), hcat, 1:n) -end - -@inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::CuMatrix, X::CuMatrix, u) = - mapreduce(j -> message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) - -@inline function update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::CuMatrix, u) - M = convert(typeof(X), M) - update_batch_vertex(mp, M, X, u) -end - -@inline function update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::AbstractMatrix, u) - X = convert(typeof(M), X) - update_batch_vertex(mp, M, X, u) -end - -@inline update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::CuMatrix, u) = - mapreduce(i -> update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) diff --git a/src/featured_graph.jl b/src/featured_graph.jl new file mode 100644 index 000000000..9d06e60eb --- /dev/null +++ b/src/featured_graph.jl @@ -0,0 +1,279 @@ +#=================================== +Define FeaturedGraph type as a subtype of LightGraphs' AbstractGraph. +For the core methods to be implemented by any AbstractGraph, see +https://juliagraphs.org/LightGraphs.jl/latest/types/#AbstractGraph-Type +https://juliagraphs.org/LightGraphs.jl/latest/developing/#Developing-Alternate-Graph-Types +=============================================# + +abstract type AbstractFeaturedGraph <: AbstractGraph{Int} end + +""" + NullGraph() + +Null object for `FeaturedGraph`. 
+""" +struct NullGraph <: AbstractFeaturedGraph end + + +struct FeaturedGraph <: AbstractFeaturedGraph + edge_index + num_nodes::Int + num_edges::Int + # ndata::Dict{String, Any} # https://github.com/FluxML/Zygote.jl/issues/717 + # edata::Dict{String, Any} + # gdata::Dict{String, Any} + nf + ef + gf +end + + +function FeaturedGraph(s::AbstractVector{Int}, t::AbstractVector{Int}; + num_nodes = max(maximum(s), maximum(t)), + # ndata = Dict{String, Any}(), + # edata = Dict{String, Any}(), + # gdata = Dict{String, Any}(), + nf = nothing, + ef = nothing, + gf = nothing) + + @assert length(s) == length(t) + @assert min(minimum(s), minimum(t)) >= 1 + @assert max(maximum(s), maximum(t)) <= num_nodes + + num_edges = length(s) + + ## I would like to have dict data store, but currently this + ## doesn't play well with zygote due to + ## https://github.com/FluxML/Zygote.jl/issues/717 + # ndata["x"] = nf + # edata["e"] = ef + # gdata["g"] = gf + + + FeaturedGraph((s, t), num_nodes, num_edges, + nf, ef, gf) +end + +function FeaturedGraph(adj_mat::AbstractMatrix; dir=:out, kws...) + @assert dir ∈ [:out, :in] + num_nodes = size(adj_mat, 1) + @assert num_nodes == size(adj_mat, 2) + @assert all(x -> (x == 1) || (x == 0), adj_mat) + num_edges = round(Int, sum(adj_mat)) + s = zeros(Int, num_edges) + t = zeros(Int, num_edges) + e = 0 + for j in 1:num_nodes + for i in 1:num_nodes + if adj_mat[i, j] == 1 + e += 1 + s[e] = i + t[e] = j + end + end + end + @assert e == num_edges + if dir == :in + s, t = t, s + end + FeaturedGraph(s, t; num_nodes, kws...) +end + +function FeaturedGraph(adj_list::AbstractVector{<:AbstractVector}; dir=:out, kws...) + @assert dir ∈ [:out, :in] + num_nodes = length(adj_list) + num_edges = sum(length.(adj_list)) + s = zeros(Int, num_edges) + t = zeros(Int, num_edges) + e = 0 + for i in 1:num_nodes + for j in adj_list[i] + e += 1 + s[e] = i + t[e] = j + end + end + @assert e == num_edges + if dir == :in + s, t = t, s + end + FeaturedGraph(s, t; num_nodes, kws...) +end + +FeaturedGraph(g::AbstractGraph; kws...) = FeaturedGraph(adjacency_matrix(g, dir=:out); kws...) + +function FeaturedGraph(fg::FeaturedGraph; + # ndata=copy(fg.ndata), edata=copy(fg.edata), gdata=copy(fg.gdata), # copy keeps the refs to old data + nf=node_feature(fg), ef=edge_feature(fg), gf=global_feature(fg)) + + FeaturedGraph(fg.edge_index[1], fg.edge_index[2]; + # ndata, edata, gdata, + nf, ef, gf) +end + +@functor FeaturedGraph + +""" + edge_index(fg::FeaturedGraph) + +Return a tuple containing two vectors, respectively containing the source and target +nodes of the edges in the graph `fg`. + +```julia +s, t = edge_index(fg) +``` +""" +edge_index(fg::FeaturedGraph) = fg.edge_index + +LightGraphs.edges(fg::FeaturedGraph) = zip(edge_index(fg)...) 
+ +LightGraphs.edgetype(fg::FeaturedGraph) = Tuple{Int, Int} + +function LightGraphs.has_edge(fg::FeaturedGraph, i::Integer, j::Integer) + s, t = edge_index(fg) + return any((s .== i) .& (t .== j)) +end + +LightGraphs.nv(fg::FeaturedGraph) = fg.num_nodes +LightGraphs.ne(fg::FeaturedGraph) = fg.num_edges +LightGraphs.has_vertex(fg::FeaturedGraph, i::Int) = i in 1:fg.num_nodes +LightGraphs.vertices(fg::FeaturedGraph) = 1:fg.num_nodes + +function LightGraphs.outneighbors(fg::FeaturedGraph, i::Integer) + s, t = edge_index(fg) + return t[s .== i] +end + +function LightGraphs.inneighbors(fg::FeaturedGraph, i::Integer) + s, t = edge_index(fg) + return s[t .== i] +end + +LightGraphs.is_directed(::FeaturedGraph) = true +LightGraphs.is_directed(::Type{FeaturedGraph}) = true + +function adjacency_list(fg::FeaturedGraph; dir=:out) + @assert dir ∈ [:out, :in] + fneighs = dir == :out ? outneighbors : inneighbors + return [fneighs(fg, i) for i in 1:fg.num_nodes] +end + +# TODO return sparse matrix +function LightGraphs.adjacency_matrix(fg::FeaturedGraph, T::DataType=Int; dir=:out) + # TODO dir=:both + s, t = edge_index(fg) + n = fg.num_nodes + adj_mat = fill!(similar(s, T, (n, n)), 0) + adj_mat[s .+ n .* (t .- 1)] .= 1 # exploiting linear indexing + return dir == :out ? adj_mat : adj_mat' +end + +function LightGraphs.degree(fg::FeaturedGraph; dir=:both) + s, t = edge_index(fg) + degs = fill!(similar(s, eltype(s), fg.num_nodes), 0) + o = fill!(similar(s, eltype(s), fg.num_edges), 1) + if dir ∈ [:out, :both] + NNlib.scatter!(+, degs, o, s) + end + if dir ∈ [:in, :both] + NNlib.scatter!(+, degs, o, t) + end + return degs +end + +# node_feature(fg::FeaturedGraph) = fg.ndata["x"] +# edge_feature(fg::FeaturedGraph) = fg.edata["e"] +# global_feature(fg::FeaturedGraph) = fg.gdata["g"] + +node_feature(fg::FeaturedGraph) = fg.nf +edge_feature(fg::FeaturedGraph) = fg.ef +global_feature(fg::FeaturedGraph) = fg.gf + +# function Base.getproperty(fg::FeaturedGraph, sym::Symbol) +# if sym === :nf +# return fg.ndata["x"] +# elseif sym === :ef +# return fg.edata["e"] +# elseif sym === :gf +# return fg.gdata["g"] +# else # fallback to getfield +# return getfield(fg, sym) +# end +# end + +# function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::Symbol=:out) +# A = adjacency_matrix(fg, T; dir=dir) +# D = Diagonal(vec(sum(A; dims=2))) +# return D - A +# end + +## from GraphLaplacians + +""" + normalized_laplacian(fg, T=Float32; selfloop=false, dir=:out) + +Normalized Laplacian matrix of graph `g`. + +# Arguments + +- `fg`: A `FeaturedGraph`. +- `T`: result element type of degree vector; default `Float32`. +- `selfloop`: adding self loop while calculating the matrix. +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function normalized_laplacian(fg::FeaturedGraph, T::DataType=Float32; selfloop::Bool=false, dir::Symbol=:out) + A = adjacency_matrix(fg, T; dir=dir) + selfloop && (A += I) + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + return I - inv_sqrtD * A * inv_sqrtD +end + +@doc raw""" + scaled_laplacian(g[, T]; dir=:out) + +Scaled Laplacian matrix of graph `g`, +defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. + +# Arguments + +- `g`: should be a adjacency matrix, `FeaturedGraph`, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). +- `T`: result element type of degree vector; default is the element type of `g` (optional). 
+- `dir`: the edge directionality considered (:out, :in, :both). +""" +function scaled_laplacian(fg::FeaturedGraph, T::DataType=Float32; dir=:out) + A = adjacency_matrix(fg, T; dir=dir) + @assert issymmetric(A) "scaled_laplacian only works with symmetric matrices" + E = eigen(Symmetric(A)).values + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + Lnorm = I - inv_sqrtD * A * inv_sqrtD + return 2 / maximum(E) * Lnorm - I +end + +function add_self_loops(fg::FeaturedGraph) + s, t = edge_index(fg) + @assert edge_feature(fg) === nothing + mask_old_loops = s .!= t + s = s[mask_old_loops] + t = t[mask_old_loops] + n = fg.num_nodes + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + FeaturedGraph(s, t, nf=node_feature(fg), ef=edge_feature(fg), gf=global_feature(fg)) +end + +@non_differentiable normalized_laplacian(x...) +@non_differentiable scaled_laplacian(x...) +@non_differentiable adjacency_matrix(x...) +@non_differentiable adjacency_list(x...) +@non_differentiable degree(x...) +@non_differentiable add_self_loops(x...) + +# delete when https://github.com/JuliaDiff/ChainRules.jl/pull/472 is merged +function ChainRulesCore.rrule(::typeof(copy), x) + copy_pullback(ȳ) = (NoTangent(), ȳ) + return copy(x), copy_pullback +end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 256daba86..4d9717162 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -12,11 +12,10 @@ Graph convolutional layer. - `bias`: Add learnable bias. - `init`: Weights' initializer. - The input to the layer is a node feature array `X` of size `(num_features, num_nodes)`. """ -struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} +struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} <: MessagePassing weight::A bias::B σ::F @@ -36,9 +35,23 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) = @functor GCNConv +# function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix) +# L̃ = normalized_laplacian(fg, eltype(x); selfloop=true) +# l.σ.(l.weight * x * L̃ .+ l.bias) +# end + +message(l::GCNConv, xi, xj) = xj +update(l::GCNConv, m, x) = m + function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix) - L̃ = normalized_laplacian(fg, eltype(x); selfloop=true) - l.σ.(l.weight * x * L̃ .+ l.bias) + fg = add_self_loops(fg) + T = eltype(l.weight) + # cout = sqrt.(degree(fg, dir=:out)) + cin = reshape(sqrt.(T.(degree(fg, dir=:in))), 1, :) + x = cin .* x + _, x = propagate(l, fg, nothing, x, nothing, +) + x = cin .* x + return l.σ.(l.weight * x .+ l.bias) end (l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) @@ -153,13 +166,13 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) = @functor GraphConv -message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j +message(gc::GraphConv, x_i, x_j, e_ij) = x_j -update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias) +update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias) function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix) check_num_nodes(fg, x) - _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) + _, x = propagate(gc, fg, nothing, x, nothing, +) x end @@ -194,7 +207,7 @@ Graph attentional layer. - `out`: The dimension of output features. - `bias::Bool`: Keyword argument, whether to learn the additive bias. - `heads`: Number attention heads -- `concat`: Concatenate layer output or not. If not, layer output is averaged. 
+- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU. """ struct GATConv{V<:AbstractFeaturedGraph, T, A<:AbstractMatrix{T}, B} <: MessagePassing @@ -222,53 +235,34 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...) @functor GATConv -# Here the α that has not been softmaxed is the first number of the output message -function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector) - x_i = reshape(gat.weight*x_i, :, gat.heads) - x_j = reshape(gat.weight*x_j, :, gat.heads) - x_ij = vcat(x_i, x_j+zero(x_j)) - e = sum(x_ij .* gat.a, dims=1) # inner product for each head, output shape: (1, gat.heads) - e_ij = leakyrelu.(e, gat.negative_slope) - vcat(e_ij, x_j) # shape: (n+1, gat.heads) -end - -# After some reshaping due to the multihead, we get the α from each message, -# then get the softmax over every α, and eventually multiply the message by α -function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix) - e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js) - n = size(e_ij, 1) - αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2) - msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :) - reshape(msgs, (n-1)*gat.heads, :) -end - -update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(gat, adj, X) - -function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix) - n = size(adj, 1) - # a vertex must always receive a message from itself - Zygote.ignore() do - GraphLaplacians.add_self_loop!(adj, n) - end - mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n) -end - -# The same as update function in batch manner -update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(gat, M) - -function update_batch_vertex(gat::GATConv, M::AbstractMatrix) - M = M .+ gat.bias +function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) + check_num_nodes(fg, X) + fg = add_self_loops(fg) + chin, chout = gat.channel + heads = gat.heads + + source, target = edge_index(fg) + Wx = gat.weight*X + Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes + Wxi = NNlib.gather(Wx, target) # chout × nheads × nedges + Wxj = NNlib.gather(Wx, source) + + # Edge Message + # Computing softmax. TODO make it numerically stable + aWW = sum(gat.a .* cat(Wxi, Wxj, dims=1), dims=1) # 1 × nheads × nedges + α = exp.(leakyrelu.(aWW, gat.negative_slope)) + m̄ = NNlib.scatter(+, α .* Wxj, target) # chout × nheads × nnodes + ᾱ = NNlib.scatter(+, α, target) # 1 × nheads × nnodes + + # Node update + b = reshape(gat.bias, chout, heads) + X = m̄ ./ ᾱ .+ b # chout × nheads × nnodes if !gat.concat - N = size(M, 2) - M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N) + X = sum(X, dims=2) end - return M -end -function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) - check_num_nodes(fg, X) - _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) - X + # We finally return a matrix + return reshape(X, :, size(X, 3)) end (l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) @@ -317,24 +311,23 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) 
= @functor GatedGraphConv -message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j +message(ggc::GatedGraphConv, x_i, x_j, e_ij) = x_j -update(ggc::GatedGraphConv, m::AbstractVector, x) = m +update(ggc::GatedGraphConv, m, x) = m function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real} check_num_nodes(fg, H) m, n = size(H) @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." - adj = adjacency_list(fg) if m < ggc.out_ch Hpad = similar(H, S, ggc.out_ch - m, n) H = vcat(H, fill!(Hpad, 0)) end for i = 1:ggc.num_layers M = view(ggc.weight, :, :, i) * H - _, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +) - H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381 + _, M = propagate(ggc, fg, nothing, M, nothing, +) + H, _ = ggc.gru(H, M) end H end @@ -373,12 +366,13 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...) @functor EdgeConv -message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i)) -update(ec::EdgeConv, m::AbstractVector, x) = m +message(ec::EdgeConv, x_i, x_j, e_ij) = ec.nn(vcat(x_i, x_j .- x_i)) + +update(ec::EdgeConv, m, x) = m function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix) check_num_nodes(fg, X) - _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) + _, X = propagate(ec, fg, nothing, X, nothing, ec.aggr) X end diff --git a/src/layers/gn.jl b/src/layers/gn.jl deleted file mode 100644 index 496d3dcf2..000000000 --- a/src/layers/gn.jl +++ /dev/null @@ -1,70 +0,0 @@ -_view(::Nothing, i) = nothing -_view(A::Fill{T,2,Axes}, i) where {T,Axes} = view(A, :, 1) -_view(A::AbstractMatrix, idx) = view(A, :, idx) - -aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2)) -aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2)) -aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2)) -aggregate(aggr::typeof(/), X) = 1 ./ vec(prod(X, dims=2)) -aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2)) -aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2)) -aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2)) - -abstract type GraphNet end - -@inline update_edge(gn::GraphNet, e, vi, vj, u) = e -@inline update_vertex(gn::GraphNet, ē, vi, u) = vi -@inline update_global(gn::GraphNet, ē, v̄, u) = u - -@inline function update_batch_edge(gn::GraphNet, adj, E, V, u) - n = size(adj, 1) - edge_idx = edge_index_table(adj) - mapreduce(i -> apply_batch_message(gn, i, adj[i], edge_idx, E, V, u), hcat, 1:n) -end - -@inline apply_batch_message(gn::GraphNet, i, js, edge_idx, E, V, u) = - mapreduce(j -> update_edge(gn, _view(E, edge_idx[(i,j)]), _view(V, i), _view(V, j), u), hcat, js) - -@inline update_batch_vertex(gn::GraphNet, Ē, V, u) = - mapreduce(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), hcat, 1:size(V,2)) - -@inline function aggregate_neighbors(gn::GraphNet, aggr, E, accu_edge) - @assert !iszero(accu_edge) "accumulated edge must not be zero." 
- cluster = generate_cluster(E, accu_edge) - NNlib.scatter(aggr, E, cluster) -end - -@inline function aggregate_neighbors(gn::GraphNet, aggr::Nothing, E, accu_edge) - @nospecialize E accu_edge - return nothing -end - -@inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E) - -@inline function aggregate_edges(gn::GraphNet, aggr::Nothing, E) - @nospecialize E - return nothing -end - -@inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V) - -@inline function aggregate_vertices(gn::GraphNet, aggr::Nothing, V) - @nospecialize V - return nothing -end - -function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) - E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr) - FeaturedGraph(fg, nf=V, ef=E, gf=u) -end - -function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P, - naggr=nothing, eaggr=nothing, vaggr=nothing) where {S<:AbstractVector,R,Q,P} - E = update_batch_edge(gn, adj, E, V, u) - Ē = aggregate_neighbors(gn, naggr, E, accumulated_edges(adj)) - V = update_batch_vertex(gn, Ē, V, u) - ē = aggregate_edges(gn, eaggr, E) - v̄ = aggregate_vertices(gn, vaggr, V) - u = update_global(gn, ē, v̄, u) - return E, V, u -end diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl index 9aee05391..557d70d0d 100644 --- a/src/layers/msgpass.jl +++ b/src/layers/msgpass.jl @@ -1,7 +1,70 @@ -abstract type MessagePassing <: GraphNet end +# Adapted message passing from paper +# "Relational inductive biases, deep learning, and graph networks" +""" + MessagePassing + +The abstract type from which all message passing layers are derived. + +Related methods are [`propagate`](@ref), [`message`](@ref), +[`update`](@ref), [`update_edge`](@ref), and [`update_global`](@ref). +""" +abstract type MessagePassing end + +""" + propagate(mp::MessagePassing, fg::FeaturedGraph, aggr) + propagate(mp::MessagePassing, fg::FeaturedGraph, E, X, u, aggr) + +Perform the sequence of operation implementing the message-passing scheme +and updating node, edge, and global features `X`, `E`, and `u` respectively. + +The computation involved is the following: + +```julia +M = compute_batch_message(mp, fg, E, X, u) +E = update_edge(mp, M, E, u) +M̄ = aggregate_neighbors(mp, aggr, fg, M) +X = update(mp, M̄, X, u) +u = update_global(mp, E, X, u) +``` + +Custom layers sub-typing [`MessagePassing`](@ref) +typically call define their own [`update`](@ref) +and [`message`](@ref) function, than call +this method in the forward pass: + +```julia +function (l::GNNLayer)(fg, X) + ... some prepocessing if needed ... + E = nothing + u = nothing + propagate(l, fg, E, X, u, +) +end +``` + +See also [`message`](@ref) and [`update`](@ref). +""" +function propagate end + +function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr) + E, X, u = propagate(mp, fg, + edge_feature(fg), node_feature(fg), global_feature(fg), + aggr) + FeaturedGraph(fg, nf=X, ef=E, gf=u) +end + +function propagate(mp::MessagePassing, fg::FeaturedGraph, E, X, u, aggr) + M = compute_batch_message(mp, fg, E, X, u) + E = update_edge(mp, M, E, u) + M̄ = aggregate_neighbors(mp, aggr, fg, M) + X = update(mp, M̄, X, u) + u = update_global(mp, E, X, u) + return E, X, u +end """ + message(mp::MessagePassing, x_i, x_j, e_ij, u) message(mp::MessagePassing, x_i, x_j, e_ij) + message(mp::MessagePassing, x_i, x_j) Message function for the message-passing scheme, returning the message from node `j` to node `i` . @@ -15,51 +78,82 @@ specialize this method with custom behavior. 
# Arguments -- `mp`: message-passing layer. -- `x_i`: the features of node `i`. -- `x_j`: the features of the nighbor `j` of node `i`. -- `e_ij`: the features of edge (`i`, `j`). +- `mp`: A [`MessagePassing`](@ref) layer. +- `x_i`: Features of the central node `i`. +- `x_j`: Features of the neighbor `j` of node `i`. +- `e_ij`: Features of edge (`i`, `j`). +- `u`: Global features. -See also [`update`](@ref). +See also [`update`](@ref) and [`propagate`](@ref). """ -@inline message(mp::MessagePassing, x_i, x_j, e_ij) = x_j -@inline message(mp::MessagePassing, i::Integer, j::Integer, x_i, x_j, e_ij) = x_j +function message end """ - update(mp::MessagePassing, m, x) + update(mp::MessagePassing, m̄, x, u) + update(mp::MessagePassing, m̄, x) Update function for the message-passing scheme, returning a new set of node features `x′` based on old features `x` and the incoming message from the neighborhood -aggregation `m`. +aggregation `m̄`. -By default, the function returns `m`. +By default, the function returns `m̄`. Layers subtyping [`MessagePassing`](@ref) should specialize this method with custom behavior. # Arguments -- `mp`: message-passing layer. -- `m`: the aggregated edge messages from the [`message`](@ref) function. -- `x`: the node features to be updated. +- `mp`: A [`MessagePassing`](@ref) layer. +- `m̄`: Aggregated edge messages from the [`message`](@ref) function. +- `x`: Node features to be updated. +- `u`: Global features. -See also [`message`](@ref). +See also [`message`](@ref) and [`propagate`](@ref). """ -@inline update(mp::MessagePassing, m, x) = m -@inline update(mp::MessagePassing, i::Integer, m, x) = m +function update end -@inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::AbstractMatrix, X::AbstractMatrix, u) = - mapreduce(j -> GeometricFlux.message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) -@inline update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::AbstractMatrix, u) = - mapreduce(i -> GeometricFlux.update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) +_gather(x, i) = NNlib.gather(x, i) +_gather(x::Nothing, i) = nothing -function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr=+) - E, X = propagate(mp, adjacency_list(fg), fg.ef, fg.nf, aggr) - FeaturedGraph(fg, nf=X, ef=E, gf=Fill(0.f0, 0)) +## Step 1. 
+ +function compute_batch_message(mp::MessagePassing, fg, E, X, u) + s, t = edge_index(fg) + Xi = _gather(X, t) + Xj = _gather(X, s) + M = message(mp, Xi, Xj, E, u) + return M end -function propagate(mp::MessagePassing, adj::AbstractVector{S}, E::R, X::Q, aggr) where {S<:AbstractVector,R,Q} - E, X, u = propagate(mp, adj, E, X, Fill(0.f0, 0), aggr, nothing, nothing) - E, X +# @inline message(mp::MessagePassing, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future +@inline message(mp::MessagePassing, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij) +@inline message(mp::MessagePassing, x_i, x_j, e_ij) = message(mp, x_i, x_j) +@inline message(mp::MessagePassing, x_i, x_j) = x_j + +## Step 2 + +@inline update_edge(mp::MessagePassing, M, E, u) = update_edge(mp::MessagePassing, M, E) +@inline update_edge(mp::MessagePassing, M, E) = E + +## Step 3 + +function aggregate_neighbors(mp::MessagePassing, aggr, fg, E) + s, t = edge_index(fg) + NNlib.scatter(aggr, E, t) end + +aggregate_neighbors(mp::MessagePassing, aggr::Nothing, fg, E) = nothing + +## Step 4 + +# @inline update(mp::MessagePassing, i, m̄, x, u) = update(mp, m, x, u) +@inline update(mp::MessagePassing, m̄, x, u) = update(mp, m̄, x) +@inline update(mp::MessagePassing, m̄, x) = m̄ + +## Step 5 + +@inline update_global(mp::MessagePassing, E, X, u) = u + +### end steps ### + diff --git a/src/utils.jl b/src/utils.jl index 9a59b563f..67066ec4d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,79 +1,5 @@ -""" - accumulated_edges(adj) -Return a vector which acts as a mapping table. The index is the vertex index, -value is accumulated numbers of edge (current vertex not included). -""" -accumulated_edges(adj::AbstractVector{<:AbstractVector{<:Integer}}) = [0, cumsum(map(length, adj))...] - -Zygote.@nograd accumulated_edges - -Zygote.@nograd function generate_cluster(M::AbstractArray{T,N}, accu_edge) where {T,N} - num_V = length(accu_edge) - 1 - num_E = accu_edge[end] - cluster = similar(M, Int, num_E) - @inbounds for i = 1:num_V - j = accu_edge[i] - k = accu_edge[i+1] - cluster[j+1:k] .= i - end - cluster -end - -""" - edge_index_table(adj[, directed]) - -Generate a mapping from vertex pair (i, j) to edge index. The edge indecies are determined by -the sorted vertex indecies. 
-""" -function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}}, directed::Bool=is_directed(adj)) - table = Dict{Tuple{UInt32,UInt32},UInt64}() - e = one(UInt64) - if directed - for (i, js) = enumerate(adj) - js = sort(js) - for j = js - table[(i, j)] = e - e += one(UInt64) - end - end - else - for (i, js) = enumerate(adj) - js = sort(js) - js = js[i .≤ js] - for j = js - table[(i, j)] = e - table[(j, i)] = e - e += one(UInt64) - end - end - end - table -end - -function edge_index_table(vpair::AbstractVector{<:Tuple}) - table = Dict{Tuple{UInt32,UInt32},UInt64}() - for (i, p) = enumerate(vpair) - table[p] = i - end - table -end - -edge_index_table(fg::FeaturedGraph) = edge_index_table(fg.graph, fg.directed) - -Zygote.@nograd edge_index_table - -### TODO move these to GraphSignals ###### -import GraphSignals: FeaturedGraph - -function FeaturedGraph(fg::FeaturedGraph; - nf=node_feature(fg), - ef=edge_feature(fg), - gf=global_feature(fg)) - - return FeaturedGraph(graph(fg); nf, ef, gf) -end function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) @assert nv(fg) == size(x, ndims(x)) -end +end \ No newline at end of file diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index 386edc94e..5ddf68411 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -1,21 +1,19 @@ -using Flux: Dense, gpu - -in_channel = 3 -out_channel = 5 -N = 4 -adj = [0 1 0 1; - 1 0 1 0; - 0 1 0 1; - 1 0 1 0] +@testset "cuda/conv" begin + in_channel = 3 + out_channel = 5 + N = 4 + adj = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] -fg = FeaturedGraph(adj) + fg = FeaturedGraph(adj) -@testset "cuda/conv" begin @testset "GCNConv" begin gc = GCNConv(fg, in_channel=>out_channel) |> gpu @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test Array(adjacency_matrix(gc.fg)) == adj + @test adjacency_matrix(gc.fg |> cpu) == adj X = rand(in_channel, N) |> gpu Y = gc(X) @@ -35,21 +33,23 @@ fg = FeaturedGraph(adj) cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test Array(adjacency_matrix(cc.fg)) == adj + @test adjacency_matrix(cc.fg |> cpu) == adj @test cc.k == k - @test cc.in_channel == in_channel - @test cc.out_channel == out_channel + + @test_broken begin + X = rand(in_channel, N) |> gpu + Y = cc(X) + @test size(Y) == (out_channel, N) - X = rand(in_channel, N) |> gpu - Y = cc(X) - @test size(Y) == (out_channel, N) + g = Zygote.gradient(x -> sum(cc(x)), X)[1] + @test size(g) == size(X) - # g = Zygote.gradient(x -> sum(cc(x)), X)[1] - # @test size(g) == size(X) + g = Zygote.gradient(model -> sum(model(X)), cc)[1] + @test size(g.weight) == size(cc.weight) + @test size(g.bias) == size(cc.bias) - # g = Zygote.gradient(model -> sum(model(X)), cc)[1] - # @test size(g.weight) == size(cc.weight) - # @test size(g.bias) == size(cc.bias) + true + end end @testset "GraphConv" begin diff --git a/test/cuda/featured_graph.jl b/test/cuda/featured_graph.jl new file mode 100644 index 000000000..15430e849 --- /dev/null +++ b/test/cuda/featured_graph.jl @@ -0,0 +1,36 @@ +@testset "featured graph" begin + s = [1,1,2,3,4,5,5,5] + t = [2,5,3,2,1,4,3,1] + s, t = [s; t], [t; s] #symmetrize + fg = FeaturedGraph(s, t) + fg_gpu = fg |> gpu + + @testset "functor" begin + s_gpu, t_gpu = edge_index(fg_gpu) + @test s_gpu isa CuVector{Int} + @test Array(s_gpu) == s + @test t_gpu isa CuVector{Int} + @test Array(t_gpu) == t + end + + @testset "adjacency_matrix" begin + mat = adjacency_matrix(fg) + mat_gpu = 
adjacency_matrix(fg_gpu) + @test mat_gpu isa CuMatrix{Int} + end + + @testset "normalized_laplacian" begin + mat = normalized_laplacian(fg) + mat_gpu = normalized_laplacian(fg_gpu) + @test mat_gpu isa CuMatrix{Float32} + end + + @testset "scaled_laplacian" begin + @test_broken begin + mat = scaled_laplacian(fg) + mat_gpu = scaled_laplacian(fg_gpu) + @test mat_gpu isa CuMatrix{Float32} + true + end + end +end diff --git a/test/cuda/msgpass.jl b/test/cuda/msgpass.jl index e372a41b5..8f1259799 100644 --- a/test/cuda/msgpass.jl +++ b/test/cuda/msgpass.jl @@ -2,12 +2,12 @@ in_channel = 10 out_channel = 5 N = 6 T = Float32 -adj = [0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] +adj = [0 1 0 0 0 0; + 1 0 0 1 1 1; + 0 0 0 0 0 1; + 0 1 0 0 1 0; + 0 1 0 1 0 1; + 0 1 1 0 1 0] struct NewCudaLayer <: MessagePassing weight @@ -15,12 +15,12 @@ end NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n)) @functor NewCudaLayer -(l::NewCudaLayer)(X) = propagate(l, X, +) +(l::NewCudaLayer)(fg) = GeometricFlux.propagate(l, fg, +) GeometricFlux.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j GeometricFlux.update(::NewCudaLayer, m, x) = m X = rand(T, in_channel, N) |> gpu -fg = FeaturedGraph(adj, nf=X) +fg = FeaturedGraph(adj, nf=X) |> gpu l = NewCudaLayer(out_channel, in_channel) |> gpu @testset "cuda/msgpass" begin diff --git a/test/featured_graph.jl b/test/featured_graph.jl new file mode 100644 index 000000000..5000b8fa9 --- /dev/null +++ b/test/featured_graph.jl @@ -0,0 +1,81 @@ +@testset "FeaturedGraph" begin + @testset "symmetric graph" begin + u = [1, 2, 3, 4, 2, 3, 4, 1] + v = [2, 3, 4, 1, 1, 2, 3, 4] + adj_mat = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + adj_list_out = [[2,4], [3,1], [4,2], [1,3]] + adj_list_in = [[4,2], [1,3], [2,4], [3,1]] + + # core functionality + fg = FeaturedGraph(u, v) + @test fg.num_edges == 8 + @test fg.num_nodes == 4 + @test collect(edges(fg)) == collect(zip(u, v)) + @test sort(outneighbors(fg, 1)) == [2, 4] + @test sort(inneighbors(fg, 1)) == [2, 4] + @test is_directed(fg) == true + + # adjacency + @test adjacency_matrix(fg) == adj_mat + @test adjacency_matrix(fg; dir=:in) == adj_mat + @test adjacency_matrix(fg; dir=:out) == adj_mat + @test adjacency_list(fg; dir=:in) == adj_list_in + @test adjacency_list(fg; dir=:out) == adj_list_out + + @testset "constructors" begin + fg = FeaturedGraph(adj_mat) + adjacency_matrix(fg; dir=:out) == adj_mat + adjacency_matrix(fg; dir=:in) == adj_mat + end + + @testset "degree" begin + fg = FeaturedGraph(adj_mat) + @test degree(fg, dir=:out) == vec(sum(adj_mat, dims=2)) + @test degree(fg, dir=:in) == vec(sum(adj_mat, dims=1)) + end + end + + @testset "asymmetric graph" begin + u = [1, 2, 3, 4] + v = [2, 3, 4, 1] + adj_mat_out = [0 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + adj_list_out = [[2], [3], [4], [1]] + + + adj_mat_in = [0 0 0 1 + 1 0 0 0 + 0 1 0 0 + 0 0 1 0] + adj_list_in = [[4], [1], [2], [3]] + + # core functionality + fg = FeaturedGraph(u, v) + @test fg.num_edges == 4 + @test fg.num_nodes == 4 + @test collect(edges(fg)) == collect(zip(u, v)) + @test sort(outneighbors(fg, 1)) == [2] + @test sort(inneighbors(fg, 1)) == [4] + @test is_directed(fg) == true + + # adjacency + @test adjacency_matrix(fg) == adj_mat_out + @test adjacency_list(fg) == adj_list_out + @test adjacency_matrix(fg, dir=:out) == adj_mat_out + @test adjacency_list(fg, dir=:out) == adj_list_out + @test adjacency_matrix(fg, dir=:in) == adj_mat_in + @test adjacency_list(fg, dir=:in) == adj_list_in 
+ + @testset "degree" begin + fg = FeaturedGraph(adj_mat_out) + @test degree(fg, dir=:out) == vec(sum(adj_mat_out, dims=2)) + @test degree(fg, dir=:in) == vec(sum(adj_mat_out, dims=1)) + end + end + +end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index ffe6ee3cf..ca5a47367 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -1,23 +1,22 @@ -T = Float32 -in_channel = 3 -out_channel = 5 -N = 4 -adj = T[0. 1. 0. 1.; - 1. 0. 1. 0.; - 0. 1. 0. 1.; - 1. 0. 1. 0.] - -fg = FeaturedGraph(adj) - -adj_single_vertex = T[0. 0. 0. 1.; - 0. 0. 0. 0.; - 0. 0. 0. 1.; - 1. 0. 1. 0.] - -fg_single_vertex = FeaturedGraph(adj_single_vertex) - - @testset "layer" begin + T = Float32 + in_channel = 3 + out_channel = 5 + N = 4 + adj = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + + fg = FeaturedGraph(adj) + + adj_single_vertex = [0 0 0 1 + 0 0 0 0 + 0 0 0 1 + 1 0 1 0] + + fg_single_vertex = FeaturedGraph(adj_single_vertex) + @testset "GCNConv" begin X = rand(T, in_channel, N) Xt = transpose(rand(T, N, in_channel)) @@ -25,13 +24,15 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(fg, in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test graph(gc.fg) === adj + @test adjacency_matrix(gc.fg) == adj Y = gc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = gc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(gc(x)), X)[1] @@ -46,20 +47,22 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test !has_graph(gc.fg) + # @test !has_graph(gc.fg) fg = FeaturedGraph(adj, nf=X) fg_ = gc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError gc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = gc(fgt) + @test node_feature(fgt_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1] @test size(g.weight) == size(gc.weight) @@ -81,14 +84,17 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) cc = ChebConv(fg, in_channel=>out_channel, k) @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test graph(cc.fg) === adj + @test adjacency_matrix(cc.fg) == adj + @test cc.k == k Y = cc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = cc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(cc(x)), X)[1] @@ -103,21 +109,23 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) cc = ChebConv(in_channel=>out_channel, k) @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test !has_graph(cc.fg) + # @test !has_graph(cc.fg) @test cc.k == k fg = FeaturedGraph(adj, nf=X) fg_ = cc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError cc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = cc(fgt) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(cc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == 
size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), cc)[1] @test size(g.weight) == size(cc.weight) @@ -141,10 +149,12 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(gc.bias) == (out_channel,) Y = gc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = gc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(gc(x)), X)[1] @@ -164,16 +174,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) fg = FeaturedGraph(adj, nf=X) fg_ = gc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError gc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = gc(fgt) + @test node_feature(fgt_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1] @test size(g.weight1) == size(gc.weight1) @@ -209,10 +221,12 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(gat.a) == (2*out_channel, heads) Y = gat(X) + @test Y isa Matrix{T} @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) # Test with transposed features Y = gat(Xt) + @test Y isa Matrix{T} @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) g = Zygote.gradient(x -> sum(gat(x)), X)[1] @@ -235,16 +249,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) fg_ = gat(fg_gat) Y = node_feature(fg_) + @test Y isa Matrix{T} @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) @test_throws MethodError gat(X) # Test with transposed features fgt = FeaturedGraph(adj_gat, nf=Xt) fgt_ = gat(fgt) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fgt_)) == (concat ? 
(out_channel*heads, N) : (out_channel, N)) g = Zygote.gradient(x -> sum(node_feature(gat(x))), fg_gat)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg_gat))), gat)[1] @test size(g.weight) == size(gat.weight) @@ -269,11 +285,12 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(ggc.weight) == (out_channel, out_channel, num_layers) Y = ggc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) - # Test with transposed features Y = ggc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(ggc(x)), X)[1] @@ -289,16 +306,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) fg = FeaturedGraph(adj, nf=X) fg_ = ggc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError ggc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = ggc(fgt) + @test node_feature(fgt_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(ggc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), ggc)[1] @test size(g.weight) == size(ggc.weight) @@ -313,10 +332,12 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test adjacency_list(ec.fg) == [[2,4], [1,3], [2,4], [1,3]] Y = ec(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = ec(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(ec(x)), X)[1] @@ -332,16 +353,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) fg = FeaturedGraph(adj, nf=X) fg_ = ec(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError ec(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = ec(fgt) + @test node_feature(fgt_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(ec(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), ec)[1] @test size(g.nn.weight) == size(ec.nn.weight) diff --git a/test/layers/gn.jl b/test/layers/gn.jl deleted file mode 100644 index 0c32ec1f9..000000000 --- a/test/layers/gn.jl +++ /dev/null @@ -1,76 +0,0 @@ -in_channel = 10 -out_channel = 5 -num_V = 6 -num_E = 7 -T = Float32 - -adj = T[0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] 
- -struct NewGNLayer <: GraphNet -end - -V = rand(T, in_channel, num_V) -E = rand(T, in_channel, 2num_E) -u = rand(T, in_channel) - -@testset "gn" begin - l = NewGNLayer() - - @testset "without aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg) - - fg = FeaturedGraph(adj, nf=V) - fg_ = l(fg) - - @test graph(fg_) === adj - @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (0, 2*num_E) - @test size(global_feature(fg_)) == (0,) - end - - @testset "with neighbor aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) - - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0)) - l = NewGNLayer() - fg_ = l(fg) - - @test graph(fg_) === adj - @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (in_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) - end - - GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = rand(T, out_channel) - @testset "update edge with neighbor aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) - - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0)) - l = NewGNLayer() - fg_ = l(fg) - - @test graph(fg_) === adj - @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) - end - - GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = rand(T, out_channel) - @testset "update edge/vertex with all aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +, +, +) - - fg = FeaturedGraph(adj, nf=V, ef=E, gf=u) - l = NewGNLayer() - fg_ = l(fg) - - @test graph(fg_) === adj - @test size(node_feature(fg_)) == (out_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (in_channel,) - end -end diff --git a/test/layers/misc.jl b/test/layers/misc.jl index 58ba83029..9f89500cd 100644 --- a/test/layers/misc.jl +++ b/test/layers/misc.jl @@ -16,7 +16,7 @@ x -> x .+ 2., x -> x .+ 3.) fg_ = layer(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test node_feature(fg_) == nf .+ 1. @test edge_feature(fg_) == ef .+ 2. @test global_feature(fg_) == gf .+ 3. diff --git a/test/layers/msgpass.jl b/test/layers/msgpass.jl index 9e9c86caa..54e2ed2c3 100644 --- a/test/layers/msgpass.jl +++ b/test/layers/msgpass.jl @@ -1,55 +1,114 @@ -in_channel = 10 -out_channel = 5 -num_V = 6 -num_E = 7 -T = Float32 - -adj = T[0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] 
- -struct NewLayer <: MessagePassing - weight -end -NewLayer(m, n) = NewLayer(randn(T, m,n)) +@testset "MessagePassing" begin + in_channel = 10 + out_channel = 5 + num_V = 6 + num_E = 14 + T = Float32 + + adj = [0 1 0 0 0 0 + 1 0 0 1 1 1 + 0 0 0 0 0 1 + 0 1 0 0 1 0 + 0 1 0 1 0 1 + 0 1 1 0 1 0] + + struct NewLayer <: MessagePassing end + + X = rand(T, in_channel, num_V) + E = rand(T, in_channel, num_E) + u = rand(T, in_channel) -(l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) -X = Array{T}(reshape(1:num_V*in_channel, in_channel, num_V)) -fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E)) + @testset "no aggregation" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, nothing) + + fg = FeaturedGraph(adj, nf=X) + fg_ = l(fg) + + @test adjacency_matrix(fg_) == adj + @test node_feature(fg_) === nothing + @test edge_feature(fg_) === nothing + @test global_feature(fg_) === nothing + end -l = NewLayer(out_channel, in_channel) + @testset "neighbor aggregation (+)" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) -@testset "msgpass" begin - @testset "no message or update" begin + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (in_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test edge_feature(fg_) ≈ E + @test global_feature(fg_) ≈ u + end + + GeometricFlux.message(l::NewLayer, xi, xj, e, u) = ones(T, out_channel, size(e,2)) + + @testset "custom message and neighbor aggregation" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) + fg_ = l(fg) + + @test adjacency_matrix(fg_) == adj + @test size(node_feature(fg_)) == (out_channel, num_V) + @test edge_feature(fg_) ≈ edge_feature(fg) + @test global_feature(fg_) ≈ global_feature(fg) end - GeometricFlux.message(l::NewLayer, x_i, x_j, e_ij) = l.weight * x_j - @testset "message function" begin + GeometricFlux.update_edge(l::NewLayer, m, e) = m + + @testset "update_edge" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test size(edge_feature(fg_)) == (out_channel, num_E) + @test global_feature(fg_) ≈ global_feature(fg) + end + + GeometricFlux.update(l::NewLayer, m̄, xi, u) = rand(T, 2*out_channel, size(xi, 2)) + + @testset "update edge/vertex" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) + fg_ = l(fg) + + @test all(adjacency_matrix(fg_) .== adj) + @test size(node_feature(fg_)) == (2*out_channel, num_V) + @test size(edge_feature(fg_)) == (out_channel, num_E) + @test size(global_feature(fg_)) == (in_channel,) + end + + struct NewLayerW <: MessagePassing + weight end - GeometricFlux.update(l::NewLayer, m, x) = l.weight * x + m - @testset "message and update" begin + NewLayerW(in, out) = NewLayerW(randn(T, out, in)) + + GeometricFlux.message(l::NewLayerW, x_i, x_j, e_ij) = l.weight * x_j + GeometricFlux.update(l::NewLayerW, m, x) = l.weight * x + m + + @testset "message and update with weights" begin + l = NewLayerW(in_channel, out_channel) + (l::NewLayerW)(fg) = 
GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test edge_feature(fg_) === E + @test global_feature(fg_) === u end end diff --git a/test/runtests.jl b/test/runtests.jl index 46f335f15..427e45bc9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,23 +2,24 @@ using GeometricFlux using GeometricFlux.Datasets using Flux using Flux: @functor -using FillArrays -using GraphSignals -using LightGraphs: SimpleGraph, SimpleDiGraph, add_edge!, nv, ne +using LightGraphs using LinearAlgebra using NNlib -using SparseArrays: SparseMatrixCSC using Statistics: mean using Zygote using Test +using CUDA +using Flux: gpu +using NNlibCUDA +CUDA.allowscalar(false) cuda_tests = [ - # "cuda/conv", - # "cuda/msgpass", + "cuda/conv", + "cuda/msgpass", ] tests = [ - "layers/gn", + "featured_graph", "layers/msgpass", "layers/conv", "layers/pool", @@ -27,9 +28,6 @@ tests = [ ] if Flux.use_cuda[] - using CUDA - using Flux: gpu - using NNlibCUDA append!(tests, cuda_tests) else @warn "CUDA unavailable, not testing GPU support"
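Stepping back from the individual files: the common thread of this PR is replacing the old per-node `mapreduce` loops (`src/layers/gn.jl`, `src/cuda/*`) with `NNlib.gather`/`NNlib.scatter` over the `(source, target)` edge index. A standalone sketch of that pattern, with arbitrary toy values:

```julia
using NNlib

# Edges 1→2, 3→2, 2→3, stored as (source, target) vectors.
s, t = [1, 3, 2], [2, 2, 3]
X = Float32[10 20 30]            # 1×3: one feature per node

Xj = NNlib.gather(X, s)          # neighbor feature per edge: [10 30 20]
M  = Xj                          # identity message, as in the default `message`
M̄  = NNlib.scatter(+, M, t)      # sum messages at each target node: [0 40 20]
# Node 1 has no incoming edge, so it keeps the neutral element of `+` (zero).
```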
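Finally, a sketch of the end-to-end GPU usage that the re-enabled `cuda/*` tests exercise (it assumes a working CUDA device and uses the existing `GCNConv(in=>out)` constructor; the layer and feature sizes are arbitrary):

```julia
using GeometricFlux, Flux, CUDA

adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
X = rand(Float32, 3, 4)

# FeaturedGraph is a @functor, so `gpu` moves its index and feature arrays.
fg = FeaturedGraph(adj, nf = X) |> gpu
l  = GCNConv(3 => 5) |> gpu

fg_out = l(fg)
size(node_feature(fg_out))    # (5, 4)
```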