diff --git a/Project.toml b/Project.toml
index f4f02e75a..e5da43e8e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,39 +5,38 @@ version = "0.7.6"
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-GraphLaplacians = "a1251efa-393a-423f-9d7b-faaecba535dc"
 GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8"
-GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
+KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
 CUDA = "3.3"
 DataStructures = "0.18"
 FillArrays = "0.11, 0.12"
 Flux = "0.12"
-GraphLaplacians = "0.1"
 GraphMLDatasets = "0.1"
-GraphSignals = "0.2"
+KrylovKit = "0.5"
 LightGraphs = "1.3"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 Reexport = "1.1"
-Zygote = "0.6"
 julia = "1.6"

 [extras]
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["SparseArrays", "Test"]
+test = ["SparseArrays", "Test", "Zygote"]
diff --git a/docs/make.jl b/docs/make.jl
index c94577157..ad84cd068 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -19,10 +19,12 @@ makedocs(
         ["Message passing scheme" => "abstractions/msgpass.md",
          "Graph network block" => "abstractions/gn.md"],
       "Manual" =>
-        ["Convolutional Layers" => "manual/conv.md",
+        [
+         "Graphs" => "manual/featuredgraph.md",
+         "Convolutional Layers" => "manual/conv.md",
          "Pooling Layers" => "manual/pool.md",
          "Models" => "manual/models.md",
-         "Linear Algebra" => "manual/linalg.md"]
+        ]
     ]
 )
diff --git a/docs/src/manual/featuredgraph.md b/docs/src/manual/featuredgraph.md
new file mode 100644
index 000000000..7d39be937
--- /dev/null
+++ b/docs/src/manual/featuredgraph.md
@@ -0,0 +1,20 @@
+# Graphs
+
+GeometricFlux relies on the [`FeaturedGraph`](@ref)
+type to represent graph structures and feature arrays associated with
+nodes and edges.
+
+
+```@docs
+GeometricFlux.FeaturedGraph
+GeometricFlux.edge_index
+GeometricFlux.graph
+GeometricFlux.adjacency_list
+GeometricFlux.adjacency_matrix
+GeometricFlux.add_self_loops
+GeometricFlux.remove_self_loops
+GeometricFlux.degree
+GeometricFlux.laplacian_matrix
+GeometricFlux.normalized_laplacian
+GeometricFlux.scaled_laplacian
+```
diff --git a/docs/src/manual/linalg.md b/docs/src/manual/linalg.md
index 00a8604ee..e69de29bb 100644
--- a/docs/src/manual/linalg.md
+++ b/docs/src/manual/linalg.md
@@ -1,22 +0,0 @@
-# Linear Algebra
-
-
-```@docs
-GraphSignals.degrees
-```
-
-```@docs
-GraphSignals.degree_matrix
-```
-
-```@docs
-GraphSignals.inv_sqrt_degree_matrix
-```
-
-```@docs
-GraphSignals.laplacian_matrix
-```
-
-```@docs
-GraphSignals.normalized_laplacian
-```
diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl
index 0c10a3c36..ab74763fa 100644
--- a/src/GeometricFlux.jl
+++ b/src/GeometricFlux.jl
@@ -1,20 +1,32 @@
 module GeometricFlux

+using NNlib: similar
+using LinearAlgebra: similar, fill!
 using Statistics: mean
-using LinearAlgebra: Adjoint, norm, Transpose
-using Reexport
-
+using LinearAlgebra
+using SparseArrays
+import KrylovKit
 using CUDA
 using FillArrays: Fill
 using Flux
 using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
 using NNlib, NNlibCUDA
-using GraphLaplacians
-@reexport using GraphSignals
-using LightGraphs
-using Zygote
+using ChainRulesCore
+import LightGraphs
+using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv,
+                   adjacency_matrix, degree

 export
+    # featured_graph
+    FeaturedGraph,
+    graph, edge_index,
+    node_feature, edge_feature, global_feature,
+    adjacency_list, normalized_laplacian, scaled_laplacian,
+    add_self_loops,
+
+    # from LightGraphs
+    adjacency_matrix,
+
     # layers/gn
     GraphNet,
@@ -50,8 +62,10 @@ export
     # utils
     generate_cluster

+include("featuredgraph.jl")
+include("graph_conversions.jl")
 include("datasets.jl")
-
 include("utils.jl")

 include("layers/gn.jl")
