Skip to content

Commit

Permalink
fix gatedgraphconv
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 17, 2024
1 parent 6a23a70 commit cc387e4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 5 additions & 1 deletion GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,11 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l

function (l::GatedGraphConv)(g, x, ps, st)
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style
# make the forward compatible with Flux.GRUCell style
function fgru(x, h)
y, (h, ) = gru((x, (h,)))
return y, h
end
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
return GNNlib.gated_graph_conv(m, g, x), st
end
Expand Down
3 changes: 1 addition & 2 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
for i in 1:(l.num_layers)
m = view(l.weight, :, :, i) * h
m = propagate(copy_xj, g, l.aggr; xj = m)
# in gru forward, hidden state is first argument, input is second
h = l.gru(m, h)
_, h = l.gru(m, h)
end
return h
end
Expand Down

0 comments on commit cc387e4

Please sign in to comment.