Skip to content

Commit

Permalink
Change Tuple in NamedTuple (JuliaGraphs#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi authored Aug 28, 2023
1 parent e3019dd commit 9a4f4b7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, T, A <: AbstractMatrix,
end

@functor GATConv
Flux.trainable(l::GATConv) = (l.dense_x, l.dense_e, l.bias, l.a)
Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)

GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)

Expand Down Expand Up @@ -457,7 +457,7 @@ struct GATv2Conv{T, A1, A2, A3, B, C <: AbstractMatrix, F} <: GNNLayer
end

@functor GATv2Conv
Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.dense_e, l.bias, l.a)
Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a)

function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
Expand Down Expand Up @@ -668,7 +668,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer
end

@functor GINConv
Flux.trainable(l::GINConv) = (l.nn,)
Flux.trainable(l::GINConv) = (nn = l.nn,)

GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)

Expand Down Expand Up @@ -1569,7 +1569,7 @@ end
@functor TransformerConv

function Flux.trainable(l::TransformerConv)
(l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
(W1 = l.W1, W2 = l.W2, W3 = l.W3, W4 = l.W4, W5 = l.W5, W6 = l.W6, FF = l.FF, BN1 = l.BN1, BN2 = l.BN2)
end

function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
Expand Down

0 comments on commit 9a4f4b7

Please sign in to comment.