Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 28, 2022
1 parent 10d73db commit 1558ffa
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
@testset "basic" begin
@testset "GNNChain" begin
n, din, d, dout = 10, 3, 4, 2
deg = 4

g = GNNGraph(random_regular_graph(n, 4),
g = GNNGraph(random_regular_graph(n, deg),
graph_type=GRAPH_T,
ndata= randn(Float32, din, n))

x = g.ndata.x

gnn = GNNChain(GCNConv(din => d),
BatchNorm(d),
x -> tanh.(x),
Expand All @@ -18,20 +20,27 @@
test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[, :σ²])

@testset "constructor with names" begin
m = GNNChain(GCNConv(2=>5),
BatchNorm(5),
m = GNNChain(GCNConv(din=>d),
BatchNorm(d),
x -> relu.(x),
Dense(5, 4))
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

Dense(d, dout))

m2 = GNNChain(enc = m,
dec = DotDecoder())

@test m2[:enc] === m
@test m2(g, x) == m2[:dec](g, m2[:enc](g, x))
end

@testset "constructor with vector" begin
m = GNNChain(GCNConv(din=>d),
BatchNorm(d),
x -> relu.(x),
Dense(d, dout))
m2 = GNNChain([m.layers...])
@test m2(g, x) == m(g, x)
end

@testset "Parallel" begin
AddResidual(l) = Parallel(+, identity, l)

Expand Down

0 comments on commit 1558ffa

Please sign in to comment.