add w_mul_xj
CarloLucibello committed Jan 22, 2022
1 parent df5b508 commit 5d10ed9
Showing 9 changed files with 100 additions and 48 deletions.
1 change: 1 addition & 0 deletions docs/src/api/messagepassing.md
@@ -26,4 +26,5 @@ copy_xi
copy_xj
xi_dot_xj
e_mul_xj
w_mul_xj
```
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
@@ -47,6 +47,7 @@ export add_nodes,
rand_edge_split,
remove_self_loops,
remove_multi_edges,
set_edge_weight,
# from Flux
batch,
unbatch,
10 changes: 5 additions & 5 deletions src/GNNGraphs/convert.jl
@@ -86,7 +86,7 @@ function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted=
A = T.(A)
end
if !weighted
A = map(x -> x > 0 ? T(1) : T(0), A)
A = map(x -> ifelse(x > 0, T(1), T(0)), A)
end
return A, num_nodes, num_edges
end
@@ -121,10 +121,10 @@ function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=t
val = T(1)
end
A = fill!(similar(s, T, (n, n)), 0)
v = vec(A)
v = vec(A) # vec view of A
idxs = s .+ n .* (t .- 1)
NNlib.scatter!(+, v, val, idxs)
# A[s .+ n .* (t .- 1)] .= val # exploiting linear indexing
# A[idxs] .= val # exploiting linear indexing
NNlib.scatter!(+, v, val, idxs) # using scatter instead of indexing since there could be multiple edges
return A, n, length(s)
end
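
A minimal standalone sketch (hypothetical sizes) of why `scatter!` with `+` is used here instead of plain indexed assignment: parallel edges map to the same linear index and must accumulate rather than overwrite.

using NNlib
s, t = [1, 1, 2], [2, 2, 3]            # two parallel edges 1→2
val = [1.0, 1.0, 5.0]
n = 3
A = zeros(n, n)
idxs = s .+ n .* (t .- 1)              # linear indices into A
NNlib.scatter!(+, vec(A), val, idxs)   # A[1,2] == 2.0: both parallel edges counted
# A[idxs] .= val would overwrite instead, leaving A[1,2] == 1.0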

@@ -146,7 +146,7 @@ function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing, weighted
A = sparse(A)
end
if !weighted
A = map(x -> x > 0 ? T(1) : T(0), A)
A = map(x -> ifelse(x > 0, T(1), T(0)), A)
end
return A, num_nodes, num_edges
end
15 changes: 15 additions & 0 deletions src/GNNGraphs/transform.jl
@@ -97,6 +97,7 @@ end
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
Add to graph `g` the edges with source nodes `s` and target nodes `t`.
Optionally, pass the features `edata` for the new edges.
"""
function add_edges(g::GNNGraph{<:COO_T},
snew::AbstractVector{<:Integer},
@@ -155,6 +156,20 @@ function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata=(;))
ndata, g.edata, g.gdata)
end

"""
set_edge_weight(g::GNNGraph, w::AbstractVector)
Set `w` as edge weights in the returned graph.
"""
function set_edge_weight(g::GNNGraph, w::AbstractVector)
s, t = edge_index(g)
@assert length(w) == length(s)

return GNNGraph((s, t, w),
g.num_nodes, g.num_edges, g.num_graphs,
g.graph_indicator,
g.ndata, g.edata, g.gdata)
end
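
For illustration, a small usage sketch of the new `set_edge_weight` (graph sizes are made up; it mirrors the test added below). Weights are attached in the order given by `edge_index`.

using GraphNeuralNetworks
g = rand_graph(10, 20)             # 10 nodes, 20 edges
w = rand(Float32, 20)              # one weight per edge, in edge_index order
gw = set_edge_weight(g, w)         # new graph carrying w as edge weights
@assert get_edge_weight(gw) == w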

function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
nv1, nv2 = g1.num_nodes, g2.num_nodes
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -37,6 +37,7 @@ export
copy_xi,
xi_dot_xj,
e_mul_xj,
w_mul_xj,

# layers/basic
GNNLayer,
42 changes: 8 additions & 34 deletions src/layers/conv.jl
@@ -74,16 +74,11 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
GCNConv(W, b, σ, add_self_loops, use_edge_weight)
end

function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix)
# Extract edge_weight from g if available and l.edge_weight == true,
# otherwise return nothing.
edge_weight = GNNGraphs._get_edge_weight(g, l.use_edge_weight) # vector or nothing
return l(g, x, edge_weight)
end

function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::EW) where
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW=nothing) where
{T, EW<:Union{Nothing,AbstractVector}}

@assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"

if l.add_self_loops
g = add_self_loops(g)
if edge_weight !== nothing
@@ -100,10 +95,12 @@ function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::E
d = degree(g, T; dir=:in, edge_weight)
c = 1 ./ sqrt.(d)
x = x .* c'
if edge_weight === nothing
x = propagate(copy_xj, g, +, xj=x)
else
if edge_weight !== nothing
x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight)
elseif l.use_edge_weight
x = propagate(w_mul_xj, g, +, xj=x)
else
x = propagate(copy_xj, g, +, xj=x)
end
x = x .* c'
if Dout >= Din
@@ -112,29 +109,6 @@ function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix{T}, edge_weight::E
return l.σ.(x .+ l.bias)
end
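
A hedged sketch of calling the reworked forward pass (sizes are made up; `edge_weight` is now an optional positional argument). When no explicit weights are passed, the layer falls back to the graph's own weights if it was built with `use_edge_weight=true` (via `w_mul_xj`), and to plain `copy_xj` otherwise.

using GraphNeuralNetworks
g = rand_graph(10, 30)
x = rand(Float32, 3, 10)            # 3 input features per node
w = rand(Float32, 30)               # one weight per edge
l = GCNConv(3 => 5)
y = l(g, x)                         # unweighted: copy_xj path
y = l(g, x, w)                      # explicit weights: e_mul_xj path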

# TODO merge the ADJMAT_T and COO_T methods
# The main problem is handling the weighted case for both.
function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix{T}) where T
if l.add_self_loops
g = add_self_loops(g)
end
Dout, Din = size(l.weight)
if Dout < Din
# multiply before convolution if it is more convenient, otherwise multiply after
x = l.weight * x
end
d = degree(g, T; dir=:in, edge_weight=l.use_edge_weight)
c = 1 ./ sqrt.(d)
x = x .* c'
A = adjacency_matrix(g, weighted=l.use_edge_weight)
x = x * A
x = x .* c'
if Dout >= Din
x = l.weight * x
end
return l.σ.(x .+ l.bias)
end

function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector)
g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO
return l(g, x, edge_weight)
48 changes: 43 additions & 5 deletions src/msgpass.jl
@@ -149,7 +149,7 @@ _scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)



### SPECIALIZATIONS OF PROPAGATE ###
### MESSAGE FUNCTIONS ###
"""
copy_xj(xi, xj, e) = xj
"""
@@ -178,26 +178,64 @@ function e_mul_xj(xi, xj::AbstractArray{Tj,Nj}, e::AbstractArray{Te,Ne}) where {
return e .* xj
end

"""
w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj
Similar to [`e_mul_xj`](@ref) but specialized on scalar edge features (weights).
"""
w_mul_xj(xi, xj::AbstractArray, w::Nothing) = xj # same as copy_xj if no weights

function w_mul_xj(xi, xj::AbstractArray{Tj,Nj}, w::AbstractVector) where {Tj, Nj}
w = reshape(w, ntuple(_ -> 1, Nj-1)..., length(w))
return w .* xj
end
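
To illustrate the new message function, a sketch on a made-up graph of `w_mul_xj` driven through `propagate`: the weights come from the graph itself, so no `e` argument is passed.

using GraphNeuralNetworks
s, t, w = [1, 2, 3], [2, 3, 1], [0.5, 1.0, 2.0]
g = GNNGraph((s, t, w))                  # weights stored in the graph
x = rand(Float32, 4, 3)                  # 4 features, 3 nodes
y = propagate(w_mul_xj, g, +; xj = x)    # sums w_e .* x[:, src(e)] into each target node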


###### PROPAGATE SPECIALIZATIONS ####################

## COPY_XJ

function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
A = adjacency_matrix(g, weighted=false)
return xj * A
end

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
propagate((xi,xj,e) -> copy_xj(xi,xj,e), g, +, xi, xj, e)
end

## E_MUL_XJ

# for weighted convolution
function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e::AbstractVector)
s, t = edge_index(g)
g = GNNGraph((s, t, e); g.num_nodes)
g = set_edge_weight(g, e)
A = adjacency_matrix(g, weighted=true)
return xj * A
end

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e::AbstractVector)
propagate((xi,xj,e) -> e_mul_xj(xi,xj,e), g, +, xi, xj, e)
end

## W_MUL_XJ

# for weighted convolution
function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e::Nothing)
A = adjacency_matrix(g, weighted=true)
return xj * A
end

## avoid the fast path on gpu until we have better cuda support
function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
propagate((xi,xj,e)->copy_xj(xi,xj,e), g, +, xi, xj, e)
function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e::Nothing)
propagate((xi,xj,e) -> w_mul_xj(xi,xj,e), g, +, xi, xj, e)
end
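
The specializations above exist because, with `+` aggregation, weighted message passing collapses to a single feature-matrix × adjacency-matrix product. A sketch of that identity, mirroring the tests below (sizes are arbitrary):

using GraphNeuralNetworks, SparseArrays
A = sprand(64, 64, 0.1)                        # weighted sparse adjacency
X = rand(8, 64)                                # node features, one column per node
g = GNNGraph(A, ndata = X, edata = A.nzval)
# unfused: materialize per-edge messages and scatter-sum into targets
unfused = propagate((xi, xj, e) -> reshape(e, 1, :) .* xj, g, +; xj = g.ndata.x, e = g.edata.e)
# fused: hits the xj * A specialization above
fused = propagate(e_mul_xj, g, +; xj = g.ndata.x, e = g.edata.e)
@assert unfused ≈ X * A
@assert fused ≈ X * A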





# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
# A = adjacency_matrix(g, weighted=false)
# D = compute_degree(A)
14 changes: 14 additions & 0 deletions test/GNNGraphs/transform.jl
@@ -173,4 +173,18 @@
@test g2.num_edges < 50
end
end

@testset "set_edge_weight" begin
g = rand_graph(10, 20, graph_type=GRAPH_T)
w = rand(20)

gw = set_edge_weight(g, w)
@test get_edge_weight(gw) == w

# now from weighted graph
s, t = edge_index(g)
g2 = GNNGraph(s, t, rand(20), graph_type=GRAPH_T)
gw2 = set_edge_weight(g2, w)
@test get_edge_weight(gw2) == w
end
end
16 changes: 12 additions & 4 deletions test/msgpass.jl
@@ -85,28 +85,36 @@
@test spmm_copyxj_fused(g) ≈ X * Adj
end

@testset "e_mul_xj for weighted conv" begin
@testset "e_mul_xj adn w_mul_xj for weighted conv" begin
n = 128
A = sprand(n, n, 0.1)
Adj = map(x -> x > 0 ? 1 : 0, A)
X = rand(10, n)

g = GNNGraph(A, ndata=X, edata=reshape(A.nzval, 1, :), graph_type=GRAPH_T)
g = GNNGraph(A, ndata=X, edata=A.nzval, graph_type=GRAPH_T)

function spmm_unfused(g)
propagate(
(xi, xj, e) -> e .* xj ,
(xi, xj, e) -> reshape(e, 1, :) .* xj,
g, +; xj=g.ndata.x, e=g.edata.e
)
end
function spmm_fused(g)
propagate(
e_mul_xj,
g, +; xj=g.ndata.x, e=vec(g.edata.e)
g, +; xj=g.ndata.x, e=g.edata.e
)
end

function spmm_fused2(g)
propagate(
w_mul_xj,
g, +; xj=g.ndata.x
)
end

@test spmm_unfused(g) ≈ X * A
@test spmm_fused(g) ≈ X * A
@test spmm_fused2(g) ≈ X * A
end
end
