From 0f5ae6c8df1262ca270f09b73469b731ab9d38b8 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 20 Jul 2021 14:26:53 +0200 Subject: [PATCH 01/11] implement COO featured graph --- Project.toml | 5 +- src/GeometricFlux.jl | 27 +++- src/cuda/conv.jl | 32 ++--- src/featured_graph.jl | 302 +++++++++++++++++++++++++++++++++++++++++ src/layers/conv.jl | 31 +++++ src/layers/gn.jl | 7 +- src/layers/msgpass.jl | 6 +- src/models.jl | 4 + src/utils.jl | 15 +- test/cuda/conv.jl | 8 ++ test/featured_graph.jl | 73 ++++++++++ test/layers/conv.jl | 23 ++-- test/layers/gn.jl | 16 ++- test/layers/misc.jl | 2 +- test/layers/msgpass.jl | 16 ++- test/runtests.jl | 5 +- 16 files changed, 511 insertions(+), 61 deletions(-) create mode 100644 src/featured_graph.jl create mode 100644 test/featured_graph.jl diff --git a/Project.toml b/Project.toml index f4f02e75a..aa0f9ff27 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,11 @@ version = "0.7.6" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphLaplacians = "a1251efa-393a-423f-9d7b-faaecba535dc" GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8" -GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -25,9 +24,7 @@ CUDA = "3.3" DataStructures = "0.18" FillArrays = "0.11, 0.12" Flux = "0.12" -GraphLaplacians = "0.1" GraphMLDatasets = "0.1" -GraphSignals = "0.2" LightGraphs = "1.3" NNlib = "0.7" NNlibCUDA = "0.1" diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index 69be76864..bbfe79581 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -1,20 +1,34 @@ module GeometricFlux using Statistics: mean -using LinearAlgebra: Adjoint, norm, Transpose -using Reexport +using LinearAlgebra +using FillArrays: Fill using CUDA -using FillArrays: Fill using Flux using Flux: glorot_uniform, leakyrelu, GRUCell, @functor using NNlib, NNlibCUDA -using GraphLaplacians -@reexport using GraphSignals -using LightGraphs using Zygote +using ChainRulesCore + + +# import GraphLaplacians +# using GraphLaplacians: normalized_laplacian, scaled_laplacian +# using GraphLaplacians: adjacency_matrix +# using Reexport +# @reexport using GraphSignals +import LightGraphs +using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv, adjacency_matrix export + # featured_graph + FeaturedGraph, + adjacency_list, + # graph, + # has_graph, + node_feature, edge_feature, global_feature, + # ne, nv, adjacency_matrix, # from LightGraphs + # layers/gn GraphNet, @@ -49,6 +63,7 @@ export # utils generate_cluster +include("featured_graph.jl") include("datasets.jl") include("utils.jl") diff --git a/src/cuda/conv.jl b/src/cuda/conv.jl index d29c2fce3..d6464b37e 100644 --- a/src/cuda/conv.jl +++ b/src/cuda/conv.jl @@ -1,23 +1,23 @@ -(g::GCNConv)(L̃::AbstractMatrix, X::CuArray) = g(cu(L̃), X) +# (g::GCNConv)(L̃::AbstractMatrix, X::CuArray) = g(cu(L̃), X) -(g::GCNConv)(L̃::CuArray, X::CuArray) = g.σ.(g.weight * X * L̃ .+ g.bias) +# (g::GCNConv)(L̃::CuArray, X::CuArray) = g.σ.(g.weight * X * L̃ .+ g.bias) -(c::ChebConv)(L̃::AbstractMatrix, X::CuArray) = c(cu(L̃), X) +# (c::ChebConv)(L̃::AbstractMatrix, X::CuArray) = c(cu(L̃), X) -function (c::ChebConv)(L̃::CuArray, X::CuArray) - @assert 
size(X, 1) == c.in_channel "Input feature size must match input channel size." - @assert size(X, 2) == size(L̃, 1) "Input vertex number must match Laplacian matrix size." +# function (c::ChebConv)(L̃::CuArray, X::CuArray) +# @assert size(X, 1) == c.in_channel "Input feature size must match input channel size." +# @assert size(X, 2) == size(L̃, 1) "Input vertex number must match Laplacian matrix size." - Z_prev = X - Z = X * L̃ - Y = view(c.weight,:,:,1) * Z_prev - Y += view(c.weight,:,:,2) * Z - for k = 3:c.k - Z, Z_prev = 2*Z*L̃ - Z_prev, Z - Y += view(c.weight,:,:,k) * Z - end - return Y .+ c.bias -end +# Z_prev = X +# Z = X * L̃ +# Y = view(c.weight,:,:,1) * Z_prev +# Y += view(c.weight,:,:,2) * Z +# for k = 3:c.k +# Z, Z_prev = 2*Z*L̃ - Z_prev, Z +# Y += view(c.weight,:,:,k) * Z +# end +# return Y .+ c.bias +# end # Avoid ambiguity diff --git a/src/featured_graph.jl b/src/featured_graph.jl new file mode 100644 index 000000000..1dd10c551 --- /dev/null +++ b/src/featured_graph.jl @@ -0,0 +1,302 @@ +#=================================== +Define FeaturedGraph type as a subtype of LightGraphs' AbstractGraph. + +All LightGraphs functions rely on a standard API to function. +As long as your graph structure is a subtype of AbstractGraph and +implements the following API functions with the given return values, +all functions within the LightGraphs package should just work: + edges + Base.eltype + edgetype (example: edgetype(g::CustomGraph) = LightGraphs.SimpleEdge{eltype(g)})) + has_edge + has_vertex + inneighbors + ne + nv + outneighbors + vertices + is_directed(::Type{CustomGraph})::Bool (example: is_directed(::Type{<:CustomGraph}) = false) + is_directed(g::CustomGraph)::Bool + zero +=============================================# + +abstract type AbstractFeaturedGraph <: AbstractGraph{Int} end + +""" + NullGraph() + +Null object for `FeaturedGraph`. +""" +struct NullGraph <: AbstractFeaturedGraph end + + +struct FeaturedGraph <: AbstractFeaturedGraph + edge_index + num_nodes::Int + num_edges::Int + # ndata::Dict{String, Any} # https://github.com/FluxML/Zygote.jl/issues/717 + # edata::Dict{String, Any} + # gdata::Dict{String, Any} + nf + ef + gf +end + + +function FeaturedGraph(u::AbstractVector{Int}, v::AbstractVector{Int}; + num_nodes = max(maximum(u), maximum(v)), + # ndata = Dict{String, Any}(), + # edata = Dict{String, Any}(), + # gdata = Dict{String, Any}(), + nf = nothing, + ef = nothing, + gf = nothing) + + @assert length(u) == length(v) + @assert min(minimum(u), minimum(v)) >= 1 + @assert max(maximum(u), maximum(v)) <= num_nodes + + num_edges = length(u) + + ## I would like to have dict data store, but currently this + ## doesn't play well with zygote due to + ## https://github.com/FluxML/Zygote.jl/issues/717 + # ndata["x"] = nf + # edata["e"] = ef + # gdata["g"] = gf + + + FeaturedGraph((u, v), num_nodes, num_edges, + nf, ef, gf) +end + +# Construct from adjacency matrix # TODO deprecate? +function FeaturedGraph(adj_mat::AbstractMatrix; dir=:out, kws...) + @assert dir == :out # TODO + num_nodes = size(adj_mat, 1) + @assert num_nodes == size(adj_mat, 2) + @assert all(x -> (x == 1) || (x == 0), adj_mat) + num_edges = round(Int, sum(adj_mat)) + u = zeros(Int, num_edges) + v = zeros(Int, num_edges) + e = 0 + for j in 1:num_nodes + for i in 1:num_nodes + if adj_mat[i, j] == 1 + e += 1 + u[e] = i + v[e] = j + end + end + end + @assert e == num_edges + FeaturedGraph(u, v; num_nodes, kws...) +end + + +# Construct from adjacency list # TODO deprecate? 
+function FeaturedGraph(adj_list::AbstractVector{<:AbstractVector}; dir=:out, kws...) + @assert dir == :out # TODO + num_nodes = length(adj_list) + num_edges = sum(length.(adj_list)) + u = zeros(Int, num_edges) + v = zeros(Int, num_edges) + e = 0 + for i in 1:num_nodes + for j in adj_list[i] + e += 1 + u[e] = i + v[e] = j + end + end + @assert e == num_edges + FeaturedGraph(u, v; num_nodes, kws...) +end + + +# from other featured_graph +function FeaturedGraph(fg::FeaturedGraph; + # ndata=copy(fg.ndata), edata=copy(fg.edata), gdata=copy(fg.gdata), # copy keeps the refs to old data + nf=node_feature(fg), ef=edge_feature(fg), gf=global_feature(fg)) + + FeaturedGraph(fg.edge_index[1], fg.edge_index[2]; + # ndata, edata, gdata, + nf, ef, gf) +end + + +@functor FeaturedGraph + +LightGraphs.edges(fg::FeaturedGraph) = zip(fg.edge_index[1], fg.edge_index[2]) + +LightGraphs.edgetype(fg::FeaturedGraph) = Tuple{eltype(fg.edge_index[1]), eltype(fg.edge_index[2])} + +function LightGraphs.has_edge(fg::FeaturedGraph, i::Integer, j::Integer) + u, v = fg.edge_index + return any((u .== i) .& (v .== j)) +end + +LightGraphs.nv(fg::FeaturedGraph) = fg.num_nodes +LightGraphs.ne(fg::FeaturedGraph) = fg.num_edges +LightGraphs.has_vertex(fg::FeaturedGraph, i::Int) = i in 1:fg.num_nodes +LightGraphs.vertices(fg::FeaturedGraph) = 1:fg.num_nodes + +function LightGraphs.outneighbors(fg::FeaturedGraph, i::Integer) + u, v = fg.edge_index + return v[u .== i] +end + +function LightGraphs.inneighbors(fg::FeaturedGraph, i::Integer) + u, v = fg.edge_index + return u[v .== i] +end + +LightGraphs.is_directed(::FeaturedGraph) = true +LightGraphs.is_directed(::Type{FeaturedGraph}) = true + +function adjacency_list(fg::FeaturedGraph; dir=:out) + # TODO probably this has to be called with `dir=:in` by gnn layers + # TODO dir=:both + fneighs = dir == :out ? outneighbors : inneighbors + return [fneighs(fg, i) for i in 1:fg.num_nodes] +end + +# TODO return sparse matrix +function LightGraphs.adjacency_matrix(fg::FeaturedGraph, T::DataType=Int; dir=:out) + # TODO dir=:both + u, v = fg.edge_index + n = fg.num_nodes + adj_mat = zeros(T, n, n) + adj_mat[u .+ n .* (v .- 1)] .= 1 # exploiting linear indexing + return dir == :out ? adj_mat : adj_mat' +end + +Zygote.@nograd adjacency_matrix, adjacency_list + + +# function ChainRulesCore.rrule(::typeof(copy), x) +# copy_pullback(ȳ) = (NoTangent(), ȳ) +# return copy(x), copy_pullback +# end + +# node_feature(fg::FeaturedGraph) = fg.ndata["x"] +# edge_feature(fg::FeaturedGraph) = fg.edata["e"] +# global_feature(fg::FeaturedGraph) = fg.gdata["g"] + +node_feature(fg::FeaturedGraph) = fg.nf +edge_feature(fg::FeaturedGraph) = fg.ef +global_feature(fg::FeaturedGraph) = fg.gf + +## TO DEPRECATE EVERYTHING BELOW ??? 
############################## + +# has_graph(fg::FeaturedGraph) = true +# has_graph(fg::NullGraph) = false +# graph(fg::FeaturedGraph) = adjacency_list(fg) # DEPRECATE + +# function Base.getproperty(fg::FeaturedGraph, sym::Symbol) +# if sym === :nf +# return fg.ndata["x"] +# elseif sym === :ef +# return fg.edata["e"] +# elseif sym === :gf +# return fg.gdata["g"] +# else # fallback to getfield +# return getfield(fg, sym) +# end +# end + +## Already in GraphSignals ############## +LightGraphs.ne(adj_list::AbstractVector{<:AbstractVector}) = sum(length.(adj_list)) +LightGraphs.nv(adj_list::AbstractVector{<:AbstractVector}) = length(adj_list) +LightGraphs.ne(adj_mat::AbstractMatrix) = round(Int, sum(adj_mat)) +LightGraphs.nv(adj_mat::AbstractMatrix) = size(adj_mat, 1) + +adjacency_list(adj::AbstractVector{<:AbstractVector}) = adj + +function LightGraphs.is_directed(g::AbstractVector{T}) where {T<:AbstractVector} + edges = Set{Tuple{Int64,Int64}}() + for (i, js) in enumerate(g) + for j in Set(js) + if i != j + e = (i,j) + if e in edges + pop!(edges, e) + else + push!(edges, (j,i)) + end + end + end + end + !isempty(edges) +end + +LightGraphs.is_directed(g::AbstractMatrix) = !issymmetric(Matrix(g)) + +# function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::Symbol=:out) +# A = adjacency_matrix(fg, T; dir=dir) +# D = Diagonal(vec(sum(A; dims=2))) +# return D - A +# end + +## from GraphLaplacians + +""" + normalized_laplacian(g[, T]; selfloop=false, dir=:out) + +Normalized Laplacian matrix of graph `g`. + +# Arguments + +- `g`: should be a adjacency matrix, `FeaturedGraph`, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). +- `T`: result element type of degree vector; default is the element type of `g` (optional). +- `selfloop`: adding self loop while calculating the matrix (optional). +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function normalized_laplacian(fg::FeaturedGraph, T::DataType=Int; selfloop::Bool=false, dir::Symbol=:out) + A = adjacency_matrix(fg, T; dir=dir) + selfloop && (A += I) + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + return I - inv_sqrtD * A * inv_sqrtD +end + +@doc raw""" + scaled_laplacian(g[, T]; dir=:out) + +Scaled Laplacian matrix of graph `g`, +defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. + +# Arguments + +- `g`: should be a adjacency matrix, `FeaturedGraph`, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). +- `T`: result element type of degree vector; default is the element type of `g` (optional). +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function scaled_laplacian(fg::FeaturedGraph, T::DataType=Int; dir=:out) + A = adjacency_matrix(fg, T; dir=dir) + @assert issymmetric(A) "scaled_laplacian only works with symmetric matrices" + E = eigen(Symmetric(A)).values + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + Lnorm = I - inv_sqrtD * A * inv_sqrtD + return 2 / maximum(E) * Lnorm - I +end + + +function add_self_loop!(adj::AbstractVector{<:AbstractVector}) + for i = 1:length(adj) + i in adj[i] || push!(adj[i], i) + end + adj +end + +# # TODO Do we need a separate package just for laplacians? 
+# GraphLaplacians.scaled_laplacian(fg::FeaturedGraph, T::DataType) = +# scaled_laplacian(adjacency_matrix(fg, T)) +# GraphLaplacians.normalized_laplacian(fg::FeaturedGraph, T::DataType; kws...) = +# normalized_laplacian(adjacency_matrix(fg, T); kws...) + + +@non_differentiable normalized_laplacian(x...) +@non_differentiable scaled_laplacian(x...) +@non_differentiable add_self_loop!(x...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 256daba86..f451c8493 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -88,10 +88,16 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) = function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T check_num_nodes(fg, X) +<<<<<<< HEAD @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size." L̃ = scaled_laplacian(fg, eltype(X)) +======= + L̃ = scaled_laplacian(fg, eltype(X)) + + @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size." +>>>>>>> 17dbba7 (implement COO featured graph) Z_prev = X Z = X * L̃ Y = view(c.weight,:,:,1) * Z_prev @@ -159,7 +165,11 @@ update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1* function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix) check_num_nodes(fg, x) +<<<<<<< HEAD _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) +======= + _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) +>>>>>>> 17dbba7 (implement COO featured graph) x end @@ -246,11 +256,16 @@ update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix) n = size(adj, 1) +<<<<<<< HEAD # a vertex must always receive a message from itself Zygote.ignore() do GraphLaplacians.add_self_loop!(adj, n) end mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n) +======= + add_self_loop!(adj) + mapreduce(i -> apply_batch_message(g, i, adj[i], X), hcat, 1:n) +>>>>>>> 17dbba7 (implement COO featured graph) end # The same as update function in batch manner @@ -267,7 +282,11 @@ end function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) check_num_nodes(fg, X) +<<<<<<< HEAD _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) +======= + _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) +>>>>>>> 17dbba7 (implement COO featured graph) X end @@ -324,6 +343,7 @@ update(ggc::GatedGraphConv, m::AbstractVector, x) = m function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real} check_num_nodes(fg, H) +<<<<<<< HEAD m, n = size(H) @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." adj = adjacency_list(fg) @@ -331,6 +351,13 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T Hpad = similar(H, S, ggc.out_ch - m, n) H = vcat(H, fill!(Hpad, 0)) end +======= + adj = adjacency_list(fg) + m, n = size(H) + @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." 
+ (m < ggc.out_ch) && (H = vcat(H, zeros(S, ggc.out_ch - m, n))) + +>>>>>>> 17dbba7 (implement COO featured graph) for i = 1:ggc.num_layers M = view(ggc.weight, :, :, i) * H _, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +) @@ -378,7 +405,11 @@ update(ec::EdgeConv, m::AbstractVector, x) = m function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix) check_num_nodes(fg, X) +<<<<<<< HEAD _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) +======= + _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) +>>>>>>> 17dbba7 (implement COO featured graph) X end diff --git a/src/layers/gn.jl b/src/layers/gn.jl index 496d3dcf2..360d58e97 100644 --- a/src/layers/gn.jl +++ b/src/layers/gn.jl @@ -1,5 +1,4 @@ _view(::Nothing, i) = nothing -_view(A::Fill{T,2,Axes}, i) where {T,Axes} = view(A, :, 1) _view(A::AbstractMatrix, idx) = view(A, :, idx) aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2)) @@ -54,7 +53,13 @@ end end function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) +<<<<<<< HEAD E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr) +======= + E, V, u = propagate(gn, adjacency_list(fg), + edge_feature(fg), node_feature(fg), global_feature(fg), + naggr, eaggr, vaggr) +>>>>>>> 17dbba7 (implement COO featured graph) FeaturedGraph(fg, nf=V, ef=E, gf=u) end diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl index 9aee05391..b0f20a8ea 100644 --- a/src/layers/msgpass.jl +++ b/src/layers/msgpass.jl @@ -55,11 +55,11 @@ See also [`message`](@ref). mapreduce(i -> GeometricFlux.update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr=+) - E, X = propagate(mp, adjacency_list(fg), fg.ef, fg.nf, aggr) - FeaturedGraph(fg, nf=X, ef=E, gf=Fill(0.f0, 0)) + E, X = propagate(mp, adjacency_list(fg), edge_feature(fg), node_feature(fg), aggr) + FeaturedGraph(fg, nf=X, ef=E) end function propagate(mp::MessagePassing, adj::AbstractVector{S}, E::R, X::Q, aggr) where {S<:AbstractVector,R,Q} - E, X, u = propagate(mp, adj, E, X, Fill(0.f0, 0), aggr, nothing, nothing) + E, X, u = propagate(mp, adj, E, X, nothing, aggr, nothing, nothing) E, X end diff --git a/src/models.jl b/src/models.jl index 9a9594453..21439933f 100644 --- a/src/models.jl +++ b/src/models.jl @@ -71,7 +71,11 @@ end function (i::InnerProductDecoder)(fg::FeaturedGraph) Z = node_feature(fg) A = i(Z) +<<<<<<< HEAD return FeaturedGraph(fg, nf=A) +======= + FeaturedGraph(fg, nf=A) +>>>>>>> 17dbba7 (implement COO featured graph) end diff --git a/src/utils.jl b/src/utils.jl index 9a59b563f..6bc8b5f51 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,10 +8,12 @@ accumulated_edges(adj::AbstractVector{<:AbstractVector{<:Integer}}) = [0, cumsum Zygote.@nograd accumulated_edges -Zygote.@nograd function generate_cluster(M::AbstractArray{T,N}, accu_edge) where {T,N} + +Zygote.@nograd function generate_cluster(M, accu_edge) num_V = length(accu_edge) - 1 num_E = accu_edge[end] - cluster = similar(M, Int, num_E) + # cluster = similar(M, Int, num_E) + cluster = zeros(Int, num_E) @inbounds for i = 1:num_V j = accu_edge[i] k = accu_edge[i+1] @@ -59,8 +61,6 @@ function edge_index_table(vpair::AbstractVector{<:Tuple}) table end -edge_index_table(fg::FeaturedGraph) = edge_index_table(fg.graph, fg.directed) - Zygote.@nograd edge_index_table ### TODO move these to GraphSignals ###### @@ -77,3 +77,10 @@ end function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) @assert nv(fg) == 
size(x, ndims(x)) end +<<<<<<< HEAD +======= + +### TODO move these to GraphSignals ###### +# @functor FeaturedGraph +# Zygote.@nograd normalized_laplacian, scaled_laplacian +>>>>>>> 17dbba7 (implement COO featured graph) diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index 386edc94e..757b5ed92 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -15,7 +15,11 @@ fg = FeaturedGraph(adj) gc = GCNConv(fg, in_channel=>out_channel) |> gpu @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) +<<<<<<< HEAD @test Array(adjacency_matrix(gc.fg)) == adj +======= + @test Array(adjacency_matrix(gc.fg) ≈ adj +>>>>>>> 17dbba7 (implement COO featured graph) X = rand(in_channel, N) |> gpu Y = gc(X) @@ -35,7 +39,11 @@ fg = FeaturedGraph(adj) cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) +<<<<<<< HEAD @test Array(adjacency_matrix(cc.fg)) == adj +======= + @test Array(adjacency_matrix(cc.fg)) ≈ adj +>>>>>>> 17dbba7 (implement COO featured graph) @test cc.k == k @test cc.in_channel == in_channel @test cc.out_channel == out_channel diff --git a/test/featured_graph.jl b/test/featured_graph.jl new file mode 100644 index 000000000..e0dc7031e --- /dev/null +++ b/test/featured_graph.jl @@ -0,0 +1,73 @@ +using LightGraphs + +@testset "FeaturedGraph" begin + @testset "symmetric graph" begin + u = [1, 2, 3, 4, 2, 3, 4, 1] + v = [2, 3, 4, 1, 1, 2, 3, 4] + adj_mat = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + adj_list_out = [[2,4], [3,1], [4,2], [1,3]] + adj_list_in = [[4,2], [1,3], [2,4], [3,1]] + + # core functionality + fg = FeaturedGraph(u, v) + @test fg.num_edges == 8 + @test fg.num_nodes == 4 + @test collect(edges(fg)) == collect(zip(u, v)) + @test sort(outneighbors(fg, 1)) == [2, 4] + @test sort(inneighbors(fg, 1)) == [2, 4] + @test is_directed(fg) == true + + # adjacency + @test adjacency_matrix(fg) == adj_mat + @test adjacency_matrix(fg; dir=:in) == adj_mat + @test adjacency_matrix(fg; dir=:out) == adj_mat + @test adjacency_list(fg; dir=:in) == adj_list_in + @test adjacency_list(fg; dir=:out) == adj_list_out + + @testset "constructors" begin + fg = FeaturedGraph(adj_mat) + adjacency_matrix(fg; dir=:out) == adj_mat + adjacency_matrix(fg; dir=:in) == adj_mat + end + end + + @testset "asymmetric graph" begin + u = [1, 2, 3, 4] + v = [2, 3, 4, 1] + adj_mat_out = [0 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + adj_list_out = [[2], [3], [4], [1]] + + + adj_mat_in = [0 0 0 1 + 1 0 0 0 + 0 1 0 0 + 0 0 1 0] + adj_list_in = [[4], [1], [2], [3]] + + # core functionality + fg = FeaturedGraph(u, v) + @test fg.num_edges == 4 + @test fg.num_nodes == 4 + @test collect(edges(fg)) == collect(zip(u, v)) + @test sort(outneighbors(fg, 1)) == [2] + @test sort(inneighbors(fg, 1)) == [4] + @test is_directed(fg) == true + + # adjacency + @test adjacency_matrix(fg) == adj_mat_out + @test adjacency_list(fg) == adj_list_out + @test adjacency_matrix(fg, dir=:out) == adj_mat_out + @test adjacency_list(fg, dir=:out) == adj_list_out + @test adjacency_matrix(fg, dir=:in) == adj_mat_in + @test adjacency_list(fg, dir=:in) == adj_list_in + end + + + +end \ No newline at end of file diff --git a/test/layers/conv.jl b/test/layers/conv.jl index ffe6ee3cf..b55a437a9 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -15,7 +15,7 @@ adj_single_vertex = T[0. 0. 0. 1.; 1. 0. 1. 0.] 
fg_single_vertex = FeaturedGraph(adj_single_vertex) - + @testset "layer" begin @testset "GCNConv" begin @@ -25,7 +25,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(fg, in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test graph(gc.fg) === adj + @test all(adjacency_matrix(gc.fg) .== adj) Y = gc(X) @test size(Y) == (out_channel, N) @@ -46,7 +46,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test !has_graph(gc.fg) + # @test !has_graph(gc.fg) fg = FeaturedGraph(adj, nf=X) fg_ = gc(fg) @@ -59,7 +59,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1] @test size(g.weight) == size(gc.weight) @@ -81,7 +81,8 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) cc = ChebConv(fg, in_channel=>out_channel, k) @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test graph(cc.fg) === adj + @test adjacency_matrix(cc.fg) == adj + @test cc.k == k Y = cc(X) @@ -103,7 +104,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) cc = ChebConv(in_channel=>out_channel, k) @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test !has_graph(cc.fg) + # @test !has_graph(cc.fg) @test cc.k == k fg = FeaturedGraph(adj, nf=X) @@ -117,7 +118,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(cc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), cc)[1] @test size(g.weight) == size(cc.weight) @@ -173,7 +174,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1] @test size(g.weight1) == size(gc.weight1) @@ -244,7 +245,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(node_feature(fgt_)) == (concat ? 
(out_channel*heads, N) : (out_channel, N)) g = Zygote.gradient(x -> sum(node_feature(gat(x))), fg_gat)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg_gat))), gat)[1] @test size(g.weight) == size(gat.weight) @@ -298,7 +299,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(ggc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), ggc)[1] @test size(g.weight) == size(ggc.weight) @@ -341,7 +342,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(ec(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), ec)[1] @test size(g.nn.weight) == size(ec.nn.weight) diff --git a/test/layers/gn.jl b/test/layers/gn.jl index 0c32ec1f9..5e47e7409 100644 --- a/test/layers/gn.jl +++ b/test/layers/gn.jl @@ -27,10 +27,12 @@ u = rand(T, in_channel) fg = FeaturedGraph(adj, nf=V) fg_ = l(fg) - @test graph(fg_) === adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (0, 2*num_E) - @test size(global_feature(fg_)) == (0,) + # @test edge_feature(fg_) === nothing # TODO + @test global_feature(fg_) === nothing + # @test size(edge_feature(fg_)) == (0, 2*num_E) + # @test size(global_feature(fg_)) == (0,) end @testset "with neighbor aggregation" begin @@ -40,13 +42,14 @@ u = rand(T, in_channel) l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (in_channel, 2*num_E) @test size(global_feature(fg_)) == (0,) end GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = rand(T, out_channel) + @testset "update edge with neighbor aggregation" begin (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) @@ -54,13 +57,14 @@ u = rand(T, in_channel) l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test all(adjacency_matrix(fg_) .== adj) @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) @test size(global_feature(fg_)) == (0,) end GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = rand(T, out_channel) + @testset "update edge/vertex with all aggregation" begin (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +, +, +) @@ -68,7 +72,7 @@ u = rand(T, in_channel) l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test all(adjacency_matrix(fg_) .== adj) @test size(node_feature(fg_)) == (out_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) @test size(global_feature(fg_)) == (in_channel,) diff --git a/test/layers/misc.jl b/test/layers/misc.jl index 58ba83029..9f89500cd 100644 --- a/test/layers/misc.jl +++ b/test/layers/misc.jl @@ -16,7 +16,7 @@ x -> x .+ 2., x -> x .+ 3.) fg_ = layer(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test node_feature(fg_) == nf .+ 1. @test edge_feature(fg_) == ef .+ 2. @test global_feature(fg_) == gf .+ 3. 
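This patch replaces the GraphSignals `FeaturedGraph` with a COO ("coordinate") representation: the graph is stored as a pair of source/target vectors `(u, v)`, and `adjacency_matrix` is rebuilt on demand by linear indexing (`adj_mat[u .+ n .* (v .- 1)] .= 1`). The sketch below mirrors `test/featured_graph.jl` and shows the intended API; it assumes the code exactly as it stands after this patch, where `nv`, `ne` and `adjacency_matrix` are not yet re-exported from GeometricFlux (that happens in PATCH 03), hence the explicit `using LightGraphs`.

```julia
# Minimal usage sketch of the COO FeaturedGraph from src/featured_graph.jl.
# Not part of the diff; assumes the API exactly as defined in this patch.
using GeometricFlux
using LightGraphs: nv, ne, edges, outneighbors, inneighbors, adjacency_matrix

u = [1, 2, 3, 4]                 # edge sources
v = [2, 3, 4, 1]                 # edge targets
nf = rand(Float32, 8, 4)         # one 8-dimensional feature column per node

fg = FeaturedGraph(u, v; nf=nf)  # num_nodes defaults to max(maximum(u), maximum(v))

@assert nv(fg) == 4 && ne(fg) == 4
@assert collect(edges(fg)) == [(1, 2), (2, 3), (3, 4), (4, 1)]
@assert outneighbors(fg, 1) == [2] && inneighbors(fg, 1) == [4]
@assert adjacency_matrix(fg; dir=:out) == [0 1 0 0; 0 0 1 0; 0 0 0 1; 1 0 0 0]
@assert adjacency_list(fg; dir=:in) == [[4], [1], [2], [3]]
@assert node_feature(fg) === nf
```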
diff --git a/test/layers/msgpass.jl b/test/layers/msgpass.jl index 9e9c86caa..f7b49c9c6 100644 --- a/test/layers/msgpass.jl +++ b/test/layers/msgpass.jl @@ -27,29 +27,31 @@ l = NewLayer(out_channel, in_channel) @testset "no message or update" begin fg_ = l(fg) - @test graph(fg_) == adj + @test all(adjacency_matrix(fg_) .== adj) @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (in_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test global_feature(fg_) === nothing end GeometricFlux.message(l::NewLayer, x_i, x_j, e_ij) = l.weight * x_j + @testset "message function" begin fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test global_feature(fg_) === nothing end - + GeometricFlux.update(l::NewLayer, m, x) = l.weight * x + m + @testset "message and update" begin fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test global_feature(fg_) === nothing end end diff --git a/test/runtests.jl b/test/runtests.jl index 46f335f15..98dc5c3ec 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,8 +3,8 @@ using GeometricFlux.Datasets using Flux using Flux: @functor using FillArrays -using GraphSignals -using LightGraphs: SimpleGraph, SimpleDiGraph, add_edge!, nv, ne +# using GraphSignals +using LightGraphs using LinearAlgebra using NNlib using SparseArrays: SparseMatrixCSC @@ -18,6 +18,7 @@ cuda_tests = [ ] tests = [ + "featured_graph", "layers/gn", "layers/msgpass", "layers/conv", From 0a54ce38a6998d064e9bdf1489738e4df03639bf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 28 Jul 2021 13:45:10 +0200 Subject: [PATCH 02/11] rebase --- src/featured_graph.jl | 1 - src/layers/conv.jl | 36 +----------------------------------- src/layers/gn.jl | 4 ---- src/models.jl | 4 ---- src/utils.jl | 20 +------------------- test/cuda/conv.jl | 8 -------- 6 files changed, 2 insertions(+), 71 deletions(-) diff --git a/src/featured_graph.jl b/src/featured_graph.jl index 1dd10c551..2c6731d45 100644 --- a/src/featured_graph.jl +++ b/src/featured_graph.jl @@ -124,7 +124,6 @@ function FeaturedGraph(fg::FeaturedGraph; nf, ef, gf) end - @functor FeaturedGraph LightGraphs.edges(fg::FeaturedGraph) = zip(fg.edge_index[1], fg.edge_index[2]) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f451c8493..58044d11f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -88,16 +88,10 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) = function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T check_num_nodes(fg, X) -<<<<<<< HEAD @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size." L̃ = scaled_laplacian(fg, eltype(X)) -======= - L̃ = scaled_laplacian(fg, eltype(X)) - - @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size." 
->>>>>>> 17dbba7 (implement COO featured graph) Z_prev = X Z = X * L̃ Y = view(c.weight,:,:,1) * Z_prev @@ -165,11 +159,7 @@ update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1* function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix) check_num_nodes(fg, x) -<<<<<<< HEAD _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) -======= - _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) ->>>>>>> 17dbba7 (implement COO featured graph) x end @@ -256,16 +246,8 @@ update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix) n = size(adj, 1) -<<<<<<< HEAD - # a vertex must always receive a message from itself - Zygote.ignore() do - GraphLaplacians.add_self_loop!(adj, n) - end - mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n) -======= add_self_loop!(adj) - mapreduce(i -> apply_batch_message(g, i, adj[i], X), hcat, 1:n) ->>>>>>> 17dbba7 (implement COO featured graph) + mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n) end # The same as update function in batch manner @@ -282,11 +264,7 @@ end function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) check_num_nodes(fg, X) -<<<<<<< HEAD _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) -======= - _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) ->>>>>>> 17dbba7 (implement COO featured graph) X end @@ -343,7 +321,6 @@ update(ggc::GatedGraphConv, m::AbstractVector, x) = m function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real} check_num_nodes(fg, H) -<<<<<<< HEAD m, n = size(H) @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." adj = adjacency_list(fg) @@ -351,13 +328,6 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T Hpad = similar(H, S, ggc.out_ch - m, n) H = vcat(H, fill!(Hpad, 0)) end -======= - adj = adjacency_list(fg) - m, n = size(H) - @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." 
- (m < ggc.out_ch) && (H = vcat(H, zeros(S, ggc.out_ch - m, n))) - ->>>>>>> 17dbba7 (implement COO featured graph) for i = 1:ggc.num_layers M = view(ggc.weight, :, :, i) * H _, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +) @@ -405,11 +375,7 @@ update(ec::EdgeConv, m::AbstractVector, x) = m function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix) check_num_nodes(fg, X) -<<<<<<< HEAD _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) -======= - _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) ->>>>>>> 17dbba7 (implement COO featured graph) X end diff --git a/src/layers/gn.jl b/src/layers/gn.jl index 360d58e97..326fc45d4 100644 --- a/src/layers/gn.jl +++ b/src/layers/gn.jl @@ -53,13 +53,9 @@ end end function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) -<<<<<<< HEAD - E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr) -======= E, V, u = propagate(gn, adjacency_list(fg), edge_feature(fg), node_feature(fg), global_feature(fg), naggr, eaggr, vaggr) ->>>>>>> 17dbba7 (implement COO featured graph) FeaturedGraph(fg, nf=V, ef=E, gf=u) end diff --git a/src/models.jl b/src/models.jl index 21439933f..9a9594453 100644 --- a/src/models.jl +++ b/src/models.jl @@ -71,11 +71,7 @@ end function (i::InnerProductDecoder)(fg::FeaturedGraph) Z = node_feature(fg) A = i(Z) -<<<<<<< HEAD return FeaturedGraph(fg, nf=A) -======= - FeaturedGraph(fg, nf=A) ->>>>>>> 17dbba7 (implement COO featured graph) end diff --git a/src/utils.jl b/src/utils.jl index 6bc8b5f51..a6556a06c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -63,24 +63,6 @@ end Zygote.@nograd edge_index_table -### TODO move these to GraphSignals ###### -import GraphSignals: FeaturedGraph - -function FeaturedGraph(fg::FeaturedGraph; - nf=node_feature(fg), - ef=edge_feature(fg), - gf=global_feature(fg)) - - return FeaturedGraph(graph(fg); nf, ef, gf) -end - function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) @assert nv(fg) == size(x, ndims(x)) -end -<<<<<<< HEAD -======= - -### TODO move these to GraphSignals ###### -# @functor FeaturedGraph -# Zygote.@nograd normalized_laplacian, scaled_laplacian ->>>>>>> 17dbba7 (implement COO featured graph) +end \ No newline at end of file diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index 757b5ed92..386edc94e 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -15,11 +15,7 @@ fg = FeaturedGraph(adj) gc = GCNConv(fg, in_channel=>out_channel) |> gpu @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) -<<<<<<< HEAD @test Array(adjacency_matrix(gc.fg)) == adj -======= - @test Array(adjacency_matrix(gc.fg) ≈ adj ->>>>>>> 17dbba7 (implement COO featured graph) X = rand(in_channel, N) |> gpu Y = gc(X) @@ -39,11 +35,7 @@ fg = FeaturedGraph(adj) cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) -<<<<<<< HEAD @test Array(adjacency_matrix(cc.fg)) == adj -======= - @test Array(adjacency_matrix(cc.fg)) ≈ adj ->>>>>>> 17dbba7 (implement COO featured graph) @test cc.k == k @test cc.in_channel == in_channel @test cc.out_channel == out_channel From 3cf302c423af47123585107395005afd8e236677 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 28 Jul 2021 17:36:38 +0200 Subject: [PATCH 03/11] start coo integration --- docs/src/abstractions/msgpass.md | 18 +++--- src/GeometricFlux.jl | 11 +--- src/cuda/conv.jl | 4 +- src/cuda/msgpass.jl | 82 
++++++++++++------------ src/featured_graph.jl | 2 + src/layers/gn.jl | 103 ++++++++++++++++--------------- src/layers/msgpass.jl | 22 +++---- src/utils.jl | 63 ------------------- test/cuda/msgpass.jl | 16 ++--- test/layers/gn.jl | 34 +++++----- test/runtests.jl | 10 +-- 11 files changed, 147 insertions(+), 218 deletions(-) diff --git a/docs/src/abstractions/msgpass.md b/docs/src/abstractions/msgpass.md index cad3a4f0d..1ed7f1af8 100644 --- a/docs/src/abstractions/msgpass.md +++ b/docs/src/abstractions/msgpass.md @@ -20,20 +20,20 @@ A message function accepts feature vector representing node state `x_i`, feature Messages from message function are aggregated by an aggregate function. An aggregated message is passed to update function for node-level computation. An aggregate function is given by the following: ``` -propagate(mp, fg::FeaturedGraph, aggr::Symbol=:add) +propagate(mp, fg::FeaturedGraph, aggr::Symbol=+) ``` -`propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `:add` represents an aggregate function of addition of all messages. +`propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `+` represents an aggregate function of addition of all messages. The following `aggr` are available aggregate functions: -`:add`: sum over all messages -`:sub`: negative of sum over all messages -`:mul`: multiplication over all messages -`:div`: inverse of multiplication over all messages -`:max`: the maximum of all messages -`:min`: the minimum of all messages -`:mean`: the average of all messages +`+`: sum over all messages +`-`: negative of sum over all messages +`*`: multiplication over all messages +`/`: inverse of multiplication over all messages +`max`: the maximum of all messages +`min`: the minimum of all messages +`mean`: the average of all messages ## Update function diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index bbfe79581..6d42561dd 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -1,5 +1,6 @@ module GeometricFlux +using Base: Tuple using Statistics: mean using LinearAlgebra using FillArrays: Fill @@ -21,13 +22,10 @@ import LightGraphs using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv, adjacency_matrix export - # featured_graph FeaturedGraph, adjacency_list, - # graph, - # has_graph, node_feature, edge_feature, global_feature, - # ne, nv, adjacency_matrix, # from LightGraphs + ne, nv, adjacency_matrix, # from LightGraphs # layers/gn GraphNet, @@ -58,10 +56,7 @@ export sample, # layer/selector - bypass_graph, - - # utils - generate_cluster + bypass_graph include("featured_graph.jl") include("datasets.jl") diff --git a/src/cuda/conv.jl b/src/cuda/conv.jl index d6464b37e..0e64888be 100644 --- a/src/cuda/conv.jl +++ b/src/cuda/conv.jl @@ -21,6 +21,6 @@ # Avoid ambiguity -update_batch_edge(g::GATConv, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} = update_batch_edge(g, adj, X) +# update_batch_edge(g::GATConv, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} = update_batch_edge(g, adj, X) -update_batch_vertex(g::GATConv, M::CuMatrix, X::CuMatrix, u) = update_batch_vertex(g, M) +# update_batch_vertex(g::GATConv, M::CuMatrix, X::CuMatrix, u) = update_batch_vertex(g, M) diff --git a/src/cuda/msgpass.jl b/src/cuda/msgpass.jl index 2996656a5..4820821f1 100644 
--- a/src/cuda/msgpass.jl +++ b/src/cuda/msgpass.jl @@ -1,41 +1,41 @@ -@inline function update_batch_edge(mp::MessagePassing, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} - E = fill(E.value, E.axes) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::Fill{S,2,Axes}, u) where {S,Axes} - X = fill(X.value, X.axes) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::AbstractMatrix, X::CuMatrix, u) - E = convert(typeof(X), E) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::AbstractMatrix, u) - X = convert(typeof(E), X) - update_batch_edge(mp, adj, E, X, u) -end - -@inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::CuMatrix, u) - n = size(adj, 1) - edge_idx = edge_index_table(adj) - mapreduce(i -> apply_batch_message(mp, i, adj[i], edge_idx, E, X, u), hcat, 1:n) -end - -@inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::CuMatrix, X::CuMatrix, u) = - mapreduce(j -> message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) - -@inline function update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::CuMatrix, u) - M = convert(typeof(X), M) - update_batch_vertex(mp, M, X, u) -end - -@inline function update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::AbstractMatrix, u) - X = convert(typeof(M), X) - update_batch_vertex(mp, M, X, u) -end - -@inline update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::CuMatrix, u) = - mapreduce(i -> update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) +# @inline function update_batch_edge(mp::MessagePassing, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} +# E = fill(E.value, E.axes) +# update_batch_edge(mp, adj, E, X, u) +# end + +# @inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::Fill{S,2,Axes}, u) where {S,Axes} +# X = fill(X.value, X.axes) +# update_batch_edge(mp, adj, E, X, u) +# end + +# @inline function update_batch_edge(mp::MessagePassing, adj, E::AbstractMatrix, X::CuMatrix, u) +# E = convert(typeof(X), E) +# update_batch_edge(mp, adj, E, X, u) +# end + +# @inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::AbstractMatrix, u) +# X = convert(typeof(E), X) +# update_batch_edge(mp, adj, E, X, u) +# end + +# @inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::CuMatrix, u) +# n = size(adj, 1) +# edge_idx = edge_index_table(adj) +# mapreduce(i -> apply_batch_message(mp, i, adj[i], edge_idx, E, X, u), hcat, 1:n) +# end + +# @inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::CuMatrix, X::CuMatrix, u) = +# mapreduce(j -> message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) + +# @inline function update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::CuMatrix, u) +# M = convert(typeof(X), M) +# update_batch_vertex(mp, M, X, u) +# end + +# @inline function update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::AbstractMatrix, u) +# X = convert(typeof(M), X) +# update_batch_vertex(mp, M, X, u) +# end + +# @inline update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::CuMatrix, u) = +# mapreduce(i -> update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) diff --git a/src/featured_graph.jl b/src/featured_graph.jl index 2c6731d45..67a3e5e47 100644 --- a/src/featured_graph.jl +++ b/src/featured_graph.jl @@ -126,6 +126,8 @@ end @functor FeaturedGraph 
+edge_index(fg::FeaturedGraph) = fg.edge_index + LightGraphs.edges(fg::FeaturedGraph) = zip(fg.edge_index[1], fg.edge_index[2]) LightGraphs.edgetype(fg::FeaturedGraph) = Tuple{eltype(fg.edge_index[1]), eltype(fg.edge_index[2])} diff --git a/src/layers/gn.jl b/src/layers/gn.jl index 326fc45d4..451f9c24c 100644 --- a/src/layers/gn.jl +++ b/src/layers/gn.jl @@ -1,71 +1,76 @@ -_view(::Nothing, i) = nothing -_view(A::AbstractMatrix, idx) = view(A, :, idx) +abstract type GraphNet end -aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2)) -aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2)) -aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2)) -aggregate(aggr::typeof(/), X) = 1 ./ vec(prod(X, dims=2)) -aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2)) -aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2)) -aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2)) +_view(x::AbstractMatrix, i) = view(x, :, i) # use standard indexing instead of views? +_view(x::Nothing, i) = nothing -abstract type GraphNet end +aggregate(aggr::typeof(+), X::AbstractMatrix) = vec(sum(X, dims=2)) +aggregate(aggr::typeof(-), X::AbstractMatrix) = -vec(sum(X, dims=2)) +aggregate(aggr::typeof(*), X::AbstractMatrix) = vec(prod(X, dims=2)) +aggregate(aggr::typeof(/), X::AbstractMatrix) = 1 ./ vec(prod(X, dims=2)) +aggregate(aggr::typeof(max), X::AbstractMatrix) = vec(maximum(X, dims=2)) +aggregate(aggr::typeof(min), X::AbstractMatrix) = vec(minimum(X, dims=2)) +aggregate(aggr::typeof(mean), X::AbstractMatrix) = vec(mean(X, dims=2)) +aggregate(aggr::Nothing, X::AbstractMatrix) = nothing -@inline update_edge(gn::GraphNet, e, vi, vj, u) = e -@inline update_vertex(gn::GraphNet, ē, vi, u) = vi -@inline update_global(gn::GraphNet, ē, v̄, u) = u +## Step 1. -@inline function update_batch_edge(gn::GraphNet, adj, E, V, u) - n = size(adj, 1) - edge_idx = edge_index_table(adj) - mapreduce(i -> apply_batch_message(gn, i, adj[i], edge_idx, E, V, u), hcat, 1:n) +function update_batch_edge(gn::GraphNet, st, E, X, u) + s, t = st + message(gn, X[:,t], X[:,s], E, u) # use view instead of indexing? end -@inline apply_batch_message(gn::GraphNet, i, js, edge_idx, E, V, u) = - mapreduce(j -> update_edge(gn, _view(E, edge_idx[(i,j)]), _view(V, i), _view(V, j), u), hcat, js) +message(gn::GraphNet, x_i, x_j, e_ij, u) = x_j +# message(gn::GraphNet, i, j, x_i, x_j, e_ij, u) = message(gn, x_i, x_j, e_ij, u) # TODO add in the future -@inline update_batch_vertex(gn::GraphNet, Ē, V, u) = - mapreduce(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), hcat, 1:size(V,2)) +## Step 2 -@inline function aggregate_neighbors(gn::GraphNet, aggr, E, accu_edge) - @assert !iszero(accu_edge) "accumulated edge must not be zero." 
- cluster = generate_cluster(E, accu_edge) - NNlib.scatter(aggr, E, cluster) +function aggregate_neighbors(gn::GraphNet, aggr, st, E) + s, t = st + NNlib.scatter(aggr, E, t) end -@inline function aggregate_neighbors(gn::GraphNet, aggr::Nothing, E, accu_edge) - @nospecialize E accu_edge - return nothing -end +aggregate_neighbors(gn::GraphNet, aggr::Nothing, st, E) = nothing -@inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E) +## Step 3 -@inline function aggregate_edges(gn::GraphNet, aggr::Nothing, E) - @nospecialize E - return nothing -end +update_batch_vertex(gn::GraphNet, M, X, u) = update(gn, M, X, u) + +update(gn::GraphNet, m, x, u) = x +# update(gn::GraphNet, i, m, x, u) = update(gn, m, x, u) -@inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V) -@inline function aggregate_vertices(gn::GraphNet, aggr::Nothing, V) - @nospecialize V - return nothing -end +## Step 4 + +aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E) + +## Step 5 + +aggregate_vertices(gn::GraphNet, aggr, X) = aggregate(aggr, X) + +## Step 6 + +update_global(gn::GraphNet, ē, x̄, u) = u + +### end steps ### + function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) - E, V, u = propagate(gn, adjacency_list(fg), + E, X, u = propagate(gn, edge_index(fg), edge_feature(fg), node_feature(fg), global_feature(fg), naggr, eaggr, vaggr) - FeaturedGraph(fg, nf=V, ef=E, gf=u) + FeaturedGraph(fg, nf=X, ef=E, gf=u) end -function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P, - naggr=nothing, eaggr=nothing, vaggr=nothing) where {S<:AbstractVector,R,Q,P} - E = update_batch_edge(gn, adj, E, V, u) - Ē = aggregate_neighbors(gn, naggr, E, accumulated_edges(adj)) - V = update_batch_vertex(gn, Ē, V, u) +function propagate(gn::GraphNet, st::Tuple, E, X, u, + naggr=nothing, eaggr=nothing, vaggr=nothing) + E = update_batch_edge(gn, st, E, X, u) + @show E + Ē = aggregate_neighbors(gn, naggr, st, E) + @show Ē + X = update_batch_vertex(gn, Ē, X, u) + @show X ē = aggregate_edges(gn, eaggr, E) - v̄ = aggregate_vertices(gn, vaggr, V) - u = update_global(gn, ē, v̄, u) - return E, V, u + x̄ = aggregate_vertices(gn, vaggr, X) + u = update_global(gn, ē, x̄, u) + return E, X, u end diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl index b0f20a8ea..3637fd0b9 100644 --- a/src/layers/msgpass.jl +++ b/src/layers/msgpass.jl @@ -1,7 +1,7 @@ abstract type MessagePassing <: GraphNet end """ - message(mp::MessagePassing, x_i, x_j, e_ij) +message(mp::MessagePassing, x_i, x_j, e_ij) Message function for the message-passing scheme, returning the message from node `j` to node `i` . @@ -22,11 +22,10 @@ specialize this method with custom behavior. See also [`update`](@ref). """ -@inline message(mp::MessagePassing, x_i, x_j, e_ij) = x_j -@inline message(mp::MessagePassing, i::Integer, j::Integer, x_i, x_j, e_ij) = x_j +function message end """ - update(mp::MessagePassing, m, x) +update(mp::GraphNet, m, x) Update function for the message-passing scheme, returning a new set of node features `x′` based on old @@ -45,21 +44,14 @@ specialize this method with custom behavior. See also [`message`](@ref). 
""" -@inline update(mp::MessagePassing, m, x) = m -@inline update(mp::MessagePassing, i::Integer, m, x) = m - -@inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::AbstractMatrix, X::AbstractMatrix, u) = - mapreduce(j -> GeometricFlux.message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) - -@inline update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::AbstractMatrix, u) = - mapreduce(i -> GeometricFlux.update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) +function message end function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr=+) - E, X = propagate(mp, adjacency_list(fg), edge_feature(fg), node_feature(fg), aggr) + E, X = propagate(mp, edge_index(fg), edge_feature(fg), node_feature(fg), aggr) FeaturedGraph(fg, nf=X, ef=E) end -function propagate(mp::MessagePassing, adj::AbstractVector{S}, E::R, X::Q, aggr) where {S<:AbstractVector,R,Q} - E, X, u = propagate(mp, adj, E, X, nothing, aggr, nothing, nothing) +function propagate(mp::MessagePassing, eindex::Tuple, E, X, aggr) + E, X, u = propagate(mp, eindex, E, X, nothing, aggr, nothing, nothing) E, X end diff --git a/src/utils.jl b/src/utils.jl index a6556a06c..67066ec4d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,67 +1,4 @@ -""" - accumulated_edges(adj) -Return a vector which acts as a mapping table. The index is the vertex index, -value is accumulated numbers of edge (current vertex not included). -""" -accumulated_edges(adj::AbstractVector{<:AbstractVector{<:Integer}}) = [0, cumsum(map(length, adj))...] - -Zygote.@nograd accumulated_edges - - -Zygote.@nograd function generate_cluster(M, accu_edge) - num_V = length(accu_edge) - 1 - num_E = accu_edge[end] - # cluster = similar(M, Int, num_E) - cluster = zeros(Int, num_E) - @inbounds for i = 1:num_V - j = accu_edge[i] - k = accu_edge[i+1] - cluster[j+1:k] .= i - end - cluster -end - -""" - edge_index_table(adj[, directed]) - -Generate a mapping from vertex pair (i, j) to edge index. The edge indecies are determined by -the sorted vertex indecies. -""" -function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}}, directed::Bool=is_directed(adj)) - table = Dict{Tuple{UInt32,UInt32},UInt64}() - e = one(UInt64) - if directed - for (i, js) = enumerate(adj) - js = sort(js) - for j = js - table[(i, j)] = e - e += one(UInt64) - end - end - else - for (i, js) = enumerate(adj) - js = sort(js) - js = js[i .≤ js] - for j = js - table[(i, j)] = e - table[(j, i)] = e - e += one(UInt64) - end - end - end - table -end - -function edge_index_table(vpair::AbstractVector{<:Tuple}) - table = Dict{Tuple{UInt32,UInt32},UInt64}() - for (i, p) = enumerate(vpair) - table[p] = i - end - table -end - -Zygote.@nograd edge_index_table function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) @assert nv(fg) == size(x, ndims(x)) diff --git a/test/cuda/msgpass.jl b/test/cuda/msgpass.jl index e372a41b5..5f16af3e6 100644 --- a/test/cuda/msgpass.jl +++ b/test/cuda/msgpass.jl @@ -2,12 +2,12 @@ in_channel = 10 out_channel = 5 N = 6 T = Float32 -adj = [0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] 
+adj = [0 1 0 0 0 0; + 1 0 0 1 1 1; + 0 0 0 0 0 1; + 0 1 0 0 1 0; + 0 1 0 1 0 1; + 0 1 1 0 1 0] struct NewCudaLayer <: MessagePassing weight @@ -15,12 +15,12 @@ end NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n)) @functor NewCudaLayer -(l::NewCudaLayer)(X) = propagate(l, X, +) +(l::NewCudaLayer)(X) = GeometricFlux.propagate(l, X, +) GeometricFlux.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j GeometricFlux.update(::NewCudaLayer, m, x) = m X = rand(T, in_channel, N) |> gpu -fg = FeaturedGraph(adj, nf=X) +fg = FeaturedGraph(adj, nf=X) |> gpu l = NewCudaLayer(out_channel, in_channel) |> gpu @testset "cuda/msgpass" begin diff --git a/test/layers/gn.jl b/test/layers/gn.jl index 5e47e7409..cf7e61d91 100644 --- a/test/layers/gn.jl +++ b/test/layers/gn.jl @@ -1,21 +1,21 @@ in_channel = 10 out_channel = 5 num_V = 6 -num_E = 7 +num_E = 14 T = Float32 -adj = T[0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] +adj = [0 1 0 0 0 0; + 1 0 0 1 1 1; + 0 0 0 0 0 1; + 0 1 0 0 1 0; + 0 1 0 1 0 1; + 0 1 1 0 1 0] struct NewGNLayer <: GraphNet end V = rand(T, in_channel, num_V) -E = rand(T, in_channel, 2num_E) +E = rand(T, in_channel, num_E) u = rand(T, in_channel) @testset "gn" begin @@ -29,41 +29,39 @@ u = rand(T, in_channel) @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) - # @test edge_feature(fg_) === nothing # TODO + @test edge_feature(fg_) === nothing @test global_feature(fg_) === nothing - # @test size(edge_feature(fg_)) == (0, 2*num_E) - # @test size(global_feature(fg_)) == (0,) end @testset "with neighbor aggregation" begin (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0)) + fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(2)) l = NewGNLayer() fg_ = l(fg) @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (in_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test size(edge_feature(fg_)) == (in_channel, num_E) + @test size(global_feature(fg_)) == (2,) end - GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = rand(T, out_channel) + GeometricFlux.message(l::NewGNLayer, xi, xj, e, u) = rand(T, out_channel) @testset "update edge with neighbor aggregation" begin (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0)) + fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(2)) l = NewGNLayer() fg_ = l(fg) @test all(adjacency_matrix(fg_) .== adj) @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test size(global_feature(fg_)) == (2,) end - GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = rand(T, out_channel) + GeometricFlux.update(l::NewGNLayer, ē, vi, u) = rand(T, out_channel) @testset "update edge/vertex with all aggregation" begin (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +, +, +) diff --git a/test/runtests.jl b/test/runtests.jl index 98dc5c3ec..547e655a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,11 +20,11 @@ cuda_tests = [ tests = [ "featured_graph", "layers/gn", - "layers/msgpass", - "layers/conv", - "layers/pool", - "layers/misc", - "models", + # "layers/msgpass", + # "layers/conv", + # "layers/pool", + # "layers/misc", + # "models", ] if Flux.use_cuda[] From 8a09c2622e6c90cd76f0183690d0d8050bf19d4f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 
Jul 2021 09:56:05 +0200
Subject: [PATCH 04/11] complete redesign

---
 docs/src/abstractions/msgpass.md |   2 +-
 src/GeometricFlux.jl             |  15 ----
 src/cuda/conv.jl                 |  26 -------
 src/cuda/msgpass.jl              |  41 ----------
 src/layers/conv.jl               |  86 +++++++++------------
 src/layers/gn.jl                 |  76 ------------------
 src/layers/msgpass.jl            |  75 +++++++++++++++---
 test/featured_graph.jl           |   4 -
 test/layers/conv.jl              |  39 +++++-----
 test/layers/gn.jl                |  78 -------------------
 test/layers/msgpass.jl           | 127 ++++++++++++++++++++----------
 test/runtests.jl                 |  12 ++-
 12 files changed, 218 insertions(+), 363 deletions(-)
 delete mode 100644 src/cuda/conv.jl
 delete mode 100644 src/cuda/msgpass.jl
 delete mode 100644 src/layers/gn.jl
 delete mode 100644 test/layers/gn.jl

diff --git a/docs/src/abstractions/msgpass.md b/docs/src/abstractions/msgpass.md
index 1ed7f1af8..8ab9a4bef 100644
--- a/docs/src/abstractions/msgpass.md
+++ b/docs/src/abstractions/msgpass.md
@@ -20,7 +20,7 @@ A message function accepts feature vector representing node state `x_i`, feature
 Messages from message function are aggregated by an aggregate function. An aggregated message is passed to update function for node-level computation. An aggregate function is given by the following:
 
 ```
-propagate(mp, fg::FeaturedGraph, aggr::Symbol=+)
+propagate(mp, fg::FeaturedGraph; aggr::Symbol=+)
 ```
 
 `propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `+` represents an aggregate function of addition of all messages.
diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl
index 6d42561dd..4d063d2b1 100644
--- a/src/GeometricFlux.jl
+++ b/src/GeometricFlux.jl
@@ -1,6 +1,5 @@
 module GeometricFlux
 
-using Base: Tuple
 using Statistics: mean
 using LinearAlgebra
 using FillArrays: Fill
@@ -11,13 +10,6 @@ using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
 using NNlib, NNlibCUDA
 using Zygote
 using ChainRulesCore
-
-
-# import GraphLaplacians
-# using GraphLaplacians: normalized_laplacian, scaled_laplacian
-# using GraphLaplacians: adjacency_matrix
-# using Reexport
-# @reexport using GraphSignals
 import LightGraphs
 using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv, adjacency_matrix
@@ -27,9 +19,6 @@ export
     node_feature, edge_feature, global_feature,
     ne, nv, adjacency_matrix, # from LightGraphs
 
-    # layers/gn
-    GraphNet,
-
     # layers/msgpass
     MessagePassing,
 
@@ -63,7 +52,6 @@ include("datasets.jl")
 include("utils.jl")
 
-include("layers/gn.jl")
 include("layers/msgpass.jl")
 include("layers/conv.jl")
@@ -71,9 +59,6 @@ include("layers/pool.jl")
 include("models.jl")
 include("layers/misc.jl")
 
-include("cuda/msgpass.jl")
-include("cuda/conv.jl")
-
 using .Datasets
diff --git a/src/cuda/conv.jl b/src/cuda/conv.jl
deleted file mode 100644
index 0e64888be..000000000
--- a/src/cuda/conv.jl
+++ /dev/null
@@ -1,26 +0,0 @@
-# (g::GCNConv)(L̃::AbstractMatrix, X::CuArray) = g(cu(L̃), X)
-
-# (g::GCNConv)(L̃::CuArray, X::CuArray) = g.σ.(g.weight * X * L̃ .+ g.bias)
-
-# (c::ChebConv)(L̃::AbstractMatrix, X::CuArray) = c(cu(L̃), X)
-
-# function (c::ChebConv)(L̃::CuArray, X::CuArray)
-#     @assert size(X, 1) == c.in_channel "Input feature size must match input channel size."
-#     @assert size(X, 2) == size(L̃, 1) "Input vertex number must match Laplacian matrix size."
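The commented-out body continues below; for reference, it is the standard Chebyshev recurrence. A device-generic restatement (a sketch only: `cheb_conv` is a hypothetical helper, `W` sized `out_channel × in_channel × k` with `k ≥ 2`):

```julia
# Hypothetical restatement of the deleted ChebConv body, not part of the patch.
function cheb_conv(W, L̃, X, b)
    Z_prev, Z = X, X * L̃                      # T₀(L̃)X and T₁(L̃)X
    Y = W[:, :, 1] * Z_prev + W[:, :, 2] * Z
    for k in 3:size(W, 3)
        Z, Z_prev = 2 * Z * L̃ - Z_prev, Z     # Tₖ = 2L̃Tₖ₋₁ - Tₖ₋₂
        Y += W[:, :, k] * Z
    end
    return Y .+ b
end
```

Every operation here is generic array code, so a single method covers both `Array` and `CuArray` inputs, consistent with these CUDA-only copies being deleted.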
- -# Z_prev = X -# Z = X * L̃ -# Y = view(c.weight,:,:,1) * Z_prev -# Y += view(c.weight,:,:,2) * Z -# for k = 3:c.k -# Z, Z_prev = 2*Z*L̃ - Z_prev, Z -# Y += view(c.weight,:,:,k) * Z -# end -# return Y .+ c.bias -# end - - -# Avoid ambiguity -# update_batch_edge(g::GATConv, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} = update_batch_edge(g, adj, X) - -# update_batch_vertex(g::GATConv, M::CuMatrix, X::CuMatrix, u) = update_batch_vertex(g, M) diff --git a/src/cuda/msgpass.jl b/src/cuda/msgpass.jl deleted file mode 100644 index 4820821f1..000000000 --- a/src/cuda/msgpass.jl +++ /dev/null @@ -1,41 +0,0 @@ -# @inline function update_batch_edge(mp::MessagePassing, adj, E::Fill{S,2,Axes}, X::CuMatrix, u) where {S,Axes} -# E = fill(E.value, E.axes) -# update_batch_edge(mp, adj, E, X, u) -# end - -# @inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::Fill{S,2,Axes}, u) where {S,Axes} -# X = fill(X.value, X.axes) -# update_batch_edge(mp, adj, E, X, u) -# end - -# @inline function update_batch_edge(mp::MessagePassing, adj, E::AbstractMatrix, X::CuMatrix, u) -# E = convert(typeof(X), E) -# update_batch_edge(mp, adj, E, X, u) -# end - -# @inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::AbstractMatrix, u) -# X = convert(typeof(E), X) -# update_batch_edge(mp, adj, E, X, u) -# end - -# @inline function update_batch_edge(mp::MessagePassing, adj, E::CuMatrix, X::CuMatrix, u) -# n = size(adj, 1) -# edge_idx = edge_index_table(adj) -# mapreduce(i -> apply_batch_message(mp, i, adj[i], edge_idx, E, X, u), hcat, 1:n) -# end - -# @inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::CuMatrix, X::CuMatrix, u) = -# mapreduce(j -> message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) - -# @inline function update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::CuMatrix, u) -# M = convert(typeof(X), M) -# update_batch_vertex(mp, M, X, u) -# end - -# @inline function update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::AbstractMatrix, u) -# X = convert(typeof(M), X) -# update_batch_vertex(mp, M, X, u) -# end - -# @inline update_batch_vertex(mp::MessagePassing, M::CuMatrix, X::CuMatrix, u) = -# mapreduce(i -> update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 58044d11f..492c881f1 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -153,13 +153,13 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) = @functor GraphConv -message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j +message(gc::GraphConv, x_i, x_j, e_ij) = x_j -update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias) +update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias) function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix) check_num_nodes(fg, x) - _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) + _, x = propagate(gc, fg, nothing, x, +) x end @@ -222,50 +222,34 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...) 
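The `GraphConv` refactor above moves `weight2` out of the per-edge message and into `update`, where it is applied once to the aggregated sum; for the `+` aggregation used here the result is unchanged by linearity, but it costs one matmul instead of one per edge. Hand-rolled, with made-up toy sizes and data, the computation is:

```julia
# What the refactored GraphConv computes, written out by hand (toy example).
W1, W2, b = randn(5, 3), randn(5, 3), zeros(5)
X = randn(3, 4)                       # in_channel × num_nodes
s, t = [1, 2, 3, 4], [2, 1, 4, 3]     # COO edge index: message k flows s[k] → t[k]
M̄ = zeros(3, 4)
for (j, i) in zip(s, t)
    M̄[:, i] += X[:, j]                # message is x_j, aggregated with +
end
Y = tanh.(W1 * X .+ W2 * M̄ .+ b)      # update: σ.(weight1*x .+ weight2*m̄ .+ bias)
```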
@functor GATConv -# Here the α that has not been softmaxed is the first number of the output message -function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector) - x_i = reshape(gat.weight*x_i, :, gat.heads) - x_j = reshape(gat.weight*x_j, :, gat.heads) - x_ij = vcat(x_i, x_j+zero(x_j)) - e = sum(x_ij .* gat.a, dims=1) # inner product for each head, output shape: (1, gat.heads) - e_ij = leakyrelu.(e, gat.negative_slope) - vcat(e_ij, x_j) # shape: (n+1, gat.heads) -end - -# After some reshaping due to the multihead, we get the α from each message, -# then get the softmax over every α, and eventually multiply the message by α -function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix) - e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js) - n = size(e_ij, 1) - αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2) - msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :) - reshape(msgs, (n-1)*gat.heads, :) -end - -update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(gat, adj, X) - -function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix) - n = size(adj, 1) - add_self_loop!(adj) - mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n) -end - -# The same as update function in batch manner -update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(gat, M) - -function update_batch_vertex(gat::GATConv, M::AbstractMatrix) - M = M .+ gat.bias +function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) + check_num_nodes(fg, X) + # add_self_loop!(adj) #TODO + chin, chout = gat.channel + heads = gat.heads + + source, target = edge_index(fg) + Wx = gat.weight*X + Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes + Wxi = NNlib.gather(Wx, target) # chout × nheads × nedges + Wxj = NNlib.gather(Wx, source) + + # Edge Message + # Computing softmax. TODO make it numerically stable + aWW = sum(gat.a .* cat(Wxi, Wxj, dims=1), dims=1) # 1 × nheads × nedges + α = exp.(leakyrelu.(aWW, gat.negative_slope)) + m̄ = NNlib.scatter(+, α .* Wxj, target) # chout × nheads × nnodes + ᾱ = NNlib.scatter(+, α, target) # 1 × nheads × nnodes + + # Node update + b = reshape(gat.bias, chout, heads) + X = m̄ ./ ᾱ .+ b # chout × nheads × nnodes if !gat.concat - N = size(M, 2) - M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N) + X = sum(X, dims=2) end - return M -end -function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) - check_num_nodes(fg, X) - _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) - X + # We finally return a matrix + return reshape(X, :, size(X, 3)) end (l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) @@ -314,23 +298,22 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) = @functor GatedGraphConv -message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j +message(ggc::GatedGraphConv, x_i, x_j, e_ij) = x_j -update(ggc::GatedGraphConv, m::AbstractVector, x) = m +update(ggc::GatedGraphConv, m, x) = m function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real} check_num_nodes(fg, H) m, n = size(H) @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." 
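The `exp`/`scatter` pair in the new `GATConv` forward above implements a softmax over the edges incident to each target node: `m̄ ./ ᾱ` equals the attention-weighted sum with weights that sum to one per node. The trick in isolation, with toy data (assumes NNlib's `gather`/`scatter`):

```julia
using NNlib
e = randn(Float32, 1, 2, 6)          # raw attention scores: 1 × heads × edges
t = [1, 1, 2, 2, 3, 3]               # target node of each edge
α = exp.(e)
denom = NNlib.scatter(+, α, t)       # 1 × heads × nodes: softmax denominators
attn = α ./ NNlib.gather(denom, t)   # softmax over edges sharing a target
@assert sum(attn[1, 1, 1:2]) ≈ 1     # head 1's weights into node 1 sum to one
```

As the TODO notes, this is not yet numerically stable; the usual fix is to subtract the per-target maximum from the scores (scatter with `max`, gather back) before the `exp.`.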
- adj = adjacency_list(fg) if m < ggc.out_ch Hpad = similar(H, S, ggc.out_ch - m, n) H = vcat(H, fill!(Hpad, 0)) end for i = 1:ggc.num_layers M = view(ggc.weight, :, :, i) * H - _, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +) + _, M = propagate(ggc, fg, nothing, M, +) H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381 end H @@ -370,12 +353,13 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...) @functor EdgeConv -message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i)) -update(ec::EdgeConv, m::AbstractVector, x) = m +message(ec::EdgeConv, x_i, x_j, e_ij) = ec.nn(vcat(x_i, x_j .- x_i)) + +update(ec::EdgeConv, m, x) = m function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix) check_num_nodes(fg, X) - _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) + _, X = propagate(ec, fg, nothing, X, ec.aggr) X end diff --git a/src/layers/gn.jl b/src/layers/gn.jl deleted file mode 100644 index 451f9c24c..000000000 --- a/src/layers/gn.jl +++ /dev/null @@ -1,76 +0,0 @@ -abstract type GraphNet end - -_view(x::AbstractMatrix, i) = view(x, :, i) # use standard indexing instead of views? -_view(x::Nothing, i) = nothing - -aggregate(aggr::typeof(+), X::AbstractMatrix) = vec(sum(X, dims=2)) -aggregate(aggr::typeof(-), X::AbstractMatrix) = -vec(sum(X, dims=2)) -aggregate(aggr::typeof(*), X::AbstractMatrix) = vec(prod(X, dims=2)) -aggregate(aggr::typeof(/), X::AbstractMatrix) = 1 ./ vec(prod(X, dims=2)) -aggregate(aggr::typeof(max), X::AbstractMatrix) = vec(maximum(X, dims=2)) -aggregate(aggr::typeof(min), X::AbstractMatrix) = vec(minimum(X, dims=2)) -aggregate(aggr::typeof(mean), X::AbstractMatrix) = vec(mean(X, dims=2)) -aggregate(aggr::Nothing, X::AbstractMatrix) = nothing - -## Step 1. - -function update_batch_edge(gn::GraphNet, st, E, X, u) - s, t = st - message(gn, X[:,t], X[:,s], E, u) # use view instead of indexing? 
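The `X[:,t]`/`X[:,s]` indexing in the deleted `update_batch_edge` above, and its "use view instead?" question, is what `NNlib.gather` settles in the new `compute_batch_message`: gather is equivalent to column indexing here, but comes with a scatter-based gradient and a CUDA kernel. In isolation:

```julia
using NNlib
X = rand(Float32, 3, 4)        # features × nodes
s, t = [1, 2, 3], [2, 3, 4]    # COO edge index
Xi = NNlib.gather(X, t)        # x_i: target-node features, one column per edge
Xj = NNlib.gather(X, s)        # x_j: source-node features
@assert Xi == X[:, t] && Xj == X[:, s]
```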
-end - -message(gn::GraphNet, x_i, x_j, e_ij, u) = x_j -# message(gn::GraphNet, i, j, x_i, x_j, e_ij, u) = message(gn, x_i, x_j, e_ij, u) # TODO add in the future - -## Step 2 - -function aggregate_neighbors(gn::GraphNet, aggr, st, E) - s, t = st - NNlib.scatter(aggr, E, t) -end - -aggregate_neighbors(gn::GraphNet, aggr::Nothing, st, E) = nothing - -## Step 3 - -update_batch_vertex(gn::GraphNet, M, X, u) = update(gn, M, X, u) - -update(gn::GraphNet, m, x, u) = x -# update(gn::GraphNet, i, m, x, u) = update(gn, m, x, u) - - -## Step 4 - -aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E) - -## Step 5 - -aggregate_vertices(gn::GraphNet, aggr, X) = aggregate(aggr, X) - -## Step 6 - -update_global(gn::GraphNet, ē, x̄, u) = u - -### end steps ### - - -function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) - E, X, u = propagate(gn, edge_index(fg), - edge_feature(fg), node_feature(fg), global_feature(fg), - naggr, eaggr, vaggr) - FeaturedGraph(fg, nf=X, ef=E, gf=u) -end - -function propagate(gn::GraphNet, st::Tuple, E, X, u, - naggr=nothing, eaggr=nothing, vaggr=nothing) - E = update_batch_edge(gn, st, E, X, u) - @show E - Ē = aggregate_neighbors(gn, naggr, st, E) - @show Ē - X = update_batch_vertex(gn, Ē, X, u) - @show X - ē = aggregate_edges(gn, eaggr, E) - x̄ = aggregate_vertices(gn, vaggr, X) - u = update_global(gn, ē, x̄, u) - return E, X, u -end diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl index 3637fd0b9..2be9101c8 100644 --- a/src/layers/msgpass.jl +++ b/src/layers/msgpass.jl @@ -1,7 +1,25 @@ -abstract type MessagePassing <: GraphNet end +# Adapted message passing from paper +# "Relational inductive biases, deep learning, and graph networks" +abstract type MessagePassing end + +function propagate(gn::MessagePassing, fg::FeaturedGraph, aggr=+) + E, X, u = propagate(gn, fg, + edge_feature(fg), node_feature(fg), global_feature(fg), + aggr) + FeaturedGraph(fg, nf=X, ef=E, gf=u) +end + +function propagate(gn::MessagePassing, fg::FeaturedGraph, E, X, u, aggr=+) + M = compute_batch_message(gn, fg, E, X, u) + E = update_batch_edge(gn, M, E, u) + M̄ = aggregate_neighbors(gn, aggr, fg, M) + X = update_batch_vertex(gn, M̄, X, u) + u = update_global(gn, E, X, u) + return E, X, u +end """ -message(mp::MessagePassing, x_i, x_j, e_ij) + message(mp::MessagePassing, x_i, x_j, e_ij) Message function for the message-passing scheme, returning the message from node `j` to node `i` . @@ -25,7 +43,7 @@ See also [`update`](@ref). function message end """ -update(mp::GraphNet, m, x) +update(mp::MessagePassing, m, x) Update function for the message-passing scheme, returning a new set of node features `x′` based on old @@ -46,12 +64,51 @@ See also [`message`](@ref). """ function message end -function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr=+) - E, X = propagate(mp, edge_index(fg), edge_feature(fg), node_feature(fg), aggr) - FeaturedGraph(fg, nf=X, ef=E) + +_gather(x, i) = NNlib.gather(x, i) +_gather(x::Nothing, i) = nothing + +## Step 1. 
+ +function compute_batch_message(gn::MessagePassing, fg, E, X, u) + s, t = edge_index(fg) + Xi = _gather(X, t) + Xj = _gather(X, s) + M = message(gn, Xi, Xj, E, u) + return M end -function propagate(mp::MessagePassing, eindex::Tuple, E, X, aggr) - E, X, u = propagate(mp, eindex, E, X, nothing, aggr, nothing, nothing) - E, X +# @inline message(gn::MessagePassing, i, j, x_i, x_j, e_ij, u) = message(gn, x_i, x_j, e_ij, u) # TODO add in the future +@inline message(gn::MessagePassing, x_i, x_j, e_ij, u) = message(gn, x_i, x_j, e_ij) +@inline message(gn::MessagePassing, x_i, x_j, e_ij) = message(gn, x_i, x_j) +@inline message(gn::MessagePassing, x_i, x_j) = x_j + +## Step 2 + +function aggregate_neighbors(gn::MessagePassing, aggr, fg, E) + s, t = edge_index(fg) + NNlib.scatter(aggr, E, t) end + +aggregate_neighbors(gn::MessagePassing, aggr::Nothing, fg, E) = nothing + +## Step 3 + +update_batch_vertex(gn::MessagePassing, M̄, X, u) = update(gn, M̄, X, u) + +# @inline update(gn::MessagePassing, i, m̄, x, u) = update(gn, m, x, u) +@inline update(gn::MessagePassing, m̄, x, u) = update(gn, m̄, x) +@inline update(gn::MessagePassing, m̄, x) = m̄ + +## Step 4 +update_batch_edge(gn::MessagePassing, M, E, u) = update_edge(gn::MessagePassing, M, E, u) + +@inline update_edge(gn::MessagePassing, M, E, u) = update_edge(gn::MessagePassing, M, E) +@inline update_edge(gn::MessagePassing, M, E) = E + +## Step 5 + +@inline update_global(gn::MessagePassing, E, X, u) = u + +### end steps ### + diff --git a/test/featured_graph.jl b/test/featured_graph.jl index e0dc7031e..94d42a154 100644 --- a/test/featured_graph.jl +++ b/test/featured_graph.jl @@ -1,5 +1,3 @@ -using LightGraphs - @testset "FeaturedGraph" begin @testset "symmetric graph" begin u = [1, 2, 3, 4, 2, 3, 4, 1] @@ -68,6 +66,4 @@ using LightGraphs @test adjacency_list(fg, dir=:in) == adj_list_in end - - end \ No newline at end of file diff --git a/test/layers/conv.jl b/test/layers/conv.jl index b55a437a9..8c27d3c43 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -1,23 +1,22 @@ -T = Float32 -in_channel = 3 -out_channel = 5 -N = 4 -adj = T[0. 1. 0. 1.; - 1. 0. 1. 0.; - 0. 1. 0. 1.; - 1. 0. 1. 0.] - -fg = FeaturedGraph(adj) - -adj_single_vertex = T[0. 0. 0. 1.; - 0. 0. 0. 0.; - 0. 0. 0. 1.; - 1. 0. 1. 0.] 
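Stepping back from the test reshuffling for a moment: every hook in the new msgpass.jl above has a default, with `message` falling back to `x_j` and `update` to the aggregate `m̄`, so a layer that only sums neighbor features needs no overloads at all. A hypothetical minimal subtype (`SumNeighbors` is not from the patch):

```julia
using GeometricFlux
struct SumNeighbors <: MessagePassing end
(l::SumNeighbors)(fg::FeaturedGraph) = GeometricFlux.propagate(l, fg, +)
# With all defaults, node features become the sum of in-neighbor features;
# edge and global features pass through untouched.
fg = FeaturedGraph([1, 2], [2, 1], nf = rand(Float32, 3, 2))
fg_ = SumNeighbors()(fg)
```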
- -fg_single_vertex = FeaturedGraph(adj_single_vertex) - - @testset "layer" begin + T = Float32 + in_channel = 3 + out_channel = 5 + N = 4 + adj = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + + fg = FeaturedGraph(adj) + + adj_single_vertex = [0 0 0 1 + 0 0 0 0 + 0 0 0 1 + 1 0 1 0] + + fg_single_vertex = FeaturedGraph(adj_single_vertex) + @testset "GCNConv" begin X = rand(T, in_channel, N) Xt = transpose(rand(T, N, in_channel)) @@ -25,7 +24,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(fg, in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test all(adjacency_matrix(gc.fg) .== adj) + @test adjacency_matrix(gc.fg) == adj Y = gc(X) @test size(Y) == (out_channel, N) diff --git a/test/layers/gn.jl b/test/layers/gn.jl deleted file mode 100644 index cf7e61d91..000000000 --- a/test/layers/gn.jl +++ /dev/null @@ -1,78 +0,0 @@ -in_channel = 10 -out_channel = 5 -num_V = 6 -num_E = 14 -T = Float32 - -adj = [0 1 0 0 0 0; - 1 0 0 1 1 1; - 0 0 0 0 0 1; - 0 1 0 0 1 0; - 0 1 0 1 0 1; - 0 1 1 0 1 0] - -struct NewGNLayer <: GraphNet -end - -V = rand(T, in_channel, num_V) -E = rand(T, in_channel, num_E) -u = rand(T, in_channel) - -@testset "gn" begin - l = NewGNLayer() - - @testset "without aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg) - - fg = FeaturedGraph(adj, nf=V) - fg_ = l(fg) - - @test adjacency_matrix(fg_) == adj - @test size(node_feature(fg_)) == (in_channel, num_V) - @test edge_feature(fg_) === nothing - @test global_feature(fg_) === nothing - end - - @testset "with neighbor aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) - - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(2)) - l = NewGNLayer() - fg_ = l(fg) - - @test adjacency_matrix(fg_) == adj - @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (in_channel, num_E) - @test size(global_feature(fg_)) == (2,) - end - - GeometricFlux.message(l::NewGNLayer, xi, xj, e, u) = rand(T, out_channel) - - @testset "update edge with neighbor aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) - - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(2)) - l = NewGNLayer() - fg_ = l(fg) - - @test all(adjacency_matrix(fg_) .== adj) - @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (2,) - end - - GeometricFlux.update(l::NewGNLayer, ē, vi, u) = rand(T, out_channel) - - @testset "update edge/vertex with all aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +, +, +) - - fg = FeaturedGraph(adj, nf=V, ef=E, gf=u) - l = NewGNLayer() - fg_ = l(fg) - - @test all(adjacency_matrix(fg_) .== adj) - @test size(node_feature(fg_)) == (out_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (in_channel,) - end -end diff --git a/test/layers/msgpass.jl b/test/layers/msgpass.jl index f7b49c9c6..8cba084d1 100644 --- a/test/layers/msgpass.jl +++ b/test/layers/msgpass.jl @@ -1,57 +1,114 @@ -in_channel = 10 -out_channel = 5 -num_V = 6 -num_E = 7 -T = Float32 - -adj = T[0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] 
- -struct NewLayer <: MessagePassing - weight -end -NewLayer(m, n) = NewLayer(randn(T, m,n)) +@testset "MessagePassing" begin + in_channel = 10 + out_channel = 5 + num_V = 6 + num_E = 14 + T = Float32 + + adj = [0 1 0 0 0 0 + 1 0 0 1 1 1 + 0 0 0 0 0 1 + 0 1 0 0 1 0 + 0 1 0 1 0 1 + 0 1 1 0 1 0] + + struct NewLayer <: MessagePassing end -(l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + X = rand(T, in_channel, num_V) + E = rand(T, in_channel, num_E) + u = rand(T, in_channel) -X = Array{T}(reshape(1:num_V*in_channel, in_channel, num_V)) -fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E)) -l = NewLayer(out_channel, in_channel) + @testset "default aggregation (+)" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg) -@testset "msgpass" begin - @testset "no message or update" begin + fg = FeaturedGraph(adj, nf=X) fg_ = l(fg) - @test all(adjacency_matrix(fg_) .== adj) + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (in_channel, 2*num_E) + @test edge_feature(fg_) === nothing @test global_feature(fg_) === nothing end - GeometricFlux.message(l::NewLayer, x_i, x_j, e_ij) = l.weight * x_j - - @testset "message function" begin + @testset "neighbor aggregation (+)" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) + fg_ = l(fg) + + @test adjacency_matrix(fg_) == adj + @test size(node_feature(fg_)) == (in_channel, num_V) + @test edge_feature(fg_) ≈ E + @test global_feature(fg_) ≈ u + end + + GeometricFlux.message(l::NewLayer, xi, xj, e, u) = ones(T, out_channel, size(e,2)) + + @testset "custom message and neighbor aggregation" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) fg_ = l(fg) @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test global_feature(fg_) === nothing + @test edge_feature(fg_) ≈ edge_feature(fg) + @test global_feature(fg_) ≈ global_feature(fg) end - - GeometricFlux.update(l::NewLayer, m, x) = l.weight * x + m - @testset "message and update" begin + GeometricFlux.update_edge(l::NewLayer, m, e) = m + + @testset "update_edge" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) fg_ = l(fg) @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) - @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test global_feature(fg_) === nothing + @test size(edge_feature(fg_)) == (out_channel, num_E) + @test global_feature(fg_) ≈ global_feature(fg) + end + + GeometricFlux.update(l::NewLayer, m̄, xi, u) = rand(T, 2*out_channel, size(xi, 2)) + + @testset "update edge/vertex" begin + l = NewLayer() + (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) + fg_ = l(fg) + + @test all(adjacency_matrix(fg_) .== adj) + @test size(node_feature(fg_)) == (2*out_channel, num_V) + @test size(edge_feature(fg_)) == (out_channel, num_E) + @test size(global_feature(fg_)) == (in_channel,) + end + + struct NewLayerW <: MessagePassing + weight + end + + NewLayerW(in, out) = NewLayerW(randn(T, out, in)) + + GeometricFlux.message(l::NewLayerW, x_i, x_j, e_ij) = l.weight * x_j + GeometricFlux.update(l::NewLayerW, m, x) = l.weight * x + m + + @testset "message and update with weights" begin + l = 
NewLayerW(in_channel, out_channel) + (l::NewLayerW)(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=E, gf=u) + fg_ = l(fg) + + @test adjacency_matrix(fg_) == adj + @test size(node_feature(fg_)) == (out_channel, num_V) + @test edge_feature(fg_) === E + @test global_feature(fg_) === u end end diff --git a/test/runtests.jl b/test/runtests.jl index 547e655a7..df2bc7e92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,6 @@ using GeometricFlux.Datasets using Flux using Flux: @functor using FillArrays -# using GraphSignals using LightGraphs using LinearAlgebra using NNlib @@ -19,12 +18,11 @@ cuda_tests = [ tests = [ "featured_graph", - "layers/gn", - # "layers/msgpass", - # "layers/conv", - # "layers/pool", - # "layers/misc", - # "models", + "layers/msgpass", + "layers/conv", + "layers/pool", + "layers/misc", + "models", ] if Flux.use_cuda[] From 342bd4e8afddb8cfa830c64159902eb280c4e35a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 Jul 2021 09:58:10 +0200 Subject: [PATCH 05/11] cleanup --- Project.toml | 2 -- src/GeometricFlux.jl | 1 - test/runtests.jl | 4 ++-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index aa0f9ff27..f909a8201 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "0.7.6" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" @@ -22,7 +21,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CUDA = "3.3" DataStructures = "0.18" -FillArrays = "0.11, 0.12" Flux = "0.12" GraphMLDatasets = "0.1" LightGraphs = "1.3" diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index 4d063d2b1..e4496f448 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -2,7 +2,6 @@ module GeometricFlux using Statistics: mean using LinearAlgebra -using FillArrays: Fill using CUDA using Flux diff --git a/test/runtests.jl b/test/runtests.jl index df2bc7e92..8e2d5ab6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,8 +12,8 @@ using Zygote using Test cuda_tests = [ - # "cuda/conv", - # "cuda/msgpass", + "cuda/conv", + "cuda/msgpass", ] tests = [ From bdd16b79035ac63197258306e214357e3415e999 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 Jul 2021 10:25:53 +0200 Subject: [PATCH 06/11] passing tests except for some layers not gpu compatible --- src/layers/conv.jl | 5 +-- test/cuda/conv.jl | 100 ++++++++++++++++++++++----------------------- test/runtests.jl | 1 - 3 files changed, 50 insertions(+), 56 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 492c881f1..0fc050f85 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -12,7 +12,6 @@ Graph convolutional layer. - `bias`: Add learnable bias. - `init`: Weights' initializer. - The input to the layer is a node feature array `X` of size `(num_features, num_nodes)`. """ @@ -194,7 +193,7 @@ Graph attentional layer. - `out`: The dimension of output features. - `bias::Bool`: Keyword argument, whether to learn the additive bias. - `heads`: Number attention heads -- `concat`: Concatenate layer output or not. If not, layer output is averaged. +- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads. 
- `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU. """ struct GATConv{V<:AbstractFeaturedGraph, T, A<:AbstractMatrix{T}, B} <: MessagePassing @@ -314,7 +313,7 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T for i = 1:ggc.num_layers M = view(ggc.weight, :, :, i) * H _, M = propagate(ggc, fg, nothing, M, +) - H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381 + H, _ = ggc.gru(H, M) end H end diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index 386edc94e..5ca623e6b 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -1,56 +1,52 @@ -using Flux: Dense, gpu - -in_channel = 3 -out_channel = 5 -N = 4 -adj = [0 1 0 1; - 1 0 1 0; - 0 1 0 1; - 1 0 1 0] - -fg = FeaturedGraph(adj) - @testset "cuda/conv" begin - @testset "GCNConv" begin - gc = GCNConv(fg, in_channel=>out_channel) |> gpu - @test size(gc.weight) == (out_channel, in_channel) - @test size(gc.bias) == (out_channel,) - @test Array(adjacency_matrix(gc.fg)) == adj - - X = rand(in_channel, N) |> gpu - Y = gc(X) - @test size(Y) == (out_channel, N) - - g = Zygote.gradient(x -> sum(gc(x)), X)[1] - @test size(g) == size(X) - - g = Zygote.gradient(model -> sum(model(X)), gc)[1] - @test size(g.weight) == size(gc.weight) - @test size(g.bias) == size(gc.bias) - end - - - @testset "ChebConv" begin - k = 6 - cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu - @test size(cc.weight) == (out_channel, in_channel, k) - @test size(cc.bias) == (out_channel,) - @test Array(adjacency_matrix(cc.fg)) == adj - @test cc.k == k - @test cc.in_channel == in_channel - @test cc.out_channel == out_channel - - X = rand(in_channel, N) |> gpu - Y = cc(X) - @test size(Y) == (out_channel, N) - - # g = Zygote.gradient(x -> sum(cc(x)), X)[1] - # @test size(g) == size(X) - - # g = Zygote.gradient(model -> sum(model(X)), cc)[1] - # @test size(g.weight) == size(cc.weight) - # @test size(g.bias) == size(cc.bias) - end + in_channel = 3 + out_channel = 5 + N = 4 + adj = [0 1 0 1 + 1 0 1 0 + 0 1 0 1 + 1 0 1 0] + + fg = FeaturedGraph(adj) + + # @testset "GCNConv" begin + # gc = GCNConv(fg, in_channel=>out_channel) |> gpu + # @test size(gc.weight) == (out_channel, in_channel) + # @test size(gc.bias) == (out_channel,) + # @test adjacency_matrix(gc.fg |> cpu) == adj + + # X = rand(in_channel, N) |> gpu + # Y = gc(X) + # @test size(Y) == (out_channel, N) + + # g = Zygote.gradient(x -> sum(gc(x)), X)[1] + # @test size(g) == size(X) + + # g = Zygote.gradient(model -> sum(model(X)), gc)[1] + # @test size(g.weight) == size(gc.weight) + # @test size(g.bias) == size(gc.bias) + # end + + + # @testset "ChebConv" begin + # k = 6 + # cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu + # @test size(cc.weight) == (out_channel, in_channel, k) + # @test size(cc.bias) == (out_channel,) + # @test adjacency_matrix(cc.fg |> cpu) == adj + # @test cc.k == k + + # X = rand(in_channel, N) |> gpu + # Y = cc(X) + # @test size(Y) == (out_channel, N) + + # g = Zygote.gradient(x -> sum(cc(x)), X)[1] + # @test size(g) == size(X) + + # g = Zygote.gradient(model -> sum(model(X)), cc)[1] + # @test size(g.weight) == size(cc.weight) + # @test size(g.bias) == size(cc.bias) + # end @testset "GraphConv" begin gc = GraphConv(fg, in_channel=>out_channel) |> gpu diff --git a/test/runtests.jl b/test/runtests.jl index 8e2d5ab6c..61a3e222b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,6 @@ using GeometricFlux using GeometricFlux.Datasets using Flux using Flux: @functor -using FillArrays using LightGraphs using LinearAlgebra using NNlib From 
79dfed378f373f8959719ab25e0bf7167f770084 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 Jul 2021 11:36:24 +0200 Subject: [PATCH 07/11] message passing GCNConv; gpu friendly adjacency_matrix --- src/GeometricFlux.jl | 5 ++++- src/featured_graph.jl | 19 ++++++++++++++++--- src/layers/conv.jl | 19 ++++++++++++++++--- src/layers/msgpass.jl | 4 ++-- test/cuda/conv.jl | 28 ++++++++++++++-------------- test/featured_graph.jl | 12 ++++++++++++ test/runtests.jl | 1 - 7 files changed, 64 insertions(+), 24 deletions(-) diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index e4496f448..15977d383 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -1,5 +1,7 @@ module GeometricFlux +using ChainRulesCore: eltype +using LinearAlgebra: similar using Statistics: mean using LinearAlgebra @@ -10,7 +12,8 @@ using NNlib, NNlibCUDA using Zygote using ChainRulesCore import LightGraphs -using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv, adjacency_matrix +using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv, + adjacency_matrix, degree export FeaturedGraph, diff --git a/src/featured_graph.jl b/src/featured_graph.jl index 67a3e5e47..41c5a47c7 100644 --- a/src/featured_graph.jl +++ b/src/featured_graph.jl @@ -165,14 +165,27 @@ end # TODO return sparse matrix function LightGraphs.adjacency_matrix(fg::FeaturedGraph, T::DataType=Int; dir=:out) # TODO dir=:both - u, v = fg.edge_index + u, v = edge_index(fg) n = fg.num_nodes - adj_mat = zeros(T, n, n) + adj_mat = fill!(similar(u, T, (n, n)), 0) adj_mat[u .+ n .* (v .- 1)] .= 1 # exploiting linear indexing return dir == :out ? adj_mat : adj_mat' end -Zygote.@nograd adjacency_matrix, adjacency_list +function LightGraphs.degree(fg::FeaturedGraph; dir=:both) + s, t = edge_index(fg) + degs = fill!(similar(s, eltype(s), fg.num_nodes), 0) + o = fill!(similar(s, eltype(s), fg.num_edges), 1) + if dir ∈ [:out, :both] + NNlib.scatter!(+, degs, o, s) + end + if dir ∈ [:in, :both] + NNlib.scatter!(+, degs, o, t) + end + return degs +end + +Zygote.@nograd adjacency_matrix, adjacency_list, degree # function ChainRulesCore.rrule(::typeof(copy), x) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0fc050f85..8d3b30b18 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -15,7 +15,7 @@ Graph convolutional layer. The input to the layer is a node feature array `X` of size `(num_features, num_nodes)`. """ -struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} +struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} <: MessagePassing weight::A bias::B σ::F @@ -35,9 +35,22 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) 
= @functor GCNConv +# function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix) +# L̃ = normalized_laplacian(fg, eltype(x); selfloop=true) +# l.σ.(l.weight * x * L̃ .+ l.bias) +# end + +message(l::GCNConv, xi, xj) = xj +update(l::GCNConv, m, x) = m + function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix) - L̃ = normalized_laplacian(fg, eltype(x); selfloop=true) - l.σ.(l.weight * x * L̃ .+ l.bias) + # TODO handle self loops + # cout = sqrt.(degree(fg, dir=:out)) + cin = reshape(sqrt.(degree(fg, dir=:in)), 1, :) + x = cin .* x + _, x = propagate(l, fg, nothing, x, +) + x = cin .* x + return l.σ.(l.weight * x .+ l.bias) end (l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl index 2be9101c8..82ec12ba3 100644 --- a/src/layers/msgpass.jl +++ b/src/layers/msgpass.jl @@ -43,7 +43,7 @@ See also [`update`](@ref). function message end """ -update(mp::MessagePassing, m, x) + update(mp::MessagePassing, m, x) Update function for the message-passing scheme, returning a new set of node features `x′` based on old @@ -62,7 +62,7 @@ specialize this method with custom behavior. See also [`message`](@ref). """ -function message end +function update end _gather(x, i) = NNlib.gather(x, i) diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index 5ca623e6b..b2652084a 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -9,23 +9,23 @@ fg = FeaturedGraph(adj) - # @testset "GCNConv" begin - # gc = GCNConv(fg, in_channel=>out_channel) |> gpu - # @test size(gc.weight) == (out_channel, in_channel) - # @test size(gc.bias) == (out_channel,) - # @test adjacency_matrix(gc.fg |> cpu) == adj + @testset "GCNConv" begin + gc = GCNConv(fg, in_channel=>out_channel) |> gpu + @test size(gc.weight) == (out_channel, in_channel) + @test size(gc.bias) == (out_channel,) + @test adjacency_matrix(gc.fg |> cpu) == adj - # X = rand(in_channel, N) |> gpu - # Y = gc(X) - # @test size(Y) == (out_channel, N) + X = rand(in_channel, N) |> gpu + Y = gc(X) + @test size(Y) == (out_channel, N) - # g = Zygote.gradient(x -> sum(gc(x)), X)[1] - # @test size(g) == size(X) + g = Zygote.gradient(x -> sum(gc(x)), X)[1] + @test size(g) == size(X) - # g = Zygote.gradient(model -> sum(model(X)), gc)[1] - # @test size(g.weight) == size(gc.weight) - # @test size(g.bias) == size(gc.bias) - # end + g = Zygote.gradient(model -> sum(model(X)), gc)[1] + @test size(g.weight) == size(gc.weight) + @test size(g.bias) == size(gc.bias) + end # @testset "ChebConv" begin diff --git a/test/featured_graph.jl b/test/featured_graph.jl index 94d42a154..66e193a39 100644 --- a/test/featured_graph.jl +++ b/test/featured_graph.jl @@ -29,6 +29,12 @@ fg = FeaturedGraph(adj_mat) adjacency_matrix(fg; dir=:out) == adj_mat adjacency_matrix(fg; dir=:in) == adj_mat + end + + @testset "degree" begin + fg = FeaturedGraph(adj_mat) + @test degree(fg, dir=:out) == vec(sum(adj_mat, dims=2)) + @test degree(fg, dir=:in) == vec(sum(adj_mat, dims=1)) end end @@ -64,6 +70,12 @@ @test adjacency_list(fg, dir=:out) == adj_list_out @test adjacency_matrix(fg, dir=:in) == adj_mat_in @test adjacency_list(fg, dir=:in) == adj_list_in + + @testset "degree" begin + fg = FeaturedGraph(adj_mat_out) + @test degree(fg, dir=:out) == vec(sum(adj_mat_out, dims=2)) + @test degree(fg, dir=:in) == vec(sum(adj_mat_out, dims=1)) + end end end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 61a3e222b..9ab59bacb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,6 @@ using Flux: 
@functor using LightGraphs using LinearAlgebra using NNlib -using SparseArrays: SparseMatrixCSC using Statistics: mean using Zygote using Test From 12c8d1e8bbadf2a3d56bc87e7172a76f59d45d6f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 Jul 2021 12:09:23 +0200 Subject: [PATCH 08/11] more tests --- src/GeometricFlux.jl | 11 ++++++++--- src/featured_graph.jl | 12 ++++++------ test/cuda/conv.jl | 36 ++++++++++++++++++++---------------- test/cuda/featured_graph.jl | 36 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 ++++--- 5 files changed, 74 insertions(+), 28 deletions(-) create mode 100644 test/cuda/featured_graph.jl diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index 15977d383..bc37f8e09 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -1,6 +1,7 @@ module GeometricFlux -using ChainRulesCore: eltype +using NNlib: similar +using ChainRulesCore: eltype, reshape using LinearAlgebra: similar using Statistics: mean using LinearAlgebra @@ -16,10 +17,14 @@ using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv adjacency_matrix, degree export + # featured_graph FeaturedGraph, - adjacency_list, + edge_index, node_feature, edge_feature, global_feature, - ne, nv, adjacency_matrix, # from LightGraphs + adjacency_list, normalized_laplacian, scaled_laplacian, + + # from LightGraphs + ne, nv, adjacency_matrix, # layers/msgpass MessagePassing, diff --git a/src/featured_graph.jl b/src/featured_graph.jl index 41c5a47c7..2d22ae83c 100644 --- a/src/featured_graph.jl +++ b/src/featured_graph.jl @@ -255,18 +255,18 @@ LightGraphs.is_directed(g::AbstractMatrix) = !issymmetric(Matrix(g)) ## from GraphLaplacians """ - normalized_laplacian(g[, T]; selfloop=false, dir=:out) + normalized_laplacian(fg, T=Float32; selfloop=false, dir=:out) Normalized Laplacian matrix of graph `g`. # Arguments -- `g`: should be a adjacency matrix, `FeaturedGraph`, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). -- `T`: result element type of degree vector; default is the element type of `g` (optional). -- `selfloop`: adding self loop while calculating the matrix (optional). +- `fg`: A `FeaturedGraph`. +- `T`: result element type of degree vector; default `Float32`. +- `selfloop`: adding self loop while calculating the matrix. - `dir`: the edge directionality considered (:out, :in, :both). """ -function normalized_laplacian(fg::FeaturedGraph, T::DataType=Int; selfloop::Bool=false, dir::Symbol=:out) +function normalized_laplacian(fg::FeaturedGraph, T::DataType=Float32; selfloop::Bool=false, dir::Symbol=:out) A = adjacency_matrix(fg, T; dir=dir) selfloop && (A += I) degs = vec(sum(A; dims=2)) @@ -286,7 +286,7 @@ defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normal - `T`: result element type of degree vector; default is the element type of `g` (optional). - `dir`: the edge directionality considered (:out, :in, :both). 
""" -function scaled_laplacian(fg::FeaturedGraph, T::DataType=Int; dir=:out) +function scaled_laplacian(fg::FeaturedGraph, T::DataType=Float32; dir=:out) A = adjacency_matrix(fg, T; dir=dir) @assert issymmetric(A) "scaled_laplacian only works with symmetric matrices" E = eigen(Symmetric(A)).values diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index b2652084a..5ddf68411 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -28,25 +28,29 @@ end - # @testset "ChebConv" begin - # k = 6 - # cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu - # @test size(cc.weight) == (out_channel, in_channel, k) - # @test size(cc.bias) == (out_channel,) - # @test adjacency_matrix(cc.fg |> cpu) == adj - # @test cc.k == k + @testset "ChebConv" begin + k = 6 + cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu + @test size(cc.weight) == (out_channel, in_channel, k) + @test size(cc.bias) == (out_channel,) + @test adjacency_matrix(cc.fg |> cpu) == adj + @test cc.k == k - # X = rand(in_channel, N) |> gpu - # Y = cc(X) - # @test size(Y) == (out_channel, N) + @test_broken begin + X = rand(in_channel, N) |> gpu + Y = cc(X) + @test size(Y) == (out_channel, N) - # g = Zygote.gradient(x -> sum(cc(x)), X)[1] - # @test size(g) == size(X) + g = Zygote.gradient(x -> sum(cc(x)), X)[1] + @test size(g) == size(X) - # g = Zygote.gradient(model -> sum(model(X)), cc)[1] - # @test size(g.weight) == size(cc.weight) - # @test size(g.bias) == size(cc.bias) - # end + g = Zygote.gradient(model -> sum(model(X)), cc)[1] + @test size(g.weight) == size(cc.weight) + @test size(g.bias) == size(cc.bias) + + true + end + end @testset "GraphConv" begin gc = GraphConv(fg, in_channel=>out_channel) |> gpu diff --git a/test/cuda/featured_graph.jl b/test/cuda/featured_graph.jl new file mode 100644 index 000000000..15430e849 --- /dev/null +++ b/test/cuda/featured_graph.jl @@ -0,0 +1,36 @@ +@testset "featured graph" begin + s = [1,1,2,3,4,5,5,5] + t = [2,5,3,2,1,4,3,1] + s, t = [s; t], [t; s] #symmetrize + fg = FeaturedGraph(s, t) + fg_gpu = fg |> gpu + + @testset "functor" begin + s_gpu, t_gpu = edge_index(fg_gpu) + @test s_gpu isa CuVector{Int} + @test Array(s_gpu) == s + @test t_gpu isa CuVector{Int} + @test Array(t_gpu) == t + end + + @testset "adjacency_matrix" begin + mat = adjacency_matrix(fg) + mat_gpu = adjacency_matrix(fg_gpu) + @test mat_gpu isa CuMatrix{Int} + end + + @testset "normalized_laplacian" begin + mat = normalized_laplacian(fg) + mat_gpu = normalized_laplacian(fg_gpu) + @test mat_gpu isa CuMatrix{Float32} + end + + @testset "scaled_laplacian" begin + @test_broken begin + mat = scaled_laplacian(fg) + mat_gpu = scaled_laplacian(fg_gpu) + @test mat_gpu isa CuMatrix{Float32} + true + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9ab59bacb..427e45bc9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,10 @@ using NNlib using Statistics: mean using Zygote using Test +using CUDA +using Flux: gpu +using NNlibCUDA +CUDA.allowscalar(false) cuda_tests = [ "cuda/conv", @@ -24,9 +28,6 @@ tests = [ ] if Flux.use_cuda[] - using CUDA - using Flux: gpu - using NNlibCUDA append!(tests, cuda_tests) else @warn "CUDA unavailable, not testing GPU support" From ac4b3a876ad0bf8ddeb05df13081463249f6566c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 29 Jul 2021 19:01:53 +0200 Subject: [PATCH 09/11] more cleanup; add_self_loops --- src/featured_graph.jl | 164 ++++++++++++++++------------------------- src/layers/conv.jl | 12 +-- src/layers/msgpass.jl | 117 +++++++++++++++++++---------- 
test/cuda/msgpass.jl | 2 +- test/layers/msgpass.jl | 6 +- 5 files changed, 150 insertions(+), 151 deletions(-) diff --git a/src/featured_graph.jl b/src/featured_graph.jl index 2d22ae83c..4732439c5 100644 --- a/src/featured_graph.jl +++ b/src/featured_graph.jl @@ -1,23 +1,8 @@ #=================================== Define FeaturedGraph type as a subtype of LightGraphs' AbstractGraph. - -All LightGraphs functions rely on a standard API to function. -As long as your graph structure is a subtype of AbstractGraph and -implements the following API functions with the given return values, -all functions within the LightGraphs package should just work: - edges - Base.eltype - edgetype (example: edgetype(g::CustomGraph) = LightGraphs.SimpleEdge{eltype(g)})) - has_edge - has_vertex - inneighbors - ne - nv - outneighbors - vertices - is_directed(::Type{CustomGraph})::Bool (example: is_directed(::Type{<:CustomGraph}) = false) - is_directed(g::CustomGraph)::Bool - zero +For the core methods to be implemented by any AbstractGraph, see +https://juliagraphs.org/LightGraphs.jl/latest/types/#AbstractGraph-Type +https://juliagraphs.org/LightGraphs.jl/latest/developing/#Developing-Alternate-Graph-Types =============================================# abstract type AbstractFeaturedGraph <: AbstractGraph{Int} end @@ -43,8 +28,8 @@ struct FeaturedGraph <: AbstractFeaturedGraph end -function FeaturedGraph(u::AbstractVector{Int}, v::AbstractVector{Int}; - num_nodes = max(maximum(u), maximum(v)), +function FeaturedGraph(s::AbstractVector{Int}, t::AbstractVector{Int}; + num_nodes = max(maximum(s), maximum(t)), # ndata = Dict{String, Any}(), # edata = Dict{String, Any}(), # gdata = Dict{String, Any}(), @@ -52,11 +37,11 @@ function FeaturedGraph(u::AbstractVector{Int}, v::AbstractVector{Int}; ef = nothing, gf = nothing) - @assert length(u) == length(v) - @assert min(minimum(u), minimum(v)) >= 1 - @assert max(maximum(u), maximum(v)) <= num_nodes + @assert length(s) == length(t) + @assert min(minimum(s), minimum(t)) >= 1 + @assert max(maximum(s), maximum(t)) <= num_nodes - num_edges = length(u) + num_edges = length(s) ## I would like to have dict data store, but currently this ## doesn't play well with zygote due to @@ -66,7 +51,7 @@ function FeaturedGraph(u::AbstractVector{Int}, v::AbstractVector{Int}; # gdata["g"] = gf - FeaturedGraph((u, v), num_nodes, num_edges, + FeaturedGraph((s, t), num_nodes, num_edges, nf, ef, gf) end @@ -77,20 +62,20 @@ function FeaturedGraph(adj_mat::AbstractMatrix; dir=:out, kws...) @assert num_nodes == size(adj_mat, 2) @assert all(x -> (x == 1) || (x == 0), adj_mat) num_edges = round(Int, sum(adj_mat)) - u = zeros(Int, num_edges) - v = zeros(Int, num_edges) + s = zeros(Int, num_edges) + t = zeros(Int, num_edges) e = 0 for j in 1:num_nodes for i in 1:num_nodes if adj_mat[i, j] == 1 e += 1 - u[e] = i - v[e] = j + s[e] = i + t[e] = j end end end @assert e == num_edges - FeaturedGraph(u, v; num_nodes, kws...) + FeaturedGraph(s, t; num_nodes, kws...) end @@ -99,20 +84,21 @@ function FeaturedGraph(adj_list::AbstractVector{<:AbstractVector}; dir=:out, kws @assert dir == :out # TODO num_nodes = length(adj_list) num_edges = sum(length.(adj_list)) - u = zeros(Int, num_edges) - v = zeros(Int, num_edges) + s = zeros(Int, num_edges) + t = zeros(Int, num_edges) e = 0 for i in 1:num_nodes for j in adj_list[i] e += 1 - u[e] = i - v[e] = j + s[e] = i + t[e] = j end end @assert e == num_edges - FeaturedGraph(u, v; num_nodes, kws...) + FeaturedGraph(s, t; num_nodes, kws...) 
end +FeaturedGraph(g::AbstractGraph; kws...) = FeaturedGraph(adjacency_matrix(g, dir=:out); kws...) # from other featured_graph function FeaturedGraph(fg::FeaturedGraph; @@ -126,15 +112,25 @@ end @functor FeaturedGraph +""" + edge_index(fg::FeaturedGraph) + +Return a tuple containing two vectors, respectively containing the source and target +nodes of the edges in the graph `fg`. + +```julia +s, t = edge_index(fg) +``` +""" edge_index(fg::FeaturedGraph) = fg.edge_index -LightGraphs.edges(fg::FeaturedGraph) = zip(fg.edge_index[1], fg.edge_index[2]) +LightGraphs.edges(fg::FeaturedGraph) = zip(edge_index(fg)...) -LightGraphs.edgetype(fg::FeaturedGraph) = Tuple{eltype(fg.edge_index[1]), eltype(fg.edge_index[2])} +LightGraphs.edgetype(fg::FeaturedGraph) = Tuple{Int, Int} function LightGraphs.has_edge(fg::FeaturedGraph, i::Integer, j::Integer) - u, v = fg.edge_index - return any((u .== i) .& (v .== j)) + s, t = edge_index(fg) + return any((s .== i) .& (t .== j)) end LightGraphs.nv(fg::FeaturedGraph) = fg.num_nodes @@ -143,13 +139,13 @@ LightGraphs.has_vertex(fg::FeaturedGraph, i::Int) = i in 1:fg.num_nodes LightGraphs.vertices(fg::FeaturedGraph) = 1:fg.num_nodes function LightGraphs.outneighbors(fg::FeaturedGraph, i::Integer) - u, v = fg.edge_index - return v[u .== i] + s, t = edge_index(fg) + return t[s .== i] end function LightGraphs.inneighbors(fg::FeaturedGraph, i::Integer) - u, v = fg.edge_index - return u[v .== i] + s, t = edge_index(fg) + return s[t .== i] end LightGraphs.is_directed(::FeaturedGraph) = true @@ -165,10 +161,10 @@ end # TODO return sparse matrix function LightGraphs.adjacency_matrix(fg::FeaturedGraph, T::DataType=Int; dir=:out) # TODO dir=:both - u, v = edge_index(fg) + s, t = edge_index(fg) n = fg.num_nodes - adj_mat = fill!(similar(u, T, (n, n)), 0) - adj_mat[u .+ n .* (v .- 1)] .= 1 # exploiting linear indexing + adj_mat = fill!(similar(s, T, (n, n)), 0) + adj_mat[s .+ n .* (t .- 1)] .= 1 # exploiting linear indexing return dir == :out ? adj_mat : adj_mat' end @@ -185,14 +181,6 @@ function LightGraphs.degree(fg::FeaturedGraph; dir=:both) return degs end -Zygote.@nograd adjacency_matrix, adjacency_list, degree - - -# function ChainRulesCore.rrule(::typeof(copy), x) -# copy_pullback(ȳ) = (NoTangent(), ȳ) -# return copy(x), copy_pullback -# end - # node_feature(fg::FeaturedGraph) = fg.ndata["x"] # edge_feature(fg::FeaturedGraph) = fg.edata["e"] # global_feature(fg::FeaturedGraph) = fg.gdata["g"] @@ -201,12 +189,6 @@ node_feature(fg::FeaturedGraph) = fg.nf edge_feature(fg::FeaturedGraph) = fg.ef global_feature(fg::FeaturedGraph) = fg.gf -## TO DEPRECATE EVERYTHING BELOW ??? 
############################## - -# has_graph(fg::FeaturedGraph) = true -# has_graph(fg::NullGraph) = false -# graph(fg::FeaturedGraph) = adjacency_list(fg) # DEPRECATE - # function Base.getproperty(fg::FeaturedGraph, sym::Symbol) # if sym === :nf # return fg.ndata["x"] @@ -219,33 +201,6 @@ global_feature(fg::FeaturedGraph) = fg.gf # end # end -## Already in GraphSignals ############## -LightGraphs.ne(adj_list::AbstractVector{<:AbstractVector}) = sum(length.(adj_list)) -LightGraphs.nv(adj_list::AbstractVector{<:AbstractVector}) = length(adj_list) -LightGraphs.ne(adj_mat::AbstractMatrix) = round(Int, sum(adj_mat)) -LightGraphs.nv(adj_mat::AbstractMatrix) = size(adj_mat, 1) - -adjacency_list(adj::AbstractVector{<:AbstractVector}) = adj - -function LightGraphs.is_directed(g::AbstractVector{T}) where {T<:AbstractVector} - edges = Set{Tuple{Int64,Int64}}() - for (i, js) in enumerate(g) - for j in Set(js) - if i != j - e = (i,j) - if e in edges - pop!(edges, e) - else - push!(edges, (j,i)) - end - end - end - end - !isempty(edges) -end - -LightGraphs.is_directed(g::AbstractMatrix) = !issymmetric(Matrix(g)) - # function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::Symbol=:out) # A = adjacency_matrix(fg, T; dir=dir) # D = Diagonal(vec(sum(A; dims=2))) @@ -296,21 +251,28 @@ function scaled_laplacian(fg::FeaturedGraph, T::DataType=Float32; dir=:out) return 2 / maximum(E) * Lnorm - I end - -function add_self_loop!(adj::AbstractVector{<:AbstractVector}) - for i = 1:length(adj) - i in adj[i] || push!(adj[i], i) - end - adj +function add_self_loops(fg::FeaturedGraph) + s, t = edge_index(fg) + @assert edge_feature(fg) === nothing + mask_old_loops = s .!= t + s = s[mask_old_loops] + t = t[mask_old_loops] + n = fg.num_nodes + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + FeaturedGraph(s, t, nf=node_feature(fg), ef=edge_feature(fg), gf=global_feature(fg)) end -# # TODO Do we need a separate package just for laplacians? -# GraphLaplacians.scaled_laplacian(fg::FeaturedGraph, T::DataType) = -# scaled_laplacian(adjacency_matrix(fg, T)) -# GraphLaplacians.normalized_laplacian(fg::FeaturedGraph, T::DataType; kws...) = -# normalized_laplacian(adjacency_matrix(fg, T); kws...) - - @non_differentiable normalized_laplacian(x...) @non_differentiable scaled_laplacian(x...) -@non_differentiable add_self_loop!(x...) +@non_differentiable adjacency_matrix(x...) +@non_differentiable adjacency_list(x...) +@non_differentiable degree(x...) +@non_differentiable add_self_loops(x...) 
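`add_self_loops` above first strips any loops already present and then appends one per node, so it is idempotent; it also asserts `edge_feature(fg) === nothing`, since the appended edges would have no features. A quick sketch, assuming the constructors defined above are in scope:

```julia
fg = FeaturedGraph([1, 2, 2], [2, 1, 2])    # the edge 2→2 is an existing self loop
s, t = edge_index(add_self_loops(fg))
@assert s == [1, 2, 1, 2] && t == [2, 1, 1, 2]   # old loop dropped, 1→1 and 2→2 appended
```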
+
+# delete when https://github.com/JuliaDiff/ChainRules.jl/pull/472 is merged
+function ChainRulesCore.rrule(::typeof(copy), x)
+    copy_pullback(ȳ) = (NoTangent(), ȳ)
+    return copy(x), copy_pullback
+end
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 8d3b30b18..598b1c668 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -44,11 +44,11 @@ message(l::GCNConv, xi, xj) = xj
 update(l::GCNConv, m, x) = m
 
 function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
-    # TODO handle self loops
+    fg = add_self_loops(fg)
     # cout = sqrt.(degree(fg, dir=:out))
     cin = reshape(sqrt.(degree(fg, dir=:in)), 1, :)
     x = cin .* x
-    _, x = propagate(l, fg, nothing, x, +)
+    _, x = propagate(l, fg, nothing, x, nothing, +)
     x = cin .* x
     return l.σ.(l.weight * x .+ l.bias)
 end
@@ -171,7 +171,7 @@ update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias
 
 function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
     check_num_nodes(fg, x)
-    _, x = propagate(gc, fg, nothing, x, +)
+    _, x = propagate(gc, fg, nothing, x, nothing, +)
     x
 end
@@ -236,7 +236,7 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...)
 
 function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
     check_num_nodes(fg, X)
-    # add_self_loop!(adj) #TODO
+    fg = add_self_loops(fg)
     chin, chout = gat.channel
     heads = gat.heads
@@ -325,7 +325,7 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
     end
     for i = 1:ggc.num_layers
         M = view(ggc.weight, :, :, i) * H
-        _, M = propagate(ggc, fg, nothing, M, +)
+        _, M = propagate(ggc, fg, nothing, M, nothing, +)
         H, _ = ggc.gru(H, M)
     end
     H
@@ -371,7 +371,7 @@ update(ec::EdgeConv, m, x) = m
 
 function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
     check_num_nodes(fg, X)
-    _, X = propagate(ec, fg, nothing, X, ec.aggr)
+    _, X = propagate(ec, fg, nothing, X, nothing, ec.aggr)
     X
 end
diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl
index 82ec12ba3..5547199c7 100644
--- a/src/layers/msgpass.jl
+++ b/src/layers/msgpass.jl
@@ -2,24 +2,61 @@
 # "Relational inductive biases, deep learning, and graph networks"
 abstract type MessagePassing end
 
+"""
+    propagate(mp::MessagePassing, fg::FeaturedGraph, aggr)
+    propagate(mp::MessagePassing, fg::FeaturedGraph, E, X, u, aggr)
+
+Perform the sequence of operations implementing the message-passing scheme
+and updating node, edge, and global features `X`, `E`, and `u` respectively.
+
+The computation involved is the following:
+
+```julia
+M = compute_batch_message(mp, fg, E, X, u)
+E = update_edge(mp, M, E, u)
+M̄ = aggregate_neighbors(mp, aggr, fg, M)
+X = update(mp, M̄, X, u)
+u = update_global(mp, E, X, u)
+```
+
+Custom layers subtyping [`MessagePassing`](@ref)
+typically define their own [`update`](@ref)
+and [`message`](@ref) methods, then call
+this method in the forward pass:
+
+```julia
+function (l::GNNLayer)(fg, X)
+    ... some preprocessing if needed ...
+    E = nothing
+    u = nothing
+    propagate(l, fg, E, X, u, +)
+end
+```
+
+See also [`message`](@ref) and [`update`](@ref).
+""" +function propagate end + +function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr) + E, X, u = propagate(mp, fg, edge_feature(fg), node_feature(fg), global_feature(fg), aggr) FeaturedGraph(fg, nf=X, ef=E, gf=u) end -function propagate(gn::MessagePassing, fg::FeaturedGraph, E, X, u, aggr=+) - M = compute_batch_message(gn, fg, E, X, u) - E = update_batch_edge(gn, M, E, u) - M̄ = aggregate_neighbors(gn, aggr, fg, M) - X = update_batch_vertex(gn, M̄, X, u) - u = update_global(gn, E, X, u) +function propagate(mp::MessagePassing, fg::FeaturedGraph, E, X, u, aggr) + M = compute_batch_message(mp, fg, E, X, u) + E = update_edge(mp, M, E, u) + M̄ = aggregate_neighbors(mp, aggr, fg, M) + X = update(mp, M̄, X, u) + u = update_global(mp, E, X, u) return E, X, u end """ + message(mp::MessagePassing, x_i, x_j, e_ij, u) message(mp::MessagePassing, x_i, x_j, e_ij) + message(mp::MessagePassing, x_i, x_j) Message function for the message-passing scheme, returning the message from node `j` to node `i` . @@ -33,34 +70,37 @@ specialize this method with custom behavior. # Arguments -- `mp`: message-passing layer. -- `x_i`: the features of node `i`. -- `x_j`: the features of the nighbor `j` of node `i`. -- `e_ij`: the features of edge (`i`, `j`). +- `mp`: A [`MessagePassing`](@ref) layer. +- `x_i`: Features of the central node `i`. +- `x_j`: Features of the neighbor `j` of node `i`. +- `e_ij`: Features of edge (`i`, `j`). +- `u`: Global features. -See also [`update`](@ref). +See also [`update`](@ref) and [`propagate`](@ref). """ function message end """ - update(mp::MessagePassing, m, x) + update(mp::MessagePassing, m̄, x, u) + update(mp::MessagePassing, m̄, x) Update function for the message-passing scheme, returning a new set of node features `x′` based on old features `x` and the incoming message from the neighborhood -aggregation `m`. +aggregation `m̄`. -By default, the function returns `m`. +By default, the function returns `m̄`. Layers subtyping [`MessagePassing`](@ref) should specialize this method with custom behavior. # Arguments -- `mp`: message-passing layer. -- `m`: the aggregated edge messages from the [`message`](@ref) function. -- `x`: the node features to be updated. +- `mp`: A [`MessagePassing`](@ref) layer. +- `m̄`: Aggregated edge messages from the [`message`](@ref) function. +- `x`: Node features to be updated. +- `u`: Global features. -See also [`message`](@ref). +See also [`message`](@ref) and [`propagate`](@ref). """ function update end @@ -70,45 +110,42 @@ _gather(x::Nothing, i) = nothing ## Step 1. 
@@ -70,45 +110,42 @@ _gather(x::Nothing, i) = nothing
 
 ## Step 1.
-function compute_batch_message(gn::MessagePassing, fg, E, X, u)
+function compute_batch_message(mp::MessagePassing, fg, E, X, u)
     s, t = edge_index(fg)
     Xi = _gather(X, t)
     Xj = _gather(X, s)
-    M = message(gn, Xi, Xj, E, u)
+    M = message(mp, Xi, Xj, E, u)
     return M
 end
 
-# @inline message(gn::MessagePassing, i, j, x_i, x_j, e_ij, u) = message(gn, x_i, x_j, e_ij, u) # TODO add in the future
-@inline message(gn::MessagePassing, x_i, x_j, e_ij, u) = message(gn, x_i, x_j, e_ij)
-@inline message(gn::MessagePassing, x_i, x_j, e_ij) = message(gn, x_i, x_j)
-@inline message(gn::MessagePassing, x_i, x_j) = x_j
+# @inline message(mp::MessagePassing, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
+@inline message(mp::MessagePassing, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
+@inline message(mp::MessagePassing, x_i, x_j, e_ij) = message(mp, x_i, x_j)
+@inline message(mp::MessagePassing, x_i, x_j) = x_j
 
 ## Step 2
-function aggregate_neighbors(gn::MessagePassing, aggr, fg, E)
-    s, t = edge_index(fg)
-    NNlib.scatter(aggr, E, t)
-end
-
-aggregate_neighbors(gn::MessagePassing, aggr::Nothing, fg, E) = nothing
+@inline update_edge(mp::MessagePassing, M, E, u) = update_edge(mp::MessagePassing, M, E)
+@inline update_edge(mp::MessagePassing, M, E) = E
 
 ## Step 3
-update_batch_vertex(gn::MessagePassing, M̄, X, u) = update(gn, M̄, X, u)
+function aggregate_neighbors(mp::MessagePassing, aggr, fg, E)
+    s, t = edge_index(fg)
+    NNlib.scatter(aggr, E, t)
+end
 
-# @inline update(gn::MessagePassing, i, m̄, x, u) = update(gn, m, x, u)
-@inline update(gn::MessagePassing, m̄, x, u) = update(gn, m̄, x)
-@inline update(gn::MessagePassing, m̄, x) = m̄
+aggregate_neighbors(mp::MessagePassing, aggr::Nothing, fg, E) = nothing
 
 ## Step 4
-update_batch_edge(gn::MessagePassing, M, E, u) = update_edge(gn::MessagePassing, M, E, u)
-@inline update_edge(gn::MessagePassing, M, E, u) = update_edge(gn::MessagePassing, M, E)
-@inline update_edge(gn::MessagePassing, M, E) = E
+# @inline update(mp::MessagePassing, i, m̄, x, u) = update(mp, m, x, u)
+@inline update(mp::MessagePassing, m̄, x, u) = update(mp, m̄, x)
+@inline update(mp::MessagePassing, m̄, x) = m̄
 
 ## Step 5
-@inline update_global(gn::MessagePassing, E, X, u) = u
+@inline update_global(mp::MessagePassing, E, X, u) = u
 
 ### end steps ###
diff --git a/test/cuda/msgpass.jl b/test/cuda/msgpass.jl
index 5f16af3e6..8f1259799 100644
--- a/test/cuda/msgpass.jl
+++ b/test/cuda/msgpass.jl
@@ -15,7 +15,7 @@ end
 NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n))
 @functor NewCudaLayer
 
-(l::NewCudaLayer)(X) = GeometricFlux.propagate(l, X, +)
+(l::NewCudaLayer)(fg) = GeometricFlux.propagate(l, fg, +)
 GeometricFlux.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
 GeometricFlux.update(::NewCudaLayer, m, x) = m
diff --git a/test/layers/msgpass.jl b/test/layers/msgpass.jl
index 8cba084d1..54e2ed2c3 100644
--- a/test/layers/msgpass.jl
+++ b/test/layers/msgpass.jl
@@ -19,15 +19,15 @@
 
     u = rand(T, in_channel)
 
-    @testset "default aggregation (+)" begin
+    @testset "no aggregation" begin
         l = NewLayer()
-        (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg)
+        (l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, nothing)
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = l(fg)
 
         @test adjacency_matrix(fg_) == adj
-        @test size(node_feature(fg_)) == (in_channel, num_V)
+        @test node_feature(fg_) === nothing
         @test edge_feature(fg_) === nothing
         @test global_feature(fg_) === nothing
     end
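[Editorial aside before the next patch, not part of the series: `aggregate_neighbors` above reduces the per-edge message matrix into per-node aggregates with `NNlib.scatter`, combining message columns that share an edge-target index. A tiny numeric sketch of that reduction:]

```julia
using NNlib

M = Float32[1 2 3 4]        # one scalar message per edge, 1 × 4
t = [1, 2, 1, 2]            # target node of each edge

M̄ = NNlib.scatter(+, M, t)  # 1 × 2: column i accumulates messages whose target is i
M̄ == Float32[4 6]           # true — node 1 receives 1 + 3, node 2 receives 2 + 4
```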
From 9d8770ce418a5536161f8d499b1b0526e6d152c4 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Fri, 30 Jul 2021 13:55:25 +0200
Subject: [PATCH 10/11] updates

---
 src/featured_graph.jl  | 17 +++++++++--------
 src/layers/msgpass.jl  |  8 ++++++++
 test/featured_graph.jl |  2 +-
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/featured_graph.jl b/src/featured_graph.jl
index 4732439c5..9d06e60eb 100644
--- a/src/featured_graph.jl
+++ b/src/featured_graph.jl
@@ -55,9 +55,8 @@ function FeaturedGraph(s::AbstractVector{Int}, t::AbstractVector{Int};
                   nf, ef, gf)
 end
 
-# Construct from adjacency matrix # TODO deprecate?
 function FeaturedGraph(adj_mat::AbstractMatrix; dir=:out, kws...)
-    @assert dir == :out # TODO
+    @assert dir ∈ [:out, :in]
     num_nodes = size(adj_mat, 1)
     @assert num_nodes == size(adj_mat, 2)
     @assert all(x -> (x == 1) || (x == 0), adj_mat)
@@ -75,13 +74,14 @@ function FeaturedGraph(adj_mat::AbstractMatrix; dir=:out, kws...)
         end
     end
     @assert e == num_edges
+    if dir == :in
+        s, t = t, s
+    end
     FeaturedGraph(s, t; num_nodes, kws...)
 end
-
-# Construct from adjacency list # TODO deprecate?
 function FeaturedGraph(adj_list::AbstractVector{<:AbstractVector}; dir=:out, kws...)
-    @assert dir == :out # TODO
+    @assert dir ∈ [:out, :in]
     num_nodes = length(adj_list)
     num_edges = sum(length.(adj_list))
     s = zeros(Int, num_edges)
@@ -95,12 +95,14 @@ function FeaturedGraph(adj_list::AbstractVector{<:AbstractVector}; dir=:out, kws
         end
     end
     @assert e == num_edges
+    if dir == :in
+        s, t = t, s
+    end
     FeaturedGraph(s, t; num_nodes, kws...)
 end
 
 FeaturedGraph(g::AbstractGraph; kws...) = FeaturedGraph(adjacency_matrix(g, dir=:out); kws...)
 
-# from other featured_graph
 function FeaturedGraph(fg::FeaturedGraph;
                 # ndata=copy(fg.ndata), edata=copy(fg.edata), gdata=copy(fg.gdata), # copy keeps the refs to old data
                 nf=node_feature(fg), ef=edge_feature(fg), gf=global_feature(fg))
@@ -152,8 +154,7 @@ LightGraphs.is_directed(::FeaturedGraph) = true
 LightGraphs.is_directed(::Type{FeaturedGraph}) = true
 
 function adjacency_list(fg::FeaturedGraph; dir=:out)
-    # TODO probably this has to be called with `dir=:in` by gnn layers
-    # TODO dir=:both
+    @assert dir ∈ [:out, :in]
     fneighs = dir == :out ? outneighbors : inneighbors
     return [fneighs(fg, i) for i in 1:fg.num_nodes]
 end
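[Editorial aside, not part of the patch: the `dir` keyword introduced above can be summarized with a small example. With `dir=:out`, a nonzero `adj_mat[i, j]` is read as an edge from `i` to `j`; `dir=:in` collects the same entries and then swaps source and target, reversing every edge. The matrix `A` below is invented for illustration.]

```julia
A = [0 1 0;
     0 0 1;
     1 0 0]                          # read row→column: the directed 3-cycle 1→2→3→1

fg_out = FeaturedGraph(A, dir=:out)  # edges (1,2), (2,3), (3,1)
fg_in  = FeaturedGraph(A, dir=:in)   # same entries, reversed: (2,1), (3,2), (1,3)

adjacency_list(fg_out)               # [[2], [3], [1]]
adjacency_list(fg_in)                # [[3], [1], [2]]
```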
+""" abstract type MessagePassing end """ diff --git a/test/featured_graph.jl b/test/featured_graph.jl index 66e193a39..5000b8fa9 100644 --- a/test/featured_graph.jl +++ b/test/featured_graph.jl @@ -78,4 +78,4 @@ end end -end \ No newline at end of file +end From 4071e83b920c655a2ef45c3f914aa7fccf7e133d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 30 Jul 2021 14:13:16 +0200 Subject: [PATCH 11/11] more tests and a fix --- src/layers/conv.jl | 3 ++- test/layers/conv.jl | 25 ++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 598b1c668..4d9717162 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -45,8 +45,9 @@ update(l::GCNConv, m, x) = m function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix) fg = add_self_loops(fg) + T = eltype(l.weight) # cout = sqrt.(degree(fg, dir=:out)) - cin = reshape(sqrt.(degree(fg, dir=:in)), 1, :) + cin = reshape(sqrt.(T.(degree(fg, dir=:in))), 1, :) x = cin .* x _, x = propagate(l, fg, nothing, x, nothing, +) x = cin .* x diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 8c27d3c43..ca5a47367 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -27,10 +27,12 @@ @test adjacency_matrix(gc.fg) == adj Y = gc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = gc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(gc(x)), X)[1] @@ -49,12 +51,14 @@ fg = FeaturedGraph(adj, nf=X) fg_ = gc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError gc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = gc(fgt) + @test node_feature(fgt_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] @@ -85,10 +89,12 @@ @test cc.k == k Y = cc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = cc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(cc(x)), X)[1] @@ -108,12 +114,14 @@ fg = FeaturedGraph(adj, nf=X) fg_ = cc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError cc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = cc(fgt) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(cc(x))), fg)[1] @@ -141,10 +149,12 @@ @test size(gc.bias) == (out_channel,) Y = gc(X) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) # Test with transposed features Y = gc(Xt) + @test Y isa Matrix{T} @test size(Y) == (out_channel, N) g = Zygote.gradient(x -> sum(gc(x)), X)[1] @@ -164,12 +174,14 @@ fg = FeaturedGraph(adj, nf=X) fg_ = gc(fg) + @test node_feature(fg_) isa Matrix{T} @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError gc(X) # Test with transposed features fgt = FeaturedGraph(adj, nf=Xt) fgt_ = gc(fgt) + @test node_feature(fgt_) isa Matrix{T} @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] @@ -209,10 +221,12 @@ @test size(gat.a) == (2*out_channel, heads) Y = gat(X) + @test Y isa Matrix{T} @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) # Test with transposed features Y = gat(Xt) + @test Y isa Matrix{T} @test size(Y) == (concat ? 
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 8c27d3c43..ca5a47367 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -27,10 +27,12 @@
         @test adjacency_matrix(gc.fg) == adj
 
         Y = gc(X)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         # Test with transposed features
         Y = gc(Xt)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(gc(x)), X)[1]
@@ -49,12 +51,14 @@
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = gc(fg)
+        @test node_feature(fg_) isa Matrix{T}
         @test size(node_feature(fg_)) == (out_channel, N)
         @test_throws MethodError gc(X)
 
         # Test with transposed features
         fgt = FeaturedGraph(adj, nf=Xt)
         fgt_ = gc(fgt)
+        @test node_feature(fgt_) isa Matrix{T}
         @test size(node_feature(fgt_)) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1]
@@ -85,10 +89,12 @@
         @test cc.k == k
 
         Y = cc(X)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         # Test with transposed features
         Y = cc(Xt)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(cc(x)), X)[1]
@@ -108,12 +114,14 @@
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = cc(fg)
+        @test node_feature(fg_) isa Matrix{T}
         @test size(node_feature(fg_)) == (out_channel, N)
         @test_throws MethodError cc(X)
 
         # Test with transposed features
         fgt = FeaturedGraph(adj, nf=Xt)
         fgt_ = cc(fgt)
+        @test node_feature(fgt_) isa Matrix{T}
         @test size(node_feature(fgt_)) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(node_feature(cc(x))), fg)[1]
@@ -141,10 +149,12 @@
         @test size(gc.bias) == (out_channel,)
 
         Y = gc(X)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         # Test with transposed features
         Y = gc(Xt)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(gc(x)), X)[1]
@@ -164,12 +174,14 @@
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = gc(fg)
+        @test node_feature(fg_) isa Matrix{T}
         @test size(node_feature(fg_)) == (out_channel, N)
         @test_throws MethodError gc(X)
 
         # Test with transposed features
         fgt = FeaturedGraph(adj, nf=Xt)
         fgt_ = gc(fgt)
+        @test node_feature(fgt_) isa Matrix{T}
         @test size(node_feature(fgt_)) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1]
@@ -209,10 +221,12 @@
             @test size(gat.a) == (2*out_channel, heads)
 
             Y = gat(X)
+            @test Y isa Matrix{T}
             @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
 
             # Test with transposed features
             Y = gat(Xt)
+            @test Y isa Matrix{T}
             @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
 
             g = Zygote.gradient(x -> sum(gat(x)), X)[1]
@@ -235,12 +249,14 @@
 
             fg_ = gat(fg_gat)
             Y = node_feature(fg_)
+            @test Y isa Matrix{T}
             @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N))
             @test_throws MethodError gat(X)
 
             # Test with transposed features
             fgt = FeaturedGraph(adj_gat, nf=Xt)
             fgt_ = gat(fgt)
+            @test node_feature(fgt_) isa Matrix{T}
             @test size(node_feature(fgt_)) == (concat ? (out_channel*heads, N) : (out_channel, N))
 
             g = Zygote.gradient(x -> sum(node_feature(gat(x))), fg_gat)[1]
@@ -269,11 +285,12 @@
         @test size(ggc.weight) == (out_channel, out_channel, num_layers)
 
         Y = ggc(X)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
-
         # Test with transposed features
         Y = ggc(Xt)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
@@ -289,12 +306,14 @@
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = ggc(fg)
+        @test node_feature(fg_) isa Matrix{T}
         @test size(node_feature(fg_)) == (out_channel, N)
         @test_throws MethodError ggc(X)
 
         # Test with transposed features
         fgt = FeaturedGraph(adj, nf=Xt)
         fgt_ = ggc(fgt)
+        @test node_feature(fgt_) isa Matrix{T}
         @test size(node_feature(fgt_)) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(node_feature(ggc(x))), fg)[1]
@@ -313,10 +332,12 @@
         @test adjacency_list(ec.fg) == [[2,4], [1,3], [2,4], [1,3]]
 
         Y = ec(X)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         # Test with transposed features
         Y = ec(Xt)
+        @test Y isa Matrix{T}
         @test size(Y) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(ec(x)), X)[1]
@@ -332,12 +353,14 @@
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = ec(fg)
+        @test node_feature(fg_) isa Matrix{T}
         @test size(node_feature(fg_)) == (out_channel, N)
         @test_throws MethodError ec(X)
 
         # Test with transposed features
         fgt = FeaturedGraph(adj, nf=Xt)
         fgt_ = ec(fgt)
+        @test node_feature(fgt_) isa Matrix{T}
         @test size(node_feature(fgt_)) == (out_channel, N)
 
         g = Zygote.gradient(x -> sum(node_feature(ec(x))), fg)[1]