diff --git a/GNNLux/Project.toml b/GNNLux/Project.toml index 9f27ee3d1..54944c740 100644 --- a/GNNLux/Project.toml +++ b/GNNLux/Project.toml @@ -12,14 +12,16 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ConcreteStructs = "0.2.3" -Lux = "0.5.61" -LuxCore = "0.1.20" +Lux = "1.0" +LuxCore = "1.0" NNlib = "0.9.21" Reexport = "1.2" +Static = "1.1" julia = "1.10" [extras] diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index f566cd0c6..689ad724b 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -2,7 +2,7 @@ module GNNLux using ConcreteStructs: @concrete using NNlib: NNlib, sigmoid, relu, swish using Statistics: mean -using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, parameterlength, statelength, outputsize, initialparameters, initialstates, parameterlength, statelength using Lux: Lux, Chain, Dense, GRUCell, glorot_uniform, zeros32, @@ -10,6 +10,7 @@ using Lux: Lux, Chain, Dense, GRUCell, using Reexport: @reexport using Random: AbstractRNG using GNNlib: GNNlib +using Static @reexport using GNNGraphs include("layers/basic.jl") diff --git a/GNNLux/src/layers/basic.jl b/GNNLux/src/layers/basic.jl index 32f33bdbb..ba12de728 100644 --- a/GNNLux/src/layers/basic.jl +++ b/GNNLux/src/layers/basic.jl @@ -1,14 +1,14 @@ """ - abstract type GNNLayer <: AbstractExplicitLayer end + abstract type GNNLayer <: AbstractLuxLayer end An abstract type from which graph neural network layers are derived. -It is Derived from Lux's `AbstractExplicitLayer` type. +It is derived from Lux's `AbstractLuxLayer` type. See also [`GNNChain`](@ref GNNLux.GNNChain). 
""" -abstract type GNNLayer <: AbstractExplicitLayer end +abstract type GNNLayer <: AbstractLuxLayer end -abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end +abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end @concrete struct GNNChain <: GNNContainerLayer{(:layers,)} layers <: NamedTuple @@ -24,7 +24,7 @@ function GNNChain(; kw...) return GNNChain(nt) end -_wrapforchain(l::AbstractExplicitLayer) = l +_wrapforchain(l::AbstractLuxLayer) = l _wrapforchain(l) = Lux.WrappedFunction(l) Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers)) @@ -44,7 +44,7 @@ Base.firstindex(c::GNNChain) = firstindex(c.layers) LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end]) -(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st) +(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps.layers, st.layers) function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times newst = (;) @@ -56,6 +56,6 @@ function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, help end _applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;) -_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st) +_applylayer(l::AbstractLuxLayer, g::GNNGraph, x, ps, st) = l(x, ps, st) _applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) _applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 30564ae48..04415f1f6 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -1,7 +1,9 @@ _getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false _getstate(st, name) = hasproperty(st, name) ? 
getproperty(st, name) : NamedTuple() _getstate(s::StatefulLuxLayer{true}) = s.st +_getstate(s::StatefulLuxLayer{Static.True}) = s.st _getstate(s::StatefulLuxLayer{false}) = s.st_any +_getstate(s::StatefulLuxLayer{Static.False}) = s.st_any @concrete struct GCNConv <: GNNLayer @@ -20,10 +22,9 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity; init_bias = zeros32, use_bias::Bool = true, add_self_loops::Bool = true, - use_edge_weight::Bool = false, - allow_fast_activation::Bool = true) + use_edge_weight::Bool = false) in_dims, out_dims = ch - σ = allow_fast_activation ? NNlib.fast_act(σ) : σ + σ = NNlib.fast_act(σ) return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) end @@ -121,10 +122,9 @@ function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, init_weight = glorot_uniform, init_bias = zeros32, - use_bias::Bool = true, - allow_fast_activation::Bool = true) + use_bias::Bool = true) in_dims, out_dims = ch - σ = allow_fast_activation ? NNlib.fast_act(σ) : σ + σ = NNlib.fast_act(σ) return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr) end @@ -212,11 +212,10 @@ end CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...) 
function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, - use_bias = true, init_weight = glorot_uniform, init_bias = zeros32, - allow_fast_activation = true) + use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) (nin, ein), out = ch - dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation) - dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation) + dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias) + dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias) return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias) end @@ -232,7 +231,7 @@ function (l::CGConv)(g, x, e, ps, st) end @concrete struct EdgeConv <: GNNContainerLayer{(:nn,)} - nn <: AbstractExplicitLayer + nn <: AbstractLuxLayer aggr end @@ -246,10 +245,10 @@ end function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn) m = (; nn, l.aggr) y = GNNlib.edge_conv(m, g, x) - stnew = _getstate(nn) + stnew = (; nn = _getstate(nn)) # TODO: support also aggr state if present return y, stnew end @@ -608,7 +607,7 @@ function Base.show(io::IO, l::GatedGraphConv) end @concrete struct GINConv <: GNNContainerLayer{(:nn,)} - nn <: AbstractExplicitLayer + nn <: AbstractLuxLayer ϵ <: Real aggr end @@ -616,10 +615,10 @@ end GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) function (l::GINConv)(g, x, ps, st) - nn = StatefulLuxLayer{true}(l.nn, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn) m = (; nn, l.ϵ, l.aggr) y = GNNlib.gin_conv(m, g, x) - stnew = _getstate(nn) + stnew = (; nn = _getstate(nn)) return y, stnew end @@ -669,4 +668,4 @@ function Base.show(io::IO, l::MEGNetConv) nout = l.out_dims print(io, "MEGNetConv(", nin, " => ", nout) print(io, ")") -end \ No newline at end of file +end diff --git a/GNNLux/src/layers/temporalconv.jl 
b/GNNLux/src/layers/temporalconv.jl index 687a21983..63c196a55 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -1,4 +1,4 @@ -@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct StatefulRecurrentCell <: AbstractLuxContainerLayer{(:cell,)} cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer} end @@ -7,16 +7,16 @@ function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell end function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple) - (out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry) + (out, carry), st = applyrecurrentcell(r.cell, g, x, ps.cell, st.cell, st.carry) return out, (; cell=st, carry) end function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple) - st, carry = st.cell, st.carry + stcell, carry = st.cell, st.carry for xᵢ in x - (out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry) + (out, carry), stcell = applyrecurrentcell(r.cell, g, xᵢ, ps.cell, stcell, carry) end - return out, (; cell=st, carry) + return out, (; cell=stcell, carry) end function applyrecurrentcell(l, g, x, ps, st, carry) @@ -35,7 +35,7 @@ end function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true) in_dims, out_dims = ch - conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true) + conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight) gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state) return TGCNCell(in_dims, out_dims, conv, gru, init_state) end diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index 
ac937d128..f4cabad69 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -4,16 +4,16 @@ x = randn(rng, Float32, 3, 10) @testset "GNNLayer" begin - @test GNNLayer <: LuxCore.AbstractExplicitLayer + @test GNNLayer <: LuxCore.AbstractLuxLayer end @testset "GNNContainerLayer" begin - @test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer + @test GNNContainerLayer <: LuxCore.AbstractLuxContainerLayer end @testset "GNNChain" begin - @test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} - c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) + @test GNNChain <: LuxCore.AbstractLuxContainerLayer{(:layers,)} + c = GNNChain(GraphConv(3 => 5, tanh), GCNConv(5 => 3)) test_lux_layer(rng, c, g, x, outputsize=(3,), container=true) end end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 86a056977..c0f0d28e3 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -89,7 +89,7 @@ end @testset "GINConv" begin - nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) + nn = Chain(Dense(in_dims => out_dims, tanh), Dense(out_dims => out_dims)) l = GINConv(nn, 0.5) test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end