diff --git a/src/featuredgraph.jl b/src/featuredgraph.jl
new file mode 100644
index 000000000..2abc3a85c
--- /dev/null
+++ b/src/featuredgraph.jl
@@ -0,0 +1,420 @@
+#===================================
+Define the FeaturedGraph type as a subtype of LightGraphs' AbstractGraph.
+For the core methods to be implemented by any AbstractGraph, see
+https://juliagraphs.org/LightGraphs.jl/latest/types/#AbstractGraph-Type
+https://juliagraphs.org/LightGraphs.jl/latest/developing/#Developing-Alternate-Graph-Types
+===================================#
+
+abstract type AbstractFeaturedGraph <: AbstractGraph{Int} end
+
+"""
+    NullGraph()
+
+Null object for `FeaturedGraph`.
+"""
+struct NullGraph <: AbstractFeaturedGraph end
+
+const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V}
+const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
+const ADJMAT_T = AbstractMatrix
+const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
+
+"""
+    FeaturedGraph(g; [graph_type, dir, num_nodes, nf, ef, gf])
+    FeaturedGraph(fg::FeaturedGraph; [nf, ef, gf])
+
+A type representing a graph structure that also stores feature arrays
+associated with nodes, edges, and the whole graph.
+
+A `FeaturedGraph` can be constructed out of different objects `g` representing
+the connections inside the graph, while the internal representation type
+is governed by `graph_type`.
+When constructed from another featured graph `fg`, the internal graph representation
+is preserved and shared.
+
+A `FeaturedGraph` is a LightGraphs `AbstractGraph`, therefore any functionality
+from the LightGraphs library can be used on it.
+
+# Arguments
+
+- `g`: Some data representing the graph topology. Possible types are
+    - An adjacency matrix.
+    - An adjacency list.
+    - A tuple containing the source and target vectors (COO representation).
+    - A LightGraphs graph.
+- `graph_type`: A keyword argument that specifies
+    the underlying representation used by the FeaturedGraph.
+    Currently supported values are
+    - `:coo`. Graph represented as a tuple `(source, target)`, such that the `k`-th edge
+      connects the node `source[k]` to node `target[k]`.
+      Optionally, also edge weights can be given: `(source, target, weights)`.
+    - `:sparse`. A sparse adjacency matrix representation.
+    - `:dense`. A dense adjacency matrix representation.
+    Default `:coo`.
+- `dir`: The assumed edge direction when `g` is given as an adjacency matrix or
+    an adjacency list. Possible values are `:out` and `:in`. Default `:out`.
+- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
+- `nf`: Node features. Either `nothing`, or an array whose last dimension has size `num_nodes`. Default `nothing`.
+- `ef`: Edge features. Either `nothing`, or an array whose last dimension has size `num_edges`. Default `nothing`.
+- `gf`: Global features. Default `nothing`.
+
+# Usage
+
+```julia
+using Flux, GeometricFlux
+
+# Construct from adjacency list representation
+g = [[2,3], [1,4,5], [1], [2,5], [2,4]]
+fg = FeaturedGraph(g)
+
+# Number of nodes and edges
+fg.num_nodes  # 5
+fg.num_edges  # 10
+
+# Same graph in COO representation
+s = [1,1,2,2,2,3,4,4,5,5]
+t = [2,3,1,4,5,3,2,5,2,4]
+fg = FeaturedGraph((s, t))
+fg = FeaturedGraph(s, t)  # other convenience constructor
+
+# From a LightGraphs graph
+fg = FeaturedGraph(erdos_renyi(100, 20))
+
+# Copy a featured graph while also adding node features
+# (the last dimension of `nf` must equal num_nodes)
+fg = FeaturedGraph(fg, nf=rand(5, 100))
+
+# Send to gpu
+fg = fg |> gpu
+
+# Collect edges' source and target nodes.
+# Both source and target are vectors of length num_edges
+source, target = edge_index(fg)
+```
+
+See also [`graph`](@ref), [`edge_index`](@ref), [`node_feature`](@ref), [`edge_feature`](@ref), and [`global_feature`](@ref).
+"""
+struct FeaturedGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractFeaturedGraph
+    graph::T
+    num_nodes::Int
+    num_edges::Int
+    nf
+    ef
+    gf
+    ## possible future property stores
+    # ndata::Dict{String, Any} # https://github.com/FluxML/Zygote.jl/issues/717
+    # edata::Dict{String, Any}
+    # gdata::Dict{String, Any}
+end
+
+@functor FeaturedGraph
+
+function FeaturedGraph(data;
+                       num_nodes = nothing,
+                       graph_type = :coo,
+                       dir = :out,
+                       nf = nothing,
+                       ef = nothing,
+                       gf = nothing,
+                       # ndata = Dict{String, Any}(),
+                       # edata = Dict{String, Any}(),
+                       # gdata = Dict{String, Any}()
+                       )
+
+    @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
+    @assert dir ∈ [:in, :out]
+    if graph_type == :coo
+        g, num_nodes, num_edges = to_coo(data; num_nodes, dir)
+    elseif graph_type == :dense
+        g, num_nodes, num_edges = to_dense(data; dir)
+    elseif graph_type == :sparse
+        g, num_nodes, num_edges = to_sparse(data; dir)
+    end
+
+    ## Possible future implementation of feature maps.
+    ## Currently this doesn't play well with Zygote due to
+    ## https://github.com/FluxML/Zygote.jl/issues/717
+    # ndata["x"] = nf
+    # edata["e"] = ef
+    # gdata["g"] = gf
+
+    FeaturedGraph(g, num_nodes, num_edges, nf, ef, gf)
+end
+
+# COO convenience constructors
+FeaturedGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = FeaturedGraph((s, t, v); kws...)
+FeaturedGraph((s, t)::NTuple{2}; kws...) = FeaturedGraph((s, t, nothing); kws...)
+
+# FeaturedGraph(g::AbstractGraph; kws...) = FeaturedGraph(adjacency_matrix(g, dir=:out); kws...)
+
+function FeaturedGraph(g::AbstractGraph; kws...)
+    s = LightGraphs.src.(LightGraphs.edges(g))
+    t = LightGraphs.dst.(LightGraphs.edges(g))
+    FeaturedGraph((s, t); kws...)
+end
+
+function FeaturedGraph(fg::FeaturedGraph;
+                       nf=node_feature(fg), ef=edge_feature(fg), gf=global_feature(fg))
+    # ndata=copy(fg.ndata), edata=copy(fg.edata), gdata=copy(fg.gdata), # copy keeps the refs to old data
+    FeaturedGraph(fg.graph, fg.num_nodes, fg.num_edges, nf, ef, gf) # ndata, edata, gdata,
+end
+
+"""
+    edge_index(fg::FeaturedGraph)
+
+Return a tuple containing two vectors, respectively storing
+the source and target nodes of each edge in `fg`.
+
+```julia
+s, t = edge_index(fg)
+```
+"""
+edge_index(fg::FeaturedGraph{<:COO_T}) = graph(fg)[1:2]
+
+edge_index(fg::FeaturedGraph{<:ADJMAT_T}) = to_coo(graph(fg))[1][1:2]
+
+edge_weight(fg::FeaturedGraph{<:COO_T}) = graph(fg)[3]
+
+"""
+    graph(fg::FeaturedGraph)
+
+Return the underlying implementation of the graph structure of `fg`,
+either an adjacency matrix or an edge list in the COO format.
+"""
+graph(fg::FeaturedGraph) = fg.graph
+
+LightGraphs.edges(fg::FeaturedGraph) = zip(edge_index(fg)...)
+
+LightGraphs.edgetype(fg::FeaturedGraph) = Tuple{Int, Int}
+
+function LightGraphs.has_edge(fg::FeaturedGraph{<:COO_T}, i::Integer, j::Integer)
+    s, t = edge_index(fg)
+    return any((s .== i) .& (t .== j))
+end
+
+LightGraphs.has_edge(fg::FeaturedGraph{<:ADJMAT_T}, i::Integer, j::Integer) = graph(fg)[i,j] != 0
+
+LightGraphs.nv(fg::FeaturedGraph) = fg.num_nodes
+LightGraphs.ne(fg::FeaturedGraph) = fg.num_edges
+LightGraphs.has_vertex(fg::FeaturedGraph, i::Int) = 1 <= i <= fg.num_nodes
+LightGraphs.vertices(fg::FeaturedGraph) = 1:fg.num_nodes
+
+function LightGraphs.outneighbors(fg::FeaturedGraph{<:COO_T}, i::Integer)
+    s, t = edge_index(fg)
+    return t[s .== i]
+end
+
+function LightGraphs.outneighbors(fg::FeaturedGraph{<:ADJMAT_T}, i::Integer)
+    A = graph(fg)
+    return findall(!=(0), A[i,:])
+end
+
+function LightGraphs.inneighbors(fg::FeaturedGraph{<:COO_T}, i::Integer)
+    s, t = edge_index(fg)
+    return s[t .== i]
+end
+
+function LightGraphs.inneighbors(fg::FeaturedGraph{<:ADJMAT_T}, i::Integer)
+    A = graph(fg)
+    return findall(!=(0), A[:,i])
+end
+
+LightGraphs.is_directed(::FeaturedGraph) = true
+LightGraphs.is_directed(::Type{<:FeaturedGraph}) = true
+
+function adjacency_list(fg::FeaturedGraph; dir=:out)
+    @assert dir ∈ [:out, :in]
+    fneighs = dir == :out ? outneighbors : inneighbors
+    return [fneighs(fg, i) for i in 1:fg.num_nodes]
+end
+
+function LightGraphs.adjacency_matrix(fg::FeaturedGraph{<:COO_T}, T::DataType=Int; dir=:out)
+    A, n, m = to_sparse(graph(fg), T, num_nodes=fg.num_nodes)
+    @assert size(A) == (n, n)
+    return dir == :out ? A : A'
+end
+
+function LightGraphs.adjacency_matrix(fg::FeaturedGraph{<:ADJMAT_T}, T::DataType=eltype(graph(fg)); dir=:out)
+    @assert dir ∈ [:in, :out]
+    A = graph(fg)
+    A = T != eltype(A) ? T.(A) : A
+    return dir == :out ? A : A'
+end
+
+function LightGraphs.degree(fg::FeaturedGraph{<:COO_T}; dir=:out)
+    s, t = edge_index(fg)
+    degs = fill!(similar(s, eltype(s), fg.num_nodes), 0)
+    o = fill!(similar(s, eltype(s), fg.num_edges), 1)
+    if dir ∈ [:out, :both]
+        NNlib.scatter!(+, degs, o, s)
+    end
+    if dir ∈ [:in, :both]
+        NNlib.scatter!(+, degs, o, t)
+    end
+    return degs
+end
+
+function LightGraphs.degree(fg::FeaturedGraph{<:ADJMAT_T}; dir=:out)
+    @assert dir ∈ (:in, :out)
+    A = graph(fg)
+    return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1))
+end
+
+# node_feature(fg::FeaturedGraph) = fg.ndata["x"]
+# edge_feature(fg::FeaturedGraph) = fg.edata["e"]
+# global_feature(fg::FeaturedGraph) = fg.gdata["g"]
+
+"""
+    node_feature(fg::FeaturedGraph)
+
+Return the node features of `fg`.
+"""
+node_feature(fg::FeaturedGraph) = fg.nf
+
+"""
+    edge_feature(fg::FeaturedGraph)
+
+Return the edge features of `fg`.
+"""
+edge_feature(fg::FeaturedGraph) = fg.ef
+
+"""
+    global_feature(fg::FeaturedGraph)
+
+Return the global features of `fg`.
+""" +global_feature(fg::FeaturedGraph) = fg.gf + +# function Base.getproperty(fg::FeaturedGraph, sym::Symbol) +# if sym === :nf +# return fg.ndata["x"] +# elseif sym === :ef +# return fg.edata["e"] +# elseif sym === :gf +# return fg.gdata["g"] +# else # fallback to getfield +# return getfield(fg, sym) +# end +# end + +function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::Symbol=:out) + A = adjacency_matrix(fg, T; dir=dir) + D = Diagonal(vec(sum(A; dims=2))) + return D - A +end + +""" + normalized_laplacian(fg, T=Float32; selfloop=false, dir=:out) + +Normalized Laplacian matrix of graph `g`. + +# Arguments + +- `fg`: A `FeaturedGraph`. +- `T`: result element type. +- `selfloop`: adding self loop while calculating the matrix. +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function normalized_laplacian(fg::FeaturedGraph, T::DataType=Float32; selfloop::Bool=false, dir::Symbol=:out) + A = adjacency_matrix(fg, T; dir=dir) + sz = size(A) + @assert sz[1] == sz[2] + if selfloop + A += I - Diagonal(A) + else + A -= Diagonal(A) + end + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + return I - inv_sqrtD * A * inv_sqrtD +end + +@doc raw""" + scaled_laplacian(fg, T=Float32; dir=:out) + +Scaled Laplacian matrix of graph `g`, +defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. + +# Arguments + +- `fg`: A `FeaturedGraph`. +- `T`: result element type. +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function scaled_laplacian(fg::FeaturedGraph, T::DataType=Float32; dir=:out) + L = normalized_laplacian(fg, T) + @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices" + λmax = _eigmax(L) + return 2 / λmax * L - I +end + +# _eigmax(A) = eigmax(Symmetric(A)) # Doesn't work on sparse arrays +_eigmax(A) = KrylovKit.eigsolve(Symmetric(A), 1, :LR)[1][1] # also eigs(A, x0, nev, mode) available + +# Eigenvalues for cuarray don't seem to be well supported. +# https://github.com/JuliaGPU/CUDA.jl/issues/154 +# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5 + +""" + add_self_loops(fg::FeaturedGraph) + +Return a featured graph with the same features as `fg` +but also adding edges connecting the nodes to themselves. +""" +function add_self_loops(fg::FeaturedGraph{<:COO_T}) + s, t = edge_index(fg) + @assert edge_feature(fg) === nothing + @assert edge_weight(fg) === nothing + mask_old_loops = s .!= t + s = s[mask_old_loops] + t = t[mask_old_loops] + n = fg.num_nodes + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + + FeaturedGraph((s, t, nothing), fg.num_nodes, length(s), + node_feature(fg), edge_feature(fg), global_feature(fg)) +end + +function add_self_loops(fg::FeaturedGraph{<:ADJMAT_T}) + A = graph(fg) + @assert edge_feature(fg) === nothing + nold = sum(Diagonal(A)) |> Int + A = A - Diagonal(A) + I + num_edges = fg.num_edges - nold + fg.num_nodes + FeaturedGraph(A, fg.num_nodes, num_edges, + node_feature(fg), edge_feature(fg), global_feature(fg)) +end + + +function remove_self_loops(fg::FeaturedGraph{<:COO_T}) + s, t = edge_index(fg) + # TODO remove these constraints + @assert edge_feature(fg) === nothing + @assert edge_weight(fg) === nothing + + mask_old_loops = s .!= t + s = s[mask_old_loops] + t = t[mask_old_loops] + + FeaturedGraph((s, t, nothing), fg.num_nodes, length(s), + node_feature(fg), edge_feature(fg), global_feature(fg)) +end + +@non_differentiable normalized_laplacian(x...) 
+@non_differentiable scaled_laplacian(x...)
+@non_differentiable adjacency_matrix(x...)
+@non_differentiable adjacency_list(x...)
+@non_differentiable degree(x...)
+@non_differentiable add_self_loops(x...) # TODO this is wrong, since fg carries feature arrays, needs rrule
+@non_differentiable remove_self_loops(x...) # TODO this is wrong, since fg carries feature arrays, needs rrule
+
+# # delete when https://github.com/JuliaDiff/ChainRules.jl/pull/472 is merged
+# function ChainRulesCore.rrule(::typeof(copy), x)
+#     copy_pullback(ȳ) = (NoTangent(), ȳ)
+#     return copy(x), copy_pullback
+# end
diff --git a/src/graph_conversions.jl b/src/graph_conversions.jl
new file mode 100644
index 000000000..007f990ea
--- /dev/null
+++ b/src/graph_conversions.jl
@@ -0,0 +1,134 @@
+### CONVERT TO COO REPRESENTATION ########
+
+function to_coo(coo::COO_T; dir=:out, num_nodes=nothing)
+    s, t, val = coo
+    num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
+    @assert isnothing(val) || length(val) == length(s)
+    @assert length(s) == length(t)
+    @assert min(minimum(s), minimum(t)) >= 1
+    @assert max(maximum(s), maximum(t)) <= num_nodes
+
+    num_edges = length(s)
+    return coo, num_nodes, num_edges
+end
+
+function to_coo(A::ADJMAT_T; dir=:out, num_nodes=nothing)
+    nz = findall(!=(0), A) # vec of cartesian indexes
+    s, t = ntuple(i -> map(ci -> ci[i], nz), 2)
+    if dir == :in
+        s, t = t, s
+    end
+    num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
+    num_edges = length(s)
+    return (s, t, nothing), num_nodes, num_edges
+end
+
+function to_coo(adj_list::ADJLIST_T; dir=:out, num_nodes=nothing)
+    @assert dir ∈ [:out, :in]
+    num_nodes = length(adj_list)
+    num_edges = sum(length.(adj_list))
+    @assert num_nodes > 0
+    s = similar(adj_list[1], eltype(adj_list[1]), num_edges)
+    t = similar(adj_list[1], eltype(adj_list[1]), num_edges)
+    e = 0
+    for i in 1:num_nodes
+        for j in adj_list[i]
+            e += 1
+            s[e] = i
+            t[e] = j
+        end
+    end
+    @assert e == num_edges
+    if dir == :in
+        s, t = t, s
+    end
+    (s, t, nothing), num_nodes, num_edges
+end
+
+### CONVERT TO ADJACENCY MATRIX ################
+
+### DENSE ####################
+
+to_dense(A::AbstractSparseMatrix, x...; kws...) = to_dense(collect(A), x...; kws...)
+
+function to_dense(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothing)
+    @assert dir ∈ [:out, :in]
+    num_nodes = size(A, 1)
+    @assert num_nodes == size(A, 2)
+    # @assert all(x -> (x == 1) || (x == 0), A)
+    num_edges = round(Int, sum(A))
+    if dir == :in
+        A = A'
+    end
+    if T != eltype(A)
+        A = T.(A)
+    end
+    return A, num_nodes, num_edges
+end
+
+function to_dense(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=nothing)
+    @assert dir ∈ [:out, :in]
+    num_nodes = length(adj_list)
+    num_edges = sum(length.(adj_list))
+    @assert num_nodes > 0
+    # zero-initialize: `similar` alone returns uninitialized memory
+    A = fill!(similar(adj_list[1], T, (num_nodes, num_nodes)), 0)
+    if dir == :out
+        for (i, neigs) in enumerate(adj_list)
+            A[i, neigs] .= 1
+        end
+    else
+        for (i, neigs) in enumerate(adj_list)
+            A[neigs, i] .= 1
+        end
+    end
+    A, num_nodes, num_edges
+end
+
+function to_dense(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
+    # `dir` will be ignored since the input `coo` is always in source -> target format.
+    # The output will always be an adjmat in :out format (i.e. A[i,j] denotes an edge from i to j).
+    s, t, val = coo
+    n = isnothing(num_nodes) ?
+        max(maximum(s), maximum(t)) : num_nodes
+    A = fill!(similar(s, T, (n, n)), 0)
+    if isnothing(val)
+        A[s .+ n .* (t .- 1)] .= 1 # exploiting linear indexing
+    else
+        A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
+    end
+    return A, n, length(s)
+end
+
+### SPARSE #############
+
+function to_sparse(A::ADJMAT_T, T::DataType=eltype(A); dir=:out, num_nodes=nothing)
+    @assert dir ∈ [:out, :in]
+    num_nodes = size(A, 1)
+    @assert num_nodes == size(A, 2)
+    num_edges = round(Int, sum(A))
+    if dir == :in
+        A = A'
+    end
+    if T != eltype(A)
+        A = T.(A)
+    end
+    return sparse(A), num_nodes, num_edges
+end
+
+function to_sparse(adj_list::ADJLIST_T, T::DataType=Int; dir=:out, num_nodes=nothing)
+    coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes)
+    to_sparse(coo, T; dir, num_nodes) # forward the element type as well
+end
+
+function to_sparse(coo::COO_T, T::DataType=Int; dir=:out, num_nodes=nothing)
+    s, t, eweight = coo
+    eweight = isnothing(eweight) ? fill!(similar(s, T), 1) : eweight
+    num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
+    A = sparse(s, t, eweight, num_nodes, num_nodes)
+    num_edges = length(s)
+    A, num_nodes, num_edges
+end
+
+@non_differentiable to_coo(x...)
+@non_differentiable to_dense(x...)
+@non_differentiable to_sparse(x...)
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index a8831b4e1..dc6633b48 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -246,10 +246,6 @@ update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) =

 function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix)
     n = size(adj, 1)
