Merge branch 'master' into new-multidocs
Showing 22 changed files with 302 additions and 44 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
```@meta | ||
CurrentModule = GraphNeuralNetworks | ||
``` | ||
|
||
# Samplers | ||
|
||
|
||
## Docs | ||
|
||
```@autodocs | ||
Modules = [GraphNeuralNetworks] | ||
Pages = ["samplers.jl"] | ||
Private = false | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Datasets | ||
|
||
GraphNeuralNetworks.jl doesn't come with its own datasets, but leverages those available in the Julia (and non-Julia) ecosystem. In particular, the [examples in the GraphNeuralNetworks.jl repository](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets such as Cora, PubMed, Citeseer, TUDataset and [many others](https://juliaml.github.io/MLDatasets.jl/dev/datasets/graphs/). | ||
|
||
GraphNeuralNetworks.jl provides the [`mldataset2gnngraph`](@ref) method for interfacing with MLDatasets.jl. | ||
|
||
```@docs | ||
mldataset2gnngraph | ||
``` |
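As a rough illustration of the interface this page describes, the sketch below converts the Cora dataset from MLDatasets.jl into a `GNNGraph`. It assumes both packages are installed and that the dataset can be downloaded on first use; the `DATADEPS_ALWAYS_ACCEPT` setting is an assumption added here to skip the interactive download prompt, not something the page itself prescribes.

```julia
using GraphNeuralNetworks
using MLDatasets: Cora

# Assumption: accept the dataset download without an interactive prompt.
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

dataset = Cora()                  # citation graph provided by MLDatasets.jl
g = mldataset2gnngraph(dataset)   # convert the dataset to a GNNGraph

println(g)                        # prints a summary of the converted graph
```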
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,20 @@ | ||
# GraphNeuralNetworks Monorepo | ||
|
||
This repository is a monorepo that contains all the code for the GraphNeuralNetworks project. The project is organized as a monorepo to facilitate code sharing and reusability across different components of the project. The monorepo contains the following packages: | ||
This is the documentation page for [GraphNeuralNetworks.jl](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl), a graph neural network library written in Julia and based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). | ||
GraphNeuralNetworks.jl is largely inspired by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), [Deep Graph Library](https://docs.dgl.ai/), | ||
and [GeometricFlux.jl](https://fluxml.ai/GeometricFlux.jl/stable/). | ||
|
||
- `GraphNeuralNetworks.jl`: Package that contains stateful graph convolutional layers based on the machine learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/). This is the frontend package for Flux users. It depends on the GNNlib.jl, GNNGraphs.jl, and Flux.jl packages. | ||
|
||
- `GNNLux.jl`: Package that contains stateless graph convolutional layers based on the machine learning framework [Lux.jl](https://lux.csail.mit.edu/stable/). This is the frontend package for Lux users. It depends on the GNNlib.jl, GNNGraphs.jl, and Lux.jl packages. | ||
|
||
- `GNNlib.jl`: Package that contains the core graph neural network layers and utilities. It depends on the GNNGraphs.jl package and serves as the code base for the GraphNeuralNetworks.jl and GNNLux.jl packages. | ||
|
||
- `GNNGraphs.jl`: Package that contains the graph data structures and helper functions for working with graph data. It depends on the Graphs.jl package. | ||
|
||
Here is a schema of the dependencies between the packages: | ||
|
||
![Monorepo schema](assets/schema.png) | ||
* Implements common graph convolutional layers. | ||
* Supports computations on batched graphs. | ||
* Easy to define custom layers. | ||
* CUDA support. | ||
* Integration with [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl). | ||
* [Examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks. | ||
|
||
|
||
|
||
|
||
Usage examples on real datasets can be found in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) folder. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
""" | ||
NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size]) | ||
A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). | ||
It supports multi-layer sampling of neighbors for a batch of input nodes, which is useful for mini-batch training | ||
as originally introduced in the paper "Inductive Representation Learning on Large Graphs" | ||
(see https://arxiv.org/abs/1706.02216). | ||
# Fields | ||
- `graph::GNNGraph`: The input graph. | ||
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. | ||
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. | ||
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). | ||
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. | ||
# Usage | ||
```jldoctest | ||
julia> loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) | ||
julia> batch_counter = 0 | ||
julia> for mini_batch_gnn in loader | ||
batch_counter += 1 | ||
println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", nv(mini_batch_gnn)) | ||
end | ||
``` | ||
""" | ||
struct NeighborLoader | ||
graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) | ||
num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer | ||
input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling) | ||
num_layers::Int # Number of layers for neighborhood expansion | ||
batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given | ||
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation | ||
end | ||
|
||
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Union{Vector{Int}, Nothing}=nothing, | ||
num_layers::Int, batch_size::Union{Int, Nothing}=nothing) | ||
# Default to all nodes of the graph, and to a single batch covering every input node | ||
nodes = input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes | ||
return NeighborLoader(graph, num_neighbors, nodes, num_layers, | ||
batch_size === nothing ? length(nodes) : batch_size, Dict{Int, Vector{Int}}()) | ||
end |
|
||
# Function to get cached neighbors or compute them | ||
function get_neighbors(loader::NeighborLoader, node::Int) | ||
if haskey(loader.neighbors_cache, node) | ||
return loader.neighbors_cache[node] | ||
else | ||
neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph | ||
loader.neighbors_cache[node] = neighbors | ||
return neighbors | ||
end | ||
end | ||
|
||
# Function to sample neighbors for a given node at a specific layer | ||
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) | ||
neighbors = get_neighbors(loader, node) | ||
if isempty(neighbors) | ||
return Int[] | ||
else | ||
num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer | ||
return rand(neighbors, num_samples) # Randomly sample neighbors (note: rand draws with replacement, so the result may contain duplicates) | ||
end | ||
end | ||
|
||
# Iterator protocol for NeighborLoader with lazy batch loading | ||
function Base.iterate(loader::NeighborLoader, state=1) | ||
if state > length(loader.input_nodes) | ||
return nothing # End of iteration: all input nodes have been consumed | ||
end | ||
|
||
# Determine the size of the current batch | ||
batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # The last batch may be smaller if there are not enough nodes left to fill it | ||
batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses a different set of input nodes | ||
|
||
# Set for tracking the subgraph nodes | ||
subgraph_nodes = Set(batch_nodes) | ||
|
||
for node in batch_nodes | ||
# Initialize current layer of nodes (starting with the node itself) | ||
sampled_neighbors = Set([node]) | ||
|
||
# For each GNN layer, sample the neighborhood | ||
for layer in 1:loader.num_layers | ||
new_neighbors = Set{Int}() | ||
for n in sampled_neighbors | ||
neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer | ||
new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set | ||
end | ||
sampled_neighbors = new_neighbors | ||
subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors | ||
end | ||
end | ||
|
||
# Collect subgraph nodes and their features | ||
subgraph_node_list = collect(subgraph_nodes) | ||
|
||
if isempty(subgraph_node_list) | ||
return GNNGraph(), state + batch_size | ||
end | ||
|
||
mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes | ||
|
||
# Continue iteration for the next batch | ||
return mini_batch_gnn, state + batch_size | ||
end |
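The layer-wise expansion performed in `Base.iterate` above can be sketched on a plain adjacency list, without any graph package. This is a simplified variant rather than the code in the diff: it expands all seed nodes together instead of one at a time, and it samples without replacement via `shuffle`. The names `sample_without_replacement`, `sample_subgraph_nodes`, and the toy `adj` dictionary are hypothetical and introduced only for illustration.

```julia
using Random

# Hypothetical helper: draw at most `k` distinct neighbors of `node`
# from an adjacency list (sampling without replacement).
function sample_without_replacement(rng::AbstractRNG, adj::Dict{Int, Vector{Int}},
                                    node::Int, k::Int)
    nbrs = get(adj, node, Int[])
    k = min(k, length(nbrs))
    return shuffle(rng, nbrs)[1:k]
end

# Hypothetical helper: expand `seeds` one layer at a time and
# keep every node that was reached along the way.
function sample_subgraph_nodes(rng::AbstractRNG, adj::Dict{Int, Vector{Int}},
                               seeds::Vector{Int}, num_neighbors::Vector{Int})
    visited = Set(seeds)
    frontier = Set(seeds)
    for k in num_neighbors            # one entry per layer
        next_frontier = Set{Int}()
        for n in frontier
            union!(next_frontier, sample_without_replacement(rng, adj, n, k))
        end
        union!(visited, next_frontier)
        frontier = next_frontier
    end
    return visited
end

# Toy in-neighbor lists: adj[v] holds the nodes with an edge into v.
adj = Dict(1 => [2, 3, 4], 2 => [5], 3 => [5, 6],
           4 => Int[], 5 => Int[], 6 => Int[])
rng = MersenneTwister(0)
nodes = sample_subgraph_nodes(rng, adj, [1], [2, 1])
println(sort(collect(nodes)))
```

With `num_neighbors = [2, 1]`, the first layer keeps two of node 1's three in-neighbors and the second layer keeps at most one in-neighbor for each of those, so the sampled node set stays bounded regardless of how dense the full graph is.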