Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix message-passing tests #253

Merged
merged 4 commits into from
Dec 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/GeometricFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using LinearAlgebra: Adjoint, norm, Transpose
using Random
using Reexport

using CUDA
using CUDA, CUDA.CUSPARSE
using ChainRulesCore: @non_differentiable
using FillArrays: Fill
using Flux
Expand Down Expand Up @@ -78,8 +78,6 @@ include("layers/misc.jl")
include("sampling.jl")
include("embedding/node2vec.jl")

include("cuda/conv.jl")

using .Datasets


Expand Down
26 changes: 0 additions & 26 deletions src/cuda/conv.jl

This file was deleted.

9 changes: 5 additions & 4 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
Y = view(c.weight,:,:,1) * Z_prev
Y += view(c.weight,:,:,2) * Z
for k = 3:c.k
Z, Z_prev = 2*Z*L̃ - Z_prev, Z
Z, Z_prev = 2 .* Z * L̃ - Z_prev, Z
Y += view(c.weight,:,:,k) * Z
end
return Y .+ c.bias
Expand Down Expand Up @@ -253,13 +253,14 @@ function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
end

function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
@assert check_self_loops(sg) "a vertex must have self loop (receive a message from itself)."
mapreduce(i -> apply_batch_message(gat, i, neighbors(sg, i), X), hcat, 1:nv(sg))
@assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
ys = map(i -> apply_batch_message(gat, i, GraphSignals.cpu_neighbors(sg, i), X), 1:nv(sg))
return hcat(ys...)
end

function check_self_loops(sg::SparseGraph)
for i in 1:nv(sg)
if !(i in GraphSignals.rowvalview(sg.S, i))
if !(i in collect(GraphSignals.rowvalview(sg.S, i)))
return false
end
end
Expand Down
20 changes: 14 additions & 6 deletions src/layers/gn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@ abstract type GraphNet <: AbstractGraphLayer end
@inline update_vertex(gn::GraphNet, ē, vi, u) = vi
@inline update_global(gn::GraphNet, ē, v̄, u) = u

@inline update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u) =
mapreduce(i -> apply_batch_message(gn, sg, i, neighbors(sg, i), E, V, u), hcat, vertices(sg))
@inline function update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u)
ys = map(i -> apply_batch_message(gn, sg, i, GraphSignals.cpu_neighbors(sg, i), E, V, u), vertices(sg))
return hcat(ys...)
end

@inline apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u) =
mapreduce(j -> update_edge(gn, _view(E, edge_index(sg, i, j)), _view(V, i), _view(V, j), u), hcat, js)
@inline function apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u)
# js still CuArray
es = Zygote.ignore(() -> GraphSignals.cpu_incident_edges(sg, i))
ys = map(k -> update_edge(gn, _view(E, es[k]), _view(V, i), _view(V, js[k]), u), 1:length(js))
return hcat(ys...)
end

@inline update_batch_vertex(gn::GraphNet, Ē, V, u) =
mapreduce(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), hcat, 1:size(V,2))
@inline function update_batch_vertex(gn::GraphNet, Ē, V, u)
ys = map(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), 1:size(V,2))
return hcat(ys...)
end

@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr, E) = neighbor_scatter(aggr, E, sg)
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr::Nothing, @nospecialize E) = nothing
Expand Down
139 changes: 73 additions & 66 deletions test/cuda/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,70 +51,77 @@
@test size(g.bias) == size(cc.bias)
end

# @testset "GraphConv" begin
# gc = GraphConv(fg, in_channel=>out_channel) |> gpu
# @test size(gc.weight1) == (out_channel, in_channel)
# @test size(gc.weight2) == (out_channel, in_channel)
# @test size(gc.bias) == (out_channel,)

# X = rand(in_channel, N) |> gpu
# Y = gc(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(gc(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), gc)[1]
# @test size(g.weight1) == size(gc.weight1)
# @test size(g.weight2) == size(gc.weight2)
# @test size(g.bias) == size(gc.bias)
# end

# @testset "GATConv" begin
# gat = GATConv(fg, in_channel=>out_channel) |> gpu
# @test size(gat.weight) == (out_channel, in_channel)
# @test size(gat.bias) == (out_channel,)

# X = rand(in_channel, N) |> gpu
# Y = gat(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(gat(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), gat)[1]
# @test size(g.weight) == size(gat.weight)
# @test size(g.bias) == size(gat.bias)
# @test size(g.a) == size(gat.a)
# end

# @testset "GatedGraphConv" begin
# num_layers = 3
# ggc = GatedGraphConv(fg, out_channel, num_layers) |> gpu
# @test size(ggc.weight) == (out_channel, out_channel, num_layers)

# X = rand(in_channel, N) |> gpu
# Y = ggc(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), ggc)[1]
# @test size(g.weight) == size(ggc.weight)
# end

# @testset "EdgeConv" begin
# ec = EdgeConv(fg, Dense(2*in_channel, out_channel)) |> gpu
# X = rand(in_channel, N) |> gpu
# Y = ec(X)
# @test size(Y) == (out_channel, N)

# g = Zygote.gradient(x -> sum(ec(x)), X)[1]
# @test size(g) == size(X)

# g = Zygote.gradient(model -> sum(model(X)), ec)[1]
# @test size(g.nn.weight) == size(ec.nn.weight)
# @test size(g.nn.bias) == size(ec.nn.bias)
# end
@testset "GraphConv" begin
gc = GraphConv(fg, in_channel=>out_channel) |> gpu
@test size(gc.weight1) == (out_channel, in_channel)
@test size(gc.weight2) == (out_channel, in_channel)
@test size(gc.bias) == (out_channel,)

X = rand(in_channel, N) |> gpu
Y = gc(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(gc(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), gc)[1]
@test size(g.weight1) == size(gc.weight1)
@test size(g.weight2) == size(gc.weight2)
@test size(g.bias) == size(gc.bias)
end

@testset "GATConv" begin
adj = T[1 1 0 1;
1 1 1 0;
0 1 1 1;
1 0 1 1]

fg = FeaturedGraph(adj)

gat = GATConv(fg, in_channel=>out_channel) |> gpu
@test size(gat.weight) == (out_channel, in_channel)
@test size(gat.bias) == (out_channel,)

X = rand(in_channel, N) |> gpu
Y = gat(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(gat(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), gat)[1]
@test size(g.weight) == size(gat.weight)
@test size(g.bias) == size(gat.bias)
@test size(g.a) == size(gat.a)
end

@testset "GatedGraphConv" begin
num_layers = 3
ggc = GatedGraphConv(fg, out_channel, num_layers) |> gpu
@test size(ggc.weight) == (out_channel, out_channel, num_layers)

X = rand(in_channel, N) |> gpu
Y = ggc(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), ggc)[1]
@test size(g.weight) == size(ggc.weight)
end

@testset "EdgeConv" begin
ec = EdgeConv(fg, Dense(2*in_channel, out_channel)) |> gpu
X = rand(in_channel, N) |> gpu
Y = ec(X)
@test size(Y) == (out_channel, N)

g = Zygote.gradient(x -> sum(ec(x)), X)[1]
@test size(g) == size(X)

g = Zygote.gradient(model -> sum(model(X)), ec)[1]
@test size(g.nn.weight) == size(ec.nn.weight)
@test size(g.nn.bias) == size(ec.nn.bias)
end
end
Loading