diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 40a12b25e..68a0503c3 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -40,6 +40,7 @@ export AGNNConv, include("layers/temporalconv.jl") export TGCN +export A3TGCN end #module \ No newline at end of file diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl index 2ab6235b3..50c45027f 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -56,4 +56,40 @@ function Base.show(io::IO, tgcn::TGCNCell) print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))") end -TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...)) \ No newline at end of file +TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...)) + +@concrete struct A3TGCN <: GNNContainerLayer{(:tgcn, :dense1, :dense2)} + in_dims::Int + out_dims::Int + tgcn + dense1 + dense2 +end + +function A3TGCN(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true) + in_dims, out_dims = ch + tgcn = TGCN(ch; use_bias, init_weight, init_state, init_bias, add_self_loops, use_edge_weight) + dense1 = Dense(out_dims, out_dims) + dense2 = Dense(out_dims, out_dims) + return A3TGCN(in_dims, out_dims, tgcn, dense1, dense2) +end + +function (l::A3TGCN)(g, x, ps, st) + dense1 = StatefulLuxLayer{true}(l.dense1, ps.dense1, _getstate(st, :dense1)) + dense2 = StatefulLuxLayer{true}(l.dense2, ps.dense2, _getstate(st, :dense2)) + h, st = l.tgcn(g, x, ps.tgcn, st.tgcn) + x = dense1(h) + x = dense2(x) + a = NNlib.softmax(x, dims = 3) + c = sum(a .* h , dims = 3) + if length(size(c)) == 3 + c = dropdims(c, dims = 3) + end + return c, st +end + +LuxCore.outputsize(l::A3TGCN) = (l.out_dims,) + +function Base.show(io::IO, l::A3TGCN) + print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))") +end diff --git a/GNNLux/test/layers/temporalconv_test.jl b/GNNLux/test/layers/temporalconv_test.jl index bdde7b325..073b16b49 100644 --- a/GNNLux/test/layers/temporalconv_test.jl +++ b/GNNLux/test/layers/temporalconv_test.jl @@ -2,7 +2,7 @@ using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme rng = StableRNG(1234) - g = rand_graph(10, 40, seed=1234) + g = rand_graph(rng, 10, 40) x = randn(rng, Float32, 3, 10) @testset "TGCN" begin @@ -12,4 +12,12 @@ loss = (x, ps) -> sum(first(l(g, x, ps, st))) test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) end + + @testset "A3TGCN" begin + l = A3TGCN(3=>3) + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) + end end \ No newline at end of file