diff --git a/Project.toml b/Project.toml index 0ea94fba8..aa0f9ff27 100644 --- a/Project.toml +++ b/Project.toml @@ -9,9 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -GraphLaplacians = "a1251efa-393a-423f-9d7b-faaecba535dc" GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8" -GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -26,9 +24,7 @@ CUDA = "3.3" DataStructures = "0.18" FillArrays = "0.11, 0.12" Flux = "0.12" -GraphLaplacians = "0.1" GraphMLDatasets = "0.1" -GraphSignals = "0.2" LightGraphs = "1.3" NNlib = "0.7" NNlibCUDA = "0.1" diff --git a/src/GeometricFlux.jl b/src/GeometricFlux.jl index a9eecc2ae..50244ee15 100644 --- a/src/GeometricFlux.jl +++ b/src/GeometricFlux.jl @@ -1,8 +1,7 @@ module GeometricFlux -using Base: check_channel_state using Statistics: mean -using LinearAlgebra: Adjoint, norm, Transpose +using LinearAlgebra using FillArrays: Fill using CUDA @@ -10,9 +9,11 @@ using Flux using Flux: glorot_uniform, leakyrelu, GRUCell, @functor using NNlib, NNlibCUDA using Zygote +using ChainRulesCore -import GraphLaplacians -using GraphLaplacians: normalized_laplacian, scaled_laplacian + +# import GraphLaplacians +# using GraphLaplacians: normalized_laplacian, scaled_laplacian # using GraphLaplacians: adjacency_matrix # using Reexport # @reexport using GraphSignals @@ -23,7 +24,7 @@ export # featured_graph FeaturedGraph, adjacency_list, - graph, + # graph, # has_graph, node_feature, edge_feature, global_feature, # ne, nv, adjacency_matrix, # from LightGraphs diff --git a/src/featured_graph.jl b/src/featured_graph.jl index 466b255f2..1dd10c551 100644 --- a/src/featured_graph.jl +++ b/src/featured_graph.jl @@ -171,9 +171,8 @@ function LightGraphs.adjacency_matrix(fg::FeaturedGraph, T::DataType=Int; dir=:o return dir == :out ? adj_mat : adj_mat' end -Zygote.@nograd adjacency_matrix, adjacency_list #, FeaturedGraph +Zygote.@nograd adjacency_matrix, adjacency_list -# using ChainRulesCore # function ChainRulesCore.rrule(::typeof(copy), x) # copy_pullback(ȳ) = (NoTangent(), ȳ) @@ -233,18 +232,71 @@ end LightGraphs.is_directed(g::AbstractMatrix) = !issymmetric(Matrix(g)) -function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::Symbol=:unspec) - if dir == :unspec - dir = is_directed(g) ? :both : :out +# function LightGraphs.laplacian_matrix(fg::FeaturedGraph, T::DataType=Int; dir::Symbol=:out) +# A = adjacency_matrix(fg, T; dir=dir) +# D = Diagonal(vec(sum(A; dims=2))) +# return D - A +# end + +## from GraphLaplacians + +""" + normalized_laplacian(g[, T]; selfloop=false, dir=:out) + +Normalized Laplacian matrix of graph `g`. + +# Arguments + +- `g`: should be a adjacency matrix, `FeaturedGraph`, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). +- `T`: result element type of degree vector; default is the element type of `g` (optional). +- `selfloop`: adding self loop while calculating the matrix (optional). +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function normalized_laplacian(fg::FeaturedGraph, T::DataType=Int; selfloop::Bool=false, dir::Symbol=:out) + A = adjacency_matrix(fg, T; dir=dir) + selfloop && (A += I) + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + return I - inv_sqrtD * A * inv_sqrtD +end + +@doc raw""" + scaled_laplacian(g[, T]; dir=:out) + +Scaled Laplacian matrix of graph `g`, +defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. + +# Arguments + +- `g`: should be a adjacency matrix, `FeaturedGraph`, `SimpleGraph`, `SimpleDiGraph` (from LightGraphs) or `SimpleWeightedGraph`, `SimpleWeightedDiGraph` (from SimpleWeightedGraphs). +- `T`: result element type of degree vector; default is the element type of `g` (optional). +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function scaled_laplacian(fg::FeaturedGraph, T::DataType=Int; dir=:out) + A = adjacency_matrix(fg, T; dir=dir) + @assert issymmetric(A) "scaled_laplacian only works with symmetric matrices" + E = eigen(Symmetric(A)).values + degs = vec(sum(A; dims=2)) + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + Lnorm = I - inv_sqrtD * A * inv_sqrtD + return 2 / maximum(E) * Lnorm - I +end + + +function add_self_loop!(adj::AbstractVector{<:AbstractVector}) + for i = 1:length(adj) + i in adj[i] || push!(adj[i], i) end - A = adjacency_matrix(g, T; dir=dir) - s = sum(A; dims=2) - D = Diagonal(vec(s)) - return D - A + adj end -# TODO Do we need a separate package just for laplacians? -GraphLaplacians.scaled_laplacian(fg::FeaturedGraph, T::DataType) = - scaled_laplacian(adjacency_matrix(fg, T)) -GraphLaplacians.normalized_laplacian(fg::FeaturedGraph, T::DataType; kws...) = - normalized_laplacian(adjacency_matrix(fg, T); kws...) \ No newline at end of file +# # TODO Do we need a separate package just for laplacians? +# GraphLaplacians.scaled_laplacian(fg::FeaturedGraph, T::DataType) = +# scaled_laplacian(adjacency_matrix(fg, T)) +# GraphLaplacians.normalized_laplacian(fg::FeaturedGraph, T::DataType; kws...) = +# normalized_laplacian(adjacency_matrix(fg, T); kws...) + + +@non_differentiable normalized_laplacian(x...) +@non_differentiable scaled_laplacian(x...) +@non_differentiable add_self_loop!(x...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e519fbed5..08f0bb2ab 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -247,10 +247,7 @@ update_batch_edge(g::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = up function update_batch_edge(g::GATConv, adj, X::AbstractMatrix) n = size(adj, 1) - # a vertex must always receive a message from itself - Zygote.ignore() do - GraphLaplacians.add_self_loop!(adj, n) - end + add_self_loop!(adj) mapreduce(i -> apply_batch_message(g, i, adj[i], X), hcat, 1:n) end