From 9d2346ac80c4c633894703cb735d44ce334cd18a Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Wed, 4 Sep 2024 16:19:45 +0200
Subject: [PATCH 1/9] First draft

---
 src/layers/temporalconv.jl | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl
index 443ef2a3a..823ee9d2f 100644
--- a/src/layers/temporalconv.jl
+++ b/src/layers/temporalconv.jl
@@ -484,6 +484,36 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
 _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
 _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
 
+struct EvolveGCNO
+    conv
+    lstm
+    W_init
+    init_state
+    in::Int
+    out::Int
+end
+
+Flux.@functor EvolveGCNO
+
+function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
+    in, out = ch
+    W = init(out, in)
+    conv = GCNConv(ch; bias = bias, init = init)
+    lstm = Flux.LSTM(out,out)
+    return EvolveGCNO(conv, lstm, W, init_state, in, out)
+end
+
+function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph)
+    H = egcno.init_state(egcno.out, tg.snapshots[i].num_nodes, tg.num_snapshots)
+    W = egcno.W_init
+    for i in 1:tg.num_snapshots
+        W = egcno.lstm(W)
+        H[:,:,i] .= egcno.conv(tg.snapshots[i], tg.ndata.x[i]; conv_weight = W)
+    end
+    return H
+end
+
+
 function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
     return l.(tg.snapshots, x)
 end

From 247a9a876d5ea1e215c8cf0c75026baa95470c0e Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Thu, 5 Sep 2024 14:15:46 +0200
Subject: [PATCH 2/9] Add test

---
 test/layers/temporalconv.jl | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl
index 45c8acf04..bdf44b45f 100644
--- a/test/layers/temporalconv.jl
+++ b/test/layers/temporalconv.jl
@@ -69,6 +69,12 @@ end
     @test model(g1) isa GNNGraph
 end
 
+@testset "EvolveGCNO" begin
+    evolvegcno = EvolveGCNO(in_channel => out_channel)
+    @test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S
+    @test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N)
+end
+
 @testset "GINConv" begin
     ginconv = GINConv(Dense(in_channel => out_channel),0.3)
     @test length(ginconv(tg, tg.ndata.x)) == S

From 5919956d62d2db0e8402af51898a48c8ad19b3d5 Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Thu, 5 Sep 2024 14:16:02 +0200
Subject: [PATCH 3/9] Add export `EvolveGCNO`

---
 src/GraphNeuralNetworks.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index bf6991155..cebf7b7d3 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -54,7 +54,8 @@ export TGCN,
        A3TGCN,
        GConvLSTM,
        GConvGRU,
-       DCGRU
+       DCGRU,
+       EvolveGCNO
 
 include("layers/pool.jl")
 export GlobalPool,

From 4ea56b1902ee187c1dc5146742a1ce3730cd0d35 Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Thu, 5 Sep 2024 14:17:32 +0200
Subject: [PATCH 4/9] Improve `EvolveGCNO`

---
 src/layers/temporalconv.jl | 90 +++++++++++++++++++++++++++++++++-----
 1 file changed, 79 insertions(+), 11 deletions(-)

diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl
index 823ee9d2f..b69e86eac 100644
--- a/src/layers/temporalconv.jl
+++ b/src/layers/temporalconv.jl
@@ -484,13 +484,59 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
 _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
 _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
 
+"""
+    EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
+
+Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).
+
+Performs a Graph Convolutional layer with weights derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph.
+
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init_state`: Initializer of the hidden state of the LSTM layer. Default `zeros32`.
+
+# Examples
+
+```jldoctest
+julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
+TemporalSnapshotsGNNGraph:
+  num_nodes: [10, 10, 10]
+  num_edges: [20, 14, 22]
+  num_snapshots: 3
+
+julia> ev = EvolveGCNO(4 => 5)
+EvolveGCNO(4 => 5)
+
+julia> size(ev(tg, tg.ndata.x))
+(3,)
+
+julia> size(ev(tg, tg.ndata.x)[1])
+(5, 10)
+```
+"""
 struct EvolveGCNO
     conv
-    lstm
     W_init
     init_state
     in::Int
     out::Int
+    Wf
+    Uf
+    Bf
+    Wi
+    Ui
+    Bi
+    Wo
+    Uo
+    Bo
+    Wc
+    Uc
+    Bc
 end
 
 Flux.@functor EvolveGCNO
@@ -497,23 +543,45 @@ Flux.@functor EvolveGCNO
 
 function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
     in, out = ch
     W = init(out, in)
     conv = GCNConv(ch; bias = bias, init = init)
-    lstm = Flux.LSTM(out,out)
-    return EvolveGCNO(conv, lstm, W, init_state, in, out)
+    Wf = init(out, in)
+    Uf = init(out, in)
+    Bf = bias ? init(out, in) : nothing
+    Wi = init(out, in)
+    Ui = init(out, in)
+    Bi = bias ? init(out, in) : nothing
+    Wo = init(out, in)
+    Uo = init(out, in)
+    Bo = bias ? init(out, in) : nothing
+    Wc = init(out, in)
+    Uc = init(out, in)
+    Bc = bias ? init(out, in) : nothing
+    return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
 end
 
-function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph)
-    H = egcno.init_state(egcno.out, tg.snapshots[i].num_nodes, tg.num_snapshots)
+function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
+    X = egcno.init_state(egcno.out, tg.snapshots[1].num_nodes, tg.num_snapshots)
+    H = egcno.init_state(egcno.out, egcno.in)
+    C = egcno.init_state(egcno.out, egcno.in)
     W = egcno.W_init
-    for i in 1:tg.num_snapshots
-        W = egcno.lstm(W)
-        H[:,:,i] .= egcno.conv(tg.snapshots[i], tg.ndata.x[i]; conv_weight = W)
+    X = map(1:tg.num_snapshots) do i
+        F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
+        I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
+        O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
+        C̃ = Flux.tanh.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
+        C = F .* C + I .* C̃
+        H = O .* tanh_fast.(C)
+        W = H
+        egcno.conv(tg.snapshots[i], x[i]; conv_weight = W)
     end
-    return H
+    return X
 end
+
+function Base.show(io::IO, egcno::EvolveGCNO)
+    print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
+end
 
 function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
     return l.(tg.snapshots, x)
 end
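Note on the forward pass added in [PATCH 4/9]: the `map` body implements an element-wise (Hadamard) LSTM whose input and hidden state are the GCN weight matrix itself, i.e. the EvolveGCN-O variant of the paper, and the evolved weight `W` is handed to `GCNConv` at every snapshot. One step of the recurrence, in notation of my choosing (not part of the patch), with `⊙` the element-wise product:

```math
\begin{aligned}
F_t &= \sigma(W_f \odot W_{t-1} + U_f \odot H_{t-1} + B_f) \\
I_t &= \sigma(W_i \odot W_{t-1} + U_i \odot H_{t-1} + B_i) \\
O_t &= \sigma(W_o \odot W_{t-1} + U_o \odot H_{t-1} + B_o) \\
\tilde{C}_t &= \tanh(W_c \odot W_{t-1} + U_c \odot H_{t-1} + B_c) \\
C_t &= F_t \odot C_{t-1} + I_t \odot \tilde{C}_t \\
H_t &= O_t \odot \tanh(C_t), \qquad W_t = H_t \\
X^{(t)}_{\mathrm{out}} &= \mathrm{GCNConv}\big(\mathcal{G}_t, X^{(t)};\ \mathrm{weight} = W_t\big)
\end{aligned}
```

Here `Wf, Uf, Bf, ...` are the gate parameters stored as struct fields above, and the per-snapshot output is the graph convolution applied with the evolved weight.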
From c5cc135bf764a91d5be6647d66e9926edece4708 Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Thu, 5 Sep 2024 16:44:56 +0200
Subject: [PATCH 5/9] Export `EvolveGCNO`

---
 GNNLux/src/GNNLux.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index f566cd0c6..1a9abc322 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -44,7 +44,8 @@ export TGCN,
        A3TGCN,
        GConvGRU,
        GConvLSTM,
-       DCGRU
+       DCGRU,
+       EvolveGCNO
 
 end #module
\ No newline at end of file

From 823108323ff06b9ce46c907606d0b91c61c965b2 Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Thu, 5 Sep 2024 16:45:13 +0200
Subject: [PATCH 6/9] Add `EvolveGCNO`

---
 GNNLux/src/layers/temporalconv.jl | 62 +++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl
index 687a21983..12c0608eb 100644
--- a/GNNLux/src/layers/temporalconv.jl
+++ b/GNNLux/src/layers/temporalconv.jl
@@ -274,3 +274,65 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)
 
 DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))
 
+@concrete struct EvolveGCNO <: GNNContainerLayer{(:conv,)}
+    in_dims::Int
+    out_dims::Int
+    conv
+    use_bias::Bool
+    init_weight
+    init_state::Function
+    init_bias
+end
+
+function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
+    in_dims, out_dims = ch
+    conv = GCNConv(ch; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
+    return EvolveGCNO(in_dims, out_dims, conv, use_bias, init_weight, init_state, init_bias)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
+    weight = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wf = l.init_weight(rng, l.out_dims, l.in_dims)
+    Uf = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wi = l.init_weight(rng, l.out_dims, l.in_dims)
+    Ui = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wo = l.init_weight(rng, l.out_dims, l.in_dims)
+    Uo = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wc = l.init_weight(rng, l.out_dims, l.in_dims)
+    Uc = l.init_weight(rng, l.out_dims, l.in_dims)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        Bf = l.init_bias(rng, l.out_dims, l.in_dims)
+        Bi = l.init_bias(rng, l.out_dims, l.in_dims)
+        Bo = l.init_bias(rng, l.out_dims, l.in_dims)
+        Bc = l.init_bias(rng, l.out_dims, l.in_dims)
+        return (; conv = (; weight, bias), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc, Bf, Bi, Bo, Bc))
+    else
+        return (; conv = (; weight), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc))
+    end
+end
+
+function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO)
+    h = l.init_state(rng, l.out_dims, l.in_dims)
+    c = l.init_state(rng, l.out_dims, l.in_dims)
+    return (; conv = (;), lstm = (; h, c))
+end
+
+function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
+    H = st.lstm.h
+    C = st.lstm.c
+    W = ps.conv.weight
+    X = map(1:tg.num_snapshots) do i
+        F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W + ps.lstm.Uf .* H + ps.lstm.Bf)
+        I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W + ps.lstm.Ui .* H + ps.lstm.Bi)
+        O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W + ps.lstm.Uo .* H + ps.lstm.Bo)
+        C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W + ps.lstm.Uc .* H + ps.lstm.Bc)
+        C = F .* C + I .* C̃
+        H = O .* NNlib.tanh_fast.(C)
+        W = H
+        X, _ = l.conv(tg.snapshots[i], x[i], ps.conv, st.conv; conv_weight = H)
+        X
+    end
+    return X, (conv = (), lstm = (h = H, c = C))
+end
+

From 1e451c4beb4faa4edf5f3464abb0d9f85daa8fcd Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Thu, 5 Sep 2024 16:45:25 +0200
Subject: [PATCH 7/9] Fix

---
 src/layers/temporalconv.jl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl
index b69e86eac..2f6292f28 100644
--- a/src/layers/temporalconv.jl
+++ b/src/layers/temporalconv.jl
@@ -561,7 +561,6 @@
 end
 
 function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
-    X = egcno.init_state(egcno.out, tg.snapshots[1].num_nodes, tg.num_snapshots)
     H = egcno.init_state(egcno.out, egcno.in)
     C = egcno.init_state(egcno.out, egcno.in)
     W = egcno.W_init
@@ -569,7 +568,7 @@ function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
         F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
         I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
         O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
-        C̃ = Flux.tanh.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
+        C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
         C = F .* C + I .* C̃
         H = O .* tanh_fast.(C)
         W = H
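With [PATCH 4/9] and the fix in [PATCH 7/9] applied, the Flux layer can be exercised end to end. The sketch below mirrors the docstring example and the test from [PATCH 2/9]; the `using` line and the `Float32` element type are assumptions added to make the snippet self-contained, not part of the patches:

```julia
using Flux, GraphNeuralNetworks

# Three snapshots of a 10-node graph, each carrying 4 node features.
tg = TemporalSnapshotsGNNGraph([rand_graph(10, 20; ndata = rand(Float32, 4, 10)),
                                rand_graph(10, 14; ndata = rand(Float32, 4, 10)),
                                rand_graph(10, 22; ndata = rand(Float32, 4, 10))])

ev = EvolveGCNO(4 => 5)   # GCN weights evolved by the element-wise LSTM

y = ev(tg, tg.ndata.x)    # 3-element vector, one (5, 10) output matrix per snapshot

# Gradients flow through all snapshots, as checked in the test suite:
grads = Flux.gradient(x -> sum(sum(ev(tg, x))), tg.ndata.x)[1]
```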
From d49bde79641779d3a058cb84da4b8ba300b2b786 Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Fri, 6 Sep 2024 13:30:51 +0200
Subject: [PATCH 8/9] Add `EvolveGCNO` test

---
 GNNLux/test/layers/temporalconv_test.jl | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/GNNLux/test/layers/temporalconv_test.jl b/GNNLux/test/layers/temporalconv_test.jl
index 7a7c48f4a..ec670b6bc 100644
--- a/GNNLux/test/layers/temporalconv_test.jl
+++ b/GNNLux/test/layers/temporalconv_test.jl
@@ -5,6 +5,9 @@
     g = rand_graph(rng, 10, 40)
     x = randn(rng, Float32, 3, 10)
 
+    tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
+    tx = [x for _ in 1:5]
+
     @testset "TGCN" begin
         l = TGCN(3=>3)
         ps = LuxCore.initialparameters(rng, l)
@@ -44,4 +47,12 @@
         loss = (x, ps) -> sum(first(l(g, x, ps, st)))
         test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
     end
+
+    @testset "EvolveGCNO" begin
+        l = EvolveGCNO(3=>3)
+        ps = LuxCore.initialparameters(rng, l)
+        st = LuxCore.initialstates(rng, l)
+        loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
+        test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
+    end
 end
\ No newline at end of file

From 8e6175ba2977eeea7f9d5754c46f0c5f8bb7ad93 Mon Sep 17 00:00:00 2001
From: Aurora Rossi
Date: Mon, 9 Sep 2024 14:47:37 +0200
Subject: [PATCH 9/9] Fix

---
 GNNLux/src/layers/temporalconv.jl | 26 ++++++++++++--------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl
index 12c0608eb..09594bf67 100644
--- a/GNNLux/src/layers/temporalconv.jl
+++ b/GNNLux/src/layers/temporalconv.jl
@@ -274,10 +274,9 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)
 
 DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))
 
-@concrete struct EvolveGCNO <: GNNContainerLayer{(:conv,)}
+@concrete struct EvolveGCNO <: GNNLayer
     in_dims::Int
     out_dims::Int
-    conv
     use_bias::Bool
     init_weight
     init_state::Function
     init_bias
 end
 
 function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
     in_dims, out_dims = ch
-    conv = GCNConv(ch; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
-    return EvolveGCNO(in_dims, out_dims, conv, use_bias, init_weight, init_state, init_bias)
+    return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias)
 end
 
 function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
@@ -319,20 +317,20 @@ end
 
 function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
-    H = st.lstm.h
-    C = st.lstm.c
+    H, C = st.lstm
     W = ps.conv.weight
+    m = (; ps.conv.weight, bias = _getbias(ps),
+        add_self_loops =true, use_edge_weight=true, σ = identity)
+
     X = map(1:tg.num_snapshots) do i
-        F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W + ps.lstm.Uf .* H + ps.lstm.Bf)
-        I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W + ps.lstm.Ui .* H + ps.lstm.Bi)
-        O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W + ps.lstm.Uo .* H + ps.lstm.Bo)
-        C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W + ps.lstm.Uc .* H + ps.lstm.Bc)
+        F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf)
+        I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi)
+        O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo)
+        C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc)
         C = F .* C + I .* C̃
         H = O .* NNlib.tanh_fast.(C)
         W = H
-        X, _ = l.conv(tg.snapshots[i], x[i], ps.conv, st.conv; conv_weight = H)
-        X
+        GNNlib.gcn_conv(m,tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W)
    end
-    return X, (conv = (), lstm = (h = H, c = C))
+    return X, (; conv = (;), lstm = (h = H, c = C))
 end
-
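For completeness, a matching usage sketch for the Lux version after [PATCH 9/9], mirroring the test added in [PATCH 8/9]. The import list is an assumption (the test suite loads these packages through its own harness), and `Random.default_rng()` stands in for the test RNG:

```julia
using GNNLux, GNNGraphs, LuxCore, Random   # assumed imports

rng = Random.default_rng()
g  = rand_graph(rng, 10, 40)
tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])   # 5 snapshots, 10 nodes each
tx = [randn(rng, Float32, 3, 10) for _ in 1:5]     # per-snapshot node features

l  = EvolveGCNO(3 => 3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

y, st = l(tg, tx, ps, st)   # y: 5-element vector of (3, 10) feature matrices
```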