[GNNLux] more layers pt. 3 (#471)
* more layer

more layers

stuff

* fixes
CarloLucibello authored Aug 1, 2024
1 parent 0a23ffa commit 83b6b7e
Showing 6 changed files with 108 additions and 39 deletions.
11 changes: 7 additions & 4 deletions GNNLux/src/GNNLux.jl
@@ -1,8 +1,11 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu, swish
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
initialparameters, initialstates, parameterlength, statelength
using Lux: Lux, Chain, Dense, GRUCell,
glorot_uniform, zeros32,
StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@@ -22,9 +25,9 @@ export AGNNConv,
DConv,
GATConv,
GATv2Conv,
# GatedGraphConv,
GatedGraphConv,
GCNConv,
# GINConv,
GINConv,
# GMMConv,
GraphConv,
# MEGNetConv,
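For orientation, a minimal sketch (not part of the commit) verifying that the two newly uncommented layers are exported; `names` returns a module's exported symbols:

    using GNNLux

    # both symbols are exported after this commit
    @assert :GatedGraphConv in names(GNNLux)
    @assert :GINConv in names(GNNLux)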
72 changes: 64 additions & 8 deletions GNNLux/src/layers/conv.jl
@@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
end

LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::GCNConv) = 0
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)

function Base.show(io::IO, l::GCNConv)
@@ -549,7 +548,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
end

LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::SGConv) = 0
LuxCore.outputsize(d::SGConv) = (d.out_dims,)

function Base.show(io::IO, l::SGConv)
@@ -561,14 +559,72 @@ function Base.show(io::IO, l::SGConv)
print(io, ")")
end

(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
l(g, x, edge_weight, ps, st; conv_weight)

function (l::SGConv)(g, x, edge_weight, ps, st;
conv_weight=nothing, )
(l::SGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::SGConv)(g, x, edge_weight, ps, st)
m = (; ps.weight, bias = _getbias(ps),
l.add_self_loops, l.use_edge_weight, l.k)
y = GNNlib.sg_conv(m, g, x, edge_weight)
return y, st
end
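A usage sketch of the simplified SGConv entry points above; the keyword-free method now simply forwards `nothing` as the edge weight. `rand_graph` from GNNGraphs and the feature sizes are illustrative assumptions, not part of the diff:

    using GNNLux, GNNGraphs, LuxCore, Random

    rng = Random.default_rng()
    g = rand_graph(10, 30)              # hypothetical graph: 10 nodes, 30 edges
    x = rand(Float32, 3, 10)            # 3 input features per node
    l = SGConv(3 => 8, 2)               # k = 2 propagation steps
    ps, st = LuxCore.setup(rng, l)
    y, st = l(g, x, ps, st)             # equivalent to l(g, x, nothing, ps, st)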

@concrete struct GatedGraphConv <: GNNLayer
gru
init_weight
dims::Int
num_layers::Int
aggr
end


function GatedGraphConv(dims::Int, num_layers::Int;
aggr = +, init_weight = glorot_uniform)
gru = GRUCell(dims => dims)
return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
end

LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)

function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
gru = LuxCore.initialparameters(rng, l.gru)
weight = l.init_weight(rng, l.dims, l.dims, l.num_layers)
return (; gru, weight)
end

LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers


function (l::GatedGraphConv)(g, x, ps, st)
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
return GNNlib.gated_graph_conv(m, g, x), st
end

function Base.show(io::IO, l::GatedGraphConv)
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
print(io, ", aggr=", l.aggr)
print(io, ")")
end
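A hedged sketch of the new Lux-style GatedGraphConv, reusing `g`, `x`, and `rng` from the SGConv sketch above; note the input feature dimension must not exceed `dims`, since the GNNlib forward zero-pads features up to `dims`:

    l = GatedGraphConv(4, 3)            # dims = 4, num_layers = 3
    ps, st = LuxCore.setup(rng, l)
    y, st = l(g, x, ps, st)             # x has 3 <= 4 features; y has size (4, 10)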

@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
ϵ <: Real
aggr
end

GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)

function (l::GINConv)(g, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
m = (; nn, l.ϵ, l.aggr)
y = GNNlib.gin_conv(m, g, x)
stnew = _getstate(nn)
return y, stnew
end

function Base.show(io::IO, l::GINConv)
print(io, "GINConv($(l.nn)")
print(io, ", $(l.ϵ)")
print(io, ")")
end
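A matching sketch for GINConv, mirroring the test added below: the wrapped `nn` is an ordinary Lux chain applied to `(1 + ϵ) * x_i` plus the aggregated neighbor features. Sizes continue the example above:

    using Lux: Chain, Dense
    using NNlib: relu

    nn = Chain(Dense(3 => 8, relu), Dense(8 => 8))
    l = GINConv(nn, 0.5)                # ϵ = 0.5, default aggr = +
    ps, st = LuxCore.setup(rng, l)
    y, st = l(g, x, ps, st)             # y has size (8, 10)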
12 changes: 11 additions & 1 deletion GNNLux/test/layers/conv_tests.jl
@@ -82,5 +82,15 @@
l = SGConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
end

@testset "GatedGraphConv" begin
l = GatedGraphConv(in_dims, 3)
test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
end

@testset "GINConv" begin
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end
end
1 change: 1 addition & 0 deletions GNNLux/test/shared_testsetup.jl
@@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
@test LuxCore.statelength(l) == LuxCore.statelength(st)

y, st′ = l(g, x, ps, st)
@test eltype(y) == eltype(x)
if outputsize !== nothing
@test LuxCore.outputsize(l) == outputsize
end
35 changes: 17 additions & 18 deletions GNNlib/src/layers/conv.jl
@@ -28,7 +28,7 @@ function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_w
if edge_weight !== nothing
# Pad weights with ones
# TODO for ADJMAT_T the new edges are not generally at the end
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
@assert length(edge_weight) == g.num_edges
end
end
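The `ones_like` replacement is the point of this hunk: unlike `fill!(similar(...), 1)`, it is AD-friendly out of the box, which is why the `@non_differentiable fill!` piracy can be deleted below. A small sketch of the padding step in isolation, assuming `ones_like` comes from MLUtils:

    using MLUtils: ones_like

    edge_weight = rand(Float32, 5)      # weights for 5 existing edges
    num_nodes = 3                       # self-loops to be appended
    # one weight of 1 per self-loop, matching eltype and array type
    padded = [edge_weight; ones_like(edge_weight, num_nodes)]
    @assert length(padded) == 8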
@@ -215,23 +215,22 @@ end

####################### GatedGraphConv ######################################

# TODO PIRACY! remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
@non_differentiable fill!(x...)

function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
check_num_nodes(g, H)
m, n = size(H)
@assert (m<=l.out_ch) "number of input features must less or equals to output features."
if m < l.out_ch
Hpad = similar(H, S, l.out_ch - m, n)
H = vcat(H, fill!(Hpad, 0))
function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
m, n = size(x)
@assert m <= l.dims "number of input features must be less or equal to output features."
if m < l.dims
xpad = zeros_like(x, (l.dims - m, n))
x = vcat(x, xpad)
end
h = x
for i in 1:(l.num_layers)
M = view(l.weight, :, :, i) * H
M = propagate(copy_xj, g, l.aggr; xj = M)
H, _ = l.gru(H, M)
m = view(l.weight, :, :, i) * h
m = propagate(copy_xj, g, l.aggr; xj = m)
# in gru forward, hidden state is first argument, input is second
h, _ = l.gru(h, m)
end
return H
return h
end
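Read as equations, each pass of the loop above computes, for every node ``i`` and step ``k = 1, \dots, \text{num\_layers}``:

``\mathbf{m}^{(k)}_i = \mathrm{aggr}_{j \in \mathcal{N}(i)} \left( \mathbf{W}^{(k)} \mathbf{h}^{(k-1)}_j \right)``, then ``\mathbf{h}^{(k)}_i = \mathrm{GRU}(\mathbf{h}^{(k-1)}_i, \mathbf{m}^{(k)}_i)``,

where ``\mathbf{W}^{(k)}`` is the ``k``-th slice `view(l.weight, :, :, k)` and ``\mathbf{h}^{(0)}`` is the zero-padded input.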

####################### EdgeConv ######################################
@@ -419,7 +418,7 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
if l.add_self_loops
g = add_self_loops(g)
if edge_weight !== nothing
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
@assert length(edge_weight) == g.num_edges
end
end
@@ -512,7 +511,7 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
if l.add_self_loops
g = add_self_loops(g)
if edge_weight !== nothing
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
@assert length(edge_weight) == g.num_edges
end
end
@@ -644,7 +643,7 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
if l.add_self_loops
g = add_self_loops(g)
if edge_weight !== nothing
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
@assert length(edge_weight) == g.num_edges
end
end
16 changes: 8 additions & 8 deletions src/layers/conv.jl
Expand Up @@ -486,7 +486,7 @@ where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing throu
# Arguments
- `out`: The dimension of output features.
- `num_layers`: The number of gated recurrent unit.
- `num_layers`: The number of recursion steps.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `init`: Weight initialization function.
@@ -510,25 +510,25 @@ y = l(g, x)
struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
weight::W
gru::R
out_ch::Int
dims::Int
num_layers::Int
aggr::A
end

@functor GatedGraphConv

function GatedGraphConv(out_ch::Int, num_layers::Int;
function GatedGraphConv(dims::Int, num_layers::Int;
aggr = +, init = glorot_uniform)
w = init(out_ch, out_ch, num_layers)
gru = GRUCell(out_ch, out_ch)
GatedGraphConv(w, gru, out_ch, num_layers, aggr)
w = init(dims, dims, num_layers)
gru = GRUCell(dims => dims)
GatedGraphConv(w, gru, dims, num_layers, aggr)
end


(l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H)

function Base.show(io::IO, l::GatedGraphConv)
print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)")
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
print(io, ", aggr=", l.aggr)
print(io, ")")
end
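A hedged sketch of the renamed Flux-side layer: the constructor now builds the GRU cell with the `GRUCell(dims => dims)` Pair syntax, and `show` prints the new `dims` field. `rand_graph` and the sizes are illustrative assumptions:

    using GraphNeuralNetworks, Random

    g = rand_graph(10, 30)
    x = rand(Float32, 4, 10)            # feature dim must be <= dims
    l = GatedGraphConv(4, 3)            # dims = 4, num_layers = 3
    y = l(g, x)                         # size (4, 10)
    l                                   # prints: GatedGraphConv(4, 3, aggr=+)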
@@ -1201,7 +1201,7 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
in, out = ch
W = init(out, in)
b = bias ? Flux.create_bias(W, true, out) : false
SGConv(W, b, k, add_self_loops, use_edge_weight)
return SGConv(W, b, k, add_self_loops, use_edge_weight)
end

(l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight)
