diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index 11a1d3a29..9f59f3b10 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -7,18 +7,13 @@ @test GNNLayer <: LuxCore.AbstractExplicitLayer end + @testset "GNNContainerLayer" begin + @test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer + end + @testset "GNNChain" begin @test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} - @test GNNChain <: GNNContainerLayer c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) - ps = LuxCore.initialparameters(rng, c) - st = LuxCore.initialstates(rng, c) - @test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps) - @test LuxCore.statelength(c) == LuxCore.statelength(st) - y, st′ = c(g, x, ps, st) - @test LuxCore.outputsize(c) == (3,) - @test size(y) == (3, 10) - loss = (x, ps) -> sum(first(c(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true + test_lux_layer(rng, c, g, x, outputsize=(3,), container=true) end end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index ca36bcbb0..520fcc570 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -5,89 +5,32 @@ @testset "GCNConv" begin l = GCNConv(3 => 5, relu) - @test l isa GNNLayer - ps = Lux.initialparameters(rng, l) - st = Lux.initialstates(rng, l) - @test Lux.parameterlength(l) == Lux.parameterlength(ps) - @test Lux.statelength(l) == Lux.statelength(st) - - y, _ = l(g, x, ps, st) - @test Lux.outputsize(l) == (5,) - @test size(y) == (5, 10) - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true + test_lux_layer(rng, l, g, x, outputsize=(5,)) end @testset "ChebConv" begin l = ChebConv(3 => 5, 2) - @test l isa GNNLayer - ps = Lux.initialparameters(rng, l) - st = Lux.initialstates(rng, l) - @test Lux.parameterlength(l) == Lux.parameterlength(ps) - @test Lux.statelength(l) == Lux.statelength(st) - - y, _ = l(g, x, ps, st) - @test Lux.outputsize(l) == (5,) - @test size(y) == (5, 10) - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true + test_lux_layer(rng, l, g, x, outputsize=(5,)) end @testset "GraphConv" begin l = GraphConv(3 => 5, relu) - @test l isa GNNLayer - ps = Lux.initialparameters(rng, l) - st = Lux.initialstates(rng, l) - @test Lux.parameterlength(l) == Lux.parameterlength(ps) - @test Lux.statelength(l) == Lux.statelength(st) - - y, _ = l(g, x, ps, st) - @test Lux.outputsize(l) == (5,) - @test size(y) == (5, 10) - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true + test_lux_layer(rng, l, g, x, outputsize=(5,)) end @testset "AGNNConv" begin l = AGNNConv(init_beta=1.0f0) - @test l isa GNNLayer - ps = Lux.initialparameters(rng, l) - st = Lux.initialstates(rng, l) - @test Lux.parameterlength(ps) == 1 - @test Lux.parameterlength(l) == Lux.parameterlength(ps) - @test Lux.statelength(l) == Lux.statelength(st) - - y, _ = l(g, x, ps, st) - @test size(y) == size(x) - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true + test_lux_layer(rng, l, g, x, sizey=(3,10)) end @testset "EdgeConv" begin nn = Chain(Dense(6 => 5, relu), Dense(5 => 5)) l = EdgeConv(nn, aggr = +) - @test l isa GNNContainerLayer - ps = Lux.initialparameters(rng, l) - st = Lux.initialstates(rng, l) - @test Lux.parameterlength(l) == Lux.parameterlength(ps) - @test Lux.statelength(l) == Lux.statelength(st) - y, st′ = l(g, x, ps, st) - @test size(y) == (5, 10) - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true + test_lux_layer(rng, l, g, x, sizey=(5,10), container=true) end @testset "CGConv" begin - l = CGConv(3 => 5, residual = true) - @test l isa GNNContainerLayer - ps = Lux.initialparameters(rng, l) - st = Lux.initialstates(rng, l) - @test Lux.parameterlength(l) == Lux.parameterlength(ps) - @test Lux.statelength(l) == Lux.statelength(st) - y, st′ = l(g, x, ps, st) - @test size(y) == (5, 10) - @test Lux.outputsize(l) == (5,) - loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true + l = CGConv(3 => 3, residual = true) + test_lux_layer(rng, l, g, x, outputsize=(3,), container=true) end end diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index 2ba29ea1a..1354ef387 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -2,23 +2,43 @@ import Reexport: @reexport +@reexport using Test @reexport using GNNLux -@reexport using Lux, Functors -@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test, - Zygote, Statistics -@reexport using LuxTestUtils: @jet, @test_gradients, check_approx -using MLDataDevices - -# Some Helper Functions -function get_default_rng(mode::String) - dev = mode == "cpu" ? CPUDevice() : - mode == "cuda" ? CUDADevice() : mode == "amdgpu" ? AMDGPUDevice() : nothing - rng = default_device_rng(dev) - return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng) -end +@reexport using Lux +@reexport using StableRNGs +@reexport using Random, Statistics + +using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme + +export test_lux_layer -export get_default_rng +function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; + outputsize=nothing, sizey=nothing, container=false, + atol=1.0f-2, rtol=1.0f-2) -# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing + if container + @test l isa GNNContainerLayer + else + @test l isa GNNLayer + end + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) + @test LuxCore.statelength(l) == LuxCore.statelength(st) + + y, st′ = l(g, x, ps, st) + if outputsize !== nothing + @test LuxCore.outputsize(l) == outputsize + end + if sizey !== nothing + @test size(y) == sizey + elseif outputsize !== nothing + @test size(y) == (outputsize..., g.num_nodes) + end + + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) +end end \ No newline at end of file