Skip to content

Commit

Permalink
GNNGraph inherits from Graphs.AbstractGraph (#63)
Browse files Browse the repository at this point in the history
* improve GNNGraph docstring

* inherit from AbstractGraph

* cleanup
  • Loading branch information
CarloLucibello authored Oct 29, 2021
1 parent f647c1e commit 375b787
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 15 deletions.
4 changes: 4 additions & 0 deletions docs/src/api/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ CurrentModule = GraphNeuralNetworks

Documentation page for the graph type `GNNGraph` provided GraphNeuralNetworks.jl and its related methods.

Besides the methods documented here, one can rely on the large set of functionalities
given by [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)
since `GNNGraph` inherits from `Graphs.AbstractGraph`.

## Index

```@index
Expand Down
3 changes: 3 additions & 0 deletions docs/src/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ A GNNGraph `g` is a directed graph with nodes labeled from 1 to `g.num_nodes`.
The underlying implementation allows for efficient application of graph neural network
operators, gpu movement, and storage of node/edge/graph related feature arrays.

`GNNGraph` inherits from [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)'s `AbstractGraph`,
therefore it supports most functionality from that library.

## Graph Creation
A GNNGraph can be created from several different data sources encoding the graph topology:

Expand Down
32 changes: 19 additions & 13 deletions src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,34 @@ const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
const CUMAT_T = Union{AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

"""
"""
GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
GNNGraph(g::GNNGraph; [ndata, edata, gdata])
A type representing a graph structure and storing also
feature arrays associated to nodes, edges, and to the whole graph (global features).
A type representing a graph structure that also stores
feature arrays associated to nodes, edges, and the graph itself.
A `GNNGraph` can be constructed out of different objects `data` expressing
the connections inside the graph. The internal representation type
A `GNNGraph` can be constructed out of different `data` objects
expressing the connections inside the graph. The internal representation type
is determined by `graph_type`.
When constructed from another `GNNGraph`, the internal graph representation
is preserved and shared. The node/edge/global features are transmitted
as well, unless explicitely changed though keyword arguments.
is preserved and shared. The node/edge/graph features are retained
as well, unless explicitely set by the keyword arguments
`ndata`, `edata`, and `gdata`.
A `GNNGraph` can also represent multiple graphs batched togheter
(see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
The field `g.graph_indicator` contains the graph membership
of each node.
A `GNNGraph` is a Graphs' `AbstractGraph`, therefore any functionality
from the Graphs' graph library can be used on it.
`GNNGraph`s are always directed graphs, therefore each edge is defined
by a source node and a target node (see [`edge_index`](@ref)).
Self loops (edges connecting a node to itself) and multiple edges
(more than one edge between the same pair of nodes) are supported.
A `GNNGraph` is a Graphs.jl's `AbstractGraph`, therefore it supports most
functionality from that library.
# Arguments
Expand All @@ -54,9 +60,9 @@ from the Graphs' graph library can be used on it.
Possible values are `:out` and `:in`. Default `:out`.
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
- `graph_indicator`: For batched graphs, a vector containing the graph assigment of each node. Default `nothing`.
- `ndata`: Node features. A named tuple of arrays whose last dimension has size num_nodes.
- `edata`: Edge features. A named tuple of arrays whose whose last dimension has size num_edges.
- `gdata`: Global features. A named tuple of arrays whose has size num_graphs.
- `ndata`: Node features. A named tuple of arrays whose last dimension has size `num_nodes`.
- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`.
- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`.
# Usage.
Expand Down Expand Up @@ -97,7 +103,7 @@ g = g |> gpu
source, target = edge_index(g)
```
"""
struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractGraph{Int}
graph::T
num_nodes::Int
num_edges::Int
Expand Down
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee

if n == 1
# If last array dimension is not 1, add a new dimension.
# This is mostly usefule to reshape globale feature vectors
# of size D to Dx1 matrices.
# This is mostly useful to reshape graph feature vectors
# of size D into Dx1 matrices.
function unsqz(v)
if v isa AbstractArray && size(v)[end] != 1
v = reshape(v, size(v)..., 1)
Expand Down
5 changes: 5 additions & 0 deletions test/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@
d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
@test first(d) == getgraph(g, 1:2)
end

@testset "Graphs.jl integration" begin
g = GNNGraph(erdos_renyi(10, 20))
@test g isa Graphs.AbstractGraph
end
end


0 comments on commit 375b787

Please sign in to comment.