Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement unbatch and add_nodes #65

Merged
merged 3 commits into from
Oct 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/gnngraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ lg = erdos_renyi(10, 30)
g = GNNGraph(lg)

# Same as above using convenience method rand_graph
g = rand_graph(10, 30)
g = rand_graph(10, 60)

# From an adjacency matrix
A = sprand(10, 10, 0.3)
Expand Down
6 changes: 3 additions & 3 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ export GNNGraph, node_features, edge_features, graph_features
include("query.jl")
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
graph_indicator

include("transform.jl")
export add_edges, add_self_loops, remove_self_loops, getgraph
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph

include("generate.jl")
export rand_graph
Expand All @@ -38,6 +38,6 @@ export
# from SparseArrays
sprand, sparse, blockdiag,
# from Flux
batch
batch, unbatch

end #module
54 changes: 46 additions & 8 deletions src/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,52 @@
"""
rand_graph(n, m; directed=false, kws...)
rand_graph(n, m; bidirected=true, kws...)

Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes.
Generate a random (Erdős-Rényi) `GNNGraph` with `n` nodes
and `m` edges.

If `directed=false` the output will contain `2m` edges:
the reverse edge of each edge will be present.
If `directed=true` instead, `m` unrelated edges are generated.
If `bidirected=true` the reverse edge of each edge will be present.
If `bidirected=false` instead, `m` unrelated edges are generated.
In any case, the output graph will contain no self-loops or multi-edges.

Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor.
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.

# Usage

```julia-repl
julia> g = rand_graph(5, 4, bidirected=false)
GNNGraph:
num_nodes = 5
num_edges = 4
num_graphs = 1
ndata:
edata:
gdata:


julia> edge_index(g)
([1, 3, 3, 4], [5, 4, 5, 2])

# In the bidirected case, edge data will be duplicated on the reverse edges if needed.
julia> g = rand_graph(5, 4, edata=rand(16, 2))
GNNGraph:
num_nodes = 5
num_edges = 4
num_graphs = 1
ndata:
edata:
e => (16, 4)
gdata:

# Each edge has a reverse
julia> edge_index(g)
([1, 3, 3, 4], [3, 4, 1, 3])

```
"""
function rand_graph(n::Integer, m::Integer; directed=false, kws...)
return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
function rand_graph(n::Integer, m::Integer; bidirected=true, kws...)
    # Reject invalid input with a real exception: `@assert` is a debugging aid
    # that may be disabled, so it must not guard a public entry point.
    if bidirected && !iseven(m)
        throw(ArgumentError(
            "Need even number of edges for bidirected graphs, given m=$m."))
    end
    # For a bidirected graph an undirected graph with `m ÷ 2` edges is sampled;
    # otherwise a directed graph with all `m` edges is sampled.
    num_sampled = bidirected ? m ÷ 2 : m
    # Extra keyword arguments are forwarded untouched to the `GNNGraph` constructor.
    return GNNGraph(Graphs.erdos_renyi(n, num_sampled, is_directed=!bidirected); kws...)
end
2 changes: 1 addition & 1 deletion src/GNNGraphs/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ functionality from that library.
Optionally, also edge weights can be given: `(source, target, weights)`.
- `:sparse`. A sparse adjacency matrix representation.
- `:dense`. A dense adjacency matrix representation.
Default `:coo`.
Defaults to `:coo`, currently the most supported type.
- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
Possible values are `:out` and `:in`. Default `:out`.
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
Expand Down
4 changes: 2 additions & 2 deletions src/GNNGraphs/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Graphs.is_directed(::Type{<:GNNGraph}) = true
Return the adjacency list representation (a vector of vectors)
of the graph `g`.

Calling `a` the adjacency list, if `dir=:out`
Calling `a` the adjacency list, if `dir=:out` then
`a[i]` will contain the neighbors of node `i` through
outgoing edges. If `dir=:in`, it will contain neighbors from
incoming edges instead.
Expand All @@ -75,7 +75,7 @@ end

function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out)
if g.graph[1] isa CuVector
# TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
else
A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes)
Expand Down
129 changes: 123 additions & 6 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ end
"""
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])

Add to graph `g` the edges with source nodes `s` and target nodes `t`.

Add to graph `g` the edges with source nodes `s` and target nodes `t`.
"""
function add_edges(g::GNNGraph{<:COO_T},
snew::AbstractVector{<:Integer},
Expand All @@ -79,6 +78,25 @@ function add_edges(g::GNNGraph{<:COO_T},
g.ndata, edata, g.gdata)
end


"""
add_nodes(g::GNNGraph, n; [ndata])

Add `n` new nodes to graph `g`. In the
new graph, these nodes will have indexes from `g.num_nodes + 1`
to `g.num_nodes + n`.
"""
function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata=(;))
    # Bring the features supplied for the `n` new nodes into the package's
    # normalized form (default feature name `:x`), then append them to the
    # features already stored on `g`.
    added_features = normalize_graphdata(ndata; default_name=:x, n=n)
    all_features = cat_features(g.ndata, added_features)

    # Only the node count and the node features change: the edge storage,
    # edge/graph counts, graph indicator, edge data and graph data of `g`
    # are passed through as they are.
    return GNNGraph(
        g.graph,
        g.num_nodes + n,
        g.num_edges,
        g.num_graphs,
        g.graph_indicator,
        all_features,
        g.edata,
        g.gdata,
    )
