Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GNNLux] fix tests #468

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions GNNLux/test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@
@test GNNLayer <: LuxCore.AbstractExplicitLayer
end

@testset "GNNContainerLayer" begin
@test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer
end

@testset "GNNChain" begin
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
@test GNNChain <: GNNContainerLayer
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
ps = LuxCore.initialparameters(rng, c)
st = LuxCore.initialstates(rng, c)
@test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(c) == LuxCore.statelength(st)
y, st′ = c(g, x, ps, st)
@test LuxCore.outputsize(c) == (3,)
@test size(y) == (3, 10)
loss = (x, ps) -> sum(first(c(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, c, g, x, outputsize=(3,), container=true)
end
end
71 changes: 7 additions & 64 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,89 +5,32 @@

@testset "GCNConv" begin
l = GCNConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(ps) == 1
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test size(y) == size(x)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, l, g, x, sizey=(3,10))
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
l = EdgeConv(nn, aggr = +)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
end

@testset "CGConv" begin
l = CGConv(3 => 5, residual = true)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
@test Lux.outputsize(l) == (5,)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
l = CGConv(3 => 3, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
end
end
50 changes: 35 additions & 15 deletions GNNLux/test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,43 @@

import Reexport: @reexport

@reexport using Test
@reexport using GNNLux
@reexport using Lux, Functors
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
Zygote, Statistics
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx
using MLDataDevices

# Some Helper Functions
function get_default_rng(mode::String)
dev = mode == "cpu" ? CPUDevice() :
mode == "cuda" ? CUDADevice() : mode == "amdgpu" ? AMDGPUDevice() : nothing
rng = default_device_rng(dev)
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
end
@reexport using Lux
@reexport using StableRNGs
@reexport using Random, Statistics

using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme

export test_lux_layer

export get_default_rng
function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
outputsize=nothing, sizey=nothing, container=false,
atol=1.0f-2, rtol=1.0f-2)

# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing
if container
@test l isa GNNContainerLayer
else
@test l isa GNNLayer
end

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(l) == LuxCore.statelength(st)

y, st′ = l(g, x, ps, st)
if outputsize !== nothing
@test LuxCore.outputsize(l) == outputsize
end
if sizey !== nothing
@test size(y) == sizey
elseif outputsize !== nothing
@test size(y) == (outputsize..., g.num_nodes)
end

loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

end
Loading