Skip to content

Commit

Permalink
more layers
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Aug 1, 2024
1 parent 67a51f7 commit c7eb715
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
6 changes: 4 additions & 2 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu, swish
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
initialparameters, initialstates, parameterlength, statelength
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Lux: Lux, Chain, Dense, GRUCell,
glorot_uniform, zeros32,
StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
Expand All @@ -25,7 +27,7 @@ export AGNNConv,
GATv2Conv,
GatedGraphConv,
GCNConv,
# GINConv,
GINConv,
# GMMConv,
GraphConv
# MEGNetConv,
Expand Down
45 changes: 32 additions & 13 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
end

LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::GCNConv) = 0
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)

function Base.show(io::IO, l::GCNConv)
Expand Down Expand Up @@ -518,7 +517,7 @@ function Base.show(io::IO, l::GATv2Conv)
end


@concrete struct GatedGraphConv <: GRULayer
@concrete struct GatedGraphConv <: GNNLayer
gru
init_weight
dims::Int
Expand All @@ -533,28 +532,48 @@ function GatedGraphConv(dims::Int, num_layers::Int;
return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
end

LucCore.outputsize(l::GatedGraphConv) = (l.dims,)
LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)

function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
gru = LuxCore.initialparameters(rng, l.gru)
weight = l.init_weight(rng, l.dims, l.dims)
weight = l.init_weight(rng, l.dims, l.dims, l.num_layers)
return (; gru, weight)
end

LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2
LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers

function LuxCore.initialstates(rng::AbstractRNG, l::GatedGraphConv)
return (; gru = LuxCore.initialstates(rng, l.gru))
end

LuxCore.statelength(l::GatedGraphConv) = statelength(l.gru)

function (l::GatedGraphConv)(g, H, ps, st)
GNNlib.gated_graph_conv(l, g, H)
function (l::GatedGraphConv)(g, x, ps, st)
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
return GNNlib.gated_graph_conv(m, g, x), st
end

function Base.show(io::IO, l::GatedGraphConv)
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
print(io, ", aggr=", l.aggr)
print(io, ")")
end
end

@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
ϵ <: Real
aggr
end

GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)

function (l::GINConv)(g, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
m = (; nn, l.ϵ, l.aggr)
y = GNNlib.gin_conv(m, g, x)
stnew = _getstate(nn)
return y, stnew
end

function Base.show(io::IO, l::GINConv)
print(io, "GINConv($(l.nn)")
print(io, ", $(l.ϵ)")
print(io, ")")
end
12 changes: 11 additions & 1 deletion GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,15 @@

#TODO test edge
end
end

@testset "GatedGraphConv" begin
l = GatedGraphConv(in_dims, 3)
test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
end

@testset "GINConv" begin
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end
end
1 change: 1 addition & 0 deletions GNNLux/test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
@test LuxCore.statelength(l) == LuxCore.statelength(st)

y, st′ = l(g, x, ps, st)
@test eltype(y) == eltype(x)
if outputsize !== nothing
@test LuxCore.outputsize(l) == outputsize
end
Expand Down

0 comments on commit c7eb715

Please sign in to comment.