diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 21cfceceb..684044593 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -47,7 +47,6 @@ export # layers/pool GlobalPool, - LocalPool, TopKPool, topk_index diff --git a/src/gnngraph.jl b/src/gnngraph.jl index 9629a6fbf..e41b100ec 100644 --- a/src/gnngraph.jl +++ b/src/gnngraph.jl @@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T """ - GNNGraph(data; [graph_type, dir, num_nodes, nf, ef, gf]) + GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir]) GNNGraph(g::GNNGraph; [nf, ef, gf]) A type representing a graph structure and storing also arrays @@ -43,11 +43,13 @@ from the LightGraphs' graph library can be used on it. - `:dense`. A dense adjacency matrix representation. Default `:coo`. - `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`. - Possible values are `:out` and `:in`. Defaul `:out`. -- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default nothing. -- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default nothing. -- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default nothing. -- `gf`: Global features. Default nothing. + Possible values are `:out` and `:in`. Default `:out`. +- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`. +- `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`. +- `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`. +- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`. +- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`. +- `gf`: Global features. Default `nothing`. # Usage. @@ -87,6 +89,8 @@ struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} graph::T num_nodes::Int num_edges::Int + num_graphs::Int + graph_indicator nf ef gf @@ -99,7 +103,9 @@ end @functor GNNGraph function GNNGraph(data; - num_nodes = nothing, + num_nodes = nothing, + num_graphs = 1, + graph_indicator = nothing, graph_type = :coo, dir = :out, nf = nothing, @@ -119,6 +125,9 @@ function GNNGraph(data; elseif graph_type == :sparse g, num_nodes, num_edges = to_sparse(data; dir) end + if num_graphs > 1 + @assert len(graph_indicator) = num_nodes "When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships." + end ## Possible future implementation of feature maps. ## Currently this doesn't play well with zygote due to @@ -127,8 +136,9 @@ function GNNGraph(data; # edata["e"] = ef # gdata["g"] = gf - - GNNGraph(g, num_nodes, num_edges, nf, ef, gf) + GNNGraph(g, num_nodes, num_edges, + num_graphs, graph_indicator, + nf, ef, gf) end # COO convenience constructors @@ -147,7 +157,7 @@ function GNNGraph(g::GNNGraph; nf=node_feature(g), ef=edge_feature(g), gf=global_feature(g)) # ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data - GNNGraph(g.graph, g.num_nodes, g.num_edges, nf, ef, gf) # ndata, edata, gdata, + GNNGraph(g.graph, g.num_nodes, g.num_edges, g.num_graphs, g.graph_indicator, nf, ef, gf) # ndata, edata, gdata, end @@ -370,6 +380,7 @@ function add_self_loops(g::GNNGraph{<:COO_T}) t = [t; nodes] GNNGraph((s, t, nothing), g.num_nodes, length(s), + g.num_graphs, g.graph_indicator, node_feature(g), edge_feature(g), global_feature(g)) end @@ -379,6 +390,7 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T}; add_to_existing=true) A += I num_edges = g.num_edges + g.num_nodes GNNGraph(A, g.num_nodes, num_edges, + g.num_graphs, g.graph_indicator, node_feature(g), edge_feature(g), global_feature(g)) end @@ -392,10 +404,46 @@ function remove_self_loops(g::GNNGraph{<:COO_T}) s = s[mask_old_loops] t = t[mask_old_loops] - GNNGraph((s, t, nothing), g.num_nodes, length(s), + GNNGraph((s, t, nothing), g.num_nodes, length(s), + g.num_graphs, g.graph_indicator, node_feature(g), edge_feature(g), global_feature(g)) end +function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T}) + s1, t1 = edge_index(g1) + s2, t2 = edge_index(g2) + nv1, nv2 = g1.num_nodes, g2.num_nodes + s = vcat(s1, nv1 .+ s2) + t = vcat(t1, nv1 .+ t2) + w = cat_features(edge_weight(g1), edge_weight(g2)) + + ind1 = isnothing(g1.graph_indicator) ? fill!(similar(s1, Int, nv1), 1) : g1.graph_indicator + ind2 = isnothing(g2.graph_indicator) ? fill!(similar(s2, Int, nv2), 1) : g2.graph_indicator + graph_indicator = vcat(ind1, g1.num_graphs .+ ind2) + + GNNGraph( + (s, t, w), + nv1 + nv2, g1.num_edges + g2.num_edges, + g1.num_graphs + g2.num_graphs, graph_indicator, + cat_features(node_feature(g1), node_feature(g2)), + cat_features(edge_feature(g1), edge_feature(g2)), + cat_features(global_feature(g1), global_feature(g2)), + ) +end + +# Cat public interfaces +function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) + @assert length(gothers) >= 1 + g = g1 + for go in gothers + g = _catgraphs(g, go) + end + return g +end + +Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...) +######################### + @non_differentiable normalized_laplacian(x...) @non_differentiable normalized_adjacency(x...) @non_differentiable scaled_laplacian(x...) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 9f9beccb8..163ab7bf6 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -1,45 +1,42 @@ using DataStructures: nlargest -""" - GlobalPool(aggr, dim...) - -Global pooling layer. - -It pools all features with `aggr` operation. - -# Arguments - -- `aggr`: An aggregate function applied to pool all features. -""" -struct GlobalPool{A} - aggr - cluster::A - function GlobalPool(aggr, dim...) - cluster = ones(Int64, dim) - new{typeof(cluster)}(aggr, cluster) - end -end +@doc raw""" + GlobalPool(aggr) -(l::GlobalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster) +Global pooling layer for graph neural networks. +Takes a graph and feature nodes as inputs +and performs the operation -""" - LocalPool(aggr, cluster) +```math +\mathbf{u}_V = \box_{i \in V} \mathbf{x}_i +```` +where ``V`` is the set of nodes of the input graph and +the type of aggregation represented by `\box` is selected by the `aggr` argument. +Commonly used aggregations are are `mean`, `max`, and `+`. -Local pooling layer. +```julia +using GraphNeuralNetworks, LightGraphs -It pools features with `aggr` operation accroding to `cluster`. It is implemented with `scatter` operation. +pool = GlobalPool(mean) -# Arguments - -- `aggr`: An aggregate function applied to pool all features. -- `cluster`: An index structure which indicates what features to aggregate with. +g = GNNGraph(random_regular_graph(10, 4)) +X = rand(32, 10) +pool(g, X) # => 32x1 matrix +``` """ -struct LocalPool{A<:AbstractArray} - aggr - cluster::A +struct GlobalPool{F} + aggr::F end -(l::LocalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster) +function (l::GlobalPool)(g::GNNGraph, X::AbstractArray) + if isnothing(g.graph_indicator) + # assume only one graph + indexes = fill!(similar(X, Int, g.num_nodes), 1) + else + indexes = g.graph_indicator + end + return NNlib.scatter(l.aggr, X, indexes) +end """ TopKPool(adj, k, in_channel) diff --git a/src/utils.jl b/src/utils.jl index b761d3123..ae01f2240 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,3 +9,6 @@ function sort_edge_index(u, v) p = sortperm(uv) # isless lexicographically defined for tuples return u[p], v[p] end + +cat_features(x1::Nothing, x2::Nothing) = nothing +cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims=ndims(x1)) diff --git a/test/gnngraph.jl b/test/gnngraph.jl index 3b189178e..0e92a83f3 100644 --- a/test/gnngraph.jl +++ b/test/gnngraph.jl @@ -101,4 +101,16 @@ @test adjacency_matrix(fg2) == A2 @test fg2.num_edges == sum(A2) end + + @testset "batch" begin + g1 = GNNGraph(random_regular_graph(10,2), nf=rand(16,10)) + g2 = GNNGraph(random_regular_graph(4,2), nf=rand(16,4)) + g3 = GNNGraph(random_regular_graph(7,2), nf=rand(16,7)) + + g12 = Flux.batch([g1, g2]) + g12b = blockdiag(g1, g2) + + g123 = Flux.batch([g1, g2, g3]) + @test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)] + end end diff --git a/test/layers/pool.jl b/test/layers/pool.jl index 32276d820..f35465f87 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -1,16 +1,10 @@ -cluster = [1 1 1 1; 2 2 3 3; 4 4 5 5] -X = Array(reshape(1:24, 2, 3, 4)) - @testset "pool" begin @testset "GlobalPool" begin - glb_cltr = [1 1 1 1; 1 1 1 1; 1 1 1 1] - p = GlobalPool(+, 3, 4) - @test p(X) == NNlib.scatter(+, X, glb_cltr) - end - - @testset "LocalPool" begin - p = LocalPool(+, cluster) - @test p(X) == NNlib.scatter(+, X, cluster) + n = 10 + X = rand(16, n) + g = GNNGraph(random_regular_graph(n, 4)) + p = GlobalPool(+) + @test p(g, X) ≈ NNlib.scatter(+, X, ones(Int, n)) end @testset "TopKPool" begin