Merge pull request #253 from FluxML/develop
Fix message-passing tests
yuehhua authored Dec 12, 2021
2 parents 16d43e5 + e8905c6 commit f1b3c1d
Showing 8 changed files with 305 additions and 319 deletions.
4 changes: 1 addition & 3 deletions src/GeometricFlux.jl
@@ -7,7 +7,7 @@ using LinearAlgebra: Adjoint, norm, Transpose
using Random
using Reexport

using CUDA
using CUDA, CUDA.CUSPARSE
using ChainRulesCore: @non_differentiable
using FillArrays: Fill
using Flux
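
Note: with CUDA.CUSPARSE imported (as added above), a sparse adjacency-like matrix can be stored and multiplied on the GPU. A minimal, illustrative sketch, not part of this commit, assuming a CUDA-capable device:

using CUDA, CUDA.CUSPARSE, SparseArrays

A = sprand(Float32, 4, 4, 0.5)        # sparse matrix on the CPU
if CUDA.functional()
    dA = CuSparseMatrixCSC(A)         # same matrix held on the GPU via CUSPARSE
    x = CUDA.rand(Float32, 4)
    y = dA * x                        # sparse-dense product dispatched to CUSPARSE
end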
@@ -78,8 +78,6 @@ include("layers/misc.jl")
include("sampling.jl")
include("embedding/node2vec.jl")

include("cuda/conv.jl")

using .Datasets


26 changes: 0 additions & 26 deletions src/cuda/conv.jl

This file was deleted.

9 changes: 5 additions & 4 deletions src/layers/conv.jl
@@ -104,7 +104,7 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
Y = view(c.weight,:,:,1) * Z_prev
Y += view(c.weight,:,:,2) * Z
for k = 3:c.k
Z, Z_prev = 2*Z*L̃ - Z_prev, Z
Z, Z_prev = 2 .* Z * L̃ - Z_prev, Z
Y += view(c.weight,:,:,k) * Z
end
return Y .+ c.bias
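
Note: the loop above accumulates Chebyshev terms of the scaled Laplacian L̃, using the recurrence Z_k = 2 * Z_{k-1} * L̃ - Z_{k-2}. A standalone, illustrative sketch of that recurrence; the helper name cheb_terms and the (channels, nodes) feature layout are assumptions, not part of this commit:

# Illustrative only: Z_1 = X, Z_2 = X*L̃, then Z_k = 2*Z_{k-1}*L̃ - Z_{k-2}.
function cheb_terms(X::AbstractMatrix, L̃::AbstractMatrix, k::Int)
    Z_prev, Z = X, X * L̃
    terms = [Z_prev, Z]
    for _ in 3:k
        Z, Z_prev = 2 .* Z * L̃ .- Z_prev, Z
        push!(terms, Z)
    end
    return terms
end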
@@ -253,13 +253,14 @@ function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
end

function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
@assert check_self_loops(sg) "a vertex must have self loop (receive a message from itself)."
mapreduce(i -> apply_batch_message(gat, i, neighbors(sg, i), X), hcat, 1:nv(sg))
@assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
ys = map(i -> apply_batch_message(gat, i, GraphSignals.cpu_neighbors(sg, i), X), 1:nv(sg))
return hcat(ys...)
end

function check_self_loops(sg::SparseGraph)
for i in 1:nv(sg)
if !(i in GraphSignals.rowvalview(sg.S, i))
if !(i in collect(GraphSignals.rowvalview(sg.S, i)))
return false
end
end
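
Note: the rewrite above replaces mapreduce(f, hcat, ...) with map followed by a single splatted hcat, a pattern Zygote differentiates reliably. A toy, self-contained sketch of that pattern; msg and the array sizes are illustrative, not the GAT message itself:

using Zygote

msg(X, i) = 2f0 .* X[:, i]                            # stand-in for a per-vertex message
cols(X) = hcat(map(i -> msg(X, i), 1:size(X, 2))...)  # map, then one concatenation

X = rand(Float32, 3, 5)
g = Zygote.gradient(x -> sum(cols(x)), X)[1]          # gradient has the same size as X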
20 changes: 14 additions & 6 deletions src/layers/gn.jl
@@ -16,14 +16,22 @@ abstract type GraphNet <: AbstractGraphLayer end
@inline update_vertex(gn::GraphNet, ē, vi, u) = vi
@inline update_global(gn::GraphNet, ē, v̄, u) = u

@inline update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u) =
mapreduce(i -> apply_batch_message(gn, sg, i, neighbors(sg, i), E, V, u), hcat, vertices(sg))
@inline function update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u)
ys = map(i -> apply_batch_message(gn, sg, i, GraphSignals.cpu_neighbors(sg, i), E, V, u), vertices(sg))
return hcat(ys...)
end

@inline apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u) =
mapreduce(j -> update_edge(gn, _view(E, edge_index(sg, i, j)), _view(V, i), _view(V, j), u), hcat, js)
@inline function apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u)
# js is still a CuArray; incident edge ids are gathered on the CPU, outside differentiation
es = Zygote.ignore(() -> GraphSignals.cpu_incident_edges(sg, i))
ys = map(k -> update_edge(gn, _view(E, es[k]), _view(V, i), _view(V, js[k]), u), 1:length(js))
return hcat(ys...)
end

@inline update_batch_vertex(gn::GraphNet, Ē, V, u) =
mapreduce(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), hcat, 1:size(V,2))
@inline function update_batch_vertex(gn::GraphNet, Ē, V, u)
ys = map(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), 1:size(V,2))
return hcat(ys...)
end

@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr, E) = neighbor_scatter(aggr, E, sg)
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr::Nothing, @nospecialize E) = nothing
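
Note: apply_batch_message wraps the index lookup in Zygote.ignore so that only the feature arithmetic is differentiated. A small illustrative sketch of that pattern; neighbors_of is a hypothetical lookup, not the GraphSignals API:

using Zygote

neighbors_of(i) = [max(i - 1, 1), i]             # hypothetical, gradient-free index lookup

function gather(X, i)
    js = Zygote.ignore(() -> neighbors_of(i))    # computed outside the AD tape
    return hcat(map(j -> X[:, j], js)...)
end

X = rand(Float32, 3, 4)
g = Zygote.gradient(x -> sum(gather(x, 2)), X)[1]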
139 changes: 73 additions & 66 deletions test/cuda/conv.jl
@@ -51,70 +51,77 @@
@test size(g.bias) == size(cc.bias)
end

# @testset "GraphConv" begin
# gc = GraphConv(fg, in_channel=>out_channel) |> gpu
# @test size(gc.weight1) == (out_channel, in_channel)
# @test size(gc.weight2) == (out_channel, in_channel)
# @test size(gc.bias) == (out_channel,)

# X = rand(in_channel, N) |> gpu
# Y = gc(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(gc(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), gc)[1]
# @test size(g.weight1) == size(gc.weight1)
# @test size(g.weight2) == size(gc.weight2)
# @test size(g.bias) == size(gc.bias)
# end

# @testset "GATConv" begin
# gat = GATConv(fg, in_channel=>out_channel) |> gpu
# @test size(gat.weight) == (out_channel, in_channel)
# @test size(gat.bias) == (out_channel,)

# X = rand(in_channel, N) |> gpu
# Y = gat(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(gat(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), gat)[1]
# @test size(g.weight) == size(gat.weight)
# @test size(g.bias) == size(gat.bias)
# @test size(g.a) == size(gat.a)
# end

# @testset "GatedGraphConv" begin
# num_layers = 3
# ggc = GatedGraphConv(fg, out_channel, num_layers) |> gpu
# @test size(ggc.weight) == (out_channel, out_channel, num_layers)

# X = rand(in_channel, N) |> gpu
# Y = ggc(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), ggc)[1]
# @test size(g.weight) == size(ggc.weight)
# end

# @testset "EdgeConv" begin
# ec = EdgeConv(fg, Dense(2*in_channel, out_channel)) |> gpu
# X = rand(in_channel, N) |> gpu
# Y = ec(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(ec(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), ec)[1]
# @test size(g.nn.weight) == size(ec.nn.weight)
# @test size(g.nn.bias) == size(ec.nn.bias)
# end
@testset "GraphConv" begin
gc = GraphConv(fg, in_channel=>out_channel) |> gpu
@test size(gc.weight1) == (out_channel, in_channel)
@test size(gc.weight2) == (out_channel, in_channel)
@test size(gc.bias) == (out_channel,)

X = rand(in_channel, N) |> gpu
Y = gc(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(gc(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), gc)[1]
@test size(g.weight1) == size(gc.weight1)
@test size(g.weight2) == size(gc.weight2)
@test size(g.bias) == size(gc.bias)
end

@testset "GATConv" begin
adj = T[1 1 0 1;
1 1 1 0;
0 1 1 1;
1 0 1 1]

fg = FeaturedGraph(adj)

gat = GATConv(fg, in_channel=>out_channel) |> gpu
@test size(gat.weight) == (out_channel, in_channel)
@test size(gat.bias) == (out_channel,)

X = rand(in_channel, N) |> gpu
Y = gat(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(gat(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), gat)[1]
@test size(g.weight) == size(gat.weight)
@test size(g.bias) == size(gat.bias)
@test size(g.a) == size(gat.a)
end

@testset "GatedGraphConv" begin
num_layers = 3
ggc = GatedGraphConv(fg, out_channel, num_layers) |> gpu
@test size(ggc.weight) == (out_channel, out_channel, num_layers)

X = rand(in_channel, N) |> gpu
Y = ggc(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), ggc)[1]
@test size(g.weight) == size(ggc.weight)
end

@testset "EdgeConv" begin
ec = EdgeConv(fg, Dense(2*in_channel, out_channel)) |> gpu
X = rand(in_channel, N) |> gpu
Y = ec(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(ec(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), ec)[1]
@test size(g.nn.weight) == size(ec.nn.weight)
@test size(g.nn.bias) == size(ec.nn.bias)
end
end
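
Note: the re-enabled GPU test sets above depend on definitions from earlier in the test file (T, N, in_channel, out_channel, fg) and on a CUDA-capable device. A hedged sketch of that kind of setup, with illustrative values only:

using CUDA, Flux, GeometricFlux, GraphSignals

T = Float32
N = 4
in_channel, out_channel = 3, 5
adj = T[0 1 0 1;
        1 0 1 0;
        0 1 0 1;
        1 0 1 0]
fg = FeaturedGraph(adj)                           # graph shared by the test sets
CUDA.functional() || @warn "CUDA not available; the GPU tests above would be skipped"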
