Skip to content

Commit

Permalink
implement unbatch and add_nodes (#65)
Browse files Browse the repository at this point in the history
* add unbatch

* implement add_nodes

* cleanup
  • Loading branch information
CarloLucibello authored Oct 31, 2021
1 parent 195bb6c commit da7b1e4
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/src/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ lg = erdos_renyi(10, 30)
g = GNNGraph(lg)

# Same as above using convenience method rand_graph
g = rand_graph(10, 30)
g = rand_graph(10, 60)

# From an adjacency matrix
A = sprand(10, 10, 0.3)
Expand Down
6 changes: 3 additions & 3 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ export GNNGraph, node_features, edge_features, graph_features
include("query.jl")
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
graph_indicator

include("transform.jl")
export add_edges, add_self_loops, remove_self_loops, getgraph
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph

include("generate.jl")
export rand_graph
Expand All @@ -38,6 +38,6 @@ export
# from SparseArrays
sprand, sparse, blockdiag,
# from Flux
batch
batch, unbatch

end #module
54 changes: 46 additions & 8 deletions src/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,52 @@
"""
rand_graph(n, m; directed=false, kws...)
rand_graph(n, m; bidirected=true, kws...)
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes.
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes
and `m` edges.
If `directed=false` the output will contain `2m` edges:
the reverse edge of each edge will be present.
If `directed=true` instead, `m` unrelated edges are generated.
If `bidirected=true` the reverse edge of each edge will be present.
If `bidirected=false` instead, `m` unrelated edges are generated.
In any case, the output graph will contain no self-loops or multi-edges.
Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor.
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
# Usage
```juliarepl
julia> g = rand_graph(5, 4, bidirected=false)
GNNGraph:
num_nodes = 5
num_edges = 4
num_graphs = 1
ndata:
edata:
gdata:
julia> edge_index(g)
([1, 3, 3, 4], [5, 4, 5, 2])
# In the bidirected case, edge data will be duplicated on the reverse edges if needed.
julia> g = rand_graph(5, 4, edata=rand(16, 2))
GNNGraph:
num_nodes = 5
num_edges = 4
num_graphs = 1
ndata:
edata:
e => (16, 4)
gdata:
# Each edge has a reverse
julia> edge_index(g)
([1, 3, 3, 4], [3, 4, 1, 3])
```
"""
function rand_graph(n::Integer, m::Integer; directed=false, kws...)
return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
function rand_graph(n::Integer, m::Integer; bidirected=true, kws...)
if bidirected
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
end
m2 = bidirected ? m÷2 : m
return GNNGraph(Graphs.erdos_renyi(n, m2, is_directed=!bidirected); kws...)
end
2 changes: 1 addition & 1 deletion src/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ functionality from that library.
Optionally, also edge weights can be given: `(source, target, weights)`.
- `:sparse`. A sparse adjacency matrix representation.
- `:dense`. A dense adjacency matrix representation.
Default `:coo`.
Defaults to `:coo`, currently the most supported type.
- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
Possible values are `:out` and `:in`. Default `:out`.
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
Expand Down
4 changes: 2 additions & 2 deletions src/GNNGraphs/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Graphs.is_directed(::Type{<:GNNGraph}) = true
Return the adjacency list representation (a vector of vectors)
of the graph `g`.
Calling `a` the adjacency list, if `dir=:out`
Calling `a` the adjacency list, if `dir=:out` than
`a[i]` will contain the neighbors of node `i` through
outgoing edges. If `dir=:in`, it will contain neighbors from
incoming edges instead.
Expand All @@ -75,7 +75,7 @@ end

function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out)
if g.graph[1] isa CuVector
# TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
else
A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes)
Expand Down
129 changes: 123 additions & 6 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ end
"""
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
Add to graph `g` the edges with source nodes `s` and target nodes `t`.
Add to graph `g` the edges with source nodes `s` and target nodes `t`.
"""
function add_edges(g::GNNGraph{<:COO_T},
snew::AbstractVector{<:Integer},
Expand All @@ -79,6 +78,25 @@ function add_edges(g::GNNGraph{<:COO_T},
g.ndata, edata, g.gdata)
end


"""
add_nodes(g::GNNGraph, n; [ndata])
Add `n` new nodes to graph `g`. In the
new graph, these nodes will have indexes from `g.num_nodes + 1`
to `g.num_nodes + n`.
"""
function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata=(;))
ndata = normalize_graphdata(ndata, default_name=:x, n=n)
ndata = cat_features(g.ndata, ndata)

GNNGraph(g.graph,
g.num_nodes + n, g.num_edges, g.num_graphs,
g.graph_indicator,
ndata, g.edata, g.gdata)
end