-    # a vertex must always receive a message from itself
-    Zygote.ignore() do
-        GraphLaplacians.add_self_loop!(adj, n)
-    end
     mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n)
 end
@@ -266,8 +262,10 @@ function update_batch_vertex(gat::GATConv, M::AbstractMatrix)
 end

 function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
+    # a vertex must always receive a message from itself
+    adjlist = add_self_loops(adjacency_list(fg))
     check_num_nodes(fg, X)
-    _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +)
+    _, X = propagate(gat, adjlist, Fill(0.f0, 0, ne(fg)), X, +)
     X
 end
@@ -428,9 +426,8 @@ update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
 @functor GINConv

 function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
-    gf = graph(fg)
-    GraphSignals.check_num_node(gf, X)
-    _, X = propagate(g, adjacency_list(gf), Fill(0.f0, 0, ne(gf)), X, +)
+    check_num_nodes(fg, X)
+    _, X = propagate(g, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +)
     X
 end
diff --git a/src/utils.jl b/src/utils.jl
index 9a59b563f..f0f307443 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -6,9 +6,7 @@
 value is the accumulated number of edges (current vertex not included).
 """
 accumulated_edges(adj::AbstractVector{<:AbstractVector{<:Integer}}) = [0, cumsum(map(length, adj))...]

-Zygote.@nograd accumulated_edges
-
-Zygote.@nograd function generate_cluster(M::AbstractArray{T,N}, accu_edge) where {T,N}
+function generate_cluster(M::AbstractArray{T,N}, accu_edge) where {T,N}
     num_V = length(accu_edge) - 1
     num_E = accu_edge[end]
     cluster = similar(M, Int, num_E)
@@ -21,31 +19,19 @@ Zygote.@nograd function generate_cluster(M::AbstractArray{T,N}, accu_edge) where {T,N}
 end

 """
