Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for batched graphs #17

Merged
merged 1 commit into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ export

# layers/pool
GlobalPool,
LocalPool,
TopKPool,
topk_index

Expand Down
70 changes: 59 additions & 11 deletions src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

"""
GNNGraph(data; [graph_type, dir, num_nodes, nf, ef, gf])
GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
GNNGraph(g::GNNGraph; [nf, ef, gf])

A type representing a graph structure and storing also arrays
Expand Down Expand Up @@ -43,11 +43,13 @@ from the LightGraphs' graph library can be used on it.
- `:dense`. A dense adjacency matrix representation.
Default `:coo`.
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
Possible values are `:out` and `:in`. Defaul `:out`.
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default nothing.
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default nothing.
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default nothing.
- `gf`: Global features. Default nothing.
Possible values are `:out` and `:in`. Default `:out`.
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
- `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
- `graph_indicator`. For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
- `gf`: Global features. Default `nothing`.

# Usage.

Expand Down Expand Up @@ -87,6 +89,8 @@ struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
graph::T
num_nodes::Int
num_edges::Int
num_graphs::Int
graph_indicator
nf
ef
gf
Expand All @@ -99,7 +103,9 @@ end
@functor GNNGraph

function GNNGraph(data;
num_nodes = nothing,
num_nodes = nothing,
num_graphs = 1,
graph_indicator = nothing,
graph_type = :coo,
dir = :out,
nf = nothing,
Expand All @@ -119,6 +125,9 @@ function GNNGraph(data;
elseif graph_type == :sparse
g, num_nodes, num_edges = to_sparse(data; dir)
end
if num_graphs > 1
    # Bug fix: the original used the Python-ism `len` (undefined in Julia) and `=`
    # (short-form function definition) instead of `length` and `==`.
    # Also guard against `graph_indicator === nothing`, which would otherwise
    # raise an opaque MethodError from `length(nothing)`.
    # Use `throw(ArgumentError(...))` rather than `@assert`, since asserts are
    # not guaranteed to run at higher optimization levels.
    if graph_indicator === nothing || length(graph_indicator) != num_nodes
        throw(ArgumentError("When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."))
    end
end

## Possible future implementation of feature maps.
## Currently this doesn't play well with zygote due to
Expand All @@ -127,8 +136,9 @@ function GNNGraph(data;
# edata["e"] = ef
# gdata["g"] = gf


GNNGraph(g, num_nodes, num_edges, nf, ef, gf)
GNNGraph(g, num_nodes, num_edges,
num_graphs, graph_indicator,
nf, ef, gf)
end

# COO convenience constructors
Expand All @@ -147,7 +157,7 @@ function GNNGraph(g::GNNGraph;
nf=node_feature(g), ef=edge_feature(g), gf=global_feature(g))
# ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data

GNNGraph(g.graph, g.num_nodes, g.num_edges, nf, ef, gf) # ndata, edata, gdata,
GNNGraph(g.graph, g.num_nodes, g.num_edges, g.num_graphs, g.graph_indicator, nf, ef, gf) # ndata, edata, gdata,
end


Expand Down Expand Up @@ -370,6 +380,7 @@ function add_self_loops(g::GNNGraph{<:COO_T})
t = [t; nodes]

GNNGraph((s, t, nothing), g.num_nodes, length(s),
g.num_graphs, g.graph_indicator,
node_feature(g), edge_feature(g), global_feature(g))
end

Expand All @@ -379,6 +390,7 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T}; add_to_existing=true)
A += I
num_edges = g.num_edges + g.num_nodes
GNNGraph(A, g.num_nodes, num_edges,
g.num_graphs, g.graph_indicator,
node_feature(g), edge_feature(g), global_feature(g))
end

Expand All @@ -392,10 +404,46 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
s = s[mask_old_loops]
t = t[mask_old_loops]

GNNGraph((s, t, nothing), g.num_nodes, length(s),
GNNGraph((s, t, nothing), g.num_nodes, length(s),
g.num_graphs, g.graph_indicator,
node_feature(g), edge_feature(g), global_feature(g))
end

# Disjoint union of two COO graphs: the second graph's node ids are shifted
# past the first graph's node count, features are concatenated along their
# last dimension, and per-node graph membership is tracked in `graph_indicator`.
function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
    src1, dst1 = edge_index(g1)
    src2, dst2 = edge_index(g2)
    n1, n2 = g1.num_nodes, g2.num_nodes

    # Offset the second graph's endpoints so the union has nodes 1:(n1+n2).
    src = [src1; src2 .+ n1]
    dst = [dst1; dst2 .+ n1]
    w = cat_features(edge_weight(g1), edge_weight(g2))

    # A `nothing` indicator means "single graph": every node belongs to graph 1.
    gi1 = g1.graph_indicator === nothing ? fill!(similar(src1, Int, n1), 1) : g1.graph_indicator
    gi2 = g2.graph_indicator === nothing ? fill!(similar(src2, Int, n2), 1) : g2.graph_indicator

    GNNGraph(
        (src, dst, w),
        n1 + n2, g1.num_edges + g2.num_edges,
        g1.num_graphs + g2.num_graphs, [gi1; gi2 .+ g1.num_graphs],
        cat_features(node_feature(g1), node_feature(g2)),
        cat_features(edge_feature(g1), edge_feature(g2)),
        cat_features(global_feature(g1), global_feature(g2)),
    )
end

# Cat public interfaces
"""
    blockdiag(g1::GNNGraph, gs::GNNGraph...)