end


function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
nv1, nv2 = g1.num_nodes, g2.num_nodes
if g1.graph isa COO_T
Expand Down Expand Up @@ -117,8 +135,6 @@ function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
O2 A2]
end

### Cat public interfaces #############

"""
blockdiag(xs::GNNGraph...)

Expand All @@ -133,14 +149,115 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
end

"""
batch(xs::Vector{<:GNNGraph})
batch(gs::Vector{<:GNNGraph})

Batch together multiple `GNNGraph`s into a single one
containing the total number of original nodes and edges.

Equivalent to [`SparseArrays.blockdiag`](@ref).
See also [`Flux.unbatch`](@ref).

# Usage

```julia-repl
julia> g1 = rand_graph(4, 6, ndata=ones(8, 4))
GNNGraph:
num_nodes = 4
num_edges = 6
num_graphs = 1
ndata:
x => (8, 4)
edata:
gdata:


julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7))
GNNGraph:
num_nodes = 7
num_edges = 4
num_graphs = 1
ndata:
x => (8, 7)
edata:
gdata:


julia> g12 = Flux.batch([g1, g2])
GNNGraph:
num_nodes = 11
num_edges = 10
num_graphs = 2
ndata:
x => (8, 11)
edata:
gdata:


julia> g12.ndata.x
8×11 Matrix{Float64}:
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
```
"""
function Flux.batch(gs::Vector{<:GNNGraph})
    # Batching is delegated to `blockdiag`, called with every graph of `gs`
    # as a separate positional argument.
    return blockdiag(gs...)
end


"""
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
unbatch(g::GNNGraph)

Opposite of the [`Flux.batch`](@ref) operation, returns
an array of the individual graphs batched together in `g`.

See also [`Flux.batch`](@ref) and [`getgraph`](@ref).

# Usage

```julia-repl
julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
GNNGraph:
num_nodes = 19
num_edges = 16
num_graphs = 3
ndata:
edata:
gdata:

julia> Flux.unbatch(gbatched)
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph:
num_nodes = 5
num_edges = 6
num_graphs = 1
ndata:
edata:
gdata:

GNNGraph:
num_nodes = 10
num_edges = 8
num_graphs = 1
ndata:
edata:
gdata:

GNNGraph:
num_nodes = 4
num_edges = 2
num_graphs = 1
ndata:
edata:
gdata:
```
"""
function Flux.unbatch(g::GNNGraph)
    # Extract the graphs one at a time with `getgraph`, for indices
    # 1 through `g.num_graphs`, and collect them in that order.
    return map(1:g.num_graphs) do i
        getgraph(g, i)
    end
end


"""
Expand Down
15 changes: 8 additions & 7 deletions test/GNNGraphs/generate.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
@testset "generate" begin
@testset "rand_graph" begin
n, m = 10, 20
m2 = m ÷ 2
x = rand(3, n)
e = rand(4, m)
e = rand(4, m2)
g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T)
@test g.num_nodes == n
@test g.num_edges == 2m
@test g.num_edges == m
@test g.ndata.x === x
if GRAPH_T == :coo
s, t = edge_index(g)
@test s[1:m] == t[m+1:end]
@test t[1:m] == s[m+1:end]
@test g.edata.e[:,1:m] == e
@test g.edata.e[:,m+1:end] == e
@test s[1:m2] == t[m2+1:end]
@test t[1:m2] == s[m2+1:end]
@test g.edata.e[:,1:m2] == e
@test g.edata.e[:,m2+1:end] == e
end
g = rand_graph(n, m, directed=true, graph_type=GRAPH_T)
g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
@test g.num_nodes == n
@test g.num_edges == m
end
Expand Down
25 changes: 25 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@
@test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u]
end

@testset "unbatch" begin
    g1 = rand_graph(10, 20)
    g2 = rand_graph(5, 10)
    g12 = Flux.batch([g1, g2])
    # Unbatch the batched graph itself (not the vector of inputs) so that the
    # `Flux.unbatch(::GNNGraph)` method is the one under test.
    gs = Flux.unbatch(g12)
    @test length(gs) == 2
    # The first recovered graph must have the sizes of `g1`.
    @test gs[1].num_nodes == 10
    @test gs[1].num_edges == 20
    @test gs[1].num_graphs == 1
    # The second recovered graph must have the sizes of `g2`.
    @test gs[2].num_nodes == 5
    @test gs[2].num_edges == 10
    @test gs[2].num_graphs == 1
end

@testset "getgraph" begin
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)
Expand Down Expand Up @@ -80,4 +94,15 @@
@test all(gnew.edata.e2[:,5] .== 0)
end
end

@testset "add_nodes" begin
    # Only run for the COO backend.
    if GRAPH_T == :coo
        n_old, n_added = 6, 5
        g = rand_graph(n_old, 4, ndata=rand(2, n_old), graph_type=GRAPH_T)
        gnew = add_nodes(g, n_added, ndata=ones(2, n_added))

        # Node count grows by the number of added nodes; everything else is unchanged.
        @test gnew.num_nodes == g.num_nodes + n_added
        @test gnew.num_edges == g.num_edges
        @test gnew.num_graphs == g.num_graphs

        # The features of the appended nodes occupy the trailing columns (7 to 11).
        @test all(gnew.ndata.x[:, (n_old + 1):(n_old + n_added)] .== 1)
    end
end
end