Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve docs #72

Merged
merged 1 commit into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions docs/src/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,33 @@ See also the related methods [`Graphs.adjacency_matrix`](@ref), [`edge_index`](@
## Basic Queries

```julia
source = [1,1,2,2,3,3,3,4]
target = [2,3,1,3,1,2,4,3]
g = GNNGraph(source, target)
julia> source = [1,1,2,2,3,3,3,4];

julia> target = [2,3,1,3,1,2,4,3];

julia> g = GNNGraph(source, target)
GNNGraph:
num_nodes = 4
num_edges = 8


julia> @assert g.num_nodes == 4 # number of nodes

julia> @assert g.num_edges == 8 # number of edges

julia> @assert g.num_graphs == 1 # number of subgraphs (a GNNGraph can batch many graphs together)

julia> is_directed(g) # a GNNGraph is always directed
true

julia> is_bidirected(g) # for each edge, also the reverse edge is present
true

julia> has_self_loops(g)
false

@assert g.num_nodes == 4 # number of nodes
@assert g.num_edges == 8 # number of edges
@assert g.num_graphs == 1 # number of subgraphs (a GNNGraph can batch many graphs together)
is_directed(g) # a GGNGraph is always directed
julia> has_multi_edges(g)
false
```

## Data Features
Expand Down
60 changes: 58 additions & 2 deletions docs/src/messagepassing.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,65 @@ and [`NNlib.scatter`](@ref) methods.

## Examples

### Basic use of propagate and apply_edges
### Basic use of apply_edges and propagate

TODO
The function [`apply_edges`](@ref) can be used to broadcast node data
on each edge and produce new edge data.
```julia
julia> using GraphNeuralNetworks, Graphs, Statistics

julia> g = rand_graph(10, 20)
GNNGraph:
num_nodes = 10
num_edges = 20


julia> x = ones(2,10)
2×10 Matrix{Float64}:
1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0

julia> z = 2ones(2,10)
2×10 Matrix{Float64}:
2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0
2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0

julia> apply_edges((xi, xj, e) -> xi .+ xj, g, xi=x, xj=z)
2×20 Matrix{Float64}:
3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0
3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0

# now returning a named tuple
julia> apply_edges((xi, xj, e) -> (a=xi .+ xj, b=xi .- xj), g, xi=x, xj=z)
(a = [3.0 3.0 … 3.0 3.0; 3.0 3.0 … 3.0 3.0], b = [-1.0 -1.0 … -1.0 -1.0; -1.0 -1.0 … -1.0 -1.0])

# Here we provide a named tuple input
julia> apply_edges((xi, xj, e) -> xi.a + xi.b .* xj, g, xi=(a=x,b=z), xj=z)
2×20 Matrix{Float64}:
5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0
5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0 5.0
```
The function [@propagate](@ref) instead performs also the the [`apply_edges`](@ref) operation
but then applies a reduction over each node's neighborhood.
```julia
julia> propagate((xi, xj, e) -> xi .+ xj, g, +, xi=x, xj=z)
2×10 Matrix{Float64}:
3.0 6.0 9.0 9.0 0.0 6.0 6.0 3.0 15.0 3.0
3.0 6.0 9.0 9.0 0.0 6.0 6.0 3.0 15.0 3.0

julia> degree(g)
10-element Vector{Int64}:
1
2
3
3
0
2
2
1
5
1
```

### Implementing a custom Graph Convolutional Layer

Expand Down
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using SparseArrays
using Functors: @functor
using CUDA
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops, is_directed
import Flux
using Flux: batch
import NNlib
Expand All @@ -26,6 +26,7 @@ export adjacency_list,
edge_index,
graph_indicator,
has_multi_edges,
is_directed,
is_bidirected,
normalized_laplacian,
scaled_laplacian,
Expand Down