Skip to content

Commit

Permalink
[GNNLux] Added SGConv (#475)
Browse files Browse the repository at this point in the history
* added sgconv lux

* fix

* fix

* fix

* fix
  • Loading branch information
rbSparky authored Aug 1, 2024
1 parent c82efa0 commit 0a23ffa
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 3 deletions.
4 changes: 2 additions & 2 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ export AGNNConv,
GCNConv,
# GINConv,
# GMMConv,
GraphConv
GraphConv,
# MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
# SGConv,
SGConv
# TAGConv,
# TransformerConv

Expand Down
56 changes: 56 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,60 @@ function Base.show(io::IO, l::GATv2Conv)
l.σ == identity || print(io, ", ", l.σ)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end

@concrete struct SGConv <: GNNLayer
in_dims::Int
out_dims::Int
k::Int
use_bias::Bool
add_self_loops::Bool
use_edge_weight::Bool
init_weight
init_bias
end

function SGConv(ch::Pair{Int, Int}, k = 1;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
use_edge_weight::Bool = false)
in_dims, out_dims = ch
return SGConv(in_dims, out_dims, k, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::SGConv) = 0
LuxCore.outputsize(d::SGConv) = (d.out_dims,)

function Base.show(io::IO, l::SGConv)
print(io, "SGConv(", l.in_dims, " => ", l.out_dims)
l.k || print(io, ", ", l.k)
l.use_bias || print(io, ", use_bias=false")
l.add_self_loops || print(io, ", add_self_loops=false")
!l.use_edge_weight || print(io, ", use_edge_weight=true")
print(io, ")")
end

(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
l(g, x, edge_weight, ps, st; conv_weight)

function (l::SGConv)(g, x, edge_weight, ps, st;
conv_weight=nothing, )

m = (; ps.weight, bias = _getbias(ps),
l.add_self_loops, l.use_edge_weight, l.k)
y = GNNlib.sg_conv(m, g, x, edge_weight)
return y, st
end
5 changes: 5 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,10 @@

#TODO test edge
end

@testset "SGConv" begin
l = SGConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
end

2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,4 +722,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
T1_out = T2_out
end
return h .+ l.bias
end
end

0 comments on commit 0a23ffa

Please sign in to comment.