diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 04a5f56d8..c1ea60b1e 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -329,7 +329,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix, end @functor GATConv -Flux.trainable(l::GATConv) = (l.dense_x, l.dense_e, l.bias, l.a) +Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a) GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...) @@ -457,7 +457,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer end @functor GATv2Conv -Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.dense_e, l.bias, l.a) +Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a) function GATv2Conv(ch::Pair{Int, Int}, args...; kws...) GATv2Conv((ch[1], 0) => ch[2], args...; kws...) @@ -668,7 +668,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer end @functor GINConv -Flux.trainable(l::GINConv) = (l.nn,) +Flux.trainable(l::GINConv) = (nn = l.nn,) GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) @@ -1569,7 +1569,7 @@ end @functor TransformerConv function Flux.trainable(l::TransformerConv) - (l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2) + (W1 = l.W1, W2 = l.W2, W3 = l.W3, W4 = l.W4, W5 = l.W5, W6 = l.W6, FF = l.FF, BN1 = l.BN1, BN2 = l.BN2) end function TransformerConv(ch::Pair{Int, Int}, args...; kws...)