Skip to content

Commit

Permalink
refactor GNNGraph into its own module; implement add_edges
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 30, 2021
1 parent 375b787 commit f8a2695
Show file tree
Hide file tree
Showing 16 changed files with 965 additions and 188 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
Expand All @@ -17,6 +18,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -28,12 +30,14 @@ CUDA = "3.3"
ChainRulesCore = "1"
DataStructures = "0.18"
Flux = "0.12.7"
Functors = "0.2"
Graphs = "1.4"
KrylovKit = "0.5"
LearnBase = "0.4, 0.5"
MacroTools = "0.5"
NNlib = "0.7"
NNlibCUDA = "0.1"
Reexport = "1"
StatsBase = "0.32, 0.33"
TestEnv = "1"
julia = "1.6"
Expand Down
42 changes: 40 additions & 2 deletions docs/src/api/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,52 @@ Pages = ["gnngraph.md"]

## Docs

### GNNGraph

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["GNNGraphs/gnngraph.jl"]
Private = false
```

### Query

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["GNNGraphs/query.jl"]
Private = false
```

```@docs
Graphs.adjacency_matrix
Graphs.degree
Graphs.outneighbors
Graphs.inneighbors
```

### Transform

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["gnngraph.jl"]
Pages = ["GNNGraphs/transform.jl"]
Private = false
```

```@docs
Flux.batch
SparseArrays.blockdiag
Graphs.adjacency_matrix
```

### Generate

```@autodocs
Modules = [GraphNeuralNetworks]
Pages = ["GNNGraphs/generate.jl"]
Private = false
```

### Related methods

```@docs
SparseArrays.sparse
```
13 changes: 8 additions & 5 deletions docs/src/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ A GNNGraph can be created from several different data sources encoding the graph
using GraphNeuralNetworks, Graphs, SparseArrays


# Construct GNNGraph from From Graphs's graph
# Construct a GNNGraph from from a Graphs.jl's graph
lg = erdos_renyi(10, 30)
g = GNNGraph(lg)

# Same as above using convenience method rand_graph
g = rand_graph(10, 30)

# From an adjacency matrix
A = sprand(10, 10, 0.3)
g = GNNGraph(A)
Expand Down Expand Up @@ -123,21 +126,21 @@ for g in train_loader
.....
end

# Access the nodes' graph memberships through
gall.graph_indicator
# Access the nodes' graph memberships
graph_indicator(gall)
```

## Graph Manipulation

```julia
g′ = add_self_loops(g)

g′ = remove_self_loops(g)
g′ = add_edges(g, [1, 2], [2, 3]) # add edges 1->2 and 2->3
```

## JuliaGraphs ecosystem integration

Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.

```julia
@assert Graphs.isdirected(g)
Expand Down
43 changes: 43 additions & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
module GNNGraphs

using SparseArrays
using Functors: @functor
using CUDA
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
import Flux
using Flux: batch
import NNlib
import LearnBase
import StatsBase
using LearnBase: getobs
import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random

include("gnngraph.jl")
export GNNGraph, node_features, edge_features, graph_features

include("query.jl")
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
graph_indicator

include("transform.jl")
export add_edges, add_self_loops, remove_self_loops, getgraph

include("generate.jl")
export rand_graph


include("convert.jl")
include("utils.jl")

export
# from Graphs
adjacency_matrix, degree, outneighbors, inneighbors,
# from SparseArrays
sprand, sparse, blockdiag,
# from Flux
batch

end #module
File renamed without changes.
14 changes: 14 additions & 0 deletions src/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
rand_graph(n, m; directed=false, kws...)
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes.
If `directed=false` the output will contain `2m` edges:
the reverse edge of each edge will be present.
If `directed=true` instead, `m` unrelated edges are generated.
Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor.
"""
function rand_graph(n::Integer, m::Integer; directed=false, kws...)
return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
end
Loading

0 comments on commit f8a2695

Please sign in to comment.