Skip to content

Commit

Permalink
fix GlobalPooling with graph only input
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 5, 2021
1 parent afad470 commit b720b05
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
return reduce_nodes(l.aggr, g, x)
end

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))

"""
TopKPool(adj, k, in_channel)
Expand Down
1 change: 0 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
end

sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)

if duplicate_if_needed
# Used to copy edge features on reverse edges
@assert all(s -> s == 0 || s == n || s == n÷2, sz)
Expand Down
6 changes: 4 additions & 2 deletions test/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
@testset "GlobalPool" begin
n = 10
X = rand(16, n)
g = GNNGraph(random_regular_graph(n, 4))
g = GNNGraph(random_regular_graph(n, 4), ndata=X)
p = GlobalPool(+)
@test p(g, X) NNlib.scatter(+, X, ones(Int, n))
y = p(g, X)
@test y NNlib.scatter(+, X, ones(Int, n))
test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
end

@testset "TopKPool" begin
Expand Down
12 changes: 10 additions & 2 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
verbose = false,
test_gpu = TEST_GPU,
outsize = nothing,
outtype = :node,
)

# TODO these give errors, probably some bugs in ChainRulesTestUtils
Expand Down Expand Up @@ -57,8 +58,15 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
@test ycoo y

g′ = f(l, g)
@test g′.ndata.x y

if outtype == :node
@test g′.ndata.x y
elseif outtype == :edge
@test g′.edata.e y
elseif outtype == :graph
@test g′.gdata.u y
else
@error "wrong outtype $outtype"
end
if test_gpu
ygpu = f(lgpu, ggpu, xgpu)
@test ygpu isa CuArray
Expand Down

0 comments on commit b720b05

Please sign in to comment.