Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EvolveGCNO temporal layer #489

Merged
merged 9 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ export TGCN,
A3TGCN,
GConvGRU,
GConvLSTM,
DCGRU
DCGRU,
EvolveGCNO

end #module

60 changes: 60 additions & 0 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,63 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)

DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))

@concrete struct EvolveGCNO <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight
init_state::Function
init_bias
end

function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
in_dims, out_dims = ch
return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
Wf = l.init_weight(rng, l.out_dims, l.in_dims)
Uf = l.init_weight(rng, l.out_dims, l.in_dims)
Wi = l.init_weight(rng, l.out_dims, l.in_dims)
Ui = l.init_weight(rng, l.out_dims, l.in_dims)
Wo = l.init_weight(rng, l.out_dims, l.in_dims)
Uo = l.init_weight(rng, l.out_dims, l.in_dims)
Wc = l.init_weight(rng, l.out_dims, l.in_dims)
Uc = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
Bf = l.init_bias(rng, l.out_dims, l.in_dims)
Bi = l.init_bias(rng, l.out_dims, l.in_dims)
Bo = l.init_bias(rng, l.out_dims, l.in_dims)
Bc = l.init_bias(rng, l.out_dims, l.in_dims)
return (; conv = (; weight, bias), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc, Bf, Bi, Bo, Bc))
else
return (; conv = (; weight), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc))
end
end

function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO)
h = l.init_state(rng, l.out_dims, l.in_dims)
c = l.init_state(rng, l.out_dims, l.in_dims)
return (; conv = (;), lstm = (; h, c))
end

function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
H, C = st.lstm
W = ps.conv.weight
m = (; ps.conv.weight, bias = _getbias(ps),
add_self_loops =true, use_edge_weight=true, σ = identity)

X = map(1:tg.num_snapshots) do i
F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf)
I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi)
O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo)
C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc)
C = F .* C + I .* C̃
H = O .* NNlib.tanh_fast.(C)
W = H
GNNlib.gcn_conv(m,tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W)
end
return X, (; conv = (;), lstm = (h = H, c = C))
end
11 changes: 11 additions & 0 deletions GNNLux/test/layers/temporalconv_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
g = rand_graph(rng, 10, 40)
x = randn(rng, Float32, 3, 10)

tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
tx = [x for _ in 1:5]

@testset "TGCN" begin
l = TGCN(3=>3)
ps = LuxCore.initialparameters(rng, l)
Expand Down Expand Up @@ -44,4 +47,12 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "EvolveGCNO" begin
l = EvolveGCNO(3=>3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end
end
3 changes: 2 additions & 1 deletion src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ export TGCN,
A3TGCN,
GConvLSTM,
GConvGRU,
DCGRU
DCGRU,
EvolveGCNO

include("layers/pool.jl")
export GlobalPool,
Expand Down
97 changes: 97 additions & 0 deletions src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,103 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

"""
EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)

Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).

Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph.


# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.

# Examples

```jldoctest
julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10]
num_edges: [20, 14, 22]
num_snapshots: 3

julia> ev = EvolveGCNO(4 => 5)
EvolveGCNO(4 => 5)

julia> size(ev(tg, tg.ndata.x))
(3,)

julia> size(ev(tg, tg.ndata.x)[1])
(5, 10)
```
"""
struct EvolveGCNO
conv
W_init
init_state
in::Int
out::Int
Wf
Uf
Bf
Wi
Ui
Bi
Wo
Uo
Bo
Wc
Uc
Bc
end

Flux.@functor EvolveGCNO

function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
in, out = ch
W = init(out, in)
conv = GCNConv(ch; bias = bias, init = init)
Wf = init(out, in)
Uf = init(out, in)
Bf = bias ? init(out, in) : nothing
Wi = init(out, in)
Ui = init(out, in)
Bi = bias ? init(out, in) : nothing
Wo = init(out, in)
Uo = init(out, in)
Bo = bias ? init(out, in) : nothing
Wc = init(out, in)
Uc = init(out, in)
Bc = bias ? init(out, in) : nothing
return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
end

function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
H = egcno.init_state(egcno.out, egcno.in)
C = egcno.init_state(egcno.out, egcno.in)
W = egcno.W_init
X = map(1:tg.num_snapshots) do i
F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
C = F .* C + I .* C̃
H = O .* tanh_fast.(C)
W = H
egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
end
return X
end

function Base.show(io::IO, egcno::EvolveGCNO)
print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
end

function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
return l.(tg.snapshots, x)
end
Expand Down
6 changes: 6 additions & 0 deletions test/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ end
@test model(g1) isa GNNGraph
end

@testset "EvolveGCNO" begin
evolvegcno = EvolveGCNO(in_channel => out_channel)
@test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S
@test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N)
end

@testset "GINConv" begin
ginconv = GINConv(Dense(in_channel => out_channel),0.3)
@test length(ginconv(tg, tg.ndata.x)) == S
Expand Down
Loading