diff --git a/Project.toml b/Project.toml
index 7275f88b8..3db4503e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
@@ -17,6 +18,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -28,12 +30,14 @@ CUDA = "3.3"
 ChainRulesCore = "1"
 DataStructures = "0.18"
 Flux = "0.12.7"
+Functors = "0.2"
 Graphs = "1.4"
 KrylovKit = "0.5"
 LearnBase = "0.4, 0.5"
 MacroTools = "0.5"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
+Reexport = "1"
 StatsBase = "0.32, 0.33"
 TestEnv = "1"
 julia = "1.6"
diff --git a/docs/make.jl b/docs/make.jl
index 038ab6f8a..53881728e 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,7 +1,9 @@
 using Flux, NNlib, GraphNeuralNetworks, Graphs, SparseArrays
 using Documenter
 
-DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive=true)
+DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup,
+                    :(using GraphNeuralNetworks, Graphs, SparseArrays, NNlib, Flux);
+                    recursive=true)
 
 makedocs(;
     modules=[GraphNeuralNetworks, NNlib, Flux, Graphs, SparseArrays],
diff --git a/docs/src/api/gnngraph.md b/docs/src/api/gnngraph.md
index 9f7cefde9..0483c1f78 100644
--- a/docs/src/api/gnngraph.md
+++ b/docs/src/api/gnngraph.md
@@ -4,7 +4,12 @@ CurrentModule = GraphNeuralNetworks
 
 # GNNGraph
 
-Documentation page for the graph type `GNNGraph` provided GraphNeuralNetworks.jl and its related methods.
+Documentation page for the graph type `GNNGraph` provided by GraphNeuralNetworks.jl and related methods.
+
+```@contents
+Pages = ["gnngraph.md"]
+Depth = 5
+```
 
 Besides the methods documented here, one can rely on the large set of functionalities
 given by [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)
@@ -17,16 +22,41 @@ Order = [:type, :function]
 Pages   = ["gnngraph.md"]
 ```
 
-## Docs
+## GNNGraph type
 
 ```@autodocs
-Modules = [GraphNeuralNetworks]
+Modules = [GraphNeuralNetworks.GNNGraphs]
 Pages   = ["gnngraph.jl"]
 Private = false
 ```
 
+## Query
+
+```@autodocs
+Modules = [GraphNeuralNetworks.GNNGraphs]
+Pages   = ["query.jl"]
+Private = false
+```
+
 ```@docs
-Flux.batch
-SparseArrays.blockdiag
 Graphs.adjacency_matrix
+Graphs.degree
+Graphs.outneighbors
+Graphs.inneighbors
+```
+
+## Transform
+
+```@autodocs
+Modules = [GraphNeuralNetworks.GNNGraphs]
+Pages   = ["transform.jl"]
+Private = false
+```
+
+## Generate
+
+```@autodocs
+Modules = [GraphNeuralNetworks.GNNGraphs]
+Pages   = ["generate.jl"]
+Private = false
+```
diff --git a/docs/src/gnngraph.md b/docs/src/gnngraph.md
index bb0cabb6b..c1bfc2d80 100644
--- a/docs/src/gnngraph.md
+++ b/docs/src/gnngraph.md
@@ -15,10 +15,13 @@ A GNNGraph can be created from several different data sources encoding the graph
 using GraphNeuralNetworks, Graphs, SparseArrays
 
-# Construct GNNGraph from From Graphs's graph
+# Construct a GNNGraph from a Graphs.jl graph
 lg = erdos_renyi(10, 30)
 g = GNNGraph(lg)
 
+# Same as above, using the convenience method rand_graph
+g = rand_graph(10, 30)
+
 # From an adjacency matrix
 A = sprand(10, 10, 0.3)
 g = GNNGraph(A)
@@ -33,7 +36,7 @@ target = [2,3,1,3,1,2,4,3]
 g = GNNGraph(source, target)
 ```
 
-See also the related methods [`adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).
+See also the related methods [`Graphs.adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).
 
 ## Basic Queries
 
@@ -123,21 +126,21 @@ for g in train_loader
     .....
 end
 
-# Access the nodes' graph memberships through
-gall.graph_indicator
+# Access the nodes' graph memberships
+graph_indicator(gall)
 ```
 
 ## Graph Manipulation
 
 ```julia
 g′ = add_self_loops(g)
-
 g′ = remove_self_loops(g)
+g′ = add_edges(g, [1, 2], [2, 3]) # add edges 1->2 and 2->3
 ```
 
 ## JuliaGraphs ecosystem integration
 
-Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.
+Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.
 
 ```julia
 @assert Graphs.isdirected(g)
diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl
new file mode 100644
index 000000000..d188359bd
--- /dev/null
+++ b/src/GNNGraphs/GNNGraphs.jl
@@ -0,0 +1,43 @@
+module GNNGraphs
+
+using SparseArrays
+using Functors: @functor
+using CUDA
+import Graphs
+using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
+import Flux
+using Flux: batch
+import NNlib
+import LearnBase
+import StatsBase
+using LearnBase: getobs
+import KrylovKit
+using ChainRulesCore
+using LinearAlgebra, Random
+
+include("gnngraph.jl")
+export GNNGraph, node_features, edge_features, graph_features
+
+include("query.jl")
+export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
+       graph_indicator
+
+include("transform.jl")
+export add_edges, add_self_loops, remove_self_loops, getgraph
+
+include("generate.jl")
+export rand_graph
+
+
+include("convert.jl")
+include("utils.jl")
+
+export
+    # from Graphs
+    adjacency_matrix, degree, outneighbors, inneighbors,
+    # from SparseArrays
+    sprand, sparse, blockdiag,
+    # from Flux
+    batch
+
+end #module
diff --git a/src/graph_conversions.jl b/src/GNNGraphs/convert.jl
similarity index 100%
rename from src/graph_conversions.jl
rename to src/GNNGraphs/convert.jl
diff --git a/src/GNNGraphs/generate.jl b/src/GNNGraphs/generate.jl
new file mode 100644
index 000000000..67cfcac30
--- /dev/null
+++ b/src/GNNGraphs/generate.jl
@@ -0,0 +1,14 @@
+"""
+    rand_graph(n, m; directed=false, kws...)
+
+Generate a random (Erdős-Rényi) `GNNGraph` with `n` nodes.
+
+If `directed=false` the output will contain `2m` edges:
+the reverse edge of each edge will be present.
+If `directed=true` instead, `m` independent edges are generated, with no reverse edges added.
+
+Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
+"""
+function rand_graph(n::Integer, m::Integer; directed=false, kws...)
+    return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
+end
diff --git a/src/GNNGraphs/gnngraph.jl b/src/GNNGraphs/gnngraph.jl
new file mode 100644
index 000000000..8a36b7396
--- /dev/null
+++ b/src/GNNGraphs/gnngraph.jl
@@ -0,0 +1,223 @@
+#===================================
+Define GNNGraph type as a subtype of Graphs' AbstractGraph.
+For the core methods to be implemented by any AbstractGraph, see
+https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
+https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
+=============================================#
+
+const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V}
+const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
+const ADJMAT_T = AbstractMatrix
+const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
+const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
+
+
+"""
+    GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
+    GNNGraph(g::GNNGraph; [ndata, edata, gdata])
+
+A type representing a graph structure that also stores
+feature arrays associated to nodes, edges, and the graph itself.
+
+A `GNNGraph` can be constructed out of different `data` objects
+expressing the connections inside the graph. The internal representation type
+is determined by `graph_type`.
+
+When constructed from another `GNNGraph`, the internal graph representation
+is preserved and shared. The node/edge/graph features are retained
+as well, unless explicitly set by the keyword arguments
+`ndata`, `edata`, and `gdata`.
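+
+For instance, `GNNGraph(g, ndata = rand(3, g.num_nodes))` returns a graph
+that shares the topology of `g` but carries freshly assigned node features
+(any array whose last dimension has size `num_nodes` works here).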
+
+A `GNNGraph` can also represent multiple graphs batched together
+(see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
+The field `g.graph_indicator` contains the graph membership
+of each node.
+
+`GNNGraph`s are always directed graphs, therefore each edge is defined
+by a source node and a target node (see [`edge_index`](@ref)).
+Self loops (edges connecting a node to itself) and multiple edges
+(more than one edge between the same pair of nodes) are supported.
+
+A `GNNGraph` is a Graphs.jl `AbstractGraph`, therefore it supports most
+functionality from that library.
+
+# Arguments
+
+- `data`: Some data representing the graph topology. Possible types are
+    - An adjacency matrix
+    - An adjacency list.
+    - A tuple containing the source and target vectors (COO representation)
+    - A Graphs.jl graph.
+- `graph_type`: A keyword argument that specifies
+                the underlying representation used by the GNNGraph.
+                Currently supported values are
+    - `:coo`. Graph represented as a tuple `(source, target)`, such that the `k`-th edge
+              connects the node `source[k]` to node `target[k]`.
+              Optionally, edge weights can also be given: `(source, target, weights)`.
+    - `:sparse`. A sparse adjacency matrix representation.
+    - `:dense`. A dense adjacency matrix representation.
+    Default `:coo`.
+- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
+         Possible values are `:out` and `:in`. Default `:out`.
+- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
+- `graph_indicator`: For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
+- `ndata`: Node features. A named tuple of arrays whose last dimension has size `num_nodes`.
+- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`.
+- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`.
+
+# Usage
+
+```julia
+using Flux, GraphNeuralNetworks
+
+# Construct from adjacency list representation
+data = [[2,3], [1,4,5], [1], [2,5], [2,4]]
+g = GNNGraph(data)
+
+# Number of nodes, edges, and batched graphs
+g.num_nodes  # 5
+g.num_edges  # 10
+g.num_graphs # 1
+
+# Same graph in COO representation
+s = [1,1,2,2,2,3,4,4,5,5]
+t = [2,3,1,4,5,3,2,5,2,4]
+g = GNNGraph(s, t)
+
+# From a Graphs.jl graph
+g = GNNGraph(erdos_renyi(100, 20))
+
+# Add 2 node feature arrays
+g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes)))
+
+# Add node features and edge features with default names `x` and `e`
+g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges))
+
+g.ndata.x
+g.edata.e
+
+# Send to gpu
+g = g |> gpu
+
+# Collect edges' source and target nodes.
+# Both source and target are vectors of length num_edges
+source, target = edge_index(g)
+```
+"""
+struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractGraph{Int}
+    graph::T
+    num_nodes::Int
+    num_edges::Int
+    num_graphs::Int
+    graph_indicator # vector of ints or nothing
+    ndata::NamedTuple
+    edata::NamedTuple
+    gdata::NamedTuple
+end
+
+@functor GNNGraph
+
+function GNNGraph(data;
+                  num_nodes = nothing,
+                  graph_indicator = nothing,
+                  graph_type = :coo,
+                  dir = :out,
+                  ndata = (;),
+                  edata = (;),
+                  gdata = (;),
+                  )
+
+    @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
+    @assert dir ∈ [:in, :out]
+
+    if graph_type == :coo
+        graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
+    elseif graph_type == :dense
+        graph, num_nodes, num_edges = to_dense(data; dir)
+    elseif graph_type == :sparse
+        graph, num_nodes, num_edges = to_sparse(data; dir)
+    end
+
+    num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
+
+    ndata = normalize_graphdata(ndata, default_name=:x, n=num_nodes)
+    edata = normalize_graphdata(edata, default_name=:e, n=num_edges, duplicate_if_needed=true)
+    gdata = normalize_graphdata(gdata, default_name=:u, n=num_graphs)
+
+    GNNGraph(graph,
+             num_nodes, num_edges, num_graphs,
+             graph_indicator,
+             ndata, edata, gdata)
+end
+
+# COO convenience constructors
+GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...)
+GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
+
+# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...)
+
+function GNNGraph(g::AbstractGraph; kws...)
+    s = Graphs.src.(Graphs.edges(g))
+    t = Graphs.dst.(Graphs.edges(g))
+    if !Graphs.is_directed(g)
+        # add reverse edges since GNNGraph is directed
+        s, t = [s; t], [t; s]
+    end
+    GNNGraph((s, t); num_nodes=Graphs.nv(g), kws...)
+end
+
+
+function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata, graph_type=nothing)
+
+    ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
+    edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges, duplicate_if_needed=true)
+    gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs)
+
+    if !isnothing(graph_type)
+        if graph_type == :coo
+            graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes)
+        elseif graph_type == :dense
+            graph, num_nodes, num_edges = to_dense(g.graph)
+        elseif graph_type == :sparse
+            graph, num_nodes, num_edges = to_sparse(g.graph)
+        end
+        @assert num_nodes == g.num_nodes
+        @assert num_edges == g.num_edges
+    else
+        graph = g.graph
+    end
+    GNNGraph(graph,
+             g.num_nodes, g.num_edges, g.num_graphs,
+             g.graph_indicator,
+             ndata, edata, gdata)
+end
+
+function Base.show(io::IO, g::GNNGraph)
+    println(io, "GNNGraph:
+    num_nodes = $(g.num_nodes)
+    num_edges = $(g.num_edges)
+    num_graphs = $(g.num_graphs)")
+    println(io, "    ndata:")
+    for k in keys(g.ndata)
+        println(io, "        $k => $(size(g.ndata[k]))")
+    end
+    println(io, "    edata:")
+    for k in keys(g.edata)
+        println(io, "        $k => $(size(g.edata[k]))")
+    end
+    println(io, "    gdata:")
+    for k in keys(g.gdata)
+        println(io, "        $k => $(size(g.gdata[k]))")
+    end
+end
+
+### StatsBase/LearnBase compatibility
+StatsBase.nobs(g::GNNGraph) = g.num_graphs
+LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
+
+# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
+Flux.Data._nobs(g::GNNGraph) = g.num_graphs
+Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
+
+#########################
+Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1)))
diff --git a/src/GNNGraphs/query.jl b/src/GNNGraphs/query.jl
new file mode 100644
index 000000000..6cf476412
--- /dev/null
+++ b/src/GNNGraphs/query.jl
@@ -0,0 +1,238 @@
+
+"""
+    edge_index(g::GNNGraph)
+
+Return a tuple containing two vectors, respectively storing
+the source and target nodes for each edge in `g`.
+
+```julia
+s, t = edge_index(g)
+```
+"""
+edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]
+
+edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][1:2]
+
+edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
+
+edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3]
+
+Graphs.edges(g::GNNGraph) = zip(edge_index(g)...)
+
+Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int}
+
+function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer)
+    s, t = edge_index(g)
+    return any((s .== i) .& (t .== j))
+end
+
+Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0
+
+Graphs.nv(g::GNNGraph) = g.num_nodes
+Graphs.ne(g::GNNGraph) = g.num_edges
+Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes
+Graphs.vertices(g::GNNGraph) = 1:g.num_nodes
+
+function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer)
+    s, t = edge_index(g)
+    return t[s .== i]
+end
+
+function Graphs.outneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer)
+    A = g.graph
+    return findall(!=(0), A[i,:])
+end
+
+function Graphs.inneighbors(g::GNNGraph{<:COO_T}, i::Integer)
+    s, t = edge_index(g)
+    return s[t .== i]
+end
+
+function Graphs.inneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer)
+    A = g.graph
+    return findall(!=(0), A[:,i])
+end
+
+Graphs.is_directed(::GNNGraph) = true
+Graphs.is_directed(::Type{<:GNNGraph}) = true
+
+"""
+    adjacency_list(g; dir=:out)
+
+Return the adjacency list representation (a vector of vectors)
+of the graph `g`.
+
+Calling `a` the adjacency list, if `dir=:out`
+`a[i]` will contain the neighbors of node `i` through
+outgoing edges. If `dir=:in`, it will contain neighbors from
+incoming edges instead.
+"""
+function adjacency_list(g::GNNGraph; dir=:out)
+    @assert dir ∈ [:out, :in]
+    fneighs = dir == :out ? outneighbors : inneighbors
+    return [fneighs(g, i) for i in 1:g.num_nodes]
+end
+
+function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out)
+    if g.graph[1] isa CuVector
+        # TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
+        A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
+    else
+        A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes)
+    end
+    @assert size(A) == (n, n)
+    return dir == :out ? A : A'
+end
+
+function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g.graph); dir=:out)
+    @assert dir ∈ [:in, :out]
+    A = g.graph
+    A = T != eltype(A) ? T.(A) : A
+    return dir == :out ? A : A'
+end
+
+function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out)
+    s, t = edge_index(g)
+    T = isnothing(T) ? eltype(s) : T
+    degs = fill!(similar(s, T, g.num_nodes), 0)
+    src = 1
+    if dir ∈ [:out, :both]
+        NNlib.scatter!(+, degs, src, s)
+    end
+    if dir ∈ [:in, :both]
+        NNlib.scatter!(+, degs, src, t)
+    end
+    return degs
+end
+
+function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=Int; dir=:out)
+    @assert dir ∈ (:in, :out)
+    A = adjacency_matrix(g, T)
+    return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1))
+end
+
+function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=Int; dir::Symbol=:out)
+    A = adjacency_matrix(g, T; dir=dir)
+    D = Diagonal(vec(sum(A; dims=2)))
+    return D - A
+end
+
+
+"""
+    normalized_laplacian(g, T=Float32; add_self_loops=false, dir=:out)
+
+Normalized Laplacian matrix of graph `g`.
+
+# Arguments
+
+- `g`: A `GNNGraph`.
+- `T`: result element type.
+- `add_self_loops`: add self-loops while calculating the matrix.
+- `dir`: the edge directionality considered (:out, :in, :both).
+"""
+function normalized_laplacian(g::GNNGraph, T::DataType=Float32;
+                              add_self_loops::Bool=false, dir::Symbol=:out)
+    Ã = normalized_adjacency(g, T; dir, add_self_loops)
+    return I - Ã
+end
+
+function normalized_adjacency(g::GNNGraph, T::DataType=Float32;
+                              add_self_loops::Bool=false, dir::Symbol=:out)
+    A = adjacency_matrix(g, T; dir=dir)
+    if add_self_loops
+        A = A + I
+    end
+    degs = vec(sum(A; dims=2))
+    inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
+    return inv_sqrtD * A * inv_sqrtD
+end
+
+@doc raw"""
+    scaled_laplacian(g, T=Float32; dir=:out)
+
+Scaled Laplacian matrix of graph `g`,
+defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix.
+
+# Arguments
+
+- `g`: A `GNNGraph`.
+- `T`: result element type.
+- `dir`: the edge directionality considered (:out, :in, :both).
+"""
+function scaled_laplacian(g::GNNGraph, T::DataType=Float32; dir=:out)
+    L = normalized_laplacian(g, T)
+    @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices"
+    λmax = _eigmax(L)
+    return 2 / λmax * L - I
+end
+
+# _eigmax(A) = eigmax(Symmetric(A)) # Doesn't work on sparse arrays
+function _eigmax(A)
+    x0 = _rand_dense_vector(A)
+    KrylovKit.eigsolve(Symmetric(A), x0, 1, :LR)[1][1] # also eigs(A, x0, nev, mode) available
+end
+
+_rand_dense_vector(A::AbstractMatrix{T}) where T = randn(float(T), size(A, 1))
+_rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))
+
+# Eigenvalues for cuarray don't seem to be well supported.
+# https://github.com/JuliaGPU/CUDA.jl/issues/154
+# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5
+
+"""
+    graph_indicator(g)
+
+Return a vector containing the graph membership
+(an integer from `1` to `g.num_graphs`) of each node in the graph.
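+
+For example, for a batched graph made of two graphs with 3 and 2 nodes
+respectively, `graph_indicator(g) == [1, 1, 1, 2, 2]`.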
+"""
+function graph_indicator(g; edges=false)
+    if isnothing(g.graph_indicator)
+        gi = ones_like(edge_index(g)[1], Int, g.num_nodes)
+    else
+        gi = g.graph_indicator
+    end
+    if edges
+        s, t = edge_index(g)
+        return gi[s]
+    else
+        return gi
+    end
+end
+
+function node_features(g::GNNGraph)
+    if isempty(g.ndata)
+        return nothing
+    elseif length(g.ndata) > 1
+        @error "Multiple feature arrays, access directly through `g.ndata`"
+    else
+        return g.ndata[1]
+    end
+end
+
+function edge_features(g::GNNGraph)
+    if isempty(g.edata)
+        return nothing
+    elseif length(g.edata) > 1
+        @error "Multiple feature arrays, access directly through `g.edata`"
+    else
+        return g.edata[1]
+    end
+end
+
+function graph_features(g::GNNGraph)
+    if isempty(g.gdata)
+        return nothing
+    elseif length(g.gdata) > 1
+        @error "Multiple feature arrays, access directly through `g.gdata`"
+    else
+        return g.gdata[1]
+    end
+end
+
+@non_differentiable normalized_laplacian(x...)
+@non_differentiable normalized_adjacency(x...)
+@non_differentiable scaled_laplacian(x...)
+@non_differentiable adjacency_matrix(x...)
+@non_differentiable adjacency_list(x...)
+@non_differentiable degree(x...)
+@non_differentiable graph_indicator(x...)
diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl
new file mode 100644
index 000000000..1d9b783d7
--- /dev/null
+++ b/src/GNNGraphs/transform.jl
@@ -0,0 +1,211 @@
+
+"""
+    add_self_loops(g::GNNGraph)
+
+Return a graph with the same features as `g`
+but also adding edges connecting the nodes to themselves.
+
+Nodes with already existing
+self-loops will obtain a second self-loop.
+"""
+function add_self_loops(g::GNNGraph{<:COO_T})
+    s, t = edge_index(g)
+    @assert g.edata === (;)
+    @assert edge_weight(g) === nothing
+    n = g.num_nodes
+    nodes = convert(typeof(s), [1:n;])
+    s = [s; nodes]
+    t = [t; nodes]
+
+    GNNGraph((s, t, nothing),
+             g.num_nodes, length(s), g.num_graphs,
+             g.graph_indicator,
+             g.ndata, g.edata, g.gdata)
+end
+
+function add_self_loops(g::GNNGraph{<:ADJMAT_T})
+    A = g.graph
+    @assert g.edata === (;)
+    num_edges = g.num_edges + g.num_nodes
+    A = A + I
+    GNNGraph(A,
+             g.num_nodes, num_edges, g.num_graphs,
+             g.graph_indicator,
+             g.ndata, g.edata, g.gdata)
+end
+
+
+function remove_self_loops(g::GNNGraph{<:COO_T})
+    s, t = edge_index(g)
+    # TODO remove these constraints
+    @assert g.edata === (;)
+    @assert edge_weight(g) === nothing
+
+    mask_old_loops = s .!= t
+    s = s[mask_old_loops]
+    t = t[mask_old_loops]
+
+    GNNGraph((s, t, nothing),
+             g.num_nodes, length(s), g.num_graphs,
+             g.graph_indicator,
+             g.ndata, g.edata, g.gdata)
+end
+
+"""
+    add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
+
+Add to graph `g` the edges with source nodes `s` and target nodes `t`.
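+Optional features for the new edges can be passed through `edata`.
+A small usage sketch (illustrative values):
+
+```julia
+g = GNNGraph([1, 2], [2, 3])  # two edges: 1->2, 2->3
+g = add_edges(g, [3], [1])    # add the edge 3->1
+g.num_edges                   # 3
+```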
+
+"""
+function add_edges(g::GNNGraph{<:COO_T},
+                   snew::AbstractVector{<:Integer},
+                   tnew::AbstractVector{<:Integer};
+                   edata=nothing)
+
+    @assert length(snew) == length(tnew)
+    # TODO remove this constraint
+    @assert edge_weight(g) === nothing
+
+    edata = normalize_graphdata(edata, default_name=:e, n=length(snew))
+    edata = cat_features(g.edata, edata)
+
+    s, t = edge_index(g)
+    s = [s; snew]
+    t = [t; tnew]
+
+    GNNGraph((s, t, nothing),
+             g.num_nodes, length(s), g.num_graphs,
+             g.graph_indicator,
+             g.ndata, edata, g.gdata)
+end
+
+function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
+    nv1, nv2 = g1.num_nodes, g2.num_nodes
+    if g1.graph isa COO_T
+        s1, t1 = edge_index(g1)
+        s2, t2 = edge_index(g2)
+        s = vcat(s1, nv1 .+ s2)
+        t = vcat(t1, nv1 .+ t2)
+        w = cat_features(edge_weight(g1), edge_weight(g2))
+        graph = (s, t, w)
+        ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, Int, nv1) : g1.graph_indicator
+        ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, Int, nv2) : g2.graph_indicator
+    elseif g1.graph isa ADJMAT_T
+        graph = blockdiag(g1.graph, g2.graph)
+        ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, Int, nv1) : g1.graph_indicator
+        ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, Int, nv2) : g2.graph_indicator
+    end
+    graph_indicator = vcat(ind1, g1.num_graphs .+ ind2)
+
+    GNNGraph(graph,
+             nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs,
+             graph_indicator,
+             cat_features(g1.ndata, g2.ndata),
+             cat_features(g1.edata, g2.edata),
+             cat_features(g1.gdata, g2.gdata))
+end
+
+# PIRACY
+function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
+    m1, n1 = size(A1)
+    @assert m1 == n1
+    m2, n2 = size(A2)
+    @assert m2 == n2
+    O1 = fill!(similar(A1, eltype(A1), (m1, n2)), 0)
+    O2 = fill!(similar(A1, eltype(A1), (m2, n1)), 0)
+    return [A1 O1
+            O2 A2]
+end
+
+### Cat public interfaces #############
+
+"""
+    blockdiag(xs::GNNGraph...)
+
+Equivalent to [`Flux.batch`](@ref).
+"""
+function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
+    g = g1
+    for go in gothers
+        g = blockdiag(g, go)
+    end
+    return g
+end
+
+"""
+    batch(xs::Vector{<:GNNGraph})
+
+Batch together multiple `GNNGraph`s into a single one
+containing the total number of original nodes and edges.
+
+Equivalent to [`SparseArrays.blockdiag`](@ref).
+"""
+Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
+
+
+"""
+    getgraph(g::GNNGraph, i; nmap=false)
+
+Return the subgraph of `g` induced by those nodes `j`
+for which `g.graph_indicator[j] == i` or,
+if `i` is a collection, `g.graph_indicator[j] ∈ i`.
+In other words, it extracts the component graphs from a batched graph.
+
+If `nmap=true`, return also a vector `v` mapping the new nodes to the old ones.
+The node `i` in the subgraph will correspond to the node `v[i]` in `g`.
+"""
+getgraph(g::GNNGraph, i::Int; kws...) = getgraph(g, [i]; kws...)
+
+function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
+    if g.graph_indicator === nothing
+        @assert i == [1]
+        if nmap
+            return g, 1:g.num_nodes
+        else
+            return g
+        end
+    end
+
+    node_mask = g.graph_indicator .∈ Ref(i)
+
+    nodes = (1:g.num_nodes)[node_mask]
+    nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes))
+
+    graphmap = Dict(i => inew for (inew, i) in enumerate(i))
+    graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
+
+    s, t = edge_index(g)
+    w = edge_weight(g)
+    edge_mask = s .∈ Ref(nodes)
+
+    if g.graph isa COO_T
+        s = [nodemap[i] for i in s[edge_mask]]
+        t = [nodemap[i] for i in t[edge_mask]]
+        w = isnothing(w) ? nothing : w[edge_mask]
+        graph = (s, t, w)
+    elseif g.graph isa ADJMAT_T
+        graph = g.graph[nodes, nodes]
+    end
+
+    ndata = getobs(g.ndata, node_mask)
+    edata = getobs(g.edata, edge_mask)
+    gdata = getobs(g.gdata, i)
+
+    num_edges = sum(edge_mask)
+    num_nodes = length(graph_indicator)
+    num_graphs = length(i)
+
+    gnew = GNNGraph(graph,
+                    num_nodes, num_edges, num_graphs,
+                    graph_indicator,
+                    ndata, edata, gdata)
+
+    if nmap
+        return gnew, nodes
+    else
+        return gnew
+    end
+end
+
+@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
+@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
diff --git a/src/GNNGraphs/utils.jl b/src/GNNGraphs/utils.jl
new file mode 100644
index 000000000..5b9fc92d2
--- /dev/null
+++ b/src/GNNGraphs/utils.jl
@@ -0,0 +1,72 @@
+function check_num_nodes(g::GNNGraph, x::AbstractArray)
+    @assert g.num_nodes == size(x, ndims(x))
+end
+function check_num_edges(g::GNNGraph, e::AbstractArray)
+    @assert g.num_edges == size(e, ndims(e))
+end
+
+sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...)
+
+function sort_edge_index(u, v)
+    uv = collect(zip(u, v))
+    p = sortperm(uv) # isless lexicographically defined for tuples
+    return u[p], v[p]
+end
+
+cat_features(x1::Nothing, x2::Nothing) = nothing
+cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims=ndims(x1))
+cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector}) =
+    cat(x1, x2, dims=1)
+
+function cat_features(x1::NamedTuple, x2::NamedTuple)
+    sort(collect(keys(x1))) == sort(collect(keys(x2))) || @error "cannot concatenate feature data with different keys"
+
+    NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1))
+end
+
+# Turns generic type into named tuple
+normalize_graphdata(data::Nothing; kws...) = NamedTuple()
+
+normalize_graphdata(data; default_name::Symbol, kws...) =
+    normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
+
+function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed=false)
+    # This has to work around two Zygote bugs with NamedTuples
+    # https://github.com/FluxML/Zygote.jl/issues/1071
+    # https://github.com/FluxML/Zygote.jl/issues/1072
+
+    if n == 1
+        # If last array dimension is not 1, add a new dimension.
+        # This is mostly useful to reshape global feature vectors
+        # of size D to Dx1 matrices.
+        function unsqz(v)
+            if v isa AbstractArray && size(v)[end] != 1
+                v = reshape(v, size(v)..., 1)
+            end
+            v
+        end
+
+        data = NamedTuple{keys(data)}(unsqz.(values(data)))
+    end
+
+    sz = map(x -> x isa AbstractArray ?
size(x)[end] : 0, data) + if duplicate_if_needed + # Used to copy edge features on reverse edges + @assert all(s -> s == 0 || s == n || s == n÷2, sz) + + function duplicate(v) + if v isa AbstractArray && size(v)[end] == n÷2 + v = cat(v, v, dims=ndims(v)) + end + v + end + data = NamedTuple{keys(data)}(duplicate.(values(data))) + else + @assert all(s -> s == 0 || s == n, sz) + end + return data +end + +ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1) +ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz) +ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 73d057d59..f9c19c85f 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -2,36 +2,24 @@ module GraphNeuralNetworks using Statistics: mean using LinearAlgebra, Random -using SparseArrays -import KrylovKit using Base: tail using CUDA using Flux using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch using MacroTools: @forward -import LearnBase -import StatsBase -using LearnBase: getobs using NNlib, NNlibCUDA using NNlib: scatter, gather using ChainRulesCore -import Graphs -using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree +using Reexport -export - # gnngraph - GNNGraph, - edge_index, - node_features, edge_features, graph_features, - adjacency_list, normalized_laplacian, scaled_laplacian, - add_self_loops, remove_self_loops, - getgraph, +using SparseArrays, Graphs # not needed but if removed Documenter will complain - # from Graphs - adjacency_matrix, - # from SparseArrays - sprand, sparse, blockdiag, +include("GNNGraphs/GNNGraphs.jl") +@reexport using .GNNGraphs +using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, + check_num_nodes, check_num_edges +export # utils reduce_nodes, reduce_edges, softmax_nodes, softmax_edges, @@ -67,8 +55,6 @@ export topk_index -include("gnngraph.jl") -include("graph_conversions.jl") include("utils.jl") include("layers/basic.jl") include("layers/conv.jl") diff --git a/src/gnngraph.jl b/src/gnngraph.jl deleted file mode 100644 index e7e47ac30..000000000 --- a/src/gnngraph.jl +++ /dev/null @@ -1,621 +0,0 @@ -#=================================== -Define GNNGraph type as a subtype of Graphs' AbstractGraph. -For the core methods to be implemented by any AbstractGraph, see -https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type -https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types -=============================================# - -const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V} -const ADJLIST_T = AbstractVector{T} where T <: AbstractVector -const ADJMAT_T = AbstractMatrix -const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T -const CUMAT_T = Union{AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} - -""" - GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir]) - GNNGraph(g::GNNGraph; [ndata, edata, gdata]) - -A type representing a graph structure that also stores -feature arrays associated to nodes, edges, and the graph itself. - -A `GNNGraph` can be constructed out of different `data` objects -expressing the connections inside the graph. The internal representation type -is determined by `graph_type`. - -When constructed from another `GNNGraph`, the internal graph representation -is preserved and shared. The node/edge/graph features are retained -as well, unless explicitely set by the keyword arguments -`ndata`, `edata`, and `gdata`. 
- -A `GNNGraph` can also represent multiple graphs batched togheter -(see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)). -The field `g.graph_indicator` contains the graph membership -of each node. - -`GNNGraph`s are always directed graphs, therefore each edge is defined -by a source node and a target node (see [`edge_index`](@ref)). -Self loops (edges connecting a node to itself) and multiple edges -(more than one edge between the same pair of nodes) are supported. - -A `GNNGraph` is a Graphs.jl's `AbstractGraph`, therefore it supports most -functionality from that library. - -# Arguments - -- `data`: Some data representing the graph topology. Possible type are - - An adjacency matrix - - An adjacency list. - - A tuple containing the source and target vectors (COO representation) - - A Graphs' graph. -- `graph_type`: A keyword argument that specifies - the underlying representation used by the GNNGraph. - Currently supported values are - - `:coo`. Graph represented as a tuple `(source, target)`, such that the `k`-th edge - connects the node `source[k]` to node `target[k]`. - Optionally, also edge weights can be given: `(source, target, weights)`. - - `:sparse`. A sparse adjacency matrix representation. - - `:dense`. A dense adjacency matrix representation. - Default `:coo`. -- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`. - Possible values are `:out` and `:in`. Default `:out`. -- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`. -- `graph_indicator`: For batched graphs, a vector containing the graph assigment of each node. Default `nothing`. -- `ndata`: Node features. A named tuple of arrays whose last dimension has size `num_nodes`. -- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`. -- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`. - -# Usage. - -```julia -using Flux, GraphNeuralNetworks - -# Construct from adjacency list representation -data = [[2,3], [1,4,5], [1], [2,5], [2,4]] -g = GNNGraph(data) - -# Number of nodes, edges, and batched graphs -g.num_nodes # 5 -g.num_edges # 10 -g.num_graphs # 1 - -# Same graph in COO representation -s = [1,1,2,2,2,3,4,4,5,5] -t = [2,3,1,4,5,3,2,5,2,4] -g = GNNGraph(s, t) - -# From a Graphs' graph -g = GNNGraph(erdos_renyi(100, 20)) - -# Add 2 node feature arrays -g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes))) - -# Add node features and edge features with default names `x` and `e` -g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges)) - -g.ndata.x -g.ndata.e - -# Send to gpu -g = g |> gpu - -# Collect edges' source and target nodes. 
-# Both source and target are vectors of length num_edges -source, target = edge_index(g) -``` -""" -struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractGraph{Int} - graph::T - num_nodes::Int - num_edges::Int - num_graphs::Int - graph_indicator # vector of ints or nothing - ndata::NamedTuple - edata::NamedTuple - gdata::NamedTuple -end - -@functor GNNGraph - -function GNNGraph(data; - num_nodes = nothing, - graph_indicator = nothing, - graph_type = :coo, - dir = :out, - ndata = (;), - edata = (;), - gdata = (;), - ) - - @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" - @assert dir ∈ [:in, :out] - - if graph_type == :coo - graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) - elseif graph_type == :dense - graph, num_nodes, num_edges = to_dense(data; dir) - elseif graph_type == :sparse - graph, num_nodes, num_edges = to_sparse(data; dir) - end - - num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1 - - ndata = normalize_graphdata(ndata, default_name=:x, n=num_nodes) - edata = normalize_graphdata(edata, default_name=:e, n=num_edges, duplicate_if_needed=true) - gdata = normalize_graphdata(gdata, default_name=:u, n=num_graphs) - - GNNGraph(graph, - num_nodes, num_edges, num_graphs, - graph_indicator, - ndata, edata, gdata) -end - -# COO convenience constructors -GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...) -GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...) - -# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...) - -function GNNGraph(g::AbstractGraph; kws...) - s = Graphs.src.(Graphs.edges(g)) - t = Graphs.dst.(Graphs.edges(g)) - if !Graphs.is_directed(g) - # add reverse edges since GNNGraph are directed - s, t = [s; t], [t; s] - end - GNNGraph((s, t); num_nodes=Graphs.nv(g), kws...) -end - - -function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata, graph_type=nothing) - - ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes) - edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges, duplicate_if_needed=true) - gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs) - - if !isnothing(graph_type) - if graph_type == :coo - graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes) - elseif graph_type == :dense - graph, num_nodes, num_edges = to_dense(g.graph) - elseif graph_type == :sparse - graph, num_nodes, num_edges = to_sparse(g.graph) - end - @assert num_nodes == g.num_nodes - @assert num_edges == g.num_edges - else - graph = g.graph - end - GNNGraph(graph, - g.num_nodes, g.num_edges, g.num_graphs, - g.graph_indicator, - ndata, edata, gdata) -end - -function Base.show(io::IO, g::GNNGraph) - println(io, "GNNGraph: - num_nodes = $(g.num_nodes) - num_edges = $(g.num_edges) - num_graphs = $(g.num_graphs)") - println(io, " ndata:") - for k in keys(g.ndata) - println(io, " $k => $(size(g.ndata[k]))") - end - println(io, " edata:") - for k in keys(g.edata) - println(io, " $k => $(size(g.edata[k]))") - end - println(io, " gdata:") - for k in keys(g.gdata) - println(io, " $k => $(size(g.gdata[k]))") - end -end - -""" - edge_index(g::GNNGraph) - -Return a tuple containing two vectors, respectively storing -the source and target nodes for each edges in `g`. 
- -```julia -s, t = edge_index(g) -``` -""" -edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2] - -edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][1:2] - -edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3] - -edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3] - -Graphs.edges(g::GNNGraph) = zip(edge_index(g)...) - -Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int} - -function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer) - s, t = edge_index(g) - return any((s .== i) .& (t .== j)) -end - -Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0 - -Graphs.nv(g::GNNGraph) = g.num_nodes -Graphs.ne(g::GNNGraph) = g.num_edges -Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes -Graphs.vertices(g::GNNGraph) = 1:g.num_nodes - -function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer) - s, t = edge_index(g) - return t[s .== i] -end - -function Graphs.outneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) - A = g.graph - return findall(!=(0), A[i,:]) -end - -function Graphs.inneighbors(g::GNNGraph{<:COO_T}, i::Integer) - s, t = edge_index(g) - return s[t .== i] -end - -function Graphs.inneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) - A = g.graph - return findall(!=(0), A[:,i]) -end - -Graphs.is_directed(::GNNGraph) = true -Graphs.is_directed(::Type{<:GNNGraph}) = true - -""" - adjacency_list(g; dir=:out) - -Return the adjacency list representation (a vector of vectors) -of the graph `g`. - -Calling `a` the adjacency list, if `dir=:out` -`a[i]` will contain the neighbors of node `i` through -outgoing edges. If `dir=:in`, it will contain neighbors from -incoming edges instead. -""" -function adjacency_list(g::GNNGraph; dir=:out) - @assert dir ∈ [:out, :in] - fneighs = dir == :out ? outneighbors : inneighbors - return [fneighs(g, i) for i in 1:g.num_nodes] -end - -function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out) - if g.graph[1] isa CuVector - # TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152 - A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes) - else - A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes) - end - @assert size(A) == (n, n) - return dir == :out ? A : A' -end - -function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g.graph); dir=:out) - @assert dir ∈ [:in, :out] - A = g.graph - A = T != eltype(A) ? T.(A) : A - return dir == :out ? A : A' -end - -function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out) - s, t = edge_index(g) - T = isnothing(T) ? eltype(s) : T - degs = fill!(similar(s, T, g.num_nodes), 0) - src = 1 - if dir ∈ [:out, :both] - NNlib.scatter!(+, degs, src, s) - end - if dir ∈ [:in, :both] - NNlib.scatter!(+, degs, src, t) - end - return degs -end - -function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=Int; dir=:out) - @assert dir ∈ (:in, :out) - A = adjacency_matrix(g, T) - return dir == :out ? vec(sum(A, dims=2)) : vec(sum(A, dims=1)) -end - -function Graphs.laplacian_matrix(g::GNNGraph, T::DataType=Int; dir::Symbol=:out) - A = adjacency_matrix(g, T; dir=dir) - D = Diagonal(vec(sum(A; dims=2))) - return D - A -end - - -""" - normalized_laplacian(g, T=Float32; add_self_loops=false, dir=:out) - -Normalized Laplacian matrix of graph `g`. - -# Arguments - -- `g`: A `GNNGraph`. -- `T`: result element type. -- `add_self_loops`: add self-loops while calculating the matrix. -- `dir`: the edge directionality considered (:out, :in, :both). 
-""" -function normalized_laplacian(g::GNNGraph, T::DataType=Float32; - add_self_loops::Bool=false, dir::Symbol=:out) - Ã = normalized_adjacency(g, T; dir, add_self_loops) - return I - Ã -end - -function normalized_adjacency(g::GNNGraph, T::DataType=Float32; - add_self_loops::Bool=false, dir::Symbol=:out) - A = adjacency_matrix(g, T; dir=dir) - if add_self_loops - A = A + I - end - degs = vec(sum(A; dims=2)) - inv_sqrtD = Diagonal(inv.(sqrt.(degs))) - return inv_sqrtD * A * inv_sqrtD -end - -@doc raw""" - scaled_laplacian(g, T=Float32; dir=:out) - -Scaled Laplacian matrix of graph `g`, -defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. - -# Arguments - -- `g`: A `GNNGraph`. -- `T`: result element type. -- `dir`: the edge directionality considered (:out, :in, :both). -""" -function scaled_laplacian(g::GNNGraph, T::DataType=Float32; dir=:out) - L = normalized_laplacian(g, T) - @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices" - λmax = _eigmax(L) - return 2 / λmax * L - I -end - -# _eigmax(A) = eigmax(Symmetric(A)) # Doesn't work on sparse arrays -function _eigmax(A) - x0 = _rand_dense_vector(A) - KrylovKit.eigsolve(Symmetric(A), x0, 1, :LR)[1][1] # also eigs(A, x0, nev, mode) available -end - -_rand_dense_vector(A::AbstractMatrix{T}) where T = randn(float(T), size(A, 1)) -_rand_dense_vector(A::CUMAT_T)= CUDA.randn(size(A, 1)) - -# Eigenvalues for cuarray don't seem to be well supported. -# https://github.com/JuliaGPU/CUDA.jl/issues/154 -# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5 - -""" - add_self_loops(g::GNNGraph) - -Return a graph with the same features as `g` -but also adding edges connecting the nodes to themselves. - -Nodes with already existing -self-loops will obtain a second self-loop. -""" -function add_self_loops(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - @assert g.edata === (;) - @assert edge_weight(g) === nothing - n = g.num_nodes - nodes = convert(typeof(s), [1:n;]) - s = [s; nodes] - t = [t; nodes] - - GNNGraph((s, t, nothing), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -function add_self_loops(g::GNNGraph{<:ADJMAT_T}) - A = g.graph - @assert g.edata === (;) - num_edges = g.num_edges + g.num_nodes - A = A + I - GNNGraph(A, - g.num_nodes, num_edges, g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - - -function remove_self_loops(g::GNNGraph{<:COO_T}) - s, t = edge_index(g) - # TODO remove these constraints - @assert g.edata === (;) - @assert edge_weight(g) === nothing - - mask_old_loops = s .!= t - s = s[mask_old_loops] - t = t[mask_old_loops] - - GNNGraph((s, t, nothing), - g.num_nodes, length(s), g.num_graphs, - g.graph_indicator, - g.ndata, g.edata, g.gdata) -end - -function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph) - nv1, nv2 = g1.num_nodes, g2.num_nodes - if g1.graph isa COO_T - s1, t1 = edge_index(g1) - s2, t2 = edge_index(g2) - s = vcat(s1, nv1 .+ s2) - t = vcat(t1, nv1 .+ t2) - w = cat_features(edge_weight(g1), edge_weight(g2)) - graph = (s, t, w) - ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, Int, nv1) : g1.graph_indicator - ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, Int, nv2) : g2.graph_indicator - elseif g1.graph isa ADJMAT_T - graph = blockdiag(g1.graph, g2.graph) - ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, Int, nv1) : g1.graph_indicator - ind2 = isnothing(g2.graph_indicator) ? 
ones_like(graph, Int, nv2) : g2.graph_indicator - end - graph_indicator = vcat(ind1, g1.num_graphs .+ ind2) - - GNNGraph(graph, - nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs, - graph_indicator, - cat_features(g1.ndata, g2.ndata), - cat_features(g1.edata, g2.edata), - cat_features(g1.gdata, g2.gdata)) -end - -# PIRACY -function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix) - m1, n1 = size(A1) - @assert m1 == n1 - m2, n2 = size(A2) - @assert m2 == n2 - O1 = fill!(similar(A1, eltype(A1), (m1, n2)), 0) - O2 = fill!(similar(A1, eltype(A1), (m2, n1)), 0) - return [A1 O1 - O2 A2] -end - -### Cat public interfaces ############# - -""" - blockdiag(xs::GNNGraph...) - -Equivalent to [`Flux.batch`](@ref). -""" -function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) - g = g1 - for go in gothers - g = blockdiag(g, go) - end - return g -end - -""" - batch(xs::Vector{<:GNNGraph}) - -Batch together multiple `GNNGraph`s into a single one -containing the total number of original nodes and edges. - -Equivalent to [`SparseArrays.blockdiag`](@ref). -""" -Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...) - -### StatsBase/LearnBase compatibility -StatsBase.nobs(g::GNNGraph) = g.num_graphs -LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i) - -# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683 -Flux.Data._nobs(g::GNNGraph) = g.num_graphs -Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i) - -######################### -Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1))) - -""" - getgraph(g::GNNGraph, i; nmap=false) - -Return the subgraph of `g` induced by those nodes `j` -for which `g.graph_indicator[j] == i` or, -if `i` is a collection, `g.graph_indicator[j] ∈ i`. -In other words, it extract the component graphs from a batched graph. - -If `nmap=true`, return also a vector `v` mapping the new nodes to the old ones. -The node `i` in the subgraph will correspond to the node `v[i]` in `g`. -""" -getgraph(g::GNNGraph, i::Int; kws...) = getgraph(g, [i]; kws...) - -function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false) - if g.graph_indicator === nothing - @assert i == [1] - if nmap - return g, 1:g.num_nodes - else - return g - end - end - - node_mask = g.graph_indicator .∈ Ref(i) - - nodes = (1:g.num_nodes)[node_mask] - nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes)) - - graphmap = Dict(i => inew for (inew, i) in enumerate(i)) - graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]] - - s, t = edge_index(g) - w = edge_weight(g) - edge_mask = s .∈ Ref(nodes) - - if g.graph isa COO_T - s = [nodemap[i] for i in s[edge_mask]] - t = [nodemap[i] for i in t[edge_mask]] - w = isnothing(w) ? 
nothing : w[edge_mask] - graph = (s, t, w) - elseif g.graph isa ADJMAT_T - graph = g.graph[nodes, nodes] - end - - ndata = getobs(g.ndata, node_mask) - edata = getobs(g.edata, edge_mask) - gdata = getobs(g.gdata, i) - - num_edges = sum(edge_mask) - num_nodes = length(graph_indicator) - num_graphs = length(i) - - gnew = GNNGraph(graph, - num_nodes, num_edges, num_graphs, - graph_indicator, - ndata, edata, gdata) - - if nmap - return gnew, nodes - else - return gnew - end -end - -function node_features(g::GNNGraph) - if isempty(g.ndata) - return nothing - elseif length(g.ndata) > 1 - @error "Multiple feature arrays, access directly through `g.ndata`" - else - return g.ndata[1] - end -end - -function edge_features(g::GNNGraph) - if isempty(g.edata) - return nothing - elseif length(g.edata) > 1 - @error "Multiple feature arrays, access directly through `g.edata`" - else - return g.edata[1] - end -end - -function graph_features(g::GNNGraph) - if isempty(g.gdata) - return nothing - elseif length(g.gdata) > 1 - @error "Multiple feature arrays, access directly through `g.gdata`" - else - return g.gdata[1] - end -end - - -@non_differentiable normalized_laplacian(x...) -@non_differentiable normalized_adjacency(x...) -@non_differentiable scaled_laplacian(x...) -@non_differentiable adjacency_matrix(x...) -@non_differentiable adjacency_list(x...) -@non_differentiable degree(x...) -@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule -@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule diff --git a/src/utils.jl b/src/utils.jl index b173b2898..c5b4ec698 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,79 +1,3 @@ -function check_num_nodes(g::GNNGraph, x::AbstractArray) - @assert g.num_nodes == size(x, ndims(x)) -end -function check_num_edges(g::GNNGraph, e::AbstractArray) - @assert g.num_edges == size(e, ndims(e)) -end - -sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) - -function sort_edge_index(u, v) - uv = collect(zip(u, v)) - p = sortperm(uv) # isless lexicographically defined for tuples - return u[p], v[p] -end - -cat_features(x1::Nothing, x2::Nothing) = nothing -cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims=ndims(x1)) -cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector}) = - cat(x1, x2, dims=1) - - -function cat_features(x1::NamedTuple, x2::NamedTuple) - sort(collect(keys(x1))) == sort(collect(keys(x2))) || - @error "cannot concatenate feature data with different keys" - - NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1)) -end - -# Turns generic type into named tuple -normalize_graphdata(data::Nothing; kws...) = NamedTuple() - -normalize_graphdata(data; default_name::Symbol, kws...) = - normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...) - -function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed=false) - # This had to workaround two Zygote bugs with NamedTuples - # https://github.com/FluxML/Zygote.jl/issues/1071 - # https://github.com/FluxML/Zygote.jl/issues/1072 - - if n == 1 - # If last array dimension is not 1, add a new dimension. - # This is mostly useful to reshape graph feature vectors - # of size D into Dx1 matrices. 
- function unsqz(v) - if v isa AbstractArray && size(v)[end] != 1 - v = reshape(v, size(v)..., 1) - end - v - end - - data = NamedTuple{keys(data)}(unsqz.(values(data))) - end - - sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data) - if duplicate_if_needed - # Used to copy edge features on reverse edges - @assert all(s -> s == 0 || s == n || s == n÷2, sz) - - function duplicate(v) - if v isa AbstractArray && size(v)[end] == n÷2 - v = cat(v, v, dims=ndims(v)) - end - v - end - - data = NamedTuple{keys(data)}(duplicate.(values(data))) - else - @assert all(s -> s == 0 || s == n, sz) - end - return data -end - -ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1) -ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz) -ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz) - ofeltype(x, y) = convert(float(eltype(x)), y) # Considers the src a zero dimensional object. @@ -197,7 +121,7 @@ Softmax over each node's neighborhood of the edge features `e`. ```math \mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}} - {\sum_{j'\in N(i)} e^{\mathbf{e}_{j\to i}}}. + {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}. ``` """ function softmax_edge_neighbors(g::GNNGraph, e) @@ -233,17 +157,3 @@ function broadcast_edges(g::GNNGraph, x) return gather(x, gi) end - -function graph_indicator(g; edges=false) - if isnothing(g.graph_indicator) - gi = ones_like(edge_index(g)[1], Int, g.num_nodes) - else - gi = g.graph_indicator - end - if edges - s, t = edge_index(g) - return gi[s] - else - return gi - end -end diff --git a/test/GNNGraphs/generate.jl b/test/GNNGraphs/generate.jl new file mode 100644 index 000000000..326844659 --- /dev/null +++ b/test/GNNGraphs/generate.jl @@ -0,0 +1,21 @@ +@testset "generate" begin + @testset "rand_graph" begin + n, m = 10, 20 + x = rand(3, n) + e = rand(4, m) + g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T) + @test g.num_nodes == n + @test g.num_edges == 2m + @test g.ndata.x === x + if GRAPH_T == :coo + s, t = edge_index(g) + @test s[1:m] == t[m+1:end] + @test t[1:m] == s[m+1:end] + @test g.edata.e[:,1:m] == e + @test g.edata.e[:,m+1:end] == e + end + g = rand_graph(n, m, directed=true, graph_type=GRAPH_T) + @test g.num_nodes == n + @test g.num_edges == m + end +end diff --git a/test/gnngraph.jl b/test/GNNGraphs/gnngraph.jl similarity index 74% rename from test/gnngraph.jl rename to test/GNNGraphs/gnngraph.jl index 55383d183..6227399ea 100644 --- a/test/gnngraph.jl +++ b/test/GNNGraphs/gnngraph.jl @@ -23,7 +23,7 @@ @test sort(outneighbors(g, 1)) == [2, 4] @test sort(inneighbors(g, 1)) == [2, 4] @test is_directed(g) == true - s1, t1 = GraphNeuralNetworks.sort_edge_index(edge_index(g)) + s1, t1 = sort_edge_index(edge_index(g)) @test s1 == s @test t1 == t @test vertices(g) == 1:g.num_nodes @@ -123,7 +123,7 @@ @test sort(inneighbors(g, 1)) == [4] @test is_directed(g) == true @test is_directed(typeof(g)) == true - s1, t1 = GraphNeuralNetworks.sort_edge_index(edge_index(g)) + s1, t1 = sort_edge_index(edge_index(g)) @test s1 == s @test t1 == t @@ -154,69 +154,6 @@ end end - @testset "add self-loops" begin - A = [1 1 0 0 - 0 0 1 0 - 0 0 0 1 - 1 0 0 0] - A2 = [2 1 0 0 - 0 1 1 0 - 0 0 1 1 - 1 0 0 1] - - g = GNNGraph(A; graph_type=GRAPH_T) - fg2 = add_self_loops(g) - @test adjacency_matrix(g) == A - @test g.num_edges == sum(A) - @test adjacency_matrix(fg2) == A2 - @test fg2.num_edges == sum(A2) - end - - @testset "batch" begin - #TODO add graph_type=GRAPH_T - g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10)) - 
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4)) - g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7)) - - g12 = Flux.batch([g1, g2]) - g12b = blockdiag(g1, g2) - - g123 = Flux.batch([g1, g2, g3]) - @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] - - s, t = edge_index(g123) - @test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]] - @test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]] - @test node_features(g123)[:,11:14] ≈ node_features(g2) - - # scalar graph features - g1 = GNNGraph(random_regular_graph(10,2), gdata=rand()) - g2 = GNNGraph(random_regular_graph(4,2), gdata=rand()) - g3 = GNNGraph(random_regular_graph(4,2), gdata=rand()) - g123 = Flux.batch([g1, g2, g3]) - @test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u] - end - - @testset "getgraph" begin - g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T) - g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T) - g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7), graph_type=GRAPH_T) - g = Flux.batch([g1, g2, g3]) - - g2b, nodemap = getgraph(g, 2, nmap=true) - s, t = edge_index(g2b) - @test s == edge_index(g2)[1] - @test t == edge_index(g2)[2] - @test node_features(g2b) ≈ node_features(g2) - - g2c = getgraph(g, 2) - @test g2c isa GNNGraph{typeof(g.graph)} - - g1b, nodemap = getgraph(g1, 1, nmap=true) - @test g1b === g1 - @test nodemap == 1:g1.num_nodes - end - @testset "Features" begin g = GNNGraph(sprand(10, 10, 0.3), graph_type=GRAPH_T) @@ -267,7 +204,7 @@ X = rand(10, n) E = rand(10, 2m) U = rand(10, 1) - g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U) + g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U, graph_type=GRAPH_T) for _ in 1:num_graphs]) @test LearnBase.getobs(g, 3) == getgraph(g, 3) diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl new file mode 100644 index 000000000..aabdcf221 --- /dev/null +++ b/test/GNNGraphs/transform.jl @@ -0,0 +1,83 @@ +@testset "transform" begin + @testset "add self-loops" begin + A = [1 1 0 0 + 0 0 1 0 + 0 0 0 1 + 1 0 0 0] + A2 = [2 1 0 0 + 0 1 1 0 + 0 0 1 1 + 1 0 0 1] + + g = GNNGraph(A; graph_type=GRAPH_T) + fg2 = add_self_loops(g) + @test adjacency_matrix(g) == A + @test g.num_edges == sum(A) + @test adjacency_matrix(fg2) == A2 + @test fg2.num_edges == sum(A2) + end + + @testset "batch" begin + #TODO add graph_type=GRAPH_T + g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10)) + g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4)) + g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7)) + + g12 = Flux.batch([g1, g2]) + g12b = blockdiag(g1, g2) + + g123 = Flux.batch([g1, g2, g3]) + @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] + + s, t = edge_index(g123) + @test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]] + @test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]] + @test node_features(g123)[:,11:14] ≈ node_features(g2) + + # scalar graph features + g1 = GNNGraph(random_regular_graph(10,2), gdata=rand()) + g2 = GNNGraph(random_regular_graph(4,2), gdata=rand()) + g3 = GNNGraph(random_regular_graph(4,2), gdata=rand()) + g123 = Flux.batch([g1, g2, g3]) + @test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u] + end + + @testset "getgraph" begin + g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T) + g2 = GNNGraph(random_regular_graph(4,2), 
ndata=rand(16,4), graph_type=GRAPH_T) + g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7), graph_type=GRAPH_T) + g = Flux.batch([g1, g2, g3]) + + g2b, nodemap = getgraph(g, 2, nmap=true) + s, t = edge_index(g2b) + @test s == edge_index(g2)[1] + @test t == edge_index(g2)[2] + @test node_features(g2b) ≈ node_features(g2) + + g2c = getgraph(g, 2) + @test g2c isa GNNGraph{typeof(g.graph)} + + g1b, nodemap = getgraph(g1, 1, nmap=true) + @test g1b === g1 + @test nodemap == 1:g1.num_nodes + end + + @testset "add_edges" begin + if GRAPH_T == :coo + s = [1,1,2,3] + t = [2,3,4,5] + g = GNNGraph(s, t, graph_type=GRAPH_T) + snew = [1] + tnew = [4] + gnew = add_edges(g, snew, tnew) + @test gnew.num_edges == 5 + @test sort(inneighbors(gnew, 4)) == [1, 2] + + g = GNNGraph(s, t, edata=(e1=rand(2,4), e2=rand(3,4)), graph_type=GRAPH_T) + # @test_throws ErrorException add_edges(g, snew, tnew) + gnew = add_edges(g, snew, tnew, edata=(e1=ones(2,1), e2=zeros(3,1))) + @test all(gnew.edata.e1[:,5] .== 1) + @test all(gnew.edata.e2[:,5] .== 0) + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 830920bb4..44d08a5bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using GraphNeuralNetworks +using GraphNeuralNetworks.GNNGraphs: sort_edge_index using Flux using CUDA using Flux: gpu, @functor @@ -19,7 +20,9 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ - "gnngraph", + "GNNGraphs/gnngraph", + "GNNGraphs/transform", + "GNNGraphs/generate", "utils", "msgpass", "layers/basic", @@ -31,9 +34,9 @@ tests = [ !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") -@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :sparse, :dense) +@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) global GRAPH_T = graph_type - global TEST_GPU = CUDA.functional() + global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) for t in tests startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI
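
A quick end-to-end sketch of the new `GNNGraphs` API introduced in this diff, assembled from the docstrings and tests above (it assumes the default `:coo` graph backend):

```julia
using Flux, GraphNeuralNetworks

# Two random graphs with node features. rand_graph is undirected by
# default, so each graph carries 2 * 5 = 10 (directed) edges.
g1 = rand_graph(4, 5, ndata = rand(Float32, 2, 4))
g2 = rand_graph(6, 5, ndata = rand(Float32, 2, 6))

# Batch them into a single graph holding 10 nodes and 20 edges.
g = Flux.batch([g1, g2])
@assert g.num_graphs == 2
@assert graph_indicator(g) == [fill(1, 4); fill(2, 6)]

# Extract the second component graph back out of the batch.
g2b, nmap = getgraph(g, 2, nmap = true)
@assert g2b.num_nodes == 6
@assert nmap == 5:10
```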