-    edge_index_table(adj[, directed])
+    edge_index_table(adj)

 Generate a mapping from vertex pair (i, j) to edge index. The edge indices are
 determined by the sorted vertex indices.
""" -function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}}, directed::Bool=is_directed(adj)) +function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}}) table = Dict{Tuple{UInt32,UInt32},UInt64}() e = one(UInt64) - if directed - for (i, js) = enumerate(adj) - js = sort(js) - for j = js - table[(i, j)] = e - e += one(UInt64) - end - end - else - for (i, js) = enumerate(adj) - js = sort(js) - js = js[i .≤ js] - for j = js - table[(i, j)] = e - table[(j, i)] = e - e += one(UInt64) - end + for (i, js) = enumerate(adj) + js = sort(js) + for j = js + table[(i, j)] = e + e += one(UInt64) end end table @@ -59,21 +45,30 @@ function edge_index_table(vpair::AbstractVector{<:Tuple}) table end -edge_index_table(fg::FeaturedGraph) = edge_index_table(fg.graph, fg.directed) +edge_index_table(fg::FeaturedGraph) = edge_index_table(fg.graph) -Zygote.@nograd edge_index_table +function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) + @assert nv(fg) == size(x, ndims(x)) +end -### TODO move these to GraphSignals ###### -import GraphSignals: FeaturedGraph +function add_self_loops(adjlist::AbstractVector{<:AbstractVector}) + anew = deepcopy(adjlist) + for (i, neigs) in enumerate(anew) + if i ∉ neigs + push!(neigs, i) + end + end + return anew +end -function FeaturedGraph(fg::FeaturedGraph; - nf=node_feature(fg), - ef=edge_feature(fg), - gf=global_feature(fg)) +sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) - return FeaturedGraph(graph(fg); nf, ef, gf) +function sort_edge_index(u, v) + uv = collect(zip(u, v)) + p = sortperm(uv) # isless lexicographically defined for tuples + return u[p], v[p] end -function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) - @assert nv(fg) == size(x, ndims(x)) -end +@non_differentiable accumulated_edges(x...) +@non_differentiable generate_cluster(x...) +@non_differentiable edge_index_table(x...) 
diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl
index 386edc94e..e6d79d098 100644
--- a/test/cuda/conv.jl
+++ b/test/cuda/conv.jl
@@ -8,7 +8,7 @@ adj = [0 1 0 1;
        0 1 0 1;
        1 0 1 0]

-fg = FeaturedGraph(adj)
+fg = FeaturedGraph(adj, graph_type=GRAPH_T)

 @testset "cuda/conv" begin
     @testset "GCNConv" begin
diff --git a/test/cuda/featured_graph.jl b/test/cuda/featured_graph.jl
new file mode 100644
index 000000000..5cb0dae61
--- /dev/null
+++ b/test/cuda/featured_graph.jl
@@ -0,0 +1,55 @@
+@testset "cuda/featured graph" begin
+    s = [1,1,2,3,4,5,5,5]
+    t = [2,5,3,2,1,4,3,1]
+    s, t = [s; t], [t; s] # symmetrize
+    fg = FeaturedGraph(s, t, graph_type=GRAPH_T)
+    fg_gpu = fg |> gpu
+
+    @testset "functor" begin
+        s_cpu, t_cpu = edge_index(fg)
+        s_gpu, t_gpu = edge_index(fg_gpu)
+        @test s_gpu isa CuVector{Int}
+        @test Array(s_gpu) == s_cpu
+        @test t_gpu isa CuVector{Int}
+        @test Array(t_gpu) == t_cpu
+    end
+
+    @testset "adjacency_matrix" begin
+        function test_adj()
+            mat = adjacency_matrix(fg)
+            mat_gpu = adjacency_matrix(fg_gpu)
+            @test mat_gpu isa CuMatrix{Int}
+            true
+        end
+
+        if GRAPH_T == :coo
+            # See https://github.com/JuliaGPU/CUDA.jl/pull/1093
+            @test_broken test_adj()
+        else
+            test_adj()
+        end
+    end
+
+    @testset "normalized_laplacian" begin
+        function test_normlapl()
+            mat = normalized_laplacian(fg)
+            mat_gpu = normalized_laplacian(fg_gpu)
+            @test mat_gpu isa CuMatrix{Float32}
+            true
+        end
+        if GRAPH_T == :coo
+            @test_broken test_normlapl()
+        else
+            test_normlapl()
+        end
+    end
+
+    @testset "scaled_laplacian" begin
+        @test_broken begin
+            mat = scaled_laplacian(fg)
+            mat_gpu = scaled_laplacian(fg_gpu)
+            @test mat_gpu isa CuMatrix{Float32}
+            true
+        end
+    end
+end
\ No newline at end of file
diff --git a/test/cuda/msgpass.jl b/test/cuda/msgpass.jl
index e372a41b5..18650dda9 100644
--- a/test/cuda/msgpass.jl
+++ b/test/cuda/msgpass.jl
@@ -20,7 +20,7 @@ GeometricFlux.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
 GeometricFlux.update(::NewCudaLayer, m, x) = m

 X = rand(T, in_channel, N) |> gpu
-fg = FeaturedGraph(adj, nf=X)
+fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T)

 l = NewCudaLayer(out_channel, in_channel) |> gpu

 @testset "cuda/msgpass" begin
diff --git a/test/featured_graph.jl b/test/featured_graph.jl
new file mode 100644
index 000000000..b4ba80a0f
--- /dev/null
+++ b/test/featured_graph.jl
@@ -0,0 +1,104 @@
+@testset "FeaturedGraph" begin
+    @testset "symmetric graph" begin
+        s = [1, 1, 2, 2, 3, 3, 4, 4]
+        t = [2, 4, 1, 3, 2, 4, 1, 3]
+        adj_mat = [0 1 0 1
+                   1 0 1 0
+                   0 1 0 1
+                   1 0 1 0]
+        adj_list_out = [[2,4], [1,3], [2,4], [1,3]]
+        adj_list_in = [[2,4], [1,3], [2,4], [1,3]]
+
+        # core functionality
+        fg = FeaturedGraph(s, t; graph_type=GRAPH_T)
+        @test fg.num_edges == 8
+        @test fg.num_nodes == 4
+        @test collect(edges(fg)) |> sort == collect(zip(s, t)) |> sort
+        @test sort(outneighbors(fg, 1)) == [2, 4]
+        @test sort(inneighbors(fg, 1)) == [2, 4]
+        @test is_directed(fg) == true
+        s1, t1 = sort_edge_index(edge_index(fg))
+        @test s1 == s
+        @test t1 == t
+
+        # adjacency
+        @test adjacency_matrix(fg) == adj_mat
+        @test adjacency_matrix(fg; dir=:in) == adj_mat
+        @test adjacency_matrix(fg; dir=:out) == adj_mat
+        @test sort.(adjacency_list(fg; dir=:in)) == adj_list_in
+        @test sort.(adjacency_list(fg; dir=:out)) == adj_list_out
+
+        @testset "constructors" begin
+            fg = FeaturedGraph(adj_mat; graph_type=GRAPH_T)
+            @test adjacency_matrix(fg; dir=:out) == adj_mat
+            @test adjacency_matrix(fg; dir=:in) == adj_mat
+        end
+
+        @testset "degree" begin
+            fg = FeaturedGraph(adj_mat; graph_type=GRAPH_T)
@test degree(fg, dir=:out) == vec(sum(adj_mat, dims=2)) + @test degree(fg, dir=:in) == vec(sum(adj_mat, dims=1)) + end + end + + @testset "asymmetric graph" begin + s = [1, 2, 3, 4] + t = [2, 3, 4, 1] + adj_mat_out = [0 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + adj_list_out = [[2], [3], [4], [1]] + + + adj_mat_in = [0 0 0 1 + 1 0 0 0 + 0 1 0 0 + 0 0 1 0] + adj_list_in = [[4], [1], [2], [3]] + + # core functionality + fg = FeaturedGraph(s, t; graph_type=GRAPH_T) + @test fg.num_edges == 4 + @test fg.num_nodes == 4 + @test collect(edges(fg)) |> sort == collect(zip(s, t)) |> sort + @test sort(outneighbors(fg, 1)) == [2] + @test sort(inneighbors(fg, 1)) == [4] + @test is_directed(fg) == true + s1, t1 = sort_edge_index(edge_index(fg)) + @test s1 == s + @test t1 == t + + # adjacency + @test adjacency_matrix(fg) == adj_mat_out + @test adjacency_list(fg) == adj_list_out + @test adjacency_matrix(fg, dir=:out) == adj_mat_out + @test adjacency_list(fg, dir=:out) == adj_list_out + @test adjacency_matrix(fg, dir=:in) == adj_mat_in + @test adjacency_list(fg, dir=:in) == adj_list_in + + @testset "degree" begin + fg = FeaturedGraph(adj_mat_out; graph_type=GRAPH_T) + @test degree(fg, dir=:out) == vec(sum(adj_mat_out, dims=2)) + @test degree(fg, dir=:in) == vec(sum(adj_mat_out, dims=1)) + end + end + + @testset "add self-loops" begin + A = [1 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + A2 = [1 1 0 0 + 0 1 1 0 + 0 0 1 1 + 1 0 0 1] + + fg = FeaturedGraph(A; graph_type=GRAPH_T) + fg2 = add_self_loops(fg) + @test adjacency_matrix(fg) == A + @test fg.num_edges == sum(A) + @test adjacency_matrix(fg2) == A2 + @test fg2.num_edges == sum(A2) + end +end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 3d7131b0c..27a9f7da1 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -14,7 +14,7 @@ adj_single_vertex = T[0. 0. 0. 1.; 0. 0. 0. 1.; 1. 0. 1. 0.] 
-fg_single_vertex = FeaturedGraph(adj_single_vertex) +fg_single_vertex = FeaturedGraph(adj_single_vertex, graph_type=GRAPH_T) @testset "layer" begin @@ -25,7 +25,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(fg, in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test graph(gc.fg) === adj + @test adjacency_matrix(gc.fg) == adj Y = gc(X) @test size(Y) == (out_channel, N) @@ -46,20 +46,19 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GCNConv(in_channel=>out_channel) @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test !has_graph(gc.fg) - - fg = FeaturedGraph(adj, nf=X) + + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = gc(fg) @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError gc(X) # Test with transposed features - fgt = FeaturedGraph(adj, nf=Xt) + fgt = FeaturedGraph(adj, nf=Xt, graph_type=GRAPH_T) fgt_ = gc(fgt) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1] @test size(g.weight) == size(gc.weight) @@ -81,7 +80,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) cc = ChebConv(fg, in_channel=>out_channel, k) @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test graph(cc.fg) === adj + @test adjacency_matrix(cc.fg) == adj @test cc.k == k Y = cc(X) @@ -103,21 +102,20 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) cc = ChebConv(in_channel=>out_channel, k) @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test !has_graph(cc.fg) @test cc.k == k - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = cc(fg) @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError cc(X) # Test with transposed features - fgt = FeaturedGraph(adj, nf=Xt) + fgt = FeaturedGraph(adj, nf=Xt, graph_type=GRAPH_T) fgt_ = cc(fgt) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(cc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), cc)[1] @test size(g.weight) == size(cc.weight) @@ -162,18 +160,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(gc.weight2) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = gc(fg) @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError gc(X) # Test with transposed features - fgt = FeaturedGraph(adj, nf=Xt) + fgt = FeaturedGraph(adj, nf=Xt, graph_type=GRAPH_T) fgt_ = gc(fgt) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1] @test size(g.weight1) == size(gc.weight1) @@ -195,7 +193,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @testset "layer with graph" begin for heads = [1, 2], concat = [true, false], adj_gat in [adj, adj_single_vertex] - fg_gat = FeaturedGraph(adj_gat) + fg_gat = FeaturedGraph(adj_gat, graph_type=GRAPH_T) gat = GATConv(fg_gat, in_channel=>out_channel, heads=heads, concat=concat) if adj_gat == adj @@ 
-227,24 +225,23 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @testset "layer without graph" begin for heads = [1, 2], concat = [true, false], adj_gat in [adj, adj_single_vertex] - fg_gat = FeaturedGraph(adj_gat, nf=X) + fg_gat = FeaturedGraph(adj_gat, nf=X, graph_type=GRAPH_T) gat = GATConv(in_channel=>out_channel, heads=heads, concat=concat) @test size(gat.weight) == (out_channel * heads, in_channel) @test size(gat.bias) == (out_channel * heads,) @test size(gat.a) == (2*out_channel, heads) - fg_ = gat(fg_gat) Y = node_feature(fg_) @test size(Y) == (concat ? (out_channel*heads, N) : (out_channel, N)) @test_throws MethodError gat(X) # Test with transposed features - fgt = FeaturedGraph(adj_gat, nf=Xt) + fgt = FeaturedGraph(adj_gat, nf=Xt, graph_type=GRAPH_T) fgt_ = gat(fgt) @test size(node_feature(fgt_)) == (concat ? (out_channel*heads, N) : (out_channel, N)) g = Zygote.gradient(x -> sum(node_feature(gat(x))), fg_gat)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg_gat))), gat)[1] @test size(g.weight) == size(gat.weight) @@ -287,18 +284,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) ggc = GatedGraphConv(out_channel, num_layers) @test size(ggc.weight) == (out_channel, out_channel, num_layers) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = ggc(fg) @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError ggc(X) # Test with transposed features - fgt = FeaturedGraph(adj, nf=Xt) + fgt = FeaturedGraph(adj, nf=Xt, graph_type=GRAPH_T) fgt_ = ggc(fgt) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(ggc(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), ggc)[1] @test size(g.weight) == size(ggc.weight) @@ -330,18 +327,18 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @testset "layer without graph" begin ec = EdgeConv(Dense(2*in_channel, out_channel)) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = ec(fg) @test size(node_feature(fg_)) == (out_channel, N) @test_throws MethodError ec(X) # Test with transposed features - fgt = FeaturedGraph(adj, nf=Xt) + fgt = FeaturedGraph(adj, nf=Xt, graph_type=GRAPH_T) fgt_ = ec(fgt) @test size(node_feature(fgt_)) == (out_channel, N) g = Zygote.gradient(x -> sum(node_feature(ec(x))), fg)[1] - @test size(g[].nf) == size(X) + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(fg))), ec)[1] @test size(g.nn.weight) == size(ec.nn.weight) @@ -359,7 +356,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) gc = GINConv(FeaturedGraph(adj), nn, eps=eps) @test size(gc.nn.layers[1].weight) == (out_channel, in_channel) @test size(gc.nn.layers[1].bias) == (out_channel, ) - @test graph(gc.fg) === adj + @test adjacency_matrix(gc.fg) == adj Y = gc(FeaturedGraph(adj, nf=X)) @test size(node_feature(Y)) == (out_channel, N) @@ -368,9 +365,9 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) Y = gc(FeaturedGraph(adj, nf=Xt)) @test size(node_feature(Y)) == (out_channel, N) - g = Zygote.gradient(x -> sum(node_feature(gc(x))), - FeaturedGraph(adj, nf=X))[1] - @test size(g.x.nf) == size(X) + fg = FeaturedGraph(adj, nf=X) + g = Zygote.gradient(fg -> sum(node_feature(gc(fg))), fg)[1] + @test size(g.nf) == size(X) g = Zygote.gradient(model -> sum(node_feature(model(FeaturedGraph(adj, nf=X)))), gc)[1] diff --git a/test/layers/gn.jl 
b/test/layers/gn.jl index 0c32ec1f9..fdd3e4288 100644 --- a/test/layers/gn.jl +++ b/test/layers/gn.jl @@ -1,74 +1,74 @@ -in_channel = 10 -out_channel = 5 -num_V = 6 -num_E = 7 -T = Float32 - -adj = T[0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] - -struct NewGNLayer <: GraphNet -end - -V = rand(T, in_channel, num_V) -E = rand(T, in_channel, 2num_E) -u = rand(T, in_channel) - @testset "gn" begin - l = NewGNLayer() + in_channel = 10 + out_channel = 5 + num_V = 6 + num_E = 7 + T = Float32 + + adj = [0 1 0 0 0 0 + 1 0 0 1 1 1 + 0 0 0 0 0 1 + 0 1 0 0 1 0 + 0 1 0 1 0 1 + 0 1 1 0 1 0] + + struct NewGNLayer{G} <: GraphNet end + NewGNLayer() = NewGNLayer{GRAPH_T}() + + V = rand(T, in_channel, num_V) + E = rand(T, in_channel, 2num_E) + u = rand(T, in_channel) @testset "without aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg) + (l::NewGNLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg) - fg = FeaturedGraph(adj, nf=V) + fg = FeaturedGraph(adj, nf=V, graph_type=GRAPH_T) + l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) - @test size(edge_feature(fg_)) == (0, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test all(edge_feature(fg_) .== fill(nothing, 2*num_E)) + @test global_feature(fg_) === nothing end @testset "with neighbor aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) + (l::NewGNLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg, +) - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0)) + fg = FeaturedGraph(adj, nf=V, ef=E, gf=nothing, graph_type=GRAPH_T) l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (in_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test global_feature(fg_) === nothing end - GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = rand(T, out_channel) @testset "update edge with neighbor aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +) + (l::NewGNLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg, +) + GeometricFlux.update_edge(l::NewGNLayer{GRAPH_T}, e, vi, vj, u) = rand(T, out_channel) + - fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0)) + fg = FeaturedGraph(adj, nf=V, ef=E, gf=nothing, graph_type=GRAPH_T) l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) - @test size(global_feature(fg_)) == (0,) + @test global_feature(fg_) === nothing end - GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = rand(T, out_channel) @testset "update edge/vertex with all aggregation" begin - (l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +, +, +) + (l::NewGNLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg, +, +, +) + GeometricFlux.update_vertex(l::NewGNLayer{GRAPH_T}, ē, vi, u) = rand(T, out_channel) - fg = FeaturedGraph(adj, nf=V, ef=E, gf=u) + fg = FeaturedGraph(adj, nf=V, ef=E, gf=u, graph_type=GRAPH_T) l = NewGNLayer() fg_ = l(fg) - @test graph(fg_) === adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) @test size(global_feature(fg_)) == (in_channel,) diff --git a/test/layers/misc.jl b/test/layers/misc.jl index 58ba83029..9ddcd0e01 
100644 --- a/test/layers/misc.jl +++ b/test/layers/misc.jl @@ -11,12 +11,12 @@ ef = rand(5, E) gf = rand(7) - fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=gf) + fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=gf, graph_type=GRAPH_T) layer = bypass_graph(x -> x .+ 1., x -> x .+ 2., x -> x .+ 3.) fg_ = layer(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test node_feature(fg_) == nf .+ 1. @test edge_feature(fg_) == ef .+ 2. @test global_feature(fg_) == gf .+ 3. diff --git a/test/layers/msgpass.jl b/test/layers/msgpass.jl index 9e9c86caa..1c4feb642 100644 --- a/test/layers/msgpass.jl +++ b/test/layers/msgpass.jl @@ -1,53 +1,61 @@ -in_channel = 10 -out_channel = 5 -num_V = 6 -num_E = 7 -T = Float32 - -adj = T[0. 1. 0. 0. 0. 0.; - 1. 0. 0. 1. 1. 1.; - 0. 0. 0. 0. 0. 1.; - 0. 1. 0. 0. 1. 0.; - 0. 1. 0. 1. 0. 1.; - 0. 1. 1. 0. 1. 0.] - -struct NewLayer <: MessagePassing - weight -end -NewLayer(m, n) = NewLayer(randn(T, m,n)) - -(l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +) - -X = Array{T}(reshape(1:num_V*in_channel, in_channel, num_V)) -fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E)) - -l = NewLayer(out_channel, in_channel) - @testset "msgpass" begin + in_channel = 10 + out_channel = 5 + num_V = 6 + num_E = 7 + T = Float32 + + adj = [0 1 0 0 0 0 + 1 0 0 1 1 1 + 0 0 0 0 0 1 + 0 1 0 0 1 0 + 0 1 0 1 0 1 + 0 1 1 0 1 0] + + struct NewLayer{G} <: MessagePassing + weight + end + NewLayer(m, n) = NewLayer{GRAPH_T}(randn(T, m,n)) + + X = Array{T}(reshape(1:num_V*in_channel, in_channel, num_V)) + @testset "no message or update" begin + (l::NewLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg, +) + + fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E), graph_type=GRAPH_T) + l = NewLayer(out_channel, in_channel) fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (in_channel, num_V) @test size(edge_feature(fg_)) == (in_channel, 2*num_E) @test size(global_feature(fg_)) == (0,) end - GeometricFlux.message(l::NewLayer, x_i, x_j, e_ij) = l.weight * x_j + @testset "message function" begin + (l::NewLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg, +) + GeometricFlux.message(l::NewLayer{GRAPH_T}, x_i, x_j, e_ij) = l.weight * x_j + + fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E), graph_type=GRAPH_T) + l = NewLayer(out_channel, in_channel) fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) @test size(global_feature(fg_)) == (0,) end - GeometricFlux.update(l::NewLayer, m, x) = l.weight * x + m @testset "message and update" begin + (l::NewLayer{GRAPH_T})(fg) = GeometricFlux.propagate(l, fg, +) + GeometricFlux.update(l::NewLayer{GRAPH_T}, m, x) = l.weight * x + m + + fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E), graph_type=GRAPH_T) + l = NewLayer(out_channel, in_channel) fg_ = l(fg) - @test graph(fg_) == adj + @test adjacency_matrix(fg_) == adj @test size(node_feature(fg_)) == (out_channel, num_V) @test size(edge_feature(fg_)) == (out_channel, 2*num_E) @test size(global_feature(fg_)) == (0,) diff --git a/test/models.jl b/test/models.jl index 0e4e62c48..d04b3cb70 100644 --- a/test/models.jl +++ b/test/models.jl @@ -9,7 +9,7 @@ adj = [0. 1. 0. 1.; 0. 1. 0. 1.; 1. 0. 1. 0.] 
-fg = FeaturedGraph(adj) +fg = FeaturedGraph(adj; graph_type=GRAPH_T) @testset "models" begin @testset "GAE" begin @@ -28,13 +28,13 @@ fg = FeaturedGraph(adj) @test size(Y) == (N, N) X = rand(T, 1, N) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = ipd(fg) Y = node_feature(fg_) @test size(Y) == (N, N) X = rand(T, in_channel, N) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = ipd(fg) Y = node_feature(fg_) @test size(Y) == (N, N) @@ -45,7 +45,7 @@ fg = FeaturedGraph(adj) gc = GCNConv(in_channel=>out_channel) ve = VariationalEncoder(gc, out_channel, z_dim) X = rand(T, in_channel, N) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = ve(fg) Z = node_feature(fg_) @test size(Z) == (z_dim, N) @@ -56,7 +56,7 @@ fg = FeaturedGraph(adj) gc = GCNConv(in_channel=>out_channel) vgae = VGAE(gc, out_channel, z_dim) X = rand(T, in_channel, N) - fg = FeaturedGraph(adj, nf=X) + fg = FeaturedGraph(adj, nf=X, graph_type=GRAPH_T) fg_ = vgae(fg) Y = node_feature(fg_) @test size(Y) == (N, N) diff --git a/test/runtests.jl b/test/runtests.jl index 46f335f15..e99a7ee56 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,23 +1,27 @@ using GeometricFlux using GeometricFlux.Datasets +using GeometricFlux: sort_edge_index using Flux +using CUDA +using Flux: gpu using Flux: @functor using FillArrays -using GraphSignals -using LightGraphs: SimpleGraph, SimpleDiGraph, add_edge!, nv, ne using LinearAlgebra using NNlib -using SparseArrays: SparseMatrixCSC +using LightGraphs using Statistics: mean using Zygote using Test +CUDA.allowscalar(false) cuda_tests = [ + "cuda/featured_graph", # "cuda/conv", # "cuda/msgpass", ] tests = [ + "featured_graph", "layers/gn", "layers/msgpass", "layers/conv", @@ -26,17 +30,18 @@ tests = [ "models", ] -if Flux.use_cuda[] - using CUDA - using Flux: gpu - using NNlibCUDA - append!(tests, cuda_tests) -else - @warn "CUDA unavailable, not testing GPU support" -end +!Flux.use_cuda[] && @warn("CUDA unavailable, not testing GPU support") -@testset "GeometricFlux" begin +# Testing all graph types. :sparse is a bit broken at the moment +@testset "GeometricFlux: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) + global GRAPH_T = graph_type for t in tests include("$(t).jl") end + + if Flux.use_cuda[] && GRAPH_T != :sparse + for t in cuda_tests + include("$(t).jl") + end + end end