Skip to content

Commit

Permalink
Merge pull request #48 from CarloLucibello/cl/redesign
Browse files Browse the repository at this point in the history
add nodes/edges softmax and readout
  • Loading branch information
CarloLucibello authored Sep 30, 2021
2 parents b2f6ebf + d4db332 commit 867cc10
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 42 deletions.
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ makedocs(;
"Convolutional Layers" => "api/conv.md",
"Pooling Layers" => "api/pool.md",
"Message Passing" => "api/messagepassing.md",
"NNlib" => "api/nnlib.md",
"Utils" => "api/utils.md",
],
"Developer Notes" => "dev.md",
],
Expand Down
23 changes: 0 additions & 23 deletions docs/src/api/nnlib.md

This file was deleted.

37 changes: 37 additions & 0 deletions docs/src/api/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
```@meta
CurrentModule = GraphNeuralNetworks
```

# Utility Functions

## Index

```@index
Order = [:type, :function]
Pages = ["utils.md"]
```

## Docs


### Graph-wise operations

```@docs
GraphNeuralNetworks.reduce_nodes
GraphNeuralNetworks.reduce_edges
GraphNeuralNetworks.softmax_nodes
GraphNeuralNetworks.softmax_edges
GraphNeuralNetworks.broadcast_nodes
GraphNeuralNetworks.broadcast_edges
```

### NNlib

Primitive functions implemented in NNlib.jl.

```@docs
NNlib.gather!
NNlib.gather
NNlib.scatter!
NNlib.scatter
```
6 changes: 6 additions & 0 deletions src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using MacroTools: @forward
import LearnBase
using LearnBase: getobs
using NNlib, NNlibCUDA
using NNlib: scatter, gather
using ChainRulesCore
import LightGraphs
using LightGraphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
Expand All @@ -30,6 +31,11 @@ export
# from SparseArrays
sprand, sparse, blockdiag,

# utils
reduce_nodes, reduce_edges,
softmax_nodes, softmax_edges,
broadcast_nodes, broadcast_edges,

# msgpass
apply_edges, propagate,
copyxj,
Expand Down
17 changes: 8 additions & 9 deletions src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -541,25 +541,24 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
graphmap = Dict(i => inew for (inew, i) in enumerate(i))
graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]

s, t = edge_index(g)
w = edge_weight(g)
edge_mask = s .∈ Ref(nodes)

if g.graph isa COO_T
s, t = edge_index(g)
w = edge_weight(g)
edge_mask = s .∈ Ref(nodes)
s = [nodemap[i] for i in s[edge_mask]]
t = [nodemap[i] for i in t[edge_mask]]
w = isnothing(w) ? nothing : w[edge_mask]
graph = (s, t, w)
num_edges = length(s)
edata = getobs(g.edata, edge_mask)
elseif g.graph isa ADJMAT_T
graph = g.graph[nodes, nodes]
num_edges = count(>=(0), graph)
@assert g.edata == (;) # TODO
edata = (;)
end

ndata = getobs(g.ndata, node_mask)
edata = getobs(g.edata, edge_mask)
gdata = getobs(g.gdata, i)


num_edges = sum(edge_mask)
num_nodes = length(graph_indicator)
num_graphs = length(i)

