From 05daca699c5aceec7a6ef5cd70a5c603ba613427 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sat, 5 Mar 2022 17:03:56 +0100
Subject: [PATCH] dataloader support for vector of graphs (#143)

---
 .gitignore                                 |  1 +
 docs/src/gnngraph.md                       | 36 ++++++++++--
 docs/src/index.md                          | 66 +++++++++-------------
 examples/Project.toml                      |  3 +
 examples/graph_classification_tudataset.jl | 38 ++++++-------
 src/GNNGraphs/gnngraph.jl                  | 17 +++++-
 test/GNNGraphs/gnngraph.jl                 | 34 +++++++----
 test/layers/conv.jl                        | 16 ++++--
 8 files changed, 127 insertions(+), 84 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5be50c11d..4948d2158 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,4 @@
 Manifest.toml
 /docs/build/
 .vscode
+LocalPreferences.toml
diff --git a/docs/src/gnngraph.md b/docs/src/gnngraph.md
index 8afb27941..5745b6b68 100644
--- a/docs/src/gnngraph.md
+++ b/docs/src/gnngraph.md
@@ -144,38 +144,62 @@ julia> get_edge_weight(g)
 
 ## Batches and Subgraphs
 
 Multiple `GNNGraph`s can be batched together into a single graph
-containing the total number of the original nodes
+that contains the total number of the original nodes
 and where the original graphs are disjoint subgraphs.
 
 ```julia
 using Flux
+using Flux.Data: DataLoader
 
-gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Float32,3,10)) for _ in 1:160])
+data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:160]
+gall = Flux.batch(data)
+# gall is a GNNGraph containing many graphs
 
 @assert gall.num_graphs == 160
 @assert gall.num_nodes == 1600   # 10 nodes x 160 graphs
 @assert gall.num_edges == 9600   # 30 undirected edges x 2 directions x 160 graphs
 
+# Let's create a mini-batch from gall
 g23, _ = getgraph(gall, 2:3)
 @assert g23.num_graphs == 2
 @assert g23.num_nodes == 20    # 10 nodes x 2 graphs
 @assert g23.num_edges == 120   # 30 undirected edges x 2 directions x 2 graphs
-
-# DataLoader compatibility
-train_loader = Flux.Data.DataLoader(gall, batchsize=16, shuffle=true)
+# We can pass a GNNGraph to Flux's DataLoader
+train_loader = DataLoader(gall, batchsize=16, shuffle=true)
 
 for g in train_loader
    @assert g.num_graphs == 16
    @assert g.num_nodes == 160
    @assert size(g.ndata.x) == (3, 160)
-   .....
+   # .....
 end
 
 # Access the nodes' graph memberships
 graph_indicator(gall)
 ```
 
+## DataLoader and mini-batch iteration
+
+While constructing a batched graph and passing it to the `DataLoader` is always
+an option for mini-batch iteration, the recommended way is
+to pass an array of graphs directly:
+
+```julia
+using Flux.Data: DataLoader
+
+data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:320]
+
+train_loader = DataLoader(data, batchsize=16, shuffle=true)
+
+for g in train_loader
+    @assert g.num_graphs == 16
+    @assert g.num_nodes == 160
+    @assert size(g.ndata.x) == (3, 160)
+    # .....
+end
+```
+
 ## Graph Manipulation
 
 ```julia
diff --git a/docs/src/index.md b/docs/src/index.md
index 232b59399..798b2e5b2 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -23,64 +23,50 @@ Usage examples on real datasets can be found in the [examples](https://github.co
 
 ### Data preparation
 
-First, we create our dataset consisting in multiple random graphs and associated data features.
-Then we batch the graphs together into a unique graph.
+We create a dataset consisting of multiple random graphs and associated data features.
 
 ```julia
-julia> using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics
-
-julia> all_graphs = GNNGraph[];
-
-julia> for _ in 1:1000
-           g = GNNGraph(random_regular_graph(10, 4),
-                        ndata=(; x = randn(Float32, 16,10)),  # input node features
-                        gdata=(; y = randn(Float32)))         # regression target
-           push!(all_graphs, g)
-       end
-
-julia> gbatch = Flux.batch(all_graphs)
-GNNGraph:
-    num_nodes = 10000
-    num_edges = 40000
-    num_graphs = 1000
-    ndata:
-        x => (16, 10000)
-    gdata:
-        y => (1000,)
-```
+using GraphNeuralNetworks, Graphs, Flux, CUDA, Statistics
+using Flux.Data: DataLoader
 
+all_graphs = GNNGraph[]
+
+for _ in 1:1000
+    g = GNNGraph(random_regular_graph(10, 4),
+                 ndata=(; x = randn(Float32, 16,10)),  # input node features
+                 gdata=(; y = randn(Float32)))         # regression target
+    push!(all_graphs, g)
+end
+```
 
 ### Model building
 
-We concisely define our model as a [`GNNChain`](@ref) containing 2 graph convolutional
-layers. If CUDA is available, our model will live on the gpu.
+We concisely define our model as a [`GNNChain`](@ref) containing two graph convolutional layers. If CUDA is available, our model will live on the gpu.
 
 ```julia
-julia> device = CUDA.functional() ? Flux.gpu : Flux.cpu;
-
-julia> model = GNNChain(GCNConv(16 => 64),
-                        BatchNorm(64),     # Apply batch normalization on node features (nodes dimension is batch dimension)
-                        x -> relu.(x),
-                        GCNConv(64 => 64, relu),
-                        GlobalPool(mean),  # aggregate node-wise features into graph-wise features
-                        Dense(64, 1)) |> device;
+device = CUDA.functional() ? Flux.gpu : Flux.cpu;
 
-julia> ps = Flux.params(model);
+model = GNNChain(GCNConv(16 => 64),
+                 BatchNorm(64),     # Apply batch normalization on node features (nodes dimension is batch dimension)
+                 x -> relu.(x),
+                 GCNConv(64 => 64, relu),
+                 GlobalPool(mean),  # aggregate node-wise features into graph-wise features
+                 Dense(64, 1)) |> device
 
-julia> opt = ADAM(1f-4);
+ps = Flux.params(model)
+opt = ADAM(1f-4)
 ```
 
 ### Training
 
 Finally, we use a standard Flux training pipeline to fit our dataset.
 Flux's `DataLoader` iterates over mini-batches of graphs
 (batched together into a `GNNGraph` object).
 
 ```julia
-gtrain = getgraph(gbatch, 1:800)
-gtest = getgraph(gbatch, 801:gbatch.num_graphs)
-train_loader = Flux.Data.DataLoader(gtrain, batchsize=32, shuffle=true)
-test_loader = Flux.Data.DataLoader(gtest, batchsize=32, shuffle=false)
+train_size = round(Int, 0.8 * length(all_graphs))
+train_loader = DataLoader(all_graphs[1:train_size], batchsize=32, shuffle=true)
+test_loader = DataLoader(all_graphs[train_size+1:end], batchsize=32, shuffle=false)
 
 loss(g::GNNGraph) = mean((vec(model(g, g.ndata.x)) - g.gdata.y).^2)
 
diff --git a/examples/Project.toml b/examples/Project.toml
index 6145e4e5a..fe614b8d3 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -10,3 +10,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
+
+[extras]
+CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
diff --git a/examples/graph_classification_tudataset.jl b/examples/graph_classification_tudataset.jl
index 83617c246..271814451 100644
--- a/examples/graph_classification_tudataset.jl
+++ b/examples/graph_classification_tudataset.jl
@@ -27,18 +27,21 @@ function eval_loss_accuracy(model, data_loader, device)
 end
 
 function getdataset()
-    data = TUDataset("MUTAG")
+    tudata = TUDataset("MUTAG")
 
-    x = Array{Float32}(onehotbatch(data.node_labels, 0:6))
-    y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2
+    x = Array{Float32}(onehotbatch(tudata.node_labels, 0:6))
+    y = (1 .+ Array{Float32}(tudata.graph_labels)) ./ 2
     @assert all(∈([0,1]), y) # binary classification
-    # The dataset also has edge features but we won't be using them
-    e = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
-    return GNNGraph(data.source, data.target,
-                    num_nodes=data.num_nodes,
-                    graph_indicator=data.graph_indicator,
-                    ndata=(; x), edata=(; e), gdata=(; y))
+
+    ## The dataset also has edge features but we won't be using them
+    # e = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
+
+    gall = GNNGraph(tudata.source, tudata.target,
+                    num_nodes=tudata.num_nodes,
+                    graph_indicator=tudata.graph_indicator,
+                    ndata=(; x), gdata=(; y))
+
+    return [getgraph(gall, i) for i=1:gall.num_graphs]
 end
 
 # arguments for the `train` function
@@ -66,23 +69,17 @@ function train(; kws...)
     end
 
     # LOAD DATA
-
-    NUM_TRAIN = 150
-    gfull = getdataset()
-
-    @info gfull
+    NUM_TRAIN = 150
+    data = getdataset()
+    shuffle!(data)
 
-    perm = randperm(gfull.num_graphs)
-    gtrain = getgraph(gfull, perm[1:NUM_TRAIN])
-    gtest = getgraph(gfull, perm[NUM_TRAIN+1:end])
-    train_loader = DataLoader(gtrain, batchsize=args.batchsize, shuffle=true)
-    test_loader = DataLoader(gtest, batchsize=args.batchsize, shuffle=false)
+    train_loader = DataLoader(data[1:NUM_TRAIN], batchsize=args.batchsize, shuffle=true)
+    test_loader = DataLoader(data[NUM_TRAIN+1:end], batchsize=args.batchsize, shuffle=false)
 
     # DEFINE MODEL
-    nin = size(gtrain.ndata.x, 1)
+    nin = size(data[1].ndata.x, 1)
     nhidden = args.nhidden
 
     model = GNNChain(GraphConv(nin => nhidden, relu),
@@ -94,7 +91,6 @@
     ps = Flux.params(model)
     opt = ADAM(args.η)
 
-
     # LOGGING FUNCTION
     function report(epoch)
diff --git a/src/GNNGraphs/gnngraph.jl b/src/GNNGraphs/gnngraph.jl
index f49f798e7..7100a4cfd 100644
--- a/src/GNNGraphs/gnngraph.jl
+++ b/src/GNNGraphs/gnngraph.jl
@@ -231,14 +231,27 @@ LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
 Flux.Data._nobs(g::GNNGraph) = g.num_graphs
 Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
 
+# DataLoader compatibility when passing a vector of graphs,
+# effectively using `batch` as the collate function.
+StatsBase.nobs(data::Vector{<:GNNGraph}) = length(data)
+LearnBase.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
+LearnBase.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])
+Flux.Data._nobs(g::Vector{<:GNNGraph}) = StatsBase.nobs(g)
+Flux.Data._getobs(g::Vector{<:GNNGraph}, i) = LearnBase.getobs(g, i)
+
+
 #########################
 
 function Base.:(==)(g1::GNNGraph, g2::GNNGraph)
     g1 === g2 && return true
-    all(k -> getfield(g1, k) == getfield(g2, k), fieldnames(typeof(g1)))
+    for k in fieldnames(typeof(g1))
+        k === :graph_indicator && continue
+        getfield(g1, k) != getfield(g2, k) && return false
+    end
+    return true
 end
 
 function Base.hash(g::T, h::UInt) where T<:GNNGraph
-    fs = (getfield(g, k) for k in fieldnames(typeof(g)))
+    fs = (getfield(g, k) for k in fieldnames(typeof(g)) if k !== :graph_indicator)
     return foldl((h, f) -> hash(f, h), fs, init=hash(T, h))
 end
diff --git a/test/GNNGraphs/gnngraph.jl b/test/GNNGraphs/gnngraph.jl
index 08efe3ca5..5167b58b5 100644
--- a/test/GNNGraphs/gnngraph.jl
+++ b/test/GNNGraphs/gnngraph.jl
@@ -229,22 +229,34 @@
         # Attach non array data
         g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
         @test g.edata.e == "ciao"
-    end 
+    end
 
     @testset "LearnBase and DataLoader compat" begin
         n, m, num_graphs = 10, 30, 50
         X = rand(10, n)
-        E = rand(10, 2m)
+        E = rand(10, m)
         U = rand(10, 1)
-        g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
-                        for _ in 1:num_graphs])
-
-        @test LearnBase.getobs(g, 3) == getgraph(g, 3)
-        @test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
-        @test StatsBase.nobs(g) == g.num_graphs
-
-        d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
-        @test first(d) == getgraph(g, 1:2)
+        data = [rand_graph(n, m, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
+                for _ in 1:num_graphs]
+        g = Flux.batch(data)
+
+        @testset "batch then pass to dataloader" begin
+            @test LearnBase.getobs(g, 3) == getgraph(g, 3)
+            @test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
+            @test StatsBase.nobs(g) == g.num_graphs
+
+            d = Flux.Data.DataLoader(g, batchsize=2, shuffle=false)
+            @test first(d) == getgraph(g, 1:2)
+        end
+
+        @testset "pass to dataloader and collate" begin
+            @test LearnBase.getobs(data, 3) == getgraph(g, 3)
+            @test LearnBase.getobs(data, 3:5) == getgraph(g, 3:5)
+            @test StatsBase.nobs(data) == g.num_graphs
+
+            d = Flux.Data.DataLoader(data, batchsize=2, shuffle=false)
+            @test first(d) == getgraph(g, 1:2)
+        end
     end
 
     @testset "Graphs.jl integration" begin
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 0c3c6b964..a7de105b6 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -103,11 +103,12 @@ end
 
     @testset "GATConv" begin
-    
+
         for heads in (1, 2), concat in (true, false)
             l = GATConv(in_channel => out_channel; heads, concat)
             for g in test_graphs
                 test_layer(l, g, rtol=RTOL_LOW,
+                           exclude_grad_fields = [:negative_slope],
                            outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
             end
         end
@@ -116,7 +117,9 @@
         ein = 3
         l = GATConv((in_channel, ein) => out_channel, add_self_loops=false)
         g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
-        test_layer(l, g, rtol=RTOL_LOW, outsize=(out_channel, g.num_nodes))
+        test_layer(l, g, rtol=RTOL_LOW,
+                   exclude_grad_fields = [:negative_slope],
+                   outsize=(out_channel, g.num_nodes))
     end
 
     @testset "num params" begin
@@ -135,6 +138,7 @@
             l = GATv2Conv(in_channel => out_channel, tanh; heads, concat)
             for g in test_graphs
                 test_layer(l, g, rtol=RTOL_LOW,
+                           exclude_grad_fields = [:negative_slope],
                            outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
             end
         end
@@ -143,7 +147,9 @@
         ein = 3
         l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
         g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
-        test_layer(l, g, rtol=RTOL_LOW, outsize=(out_channel, g.num_nodes))
+        test_layer(l, g, rtol=RTOL_LOW,
+                   exclude_grad_fields = [:negative_slope],
+                   outsize=(out_channel, g.num_nodes))
     end
 
     @testset "num params" begin
@@ -159,7 +165,9 @@
         ein = 3
         l = GATv2Conv((in_channel, ein) => out_channel, add_self_loops=false)
        g = GNNGraph(g1, edata=rand(T, ein, g1.num_edges))
-        test_layer(l, g, rtol=1e-3, outsize=(out_channel, g.num_nodes))
+        test_layer(l, g, rtol=1e-3,
+                   exclude_grad_fields = [:negative_slope],
+                   outsize=(out_channel, g.num_nodes))
     end
 end
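
For reference, a minimal sketch of how the vector-of-graphs `DataLoader` path introduced by this patch is meant to be exercised. This is not part of the patch itself; it assumes GraphNeuralNetworks.jl with the Flux version targeted here, and the graph sizes and batch size are purely illustrative:

```julia
using GraphNeuralNetworks, Flux
using Flux.Data: DataLoader

# A vector of graphs, each with 3-dimensional node features.
data = [rand_graph(10, 30, ndata=rand(Float32, 3, 10)) for _ in 1:64]

# Passing the vector directly: each iterate is collated with `Flux.batch`
# into a single GNNGraph holding `batchsize` graphs.
loader = DataLoader(data, batchsize=8, shuffle=true)

for g in loader
    @assert g isa GNNGraph
    @assert g.num_graphs == 8
    @assert size(g.ndata.x) == (3, 80)  # 10 nodes per graph x 8 graphs
end
```

Since 64 is divisible by 8, every mini-batch above contains exactly 8 graphs; with a non-divisible dataset length the last batch would be smaller unless `partial=false` is passed to the `DataLoader`.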