diff --git a/docs/src/api/gnngraph.md b/docs/src/api/gnngraph.md index 0a2e9ffdd..9f7cefde9 100644 --- a/docs/src/api/gnngraph.md +++ b/docs/src/api/gnngraph.md @@ -6,6 +6,10 @@ CurrentModule = GraphNeuralNetworks Documentation page for the graph type `GNNGraph` provided GraphNeuralNetworks.jl and its related methods. +Besides the methods documented here, one can rely on the large set of functionalities +given by [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl) +since `GNNGraph` inherits from `Graphs.AbstractGraph`. + ## Index ```@index diff --git a/docs/src/gnngraph.md b/docs/src/gnngraph.md index 4c9853e6a..bb0cabb6b 100644 --- a/docs/src/gnngraph.md +++ b/docs/src/gnngraph.md @@ -5,6 +5,9 @@ A GNNGraph `g` is a directed graph with nodes labeled from 1 to `g.num_nodes`. The underlying implementation allows for efficient application of graph neural network operators, gpu movement, and storage of node/edge/graph related feature arrays. +`GNNGraph` inherits from [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)'s `AbstractGraph`, +therefore it supports most functionality from that library. + ## Graph Creation A GNNGraph can be created from several different data sources encoding the graph topology: diff --git a/src/gnngraph.jl b/src/gnngraph.jl index 2c8cccf58..e7e47ac30 100644 --- a/src/gnngraph.jl +++ b/src/gnngraph.jl @@ -11,28 +11,34 @@ const ADJMAT_T = AbstractMatrix const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T const CUMAT_T = Union{AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} -""" +""" GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir]) GNNGraph(g::GNNGraph; [ndata, edata, gdata]) -A type representing a graph structure and storing also -feature arrays associated to nodes, edges, and to the whole graph (global features). +A type representing a graph structure that also stores +feature arrays associated to nodes, edges, and the graph itself. -A `GNNGraph` can be constructed out of different objects `data` expressing -the connections inside the graph. The internal representation type +A `GNNGraph` can be constructed out of different `data` objects +expressing the connections inside the graph. The internal representation type is determined by `graph_type`. When constructed from another `GNNGraph`, the internal graph representation -is preserved and shared. The node/edge/global features are transmitted -as well, unless explicitely changed though keyword arguments. +is preserved and shared. The node/edge/graph features are retained +as well, unless explicitely set by the keyword arguments +`ndata`, `edata`, and `gdata`. A `GNNGraph` can also represent multiple graphs batched togheter (see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)). The field `g.graph_indicator` contains the graph membership of each node. -A `GNNGraph` is a Graphs' `AbstractGraph`, therefore any functionality -from the Graphs' graph library can be used on it. +`GNNGraph`s are always directed graphs, therefore each edge is defined +by a source node and a target node (see [`edge_index`](@ref)). +Self loops (edges connecting a node to itself) and multiple edges +(more than one edge between the same pair of nodes) are supported. + +A `GNNGraph` is a Graphs.jl's `AbstractGraph`, therefore it supports most +functionality from that library. # Arguments @@ -54,9 +60,9 @@ from the Graphs' graph library can be used on it. Possible values are `:out` and `:in`. Default `:out`. - `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`. - `graph_indicator`: For batched graphs, a vector containing the graph assigment of each node. Default `nothing`. -- `ndata`: Node features. A named tuple of arrays whose last dimension has size num_nodes. -- `edata`: Edge features. A named tuple of arrays whose whose last dimension has size num_edges. -- `gdata`: Global features. A named tuple of arrays whose has size num_graphs. +- `ndata`: Node features. A named tuple of arrays whose last dimension has size `num_nodes`. +- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`. +- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`. # Usage. @@ -97,7 +103,7 @@ g = g |> gpu source, target = edge_index(g) ``` """ -struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} +struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractGraph{Int} graph::T num_nodes::Int num_edges::Int diff --git a/src/utils.jl b/src/utils.jl index 7f9506d27..b173b2898 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -39,8 +39,8 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee if n == 1 # If last array dimension is not 1, add a new dimension. - # This is mostly usefule to reshape globale feature vectors - # of size D to Dx1 matrices. + # This is mostly useful to reshape graph feature vectors + # of size D into Dx1 matrices. function unsqz(v) if v isa AbstractArray && size(v)[end] != 1 v = reshape(v, size(v)..., 1) diff --git a/test/gnngraph.jl b/test/gnngraph.jl index 4eca56ea8..55383d183 100644 --- a/test/gnngraph.jl +++ b/test/gnngraph.jl @@ -277,6 +277,11 @@ d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false) @test first(d) == getgraph(g, 1:2) end + + @testset "Graphs.jl integration" begin + g = GNNGraph(erdos_renyi(10, 20)) + @test g isa Graphs.AbstractGraph + end end