function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
nv1, nv2 = g1.num_nodes, g2.num_nodes
if g1.graph isa COO_T
Expand Down Expand Up @@ -117,8 +135,6 @@ function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
O2 A2]
end

### Cat public interfaces #############

"""
blockdiag(xs::GNNGraph...)
Expand All @@ -133,14 +149,115 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
end

"""
batch(xs::Vector{<:GNNGraph})
batch(gs::Vector{<:GNNGraph})
Batch together multiple `GNNGraph`s into a single one
containing the total number of original nodes and edges.
Equivalent to [`SparseArrays.blockdiag`](@ref).
See also [`Flux.unbatch`](@ref).
# Usage
```juliarepl
julia> g1 = rand_graph(4, 6, ndata=ones(8, 4))
GNNGraph:
num_nodes = 4
num_edges = 6
num_graphs = 1
ndata:
x => (8, 4)
edata:
gdata:
julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7))
GNNGraph:
num_nodes = 7
num_edges = 4
num_graphs = 1
ndata:
x => (8, 7)
edata:
gdata:
julia> g12 = Flux.batch([g1, g2])
GNNGraph:
num_nodes = 11
num_edges = 10
num_graphs = 2
ndata:
x => (8, 11)
edata:
gdata:
julia> g12.ndata.x
8×11 Matrix{Float64}:
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
```
"""
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...)


"""
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
unbatch(g::GNNGraph)
Opposite of the [`Flux.batch`](@ref) operation, returns
an array of the individual graphs batched together in `g`.
See also [`Flux.batch`](@ref) and [`getgraph`](@ref).
# Usage
```juliarepl
julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
GNNGraph:
num_nodes = 19
num_edges = 16
num_graphs = 3
ndata:
edata:
gdata:
julia> Flux.unbatch(gbatched)
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph:
num_nodes = 5
num_edges = 6
num_graphs = 1
ndata:
edata:
gdata:
GNNGraph:
num_nodes = 10
num_edges = 8
num_graphs = 1
ndata:
edata:
gdata:
GNNGraph:
num_nodes = 4
num_edges = 2
num_graphs = 1
ndata:
edata:
gdata:
```
"""
function Flux.unbatch(g::GNNGraph)
[getgraph(g, i) for i in 1:g.num_graphs]
end


"""
Expand Down
15 changes: 8 additions & 7 deletions test/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
@testset "generate" begin
@testset "rand_graph" begin
n, m = 10, 20
m2 = m ÷ 2
x = rand(3, n)
e = rand(4, m)
e = rand(4, m2)
g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T)
@test g.num_nodes == n
@test g.num_edges == 2m
@test g.num_edges == m
@test g.ndata.x === x
if GRAPH_T == :coo
s, t = edge_index(g)
@test s[1:m] == t[m+1:end]
@test t[1:m] == s[m+1:end]
@test g.edata.e[:,1:m] == e
@test g.edata.e[:,m+1:end] == e
@test s[1:m2] == t[m2+1:end]
@test t[1:m2] == s[m2+1:end]
@test g.edata.e[:,1:m2] == e
@test g.edata.e[:,m2+1:end] == e
end
g = rand_graph(n, m, directed=true, graph_type=GRAPH_T)
g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
@test g.num_nodes == n
@test g.num_edges == m
end
Expand Down
25 changes: 25 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@
@test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u]
end

@testset "unbatch" begin
g1 = rand_graph(10, 20)
g2 = rand_graph(5, 10)
g12 = Flux.batch([g1, g2])
gs = Flux.unbatch([g1,g2])
@test length(gs) == 2
@test gs[1].num_nodes == 10
@test gs[1].num_edges == 20
@test gs[1].num_graphs == 1
@test gs[2].num_nodes == 5
@test gs[2].num_edges == 10
@test gs[2].num_graphs == 1
end

@testset "getgraph" begin
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)
Expand Down Expand Up @@ -80,4 +94,15 @@
@test all(gnew.edata.e2[:,5] .== 0)
end
end

@testset "add_nodes" begin
if GRAPH_T == :coo
g = rand_graph(6, 4, ndata=rand(2, 6), graph_type=GRAPH_T)
gnew = add_nodes(g, 5, ndata=ones(2, 5))
@test gnew.num_nodes == g.num_nodes + 5
@test gnew.num_edges == g.num_edges
@test gnew.num_graphs == g.num_graphs
@test all(gnew.ndata.x[:,7:11] .== 1)
end
end
end

0 comments on commit da7b1e4

Please sign in to comment.