Skip to content

Commit

Permalink
TGCNCell
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 16, 2024
1 parent b229ab2 commit 23d9e45
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
48 changes: 48 additions & 0 deletions GraphNeuralNetworks/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,51 @@ julia> size(y[end]) # (d_out, num_nodes[end])
```
"""
EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))



@concrete struct TGCNCell <: GNNLayer
in::Int
out::Int
conv_z
dense_z
conv_r
dense_r
conv_h
dense_h
end

Flux.@layer :noexpand TGCNCell

function TGCNCell((in, out)::Pair{Int, Int}; kws...)
conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
dense_z = Dense(2*out => out, sigmoid)
conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
dense_r = Dense(2*out => out, sigmoid)
conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...))
dense_h = Dense(2*out => out, tanh)
return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h)
end

Flux.initialstates(cell::TGCNCell) = zeros_like(cell.dense_z.weight, cell.out)

(cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))

function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector)
return cell(g, x, repeat(h, 1, g.num_nodes))
end

function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix)
z = cell.conv_z(g, x)
z = cell.dense_z(vcat(z, h))
r = cell.conv_r(g, x)
r = cell.dense_r(vcat(r, h))
= cell.conv_h(g, x)
= cell.dense_h(vcat(h̃, r .* h))
h = (1 .- z) .* h .+ z .*
return h, h
end

function Base.show(io::IO, cell::TGCNCell)
print(io, "TGCNCell($(cell.in) => $(cell.out))")
end
3 changes: 1 addition & 2 deletions GraphNeuralNetworks/src/layers/temporalconv_old.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ end

Flux.@layer :noexpand TGCNCell

function TGCNCell(ch::Pair{Int, Int};
function TGCNCell((in, out)::Pair{Int, Int};
bias::Bool = true,
init = Flux.glorot_uniform,
add_self_loops = false)
in, out = ch
conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops)
gru = GRUCell(out => out)
return TGCNCell(conv, gru, in, out)
Expand Down

0 comments on commit 23d9e45

Please sign in to comment.