Skip to content

Commit

Permalink
Fix for quadratic batching in JuliaGraphs#99
Browse files Browse the repository at this point in the history
  • Loading branch information
tclements committed Jan 12, 2022
1 parent 364c57e commit 3e2d3cf
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
30 changes: 28 additions & 2 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
end
return g
end
SparseArrays.blockdiag(gs::Vector{GNNGraph}) = SparseArrays.blockdiag(gs...)

"""
batch(gs::Vector{<:GNNGraph})
Expand Down Expand Up @@ -253,8 +254,33 @@ julia> g12.ndata.x
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
```
"""
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...)

function Flux.batch(gs::Vector{<:GNNGraph})
nodes = [g.num_nodes for g in gs]

if all(y -> isa(y, COO_T), [g.graph for g in gs] )
edge_indices = [edge_index(g) for g in gs]
nodesum = cumsum([0, nodes...])[1:end-1]
s = reduce(vcat, [ei[1] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)])
t = reduce(vcat, [ei[2] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)])
w = reduce(vcat, [get_edge_weight(g) for g in gs])
w = w isa Vector{Nothing} ? nothing : w
graph = (s, t, w)
graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...)
elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs] )
graph = blockdiag([g.graph for g in gs]...)
graph_indicator = vcat([ones_like(graph,Int,nodes[ii]) .+ (ii - 1) for ii in 1:length(nodes)]...)
end

GNNGraph(graph,
sum(nodes),
sum([g.num_edges for g in gs]),
sum([g.num_graphs for g in gs]),
graph_indicator,
cat_features([g.ndata for g in gs]),
cat_features([g.edata for g in gs]),
cat_features([g.gdata for g in gs]),
)
end

"""
unbatch(g::GNNGraph)
Expand Down
14 changes: 14 additions & 0 deletions src/GNNGraphs/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ function cat_features(x1::NamedTuple, x2::NamedTuple)
NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1))
end

function cat_features(xs::Vector{NamedTuple{T1, T2}}) where {T1, T2}
symbols = [sort(collect(keys(x))) for x in xs]
all(y->y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys"
length(xs) == 1 && return xs[1]

# concatenate
syms = symbols[1]
dims = [max(1, ndims(xs[1][k])) for k in syms] # promote scalar to 1D
methods = [dim == 1 ? vcat : hcat for dim in dims] # use optimized reduce(hcat,xs) or reduce(vcat,xs)
NamedTuple(
k => reduce(methods[ii],[x[k] for x in xs]) for (ii,k) in enumerate(syms)
)
end

# Turns generic type into named tuple
normalize_graphdata(data::Nothing; kws...) = NamedTuple()

Expand Down
1 change: 1 addition & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

g12 = Flux.batch([g1, g2])
g12b = blockdiag(g1, g2)
@test g12 == g12b

g123 = Flux.batch([g1, g2, g3])
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
Expand Down

0 comments on commit 3e2d3cf

Please sign in to comment.