Skip to content

Commit

Permalink
[GNNLux] TGCN temporal layer (#470)
Browse files Browse the repository at this point in the history
* First draft

* Fix signature

* Improvement

* Export TGCN

* Fixes

* Back to previous version

* Add test

* Remove GNNlib code

* Fix

* Fix

Co-authored-by: Carlo Lucibello <[email protected]>

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
aurorarossi and CarloLucibello authored Aug 19, 2024
1 parent 87f3c60 commit 7319f4d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 2 deletions.
4 changes: 3 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ export AGNNConv,
SGConv
# TAGConv,
# TransformerConv


include("layers/temporalconv.jl")
export TGCN

end #module

59 changes: 59 additions & 0 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
end

function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell)
return (cell=LuxCore.initialstates(rng, r.cell), carry=nothing)
end

function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple)
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
return out, (; cell=st, carry)
end

function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple)
st, carry = st.cell, st.carry
for xᵢ in x
(out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry)
end
return out, (; cell=st, carry)
end

function applyrecurrentcell(l, g, x, ps, st, carry)
return Lux.apply(l, g, (x, carry), ps, st)
end

LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)

@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
in_dims::Int
out_dims::Int
conv
gru
init_state::Function
end

function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
in_dims, out_dims = ch
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
end

function (l::TGCNCell)(g, (x, h), ps, st)
if h === nothing
h = l.init_state(l.out_dims, 1)
end
x̃, stconv = l.conv(g, x, ps.conv, st.conv)
(h, (h,)), stgru = l.gru((x̃,(h,)), ps.gru,st.gru)
return (h, h), (conv=stconv, gru=stgru)
end

LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)
LuxCore.outputsize(l::GNNLux.StatefulRecurrentCell) = (l.cell.out_dims,)

function Base.show(io::IO, tgcn::TGCNCell)
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
end

TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))
15 changes: 15 additions & 0 deletions GNNLux/test/layers/temporalconv_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@testitem "layers/temporalconv" setup=[SharedTestSetup] begin
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme

rng = StableRNG(1234)
g = rand_graph(10, 40, seed=1234)
x = randn(rng, Float32, 3, 10)

@testset "TGCN" begin
l = TGCN(3=>3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end
end
2 changes: 1 addition & 1 deletion GNNlib/src/GNNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export agnn_conv,
transformer_conv

include("layers/temporalconv.jl")
export a3tgcn_conv
export tgcn_conv

include("layers/pool.jl")
export global_pool,
Expand Down

0 comments on commit 7319f4d

Please sign in to comment.