Expand Down
13 changes: 5 additions & 8 deletions src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ where ``V`` is the set of nodes of the input graph and
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
Commonly used aggregations are `mean`, `max`, and `+`.
See also [`reduce_nodes`](@ref).
# Examples
```julia
using Flux, GraphNeuralNetworks, LightGraphs
Expand All @@ -33,14 +36,8 @@ struct GlobalPool{F} <: GNNLayer
aggr::F
end

function (l::GlobalPool)(g::GNNGraph, X::AbstractArray)
if isnothing(g.graph_indicator)
# assume only one graph
indexes = fill!(similar(X, Int, g.num_nodes), 1)
else
indexes = g.graph_indicator
end
return NNlib.scatter(l.aggr, X, indexes)
# Forward pass: graph-wise aggregation of the node features `x`
# using the layer's aggregation operator (e.g. `+`, `mean`, `max`).
(l::GlobalPool)(g::GNNGraph, x::AbstractArray) = reduce_nodes(l.aggr, g, x)

"""
Expand Down
98 changes: 97 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,100 @@ function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray)
blocks = cld(max_idx, threads)
kernel(args...; threads=threads, blocks=blocks)
return dst
end
end

"""
reduce_nodes(aggr, g, x)
For a batched graph `g`, return the graph-wise aggregation of the node
features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
The returned array will have last dimension `g.num_graphs`.
"""
function reduce_nodes(aggr, g::GNNGraph, x)
@assert size(x)[end] == g.num_nodes
indexes = graph_indicator(g)
return NNlib.scatter(aggr, x, indexes)
end

"""
reduce_edges(aggr, g, e)
For a batched graph `g`, return the graph-wise aggregation of the edge
features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
The returned array will have last dimension `g.num_graphs`.
"""
function reduce_edges(aggr, g::GNNGraph, e)
@assert size(e)[end] == g.num_edges
s, t = edge_index(g)
indexes = graph_indicator(g)[s]
return NNlib.scatter(aggr, e, indexes)
end

"""
softmax_nodes(g, x)
Graph-wise softmax of the node features `x`.
"""
function softmax_nodes(g::GNNGraph, x)
@assert size(x)[end] == g.num_nodes
gi = graph_indicator(g)
max_ = gather(scatter(max, x, gi), gi)
num = exp.(x .- max_)
den = reduce_nodes(+, g, num)
den = gather(den, gi)
return num ./ den
end

"""
softmax_edges(g, e)
Graph-wise softmax of the edge features `e`.
"""
function softmax_edges(g::GNNGraph, e)
@assert size(e)[end] == g.num_edges
gi = graph_indicator(g, edges=true)
max_ = gather(scatter(max, e, gi), gi)
num = exp.(e .- max_)
den = reduce_edges(+, g, num)
den = gather(den, gi)
return num ./ den
end

"""
broadcast_nodes(g, x)
Graph-wise broadcast array `x` of size `(*, g.num_graphs)`
to size `(*, g.num_nodes)`.
"""
function broadcast_nodes(g::GNNGraph, x)
@assert size(x)[end] == g.num_graphs
gi = graph_indicator(g)
return gather(x, gi)
end

"""
broadcast_edges(g, x)
Graph-wise broadcast array `x` of size `(*, g.num_graphs)`
to size `(*, g.num_edges)`.
"""
function broadcast_edges(g::GNNGraph, x)
@assert size(x)[end] == g.num_graphs
gi = graph_indicator(g, edges=true)
return gather(x, gi)
end


# Return the vector of graph ids, one per node of `g` (or one per edge
# when `edges=true`). For a non-batched graph (`g.graph_indicator === nothing`)
# all ids are 1.
function graph_indicator(g; edges=false)
    gi = isnothing(g.graph_indicator) ?
         ones_like(edge_index(g)[1], Int, g.num_nodes) :
         g.graph_indicator
    edges || return gi
    # An edge belongs to the same graph as its source node.
    s = edge_index(g)[1]
    return gi[s]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("test_utils.jl")

tests = [
"gnngraph",
"utils",
"msgpass",
"layers/basic",
"layers/conv",
Expand Down
52 changes: 52 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
@testset "Utils" begin
De, Dx = 3, 2
g = Flux.batch([GNNGraph(erdos_renyi(10, 30),
ndata=rand(Dx, 10),
edata=rand(De, 30),
graph_type=GRAPH_T) for i=1:5])
x = g.ndata.x
e = g.edata.e

@testset "reduce_nodes" begin
r = reduce_nodes(mean, g, x)
@test size(r) == (Dx, g.num_graphs)
@test r[:,2] mean(getgraph(g, 2).ndata.x, dims=2)
end

@testset "reduce_edges" begin
r = reduce_edges(mean, g, e)
@test size(r) == (De, g.num_graphs)
@test r[:,2] mean(getgraph(g, 2).edata.e, dims=2)
end

@testset "softmax_nodes" begin
r = softmax_nodes(g, x)
@test size(r) == size(x)
@test r[:,1:10] softmax(getgraph(g, 1).ndata.x, dims=2)
end

@testset "softmax_edges" begin
r = softmax_edges(g, e)
@test size(r) == size(e)
@test r[:,1:60] softmax(getgraph(g, 1).edata.e, dims=2)
end


@testset "broadcast_nodes" begin
z = rand(4, g.num_graphs)
r = broadcast_nodes(g, z)
@test size(r) == (4, g.num_nodes)
@test r[:,1] z[:,1]
@test r[:,10] z[:,1]
@test r[:,11] z[:,2]
end

@testset "broadcast_edges" begin
z = rand(4, g.num_graphs)
r = broadcast_edges(g, z)
@test size(r) == (4, g.num_edges)
@test r[:,1] z[:,1]
@test r[:,60] z[:,1]
@test r[:,61] z[:,2]
end
end

0 comments on commit 867cc10

Please sign in to comment.