Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

COO FeaturedGraph #204

Closed
wants to merge 11 commits into from
Closed
7 changes: 1 addition & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ version = "0.7.6"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphLaplacians = "a1251efa-393a-423f-9d7b-faaecba535dc"
GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8"
GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand All @@ -23,11 +21,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
CUDA = "3.3"
DataStructures = "0.18"
FillArrays = "0.11, 0.12"
Flux = "0.12"
GraphLaplacians = "0.1"
GraphMLDatasets = "0.1"
GraphSignals = "0.2"
LightGraphs = "1.3"
NNlib = "0.7"
NNlibCUDA = "0.1"
Expand Down
18 changes: 9 additions & 9 deletions docs/src/abstractions/msgpass.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ A message function accepts feature vector representing node state `x_i`, feature
Messages from message function are aggregated by an aggregate function. An aggregated message is passed to update function for node-level computation. An aggregate function is given by the following:

```
propagate(mp, fg::FeaturedGraph, aggr::Symbol=:add)
propagate(mp, fg::FeaturedGraph; aggr::Symbol=+)
```

`propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `:add` represents an aggregate function of addition of all messages.
`propagate` function calls the whole message passing layer. `fg` acts as an input for message passing layer and `aggr` represents assignment of aggregate function to `propagate` function. `+` represents an aggregate function of addition of all messages.

The following `aggr` are available aggregate functions:

`:add`: sum over all messages
`:sub`: negative of sum over all messages
`:mul`: multiplication over all messages
`:div`: inverse of multiplication over all messages
`:max`: the maximum of all messages
`:min`: the minimum of all messages
`:mean`: the average of all messages
`+`: sum over all messages
`-`: negative of sum over all messages
`*`: multiplication over all messages
`/`: inverse of multiplication over all messages
`max`: the maximum of all messages
`min`: the minimum of all messages
`mean`: the average of all messages

## Update function

Expand Down
34 changes: 18 additions & 16 deletions src/GeometricFlux.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
module GeometricFlux

using NNlib: similar
using ChainRulesCore: eltype, reshape
using LinearAlgebra: similar
using Statistics: mean
using LinearAlgebra: Adjoint, norm, Transpose
using Reexport
using LinearAlgebra

using CUDA
using FillArrays: Fill
using Flux
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
using NNlib, NNlibCUDA
using GraphLaplacians
@reexport using GraphSignals
using LightGraphs
using Zygote
using ChainRulesCore
import LightGraphs
using LightGraphs: AbstractGraph, outneighbors, inneighbors, is_directed, ne, nv,
adjacency_matrix, degree

export
# layers/gn
GraphNet,
# featured_graph
FeaturedGraph,
edge_index,
node_feature, edge_feature, global_feature,
adjacency_list, normalized_laplacian, scaled_laplacian,

# from LightGraphs
ne, nv, adjacency_matrix,

# layers/msgpass
MessagePassing,
Expand Down Expand Up @@ -44,26 +52,20 @@ export
sample,

# layer/selector
bypass_graph,

# utils
generate_cluster
bypass_graph

include("featured_graph.jl")
include("datasets.jl")

include("utils.jl")

include("layers/gn.jl")
include("layers/msgpass.jl")

include("layers/conv.jl")
include("layers/pool.jl")
include("models.jl")
include("layers/misc.jl")

include("cuda/msgpass.jl")
include("cuda/conv.jl")

using .Datasets


Expand Down
26 changes: 0 additions & 26 deletions src/cuda/conv.jl

This file was deleted.

41 changes: 0 additions & 41 deletions src/cuda/msgpass.jl

This file was deleted.

Loading