diff --git a/docs/src/gnngraph.md b/docs/src/gnngraph.md index c1bfc2d80..b2fab3033 100644 --- a/docs/src/gnngraph.md +++ b/docs/src/gnngraph.md @@ -20,7 +20,7 @@ lg = erdos_renyi(10, 30) g = GNNGraph(lg) # Same as above using convenience method rand_graph -g = rand_graph(10, 30) +g = rand_graph(10, 60) # From an adjacency matrix A = sprand(10, 10, 0.3) diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl index d188359bd..51e8891c6 100644 --- a/src/GNNGraphs/GNNGraphs.jl +++ b/src/GNNGraphs/GNNGraphs.jl @@ -21,9 +21,9 @@ export GNNGraph, node_features, edge_features, graph_features include("query.jl") export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian, graph_indicator - + include("transform.jl") -export add_edges, add_self_loops, remove_self_loops, getgraph +export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph include("generate.jl") export rand_graph @@ -38,6 +38,6 @@ export # from SparseArrays sprand, sparse, blockdiag, # from Flux - batch + batch, unbatch end #module diff --git a/src/GNNGraphs/generate.jl b/src/GNNGraphs/generate.jl index 67cfcac30..67f5d1f1b 100644 --- a/src/GNNGraphs/generate.jl +++ b/src/GNNGraphs/generate.jl @@ -1,14 +1,52 @@ """ - rand_graph(n, m; directed=false, kws...) + rand_graph(n, m; bidirected=true, kws...) -Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes. +Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes +and `m` edges. -If `directed=false` the output will contain `2m` edges: -the reverse edge of each edge will be present. -If `directed=true` instead, `m` unrelated edges are generated. +If `bidirected=true` the reverse edge of each edge will be present. +If `bidirected=false` instead, `m` unrelated edges are generated. +In any case, the output graph will contain no self-loops or multi-edges. -Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor. +Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. + +# Usage + +```juliarepl +julia> g = rand_graph(5, 4, bidirected=false) +GNNGraph: + num_nodes = 5 + num_edges = 4 + num_graphs = 1 + ndata: + edata: + gdata: + + +julia> edge_index(g) +([1, 3, 3, 4], [5, 4, 5, 2]) + +# In the bidirected case, edge data will be duplicated on the reverse edges if needed. +julia> g = rand_graph(5, 4, edata=rand(16, 2)) +GNNGraph: + num_nodes = 5 + num_edges = 4 + num_graphs = 1 + ndata: + edata: + e => (16, 4) + gdata: + +# Each edge has a reverse +julia> edge_index(g) +([1, 3, 3, 4], [3, 4, 1, 3]) + +``` """ -function rand_graph(n::Integer, m::Integer; directed=false, kws...) - return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...) +function rand_graph(n::Integer, m::Integer; bidirected=true, kws...) + if bidirected + @assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m." + end + m2 = bidirected ? m÷2 : m + return GNNGraph(Graphs.erdos_renyi(n, m2, is_directed=!bidirected); kws...) end diff --git a/src/GNNGraphs/gnngraph.jl b/src/GNNGraphs/gnngraph.jl index 8a36b7396..a0bf2440f 100644 --- a/src/GNNGraphs/gnngraph.jl +++ b/src/GNNGraphs/gnngraph.jl @@ -56,7 +56,7 @@ functionality from that library. Optionally, also edge weights can be given: `(source, target, weights)`. - `:sparse`. A sparse adjacency matrix representation. - `:dense`. A dense adjacency matrix representation. - Default `:coo`. + Defaults to `:coo`, currently the most supported type. - `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`. Possible values are `:out` and `:in`. Default `:out`. - `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`. diff --git a/src/GNNGraphs/query.jl b/src/GNNGraphs/query.jl index 6cf476412..bb4fb6f29 100644 --- a/src/GNNGraphs/query.jl +++ b/src/GNNGraphs/query.jl @@ -62,7 +62,7 @@ Graphs.is_directed(::Type{<:GNNGraph}) = true Return the adjacency list representation (a vector of vectors) of the graph `g`. -Calling `a` the adjacency list, if `dir=:out` +Calling `a` the adjacency list, if `dir=:out` than `a[i]` will contain the neighbors of node `i` through outgoing edges. If `dir=:in`, it will contain neighbors from incoming edges instead. @@ -75,7 +75,7 @@ end function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out) if g.graph[1] isa CuVector - # TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152 + # TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152 A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes) else A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes) diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl index 1d9b783d7..825b22eed 100644 --- a/src/GNNGraphs/transform.jl +++ b/src/GNNGraphs/transform.jl @@ -54,8 +54,7 @@ end """ add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata]) -Add to graph `g` the edges with source nodes `s` and target nodes `t`. - +Add to graph `g` the edges with source nodes `s` and target nodes `t`. """ function add_edges(g::GNNGraph{<:COO_T}, snew::AbstractVector{<:Integer}, @@ -79,6 +78,25 @@ function add_edges(g::GNNGraph{<:COO_T}, g.ndata, edata, g.gdata) end + +""" + add_nodes(g::GNNGraph, n; [ndata]) + +Add `n` new nodes to graph `g`. In the +new graph, these nodes will have indexes from `g.num_nodes + 1` +to `g.num_nodes + n`. +""" +function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata=(;)) + ndata = normalize_graphdata(ndata, default_name=:x, n=n) + ndata = cat_features(g.ndata, ndata) + + GNNGraph(g.graph, + g.num_nodes + n, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, g.edata, g.gdata) +end + + function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph) nv1, nv2 = g1.num_nodes, g2.num_nodes if g1.graph isa COO_T @@ -117,8 +135,6 @@ function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix) O2 A2] end -### Cat public interfaces ############# - """ blockdiag(xs::GNNGraph...) @@ -133,14 +149,115 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) end """ - batch(xs::Vector{<:GNNGraph}) + batch(gs::Vector{<:GNNGraph}) Batch together multiple `GNNGraph`s into a single one containing the total number of original nodes and edges. Equivalent to [`SparseArrays.blockdiag`](@ref). +See also [`Flux.unbatch`](@ref). + +# Usage + +```juliarepl +julia> g1 = rand_graph(4, 6, ndata=ones(8, 4)) +GNNGraph: + num_nodes = 4 + num_edges = 6 + num_graphs = 1 + ndata: + x => (8, 4) + edata: + gdata: + + +julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7)) +GNNGraph: + num_nodes = 7 + num_edges = 4 + num_graphs = 1 + ndata: + x => (8, 7) + edata: + gdata: + + +julia> g12 = Flux.batch([g1, g2]) +GNNGraph: + num_nodes = 11 + num_edges = 10 + num_graphs = 2 + ndata: + x => (8, 11) + edata: + gdata: + + +julia> g12.ndata.x +8×11 Matrix{Float64}: + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 +``` +""" +Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...) + + """ -Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...) + unbatch(g::GNNGraph) + +Opposite of the [`Flux.batch`](@ref) operation, returns +an array of the individual graphs batched together in `g`. + +See also [`Flux.batch`](@ref) and [`getgraph`](@ref). + +# Usage + +```juliarepl +julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)]) +GNNGraph: + num_nodes = 19 + num_edges = 16 + num_graphs = 3 + ndata: + edata: + gdata: + +julia> Flux.unbatch(gbatched) +3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}: + GNNGraph: + num_nodes = 5 + num_edges = 6 + num_graphs = 1 + ndata: + edata: + gdata: + + GNNGraph: + num_nodes = 10 + num_edges = 8 + num_graphs = 1 + ndata: + edata: + gdata: + + GNNGraph: + num_nodes = 4 + num_edges = 2 + num_graphs = 1 + ndata: + edata: + gdata: +``` +""" +function Flux.unbatch(g::GNNGraph) + [getgraph(g, i) for i in 1:g.num_graphs] +end """ diff --git a/test/GNNGraphs/generate.jl b/test/GNNGraphs/generate.jl index 326844659..82e4c8f07 100644 --- a/test/GNNGraphs/generate.jl +++ b/test/GNNGraphs/generate.jl @@ -1,20 +1,21 @@ @testset "generate" begin @testset "rand_graph" begin n, m = 10, 20 + m2 = m ÷ 2 x = rand(3, n) - e = rand(4, m) + e = rand(4, m2) g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T) @test g.num_nodes == n - @test g.num_edges == 2m + @test g.num_edges == m @test g.ndata.x === x if GRAPH_T == :coo s, t = edge_index(g) - @test s[1:m] == t[m+1:end] - @test t[1:m] == s[m+1:end] - @test g.edata.e[:,1:m] == e - @test g.edata.e[:,m+1:end] == e + @test s[1:m2] == t[m2+1:end] + @test t[1:m2] == s[m2+1:end] + @test g.edata.e[:,1:m2] == e + @test g.edata.e[:,m2+1:end] == e end - g = rand_graph(n, m, directed=true, graph_type=GRAPH_T) + g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T) @test g.num_nodes == n @test g.num_edges == m end diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl index aabdcf221..b0b617b94 100644 --- a/test/GNNGraphs/transform.jl +++ b/test/GNNGraphs/transform.jl @@ -42,6 +42,20 @@ @test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u] end + @testset "unbatch" begin + g1 = rand_graph(10, 20) + g2 = rand_graph(5, 10) + g12 = Flux.batch([g1, g2]) + gs = Flux.unbatch([g1,g2]) + @test length(gs) == 2 + @test gs[1].num_nodes == 10 + @test gs[1].num_edges == 20 + @test gs[1].num_graphs == 1 + @test gs[2].num_nodes == 5 + @test gs[2].num_edges == 10 + @test gs[2].num_graphs == 1 + end + @testset "getgraph" begin g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T) g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T) @@ -80,4 +94,15 @@ @test all(gnew.edata.e2[:,5] .== 0) end end + + @testset "add_nodes" begin + if GRAPH_T == :coo + g = rand_graph(6, 4, ndata=rand(2, 6), graph_type=GRAPH_T) + gnew = add_nodes(g, 5, ndata=ones(2, 5)) + @test gnew.num_nodes == g.num_nodes + 5 + @test gnew.num_edges == g.num_edges + @test gnew.num_graphs == g.num_graphs + @test all(gnew.ndata.x[:,7:11] .== 1) + end + end end \ No newline at end of file