diff --git a/src/layers/conv.jl b/src/layers/conv.jl index a981159f8..428dbf056 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -338,7 +338,7 @@ end function Base.show(io::IO, l::GATConv) out_channel, in_channel = size(l.weight) - print(io, "GATConv(", in_channel, "=>", out_channel ÷ l.heads) + print(io, "GATConv(", in_channel, " => ", out_channel ÷ l.heads) print(io, ", LeakyReLU(λ=", l.negative_slope) print(io, "))") end @@ -439,8 +439,8 @@ end function Base.show(io::IO, l::GATv2Conv) - out, in = size(l.weight_i) - print(io, "GATv2Conv(", in, "=>", out ÷ l.heads) + out, in = size(l.dense_i.weight) + print(io, "GATv2Conv(", in, " => ", out ÷ l.heads) print(io, ", LeakyReLU(λ=", l.negative_slope) print(io, "))") end @@ -654,7 +654,7 @@ end function Base.show(io::IO, l::NNConv) out, in = size(l.weight) - print(io, "NNConv( $in => $out") + print(io, "NNConv($in => $out") print(io, ", aggr=", l.aggr) print(io, ")") end @@ -776,7 +776,7 @@ end function Base.show(io::IO, l::ResGatedGraphConv) out_channel, in_channel = size(l.A) - print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel) + print(io, "ResGatedGraphConv(", in_channel, " => ", out_channel) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end