From cc387e43095b7aa59ca4b6c921671e09830986c4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 17 Dec 2024 11:26:13 +0100 Subject: [PATCH] fix gatedgraphconv --- GNNLux/src/layers/conv.jl | 6 +++++- GNNlib/src/layers/conv.jl | 3 +-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index f92dd1ec6..63c4f90b4 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -1261,7 +1261,11 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l function (l::GatedGraphConv)(g, x, ps, st) gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) - fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style + # make the forward compatible with Flux.GRUCell style + function fgru(x, h) + y, (h, ) = gru((x, (h,))) + return y, h + end m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims) return GNNlib.gated_graph_conv(m, g, x), st end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index fe7c27d9c..bd9bd18b3 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix) for i in 1:(l.num_layers) m = view(l.weight, :, :, i) * h m = propagate(copy_xj, g, l.aggr; xj = m) - # in gru forward, hidden state is first argument, input is second - h = l.gru(m, h) + _, h = l.gru(m, h) end return h end