Batch the given graphs into a single `GNNGraph` whose adjacency matrix is the
block-diagonal concatenation of the inputs' adjacency matrices (disjoint union).
Node, edge, and global features are concatenated along their last dimension.

Called with a single graph, returns it unchanged. (The original
`@assert length(gothers) >= 1` made `blockdiag(g)` — and hence
`Flux.batch([g])` — error out; asserts also may be disabled at higher
optimization levels, so they are unsuitable for input validation.)
"""
function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
    # Left fold pairwise-merges the graphs; with no extra graphs `g1` is
    # returned as-is, which generalizes the original loop.
    return foldl(_catgraphs, gothers; init=g1)
end

Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
#########################

@non_differentiable normalized_laplacian(x...)
@non_differentiable normalized_adjacency(x...)
@non_differentiable scaled_laplacian(x...)
Expand Down
61 changes: 29 additions & 32 deletions src/layers/pool.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,42 @@
using DataStructures: nlargest

"""
GlobalPool(aggr, dim...)

Global pooling layer.

It pools all features with `aggr` operation.

# Arguments

- `aggr`: An aggregate function applied to pool all features.
"""
struct GlobalPool{A}
aggr
cluster::A
function GlobalPool(aggr, dim...)
cluster = ones(Int64, dim)
new{typeof(cluster)}(aggr, cluster)
end
end
@doc raw"""
GlobalPool(aggr)

(l::GlobalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster)
Global pooling layer for graph neural networks.
Takes a graph and feature nodes as inputs
and performs the operation

"""
LocalPool(aggr, cluster)
```math
\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
```
where ``V`` is the set of nodes of the input graph and
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
Commonly used aggregations are `mean`, `max`, and `+`.

Local pooling layer.
```julia
using GraphNeuralNetworks, LightGraphs

It pools features with `aggr` operation accroding to `cluster`. It is implemented with `scatter` operation.
pool = GlobalPool(mean)

# Arguments

- `aggr`: An aggregate function applied to pool all features.
- `cluster`: An index structure which indicates what features to aggregate with.
g = GNNGraph(random_regular_graph(10, 4))
X = rand(32, 10)
pool(g, X) # => 32x1 matrix
```
"""
struct LocalPool{A<:AbstractArray}
aggr
cluster::A
struct GlobalPool{F}
aggr::F
end

(l::LocalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster)
function (l::GlobalPool)(g::GNNGraph, X::AbstractArray)
if isnothing(g.graph_indicator)
# assume only one graph
indexes = fill!(similar(X, Int, g.num_nodes), 1)
else
indexes = g.graph_indicator
end
return NNlib.scatter(l.aggr, X, indexes)
end

"""
TopKPool(adj, k, in_channel)
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ function sort_edge_index(u, v)
p = sortperm(uv) # isless lexicographically defined for tuples
return u[p], v[p]
end

# Concatenate per-graph feature arrays when batching graphs.
# `nothing` means "no features": batching two feature-less graphs yields `nothing`.
cat_features(x1::Nothing, x2::Nothing) = nothing
# Features are joined along the last dimension (the node/edge/graph sample axis).
cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims=ndims(x1))
# Mixing a graph that has features with one that does not is a user error:
# fail loudly with a descriptive message instead of the opaque MethodError
# the missing method would otherwise raise.
cat_features(x1, x2) =
    throw(ArgumentError("cannot concatenate features: one graph has features and the other does not"))
12 changes: 12 additions & 0 deletions test/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,16 @@
@test adjacency_matrix(fg2) == A2
@test fg2.num_edges == sum(A2)
end

@testset "batch" begin
    g1 = GNNGraph(random_regular_graph(10,2), nf=rand(16,10))
    g2 = GNNGraph(random_regular_graph(4,2), nf=rand(16,4))
    g3 = GNNGraph(random_regular_graph(7,2), nf=rand(16,7))

    g12 = Flux.batch([g1, g2])
    g12b = blockdiag(g1, g2)
    # Bug fix: g12 and g12b were computed but never asserted on.
    # `Flux.batch` and `blockdiag` must agree on the batched graph.
    @test g12.num_nodes == g12b.num_nodes == 14
    @test g12.num_edges == g12b.num_edges
    @test g12.num_graphs == g12b.num_graphs == 2
    @test g12.graph_indicator == g12b.graph_indicator == [fill(1, 10); fill(2, 4)]

    g123 = Flux.batch([g1, g2, g3])
    @test g123.num_nodes == 21
    @test g123.num_graphs == 3
    @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
end
end
16 changes: 5 additions & 11 deletions test/layers/pool.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
cluster = [1 1 1 1; 2 2 3 3; 4 4 5 5]
X = Array(reshape(1:24, 2, 3, 4))

@testset "pool" begin
@testset "GlobalPool" begin
glb_cltr = [1 1 1 1; 1 1 1 1; 1 1 1 1]
p = GlobalPool(+, 3, 4)
@test p(X) == NNlib.scatter(+, X, glb_cltr)
end

@testset "LocalPool" begin
p = LocalPool(+, cluster)
@test p(X) == NNlib.scatter(+, X, cluster)
n = 10
X = rand(16, n)
g = GNNGraph(random_regular_graph(n, 4))
p = GlobalPool(+)
@test p(g, X) ≈ NNlib.scatter(+, X, ones(Int, n))
end

@testset "TopKPool" begin
Expand Down