From 695ca1c0178f1f695ab57f39575b38eab419ccce Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 22 Jul 2021 15:39:23 +0200 Subject: [PATCH] a bunch of little improve --- src/layers/conv.jl | 136 ++++++++++++++++++++---------------------- src/layers/gn.jl | 9 ++- src/layers/misc.jl | 8 +-- src/layers/msgpass.jl | 55 +++++++++-------- src/layers/pool.jl | 4 -- src/models.jl | 6 +- src/utils.jl | 17 +++++- test/cuda/conv.jl | 4 +- test/layers/conv.jl | 24 ++++---- 9 files changed, 135 insertions(+), 128 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 553a5d023..3c729443d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,11 +1,11 @@ """ - GCNConv([graph,] in => out, σ=identity; bias=true, init=glorot_uniform) + GCNConv([fg,] in => out, σ=identity; bias=true, init=glorot_uniform) Graph convolutional layer. # Arguments -- `graph`: Optionally pass a FeaturedGraph. +- `fg`: Optionally pass a [`FeaturedGraph`](@ref). - `in`: The dimension of input features. - `out`: The dimension of output features. - `σ`: Activation function. @@ -41,7 +41,7 @@ function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix) l.σ.(l.weight * x * L̃ .+ l.bias) end -(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) +(l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) (l::GCNConv)(x::AbstractMatrix) = l(l.fg, x) function Base.show(io::IO, l::GCNConv) @@ -53,14 +53,13 @@ end """ - ChebConv([graph, ]in=>out, k; bias=true, init=glorot_uniform) + ChebConv([fg,] in=>out, k; bias=true, init=glorot_uniform) Chebyshev spectral graph convolutional layer. # Arguments -- `graph`: Should be a adjacency matrix, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). Is optionnal so you can give a `FeaturedGraph` to -the layer instead of only the features. +- `fg`: Optionally pass a [`FeaturedGraph`](@ref). - `in`: The dimension of input features. - `out`: The dimension of output features. - `k`: The order of Chebyshev polynomial. @@ -72,8 +71,6 @@ struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph} bias::B fg::S k::Int - in_channel::Int - out_channel::Int end function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int; @@ -81,7 +78,7 @@ function ChebConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, k::Int; in, out = ch W = init(out, in, k) b = Flux.create_bias(W, bias, out) - ChebConv(W, b, fg, k, in, out) + ChebConv(W, b, fg, k) end ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) = @@ -90,11 +87,11 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) = @functor ChebConv function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T - L̃ = scaled_laplacian(fg, T) - - @assert size(X, 1) == c.in_channel "Input feature size must match input channel size." - GraphSignals.check_num_node(L̃, X) + check_num_nodes(fg, X) + @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size." + L̃ = scaled_laplacian(fg, eltype(X)) + Z_prev = X Z = X * L̃ Y = view(c.weight,:,:,1) * Z_prev @@ -106,25 +103,25 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T return Y .+ c.bias end -(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) +(l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) (l::ChebConv)(x::AbstractMatrix) = l(l.fg, x) function Base.show(io::IO, l::ChebConv) - print(io, "ChebConv(", l.in_channel, " => ", l.out_channel) - print(io, ", k=", l.k) + out, in, k = size(l.weight) + print(io, "ChebConv(", in, " => ", out) + print(io, ", k=", k) print(io, ")") end """ - GraphConv([graph,] in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform) + GraphConv([fg,] in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform) Graph neural network layer. # Arguments -- `graph`: Should be a adjacency matrix, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). Is optionnal so you can give a `FeaturedGraph` to -the layer instead of only the features. +- `fg`: Optionally pass a [`FeaturedGraph`](@ref). - `in`: The dimension of input features. - `out`: The dimension of output features. - `σ`: Activation function. @@ -155,18 +152,17 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) = @functor GraphConv -message(g::GraphConv, x_i, x_j::AbstractVector, e_ij) = g.weight2 * x_j +message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j -update(g::GraphConv, m::AbstractVector, x::AbstractVector) = g.σ.(g.weight1*x .+ m .+ g.bias) +update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias) function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix) - g = graph(fg) - GraphSignals.check_num_node(g, x) - _, x = propagate(gc, adjacency_list(g), Fill(0.f0, 0, ne(g)), x, +) + check_num_nodes(fg, x) + _, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +) x end -(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) +(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) (l::GraphConv)(x::AbstractMatrix) = l(l.fg, x) function Base.show(io::IO, l::GraphConv) @@ -181,7 +177,7 @@ end """ - GATConv([graph,] in => out; + GATConv([fg,] in => out; heads=1, concat=true, init=glorot_uniform @@ -192,8 +188,7 @@ Graph attentional layer. # Arguments -- `graph`: Should be a adjacency matrix, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). Is optionnal so you can give a `FeaturedGraph` to -the layer instead of only the features. +- `fg`: Optionally pass a [`FeaturedGraph`](@ref). - `in`: The dimension of input features. - `out`: The dimension of output features. - `bias::Bool`: Keyword argument, whether to learn the additive bias. @@ -227,56 +222,55 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...) @functor GATConv # Here the α that has not been softmaxed is the first number of the output message -function message(g::GATConv, x_i::AbstractVector, x_j::AbstractVector) - x_i = reshape(g.weight*x_i, :, g.heads) - x_j = reshape(g.weight*x_j, :, g.heads) +function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector) + x_i = reshape(gat.weight*x_i, :, gat.heads) + x_j = reshape(gat.weight*x_j, :, gat.heads) x_ij = vcat(x_i, x_j+zero(x_j)) - e = sum(x_ij .* g.a, dims=1) # inner product for each head, output shape: (1, g.heads) - e_ij = leakyrelu.(e, g.negative_slope) - vcat(e_ij, x_j) # shape: (n+1, g.heads) + e = sum(x_ij .* gat.a, dims=1) # inner product for each head, output shape: (1, gat.heads) + e_ij = leakyrelu.(e, gat.negative_slope) + vcat(e_ij, x_j) # shape: (n+1, gat.heads) end # After some reshaping due to the multihead, we get the α from each message, # then get the softmax over every α, and eventually multiply the message by α -function apply_batch_message(g::GATConv, i, js, X::AbstractMatrix) - e_ij = mapreduce(j -> GeometricFlux.message(g, _view(X, i), _view(X, j)), hcat, js) +function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix) + e_ij = mapreduce(j -> GeometricFlux.message(gat, _view(X, i), _view(X, j)), hcat, js) n = size(e_ij, 1) - αs = Flux.softmax(reshape(view(e_ij, 1, :), g.heads, :), dims=2) + αs = Flux.softmax(reshape(view(e_ij, 1, :), gat.heads, :), dims=2) msgs = view(e_ij, 2:n, :) .* reshape(αs, 1, :) - reshape(msgs, (n-1)*g.heads, :) + reshape(msgs, (n-1)*gat.heads, :) end -update_batch_edge(g::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(g, adj, X) +update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(gat, adj, X) -function update_batch_edge(g::GATConv, adj, X::AbstractMatrix) +function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix) n = size(adj, 1) # a vertex must always receive a message from itself Zygote.ignore() do GraphLaplacians.add_self_loop!(adj, n) end - mapreduce(i -> apply_batch_message(g, i, adj[i], X), hcat, 1:n) + mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n) end # The same as update function in batch manner -update_batch_vertex(g::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(g, M) +update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(gat, M) -function update_batch_vertex(g::GATConv, M::AbstractMatrix) - M = M .+ g.bias - if !g.concat +function update_batch_vertex(gat::GATConv, M::AbstractMatrix) + M = M .+ gat.bias + if !gat.concat N = size(M, 2) - M = reshape(mean(reshape(M, :, g.heads, N), dims=2), :, N) + M = reshape(mean(reshape(M, :, gat.heads, N), dims=2), :, N) end return M end function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix) - g = graph(fg) - GraphSignals.check_num_node(g, X) - _, X = propagate(gat, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, +) + check_num_nodes(fg, X) + _, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +) X end -(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) +(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) (l::GATConv)(x::AbstractMatrix) = l(l.fg, x) function Base.show(io::IO, l::GATConv) @@ -289,14 +283,13 @@ end """ - GatedGraphConv([graph,] out, num_layers; aggr=+, init=glorot_uniform) + GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform) Gated graph convolution layer. # Arguments -- `graph`: Should be a adjacency matrix, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). Is optionnal so you can give a `FeaturedGraph` to -the layer instead of only the features. +- `fg`: Optionally pass a [`FeaturedGraph`](@ref). - `out`: The dimension of output features. - `num_layers`: The number of gated recurrent unit. - `aggr`: Keyword argument, an aggregate function applied to the result of message function. `+`, `max` and `mean` are available. @@ -322,27 +315,29 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) = @functor GatedGraphConv -message(g::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j +message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j -update(g::GatedGraphConv, m::AbstractVector, x) = m +update(ggc::GatedGraphConv, m::AbstractVector, x) = m function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real} - adj = adjacency_list(fg) + check_num_nodes(fg, H) m, n = size(H) @assert (m <= ggc.out_ch) "number of input features must less or equals to output features." - GraphSignals.check_num_node(adj, H) - (m < ggc.out_ch) && (H = vcat(H, zeros(S, ggc.out_ch - m, n))) - + adj = adjacency_list(fg) + if m < ggc.out_ch + Hpad = similar(H, S, ggc.out_ch - m, n) + H = vcat(H, fill!(Hpad, 0)) + end for i = 1:ggc.num_layers M = view(ggc.weight, :, :, i) * H - _, M = propagate(ggc, adj, Fill(0.f0, 0, ne(adj)), M, +) + _, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +) H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381 end H end -(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) +(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) (l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x) @@ -355,14 +350,14 @@ end """ - EdgeConv(graph, nn; aggr=max) + EdgeConv([fg,] nn; aggr=max) Edge convolutional layer. # Arguments -- `graph`: Should be a adjacency matrix, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). -- `nn`: A neural network or a layer. It can be, e.g., MLP. +- `fg`: Optionally pass a [`FeaturedGraph`](@ref). +- `nn`: A neural network (e.g. a Dense layer or a MLP). - `aggr`: Keyword argument, an aggregate function applied to the result of message function. `+`, `max` and `mean` are available. """ struct EdgeConv{V<:AbstractFeaturedGraph} <: MessagePassing @@ -376,17 +371,16 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...) @functor EdgeConv -message(e::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = e.nn(vcat(x_i, x_j .- x_i)) -update(e::EdgeConv, m::AbstractVector, x) = m +message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i)) +update(ec::EdgeConv, m::AbstractVector, x) = m -function (e::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix) - g = graph(fg) - GraphSignals.check_num_node(g, X) - _, X = propagate(e, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, e.aggr) +function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix) + check_num_nodes(fg, X) + _, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr) X end -(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) +(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg))) (l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x) function Base.show(io::IO, l::EdgeConv) diff --git a/src/layers/gn.jl b/src/layers/gn.jl index 4a8106b2e..496d3dcf2 100644 --- a/src/layers/gn.jl +++ b/src/layers/gn.jl @@ -35,24 +35,27 @@ end end @inline function aggregate_neighbors(gn::GraphNet, aggr::Nothing, E, accu_edge) - @nospecialize E accu_edge num_V num_E + @nospecialize E accu_edge + return nothing end @inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E) @inline function aggregate_edges(gn::GraphNet, aggr::Nothing, E) @nospecialize E + return nothing end @inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V) @inline function aggregate_vertices(gn::GraphNet, aggr::Nothing, V) @nospecialize V + return nothing end function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr) - FeaturedGraph(graph(fg), nf=V, ef=E, gf=u) + FeaturedGraph(fg, nf=V, ef=E, gf=u) end function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P, @@ -63,5 +66,5 @@ function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P, ē = aggregate_edges(gn, eaggr, E) v̄ = aggregate_vertices(gn, vaggr, V) u = update_global(gn, ē, v̄, u) - E, V, u + return E, V, u end diff --git a/src/layers/misc.jl b/src/layers/misc.jl index 20f4e1e5c..3f6ed324f 100644 --- a/src/layers/misc.jl +++ b/src/layers/misc.jl @@ -17,11 +17,11 @@ function FeatureSelector(feature::Symbol) end function (fs::FeatureSelector)(fg::FeaturedGraph) - if fs.enable_node_feature && has_node_feature(fg) + if fs.enable_node_feature return node_feature(fg) - elseif fs.enable_edge_feature && has_edge_feature(fg) + elseif fs.enable_edge_feature return edge_feature(fg) - elseif fs.enable_global_feature && has_global_feature(fg) + elseif fs.enable_global_feature return global_feature(fg) end end @@ -36,7 +36,7 @@ Bypassing graph in FeaturedGraph and let other layer process (node, edge and glo """ function bypass_graph(nf_func=identity, ef_func=identity, gf_func=identity) return function (fg::FeaturedGraph) - FeaturedGraph(graph(fg), + FeaturedGraph(fg, nf=nf_func(node_feature(fg)), ef=ef_func(edge_feature(fg)), gf=gf_func(global_feature(fg))) diff --git a/src/layers/msgpass.jl b/src/layers/msgpass.jl index 10e5cecee..9aee05391 100644 --- a/src/layers/msgpass.jl +++ b/src/layers/msgpass.jl @@ -1,55 +1,62 @@ abstract type MessagePassing <: GraphNet end """ - message(mp, x_i, x_j, e_ij) + message(mp::MessagePassing, x_i, x_j, e_ij) -Message function for message-passing scheme. This function can be overrided to dispatch to custom layers. -First argument should be message-passing layer, the rest of arguments can be `x_i`, `x_j` and `e_ij`. +Message function for the message-passing scheme, +returning the message from node `j` to node `i` . +In the message-passing scheme. the incoming messages +from the neighborhood of `i` will later be aggregated +in order to [`update`](@ref) the features of node `i`. + +By default, the function returns `x_j`. +Layers subtyping [`MessagePassing`](@ref) should +specialize this method with custom behavior. # Arguments + - `mp`: message-passing layer. -- `x_i`: the feature of node `x_i`. -- `x_j`: the feature of neighbors of node `x_i`. -- `e_ij`: the feature of edge (`x_i`, `x_j`). +- `x_i`: the features of node `i`. +- `x_j`: the features of the nighbor `j` of node `i`. +- `e_ij`: the features of edge (`i`, `j`). + +See also [`update`](@ref). """ @inline message(mp::MessagePassing, x_i, x_j, e_ij) = x_j @inline message(mp::MessagePassing, i::Integer, j::Integer, x_i, x_j, e_ij) = x_j """ - update(mp, m, x) + update(mp::MessagePassing, m, x) -Update function for message-passing scheme. This function can be overrided to dispatch to custom layers. -First argument should be message-passing layer, the rest of arguments can be `X` and `M`. +Update function for the message-passing scheme, +returning a new set of node features `x′` based on old +features `x` and the incoming message from the neighborhood +aggregation `m`. + +By default, the function returns `m`. +Layers subtyping [`MessagePassing`](@ref) should +specialize this method with custom behavior. # Arguments + - `mp`: message-passing layer. -- `m`: the message aggregated from message function. -- `x`: the single node feature. +- `m`: the aggregated edge messages from the [`message`](@ref) function. +- `x`: the node features to be updated. + +See also [`message`](@ref). """ @inline update(mp::MessagePassing, m, x) = m @inline update(mp::MessagePassing, i::Integer, m, x) = m -@inline function update_batch_edge(mp::MessagePassing, adj, E::AbstractMatrix, X::AbstractMatrix, u) - n = size(adj, 1) - edge_idx = edge_index_table(adj) - mapreduce(i -> apply_batch_message(mp, i, adj[i], edge_idx, E, X, u), hcat, 1:n) -end - @inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::AbstractMatrix, X::AbstractMatrix, u) = mapreduce(j -> GeometricFlux.message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js) @inline update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::AbstractMatrix, u) = mapreduce(i -> GeometricFlux.update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2)) -@inline function aggregate_neighbors(mp::MessagePassing, aggr, M::AbstractMatrix, accu_edge) - @assert !iszero(accu_edge) "accumulated edge must not be zero." - cluster = generate_cluster(M, accu_edge) - NNlib.scatter(aggr, M, cluster) -end - function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr=+) E, X = propagate(mp, adjacency_list(fg), fg.ef, fg.nf, aggr) - FeaturedGraph(graph(fg), nf=X, ef=E, gf=Fill(0.f0, 0)) + FeaturedGraph(fg, nf=X, ef=E, gf=Fill(0.f0, 0)) end function propagate(mp::MessagePassing, adj::AbstractVector{S}, E::R, X::Q, aggr) where {S<:AbstractVector,R,Q} diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 788ccd5a7..f118e02e1 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -39,10 +39,6 @@ struct LocalPool{A<:AbstractArray} cluster::A end -function LocalPool(aggr, cluster::AbstractArray) - LocalPool{typeof(cluster)}(aggr, cluster) -end - (l::LocalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster) """ diff --git a/src/models.jl b/src/models.jl index 56a0f6846..9a9594453 100644 --- a/src/models.jl +++ b/src/models.jl @@ -68,10 +68,10 @@ end (i::InnerProductDecoder)(Z::AbstractMatrix)::AbstractMatrix = i.σ.(Z'*Z) -function (i::InnerProductDecoder)(fg::FeaturedGraph)::FeaturedGraph +function (i::InnerProductDecoder)(fg::FeaturedGraph) Z = node_feature(fg) A = i(Z) - FeaturedGraph(graph(fg), nf=A) + return FeaturedGraph(fg, nf=A) end @@ -106,7 +106,7 @@ end function (ve::VariationalEncoder)(fg::FeaturedGraph)::FeaturedGraph μ, logσ = summarize(ve, fg) Z = sample(μ, logσ) - FeaturedGraph(graph(fg), nf=Z) + FeaturedGraph(fg, nf=Z) end function summarize(ve::VariationalEncoder, fg::FeaturedGraph) diff --git a/src/utils.jl b/src/utils.jl index 24077eee4..46fa9d66b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -111,6 +111,17 @@ function transform(X::AbstractArray, eidx::Dict) Y end -### TODO move these to GraphSignals ###### -# @functor FeaturedGraph -# Zygote.@nograd normalized_laplacian, scaled_laplacian +function check_num_nodes(fg::FeaturedGraph, x::AbstractArray) + @assert nv(fg) == size(x, ndims(x)) +end + +### TODO move this to GraphSignals ###### +import GraphSignals: FeaturedGraph + +function FeaturedGraph(fg::FeaturedGraph; + nf=node_feature(fg), + ef=edge_feature(fg), + gf=global_feature(fg)) + + return FeaturedGraph(graph(fg); nf, ef, gf) +end \ No newline at end of file diff --git a/test/cuda/conv.jl b/test/cuda/conv.jl index eacceca76..386edc94e 100644 --- a/test/cuda/conv.jl +++ b/test/cuda/conv.jl @@ -15,7 +15,7 @@ fg = FeaturedGraph(adj) gc = GCNConv(fg, in_channel=>out_channel) |> gpu @test size(gc.weight) == (out_channel, in_channel) @test size(gc.bias) == (out_channel,) - @test collect(graph(gc.fg)) == adj + @test Array(adjacency_matrix(gc.fg)) == adj X = rand(in_channel, N) |> gpu Y = gc(X) @@ -35,7 +35,7 @@ fg = FeaturedGraph(adj) cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu @test size(cc.weight) == (out_channel, in_channel, k) @test size(cc.bias) == (out_channel,) - @test collect(graph(cc.fg)) == adj + @test Array(adjacency_matrix(cc.fg)) == adj @test cc.k == k @test cc.in_channel == in_channel @test cc.out_channel == out_channel diff --git a/test/layers/conv.jl b/test/layers/conv.jl index cecc2cf5d..ffe6ee3cf 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -67,8 +67,8 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) end @testset "bias=false" begin - length(Flux.params(GCNConv(2=>3))) == 2 - length(Flux.params(GCNConv(2=>3, bias=false))) == 1 + @test length(Flux.params(GCNConv(2=>3))) == 2 + @test length(Flux.params(GCNConv(2=>3, bias=false))) == 1 end end @@ -83,9 +83,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(cc.bias) == (out_channel,) @test graph(cc.fg) === adj @test cc.k == k - @test cc.in_channel == in_channel - @test cc.out_channel == out_channel - + Y = cc(X) @test size(Y) == (out_channel, N) @@ -107,9 +105,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @test size(cc.bias) == (out_channel,) @test !has_graph(cc.fg) @test cc.k == k - @test cc.in_channel == in_channel - @test cc.out_channel == out_channel - + fg = FeaturedGraph(adj, nf=X) fg_ = cc(fg) @test size(node_feature(fg_)) == (out_channel, N) @@ -129,8 +125,8 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) end @testset "bias=false" begin - length(Flux.params(ChebConv(2=>3, 3))) == 2 - length(Flux.params(ChebConv(2=>3, 3, bias=false))) == 1 + @test length(Flux.params(ChebConv(2=>3, 3))) == 2 + @test length(Flux.params(ChebConv(2=>3, 3, bias=false))) == 1 end end @@ -187,8 +183,8 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) @testset "bias=false" begin - length(Flux.params(GraphConv(2=>3))) == 3 - length(Flux.params(GraphConv(2=>3, bias=false))) == 2 + @test length(Flux.params(GraphConv(2=>3))) == 3 + @test length(Flux.params(GraphConv(2=>3, bias=false))) == 2 end end @@ -258,8 +254,8 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex) end @testset "bias=false" begin - length(Flux.params(GATConv(2=>3))) == 3 - length(Flux.params(GATConv(2=>3, bias=false))) == 2 + @test length(Flux.params(GATConv(2=>3))) == 3 + @test length(Flux.params(GATConv(2=>3, bias=false))) == 2 end end