
Commit

dataloader support for vector of graphs (#143)
CarloLucibello authored Mar 5, 2022
1 parent f935e8d commit 05daca6
Showing 8 changed files with 127 additions and 84 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@
Manifest.toml
/docs/build/
.vscode
LocalPreferences.toml
36 changes: 30 additions & 6 deletions docs/src/gnngraph.md
@@ -144,38 +144,62 @@ julia> get_edge_weight(g)
## Batches and Subgraphs

Multiple `GNNGraph`s can be batched together into a single graph
that contains all the nodes of the original graphs,
where each original graph lives on as a disjoint subgraph.

```julia
using Flux
using Flux.Data: DataLoader

data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:160]
gall = Flux.batch(data)

# gall is a GNNGraph containing many graphs
@assert gall.num_graphs == 160
@assert gall.num_nodes == 1600 # 10 nodes x 160 graphs
@assert gall.num_edges == 4800 # 30 edges x 160 graphs

# Let's create a mini-batch from gall
g23 = getgraph(gall, 2:3)
@assert g23.num_graphs == 2
@assert g23.num_nodes == 20 # 10 nodes x 2 graphs
@assert g23.num_edges == 60 # 30 edges x 2 graphs


# We can pass a GNNGraph to Flux's DataLoader
train_loader = DataLoader(gall, batchsize=16, shuffle=true)

for g in train_loader
@assert g.num_graphs == 16
@assert g.num_nodes == 160
@assert size(g.ndata.x) == (3, 160)
# .....
end

# Access the nodes' graph memberships
graph_indicator(gall)
```
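
`graph_indicator` returns the graph membership of each node: a vector with one entry per node, holding the index of the graph that node belongs to (this is what graph-wise pooling layers rely on). A quick sketch of what to expect for `gall` above:

```julia
gi = graph_indicator(gall)
@assert length(gi) == gall.num_nodes  # one entry per node
@assert gi[1:10] == fill(1, 10)       # the first 10 nodes belong to graph 1
@assert gi[end] == gall.num_graphs    # the last node belongs to the last graph
```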

## DataLoader and mini-batch iteration

While constructing a batched graph and passing it to the `DataLoader` is always
an option for mini-batch iteration, the recommended way is
to pass an array of graphs directly:

```julia
using Flux.Data: DataLoader

data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:320]

train_loader = DataLoader(data, batchsize=16, shuffle=true)

for g in train_loader
@assert g.num_graphs == 16
@assert g.num_nodes == 160
@assert size(g.ndata.x) == (3, 160)
# .....
end
```
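
Under the hood, the loader collates each mini-batch by calling `Flux.batch` on the selected graphs, so every iterate is again a single `GNNGraph`. A minimal sketch of this equivalence, assuming `shuffle=false` so that the first batch is deterministic:

```julia
loader = DataLoader(data, batchsize=16, shuffle=false)
@assert first(loader) == Flux.batch(data[1:16])
```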

## Graph Manipulation

```julia
Expand Down
66 changes: 26 additions & 40 deletions docs/src/index.md
@@ -23,64 +23,50 @@ Usage examples on real datasets can be found in the [examples](https://github.co

### Data preparation

We create a dataset consisting of multiple random graphs and associated data features.

```julia
using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics
using Flux.Data: DataLoader

all_graphs = GNNGraph[]

for _ in 1:1000
g = GNNGraph(random_regular_graph(10, 4),
ndata=(; x = randn(Float32, 16,10)), # input node features
gdata=(; y = randn(Float32))) # regression target
push!(all_graphs, g)
end
```
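
Each graph now carries its input features in `g.ndata.x` and its regression target in `g.gdata.y`. A quick sanity check (illustrative):

```julia
g = first(all_graphs)
@assert g.num_nodes == 10
@assert size(g.ndata.x) == (16, 10)  # 16 features per node, 10 nodes
```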

### Model building

We concisely define our model as a [`GNNChain`](@ref) containing two graph convolutional layers. If CUDA is available, our model will live on the GPU.

```julia
device = CUDA.functional() ? Flux.gpu : Flux.cpu;

model = GNNChain(GCNConv(16 => 64),
BatchNorm(64), # Apply batch normalization on node features (nodes dimension is batch dimension)
x -> relu.(x),
GCNConv(64 => 64, relu),
GlobalPool(mean), # aggregate node-wise features into graph-wise features
Dense(64, 1)) |> device

ps = Flux.params(model)
opt = ADAM(1f-4)
```

### Training

Finally, we use a standard Flux training pipeline to fit our dataset.
Flux's `DataLoader` iterates over mini-batches of graphs
(batched together into a `GNNGraph` object).

```julia
train_size = round(Int, 0.8 * length(all_graphs))
train_loader = DataLoader(all_graphs[1:train_size], batchsize=32, shuffle=true)
test_loader = DataLoader(all_graphs[train_size+1:end], batchsize=32, shuffle=false)

loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
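
# The remainder of the training script is elided in this diff. A minimal
# sketch of the standard Flux loop it would run (epoch count assumed):
for epoch in 1:100
    for g in train_loader
        g = g |> device
        gs = gradient(() -> loss(g), ps)
        Flux.Optimise.update!(opt, ps, gs)
    end
end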

3 changes: 3 additions & 0 deletions examples/Project.toml
@@ -10,3 +10,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

[extras]
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
38 changes: 17 additions & 21 deletions examples/graph_classification_tudataset.jl
@@ -27,18 +27,21 @@ function eval_loss_accuracy(model, data_loader, device)
end

function getdataset()
tudata = TUDataset("MUTAG")

x = Array{Float32}(onehotbatch(tudata.node_labels, 0:6))
y = (1 .+ Array{Float32}(tudata.graph_labels)) ./ 2
@assert all(∈([0,1]), y) # binary classification

## The dataset also has edge features but we won't be using them
# e = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))

gall = GNNGraph(tudata.source, tudata.target,
num_nodes=tudata.num_nodes,
graph_indicator=tudata.graph_indicator,
ndata=(; x), gdata=(; y))

return [getgraph(gall, i) for i=1:gall.num_graphs]
end
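
# Note: `getdataset` returns a `Vector{GNNGraph}` rather than one big batched graph,
# so the `DataLoader` below can collate each mini-batch with `Flux.batch` on the fly.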

# arguments for the `train` function
@@ -66,23 +69,17 @@ function train(; kws...)
end

# LOAD DATA


NUM_TRAIN = 150

data = getdataset()
shuffle!(data)

train_loader = DataLoader(data[1:NUM_TRAIN], batchsize=args.batchsize, shuffle=true)
test_loader = DataLoader(data[NUM_TRAIN+1:end], batchsize=args.batchsize, shuffle=false)

# DEFINE MODEL

nin = size(data[1].ndata.x, 1)
nhidden = args.nhidden

model = GNNChain(GraphConv(nin => nhidden, relu),
@@ -94,7 +91,6 @@
ps = Flux.params(model)
opt = ADAM(args.η)


# LOGGING FUNCTION

function report(epoch)
17 changes: 15 additions & 2 deletions src/GNNGraphs/gnngraph.jl
@@ -231,14 +231,27 @@ LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
Flux.Data._nobs(g::GNNGraph) = g.num_graphs
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)

# DataLoader compatibility passing a vector of graphs and
# effectively using `batch` as a collated function.
StatsBase.nobs(data::Vector{<:GNNGraph}) = length(data)
LearnBase.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
LearnBase.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
Flux.Data._nobs(g::Vector{<:GNNGraph}) = StatsBase.nobs(g)
Flux.Data._getobs(g::Vector{<:GNNGraph}, i) = LearnBase.getobs(g, i)
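
# Illustrative example: with `data::Vector{<:GNNGraph}`, `DataLoader(data, batchsize=2)`
# yields `Flux.batch(data[1:2])`, `Flux.batch(data[3:4])`, ... on iteration.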


#########################

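# Note: `graph_indicator` is skipped in `==` and `hash` below; it may be `nothing`
# for a freshly constructed graph but materialized for one extracted with `getgraph`,
# even when the two graphs are otherwise identical.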
function Base.:(==)(g1::GNNGraph, g2::GNNGraph)
g1 === g2 && return true
for k in fieldnames(typeof(g1))
k === :graph_indicator && continue
getfield(g1, k) != getfield(g2, k) && return false
end
return true
end

function Base.hash(g::T, h::UInt) where T<:GNNGraph
fs = (getfield(g, k) for k in fieldnames(typeof(g)) if k !== :graph_indicator)
return foldl((h, f) -> hash(f, h), fs, init=hash(T, h))
end
34 changes: 23 additions & 11 deletions test/GNNGraphs/gnngraph.jl
@@ -229,22 +229,34 @@
# Attach non array data
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
@test g.edata.e == "ciao"
end
end

@testset "LearnBase and DataLoader compat" begin
n, m, num_graphs = 10, 30, 50
X = rand(10, n)
E = rand(10, m)
U = rand(10, 1)
data = [rand_graph(n, m, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
for _ in 1:num_graphs]
g = Flux.batch(data)

@testset "batch then pass to dataloader" begin
@test LearnBase.getobs(g, 3) == getgraph(g, 3)
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
@test StatsBase.nobs(g) == g.num_graphs

d = Flux.Data.DataLoader(g, batchsize=2, shuffle=false)
@test first(d) == getgraph(g, 1:2)
end

@testset "pass to dataloader and collate" begin
@test LearnBase.getobs(data, 3) == getgraph(g, 3)
@test LearnBase.getobs(data, 3:5) == getgraph(g, 3:5)
@test StatsBase.nobs(data) == g.num_graphs

d = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
@test first(d) == getgraph(g, 1:2)
end
end

@testset "Graphs.jl integration" begin
16 changes: 12 additions & 4 deletions test/layers/conv.jl
@@ -103,11 +103,12 @@
end

@testset "GATConv" begin

for heads in (1, 2), concat in (true, false)
l = GATConv(in_channel => out_channel; heads, concat)
for g in test_graphs
test_layer(l, g, rtol=RTOL_LOW,
exclude_grad_fields = [:negative_slope],
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
end
end
@@ -116,7 +117,9 @@
ein = 3
l = GATConv((in_channel, ein) => out_channel, add_self_loops=false)
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
test_layer(l, g, rtol=RTOL_LOW,
exclude_grad_fields = [:negative_slope],
outsize=(out_channel, g.num_nodes))
end

@testset "num params" begin
@@ -135,6 +138,7 @@
l = GATv2Conv(in_channel => out_channel, tanh; heads, concat)
for g in test_graphs
test_layer(l, g, rtol=RTOL_LOW,
exclude_grad_fields = [:negative_slope],
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
end
end
@@ -143,7 +147,9 @@
ein = 3
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
test_layer(l, g, rtol=RTOL_LOW,
exclude_grad_fields = [:negative_slope],
outsize=(out_channel, g.num_nodes))
end

@testset "num params" begin
@@ -159,7 +165,9 @@
ein = 3
l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
test_layer(l, g, rtol=1e-3,
exclude_grad_fields = [:negative_slope],
outsize=(out_channel, g.num_nodes))
end
end

