Skip to content

Commit

Permalink
not export GraphNetwork and MessagePassing APIs
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
yuehhua committed Jul 11, 2021
1 parent a706351 commit 6833924
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 22 deletions.
14 changes: 0 additions & 14 deletions src/GeometricFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,9 @@ using Zygote
export
# layers/gn
GraphNet,
update_edge,
update_vertex,
update_global,
update_batch_edge,
update_batch_vertex,
aggregate_neighbors,
aggregate_edges,
aggregate_vertices,
propagate,

# layers/msgpass
MessagePassing,
message,
update,

# layers/conv
GCNConv,
Expand All @@ -40,9 +29,6 @@ export
GATConv,
GatedGraphConv,
EdgeConv,
message,
update,
propagate,

# layer/pool
GlobalPool,
Expand Down
2 changes: 1 addition & 1 deletion src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ end
# After some reshaping due to the multihead, we get the α from each message,
# then get the softmax over every α, and eventually multiply the message by α
function apply_batch_message(g::GATConv, i, js, X::AbstractMatrix)
e_ij = mapreduce(j -> message(g, _view(X, i), _view(X, j)), hcat, js)
e_ij = mapreduce(j -> GeometricFlux.message(g, _view(X, i), _view(X, j)), hcat, js)
n = size(e_ij, 1)
αs = Flux.softmax(reshape(view(e_ij, 1, :), g.heads, :), dims=2)
msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :)
Expand Down
4 changes: 2 additions & 2 deletions src/layers/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ First argument should be message-passing layer, the rest of arguments can be `X`
end

@inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::AbstractMatrix, X::AbstractMatrix, u) =
mapreduce(j -> message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js)
mapreduce(j -> GeometricFlux.message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js)

@inline update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::AbstractMatrix, u) =
mapreduce(i -> update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2))
mapreduce(i -> GeometricFlux.update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2))

@inline function aggregate_neighbors(mp::MessagePassing, aggr, M::AbstractMatrix, accu_edge)
@assert !iszero(accu_edge) "accumulated edge must not be zero."
Expand Down
8 changes: 4 additions & 4 deletions test/layers/gn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ u = rand(T, in_channel)
l = NewGNLayer()

@testset "without aggregation" begin
(l::NewGNLayer)(fg) = propagate(l, fg)
(l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg)

fg = FeaturedGraph(adj, nf=V)
fg_ = l(fg)
Expand All @@ -34,7 +34,7 @@ u = rand(T, in_channel)
end

@testset "with neighbor aggregation" begin
(l::NewGNLayer)(fg) = propagate(l, fg, +)
(l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +)

fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0))
l = NewGNLayer()
Expand All @@ -48,7 +48,7 @@ u = rand(T, in_channel)

GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = rand(T, out_channel)
@testset "update edge with neighbor aggregation" begin
(l::NewGNLayer)(fg) = propagate(l, fg, +)
(l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +)

fg = FeaturedGraph(adj, nf=V, ef=E, gf=zeros(0))
l = NewGNLayer()
Expand All @@ -62,7 +62,7 @@ u = rand(T, in_channel)

GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = rand(T, out_channel)
@testset "update edge/vertex with all aggregation" begin
(l::NewGNLayer)(fg) = propagate(l, fg, +, +, +)
(l::NewGNLayer)(fg) = GeometricFlux.propagate(l, fg, +, +, +)

fg = FeaturedGraph(adj, nf=V, ef=E, gf=u)
l = NewGNLayer()
Expand Down
2 changes: 1 addition & 1 deletion test/layers/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct NewLayer <: MessagePassing
end
NewLayer(m, n) = NewLayer(randn(T, m,n))

(l::NewLayer)(fg) = propagate(l, fg, +)
(l::NewLayer)(fg) = GeometricFlux.propagate(l, fg, +)

X = Array{T}(reshape(1:num_V*in_channel, in_channel, num_V))
fg = FeaturedGraph(adj, nf=X, ef=Fill(zero(T), 0, 2num_E))
Expand Down

0 comments on commit 6833924

Please sign in to comment.