diff --git a/GNNGraphs/Project.toml b/GNNGraphs/Project.toml index 310b1205e..bd3b83340 100644 --- a/GNNGraphs/Project.toml +++ b/GNNGraphs/Project.toml @@ -1,7 +1,7 @@ name = "GNNGraphs" uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c" authors = ["Carlo Lucibello and contributors"] -version = "1.2.1" +version = "1.3.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/GNNGraphs/src/sampling.jl b/GNNGraphs/src/sampling.jl index 7e723182a..6e38730f0 100644 --- a/GNNGraphs/src/sampling.jl +++ b/GNNGraphs/src/sampling.jl @@ -177,6 +177,8 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) node_map = Dict(node => i for (i, node) in enumerate(nodes)) + edge_list = [collect(t) for t in zip(edge_index(graph)[1],edge_index(graph)[2])] + # Collect edges to add source = Int[] target = Int[] @@ -187,8 +189,7 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) if neighbor in keys(node_map) push!(target, node_map[node]) push!(source, node_map[neighbor]) - - eindex = findfirst(x -> x == [neighbor, node], edge_index(graph)) + eindex = findfirst(x -> x == [neighbor, node], edge_list) push!(eindices, eindex) end end diff --git a/GNNlib/docs/src/messagepassing.md b/GNNlib/docs/src/messagepassing.md index 709192319..954fb9dd2 100644 --- a/GNNlib/docs/src/messagepassing.md +++ b/GNNlib/docs/src/messagepassing.md @@ -134,7 +134,7 @@ function (l::GCN)(g::GNNGraph, x::AbstractMatrix{T}) where T end ``` -See the `GATConv` implementation [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl) for a more complex example. +See the `GATConv` implementation [here](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl) for a more complex example. ## Built-in message functions diff --git a/GraphNeuralNetworks/Project.toml b/GraphNeuralNetworks/Project.toml index 89979ff69..eb9c44caf 100644 --- a/GraphNeuralNetworks/Project.toml +++ b/GraphNeuralNetworks/Project.toml @@ -1,7 +1,7 @@ name = "GraphNeuralNetworks" uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694" authors = ["Carlo Lucibello and contributors"] -version = "0.6.21" +version = "0.6.22" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -28,6 +29,7 @@ CUDA = "4, 5" ChainRulesCore = "1" Flux = "0.14" Functors = "0.4.1" +Graphs = "1.12" GNNGraphs = "1.0" GNNlib = "0.2" LinearAlgebra = "1" diff --git a/GraphNeuralNetworks/README.md b/GraphNeuralNetworks/README.md index 434e4cd46..565ee8f42 100644 --- a/GraphNeuralNetworks/README.md +++ b/GraphNeuralNetworks/README.md @@ -1,12 +1,12 @@ - + # GraphNeuralNetworks.jl -[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/stable) -[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/dev) -![](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg) -[![codecov](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl) +[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaGraphs.github.io/GraphNeuralNetworks.jl/stable) +[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaGraphs.github.io/GraphNeuralNetworks.jl/dev) +![](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg) +[![codecov](https://codecov.io/gh/JuliaGraphs/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaGraphs/GraphNeuralNetworks.jl) GraphNeuralNetworks.jl is a graph neural network library written in Julia and based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). @@ -18,7 +18,7 @@ Among its features: * Easy to define custom layers. * CUDA support. * Integration with [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl). -* [Examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks. +* [Examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks. * Heterogeneous and temporal graphs. ## Installation @@ -31,7 +31,7 @@ pkg> add GraphNeuralNetworks ## Usage -Usage examples can be found in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) and in the [notebooks](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/notebooks) folder. Also, make sure to read the [documentation](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/dev) for a comprehensive introduction to the library. +Usage examples can be found in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) and in the [notebooks](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/notebooks) folder. Also, make sure to read the [documentation](https://JuliaGraphs.github.io/GraphNeuralNetworks.jl/dev) for a comprehensive introduction to the library. ## Citing @@ -43,7 +43,7 @@ If you use GraphNeuralNetworks.jl in a scientific publication, we would apprecia author = {Carlo Lucibello and other contributors}, title = {GraphNeuralNetworks.jl: a geometric deep learning library for the Julia programming language}, year = 2021, - url = {https://github.com/CarloLucibello/GraphNeuralNetworks.jl} + url = {https://github.com/JuliaGraphs/GraphNeuralNetworks.jl} } ``` diff --git a/GraphNeuralNetworks/docs/make.jl b/GraphNeuralNetworks/docs/make.jl index 69b911226..eb84c3dbd 100644 --- a/GraphNeuralNetworks/docs/make.jl +++ b/GraphNeuralNetworks/docs/make.jl @@ -25,8 +25,6 @@ makedocs(; "Home" => "index.md", "Developer guide" => "dev.md", "Google Summer of Code" => "gsoc.md", - - ], "GraphNeuralNetworks.jl" =>[ "Home" => "home.md", @@ -38,7 +36,8 @@ makedocs(; "Convolutional layers" => "api/conv.md", "Pooling layers" => "api/pool.md", "Temporal Convolutional layers" => "api/temporalconv.md", - "Hetero Convolutional layers" => "api/heteroconv.md" + "Hetero Convolutional layers" => "api/heteroconv.md", + "Samplers" => "api/samplers.md", ], @@ -49,4 +48,4 @@ makedocs(; -deploydocs(;repo = "https://github.com/CarloLucibello/GraphNeuralNetworks.jl.git", dirname= "GraphNeuralNetworks") \ No newline at end of file +deploydocs(;repo = "https://github.com/CarloLucibello/GraphNeuralNetworks.jl.git", dirname= "GraphNeuralNetworks") diff --git a/GraphNeuralNetworks/docs/src/api/samplers.md b/GraphNeuralNetworks/docs/src/api/samplers.md new file mode 100644 index 000000000..f4285562c --- /dev/null +++ b/GraphNeuralNetworks/docs/src/api/samplers.md @@ -0,0 +1,14 @@ +```@meta +CurrentModule = GraphNeuralNetworks +``` + +# Samplers + + +## Docs + +```@autodocs +Modules = [GraphNeuralNetworks] +Pages = ["samplers.jl"] +Private = false +``` diff --git a/GraphNeuralNetworks/docs/src/datasets.md b/GraphNeuralNetworks/docs/src/datasets.md new file mode 100644 index 000000000..8644509c3 --- /dev/null +++ b/GraphNeuralNetworks/docs/src/datasets.md @@ -0,0 +1,9 @@ +# Datasets + +GraphNeuralNetworks.jl doesn't come with its own datasets, but leverages those available in the Julia (and non-Julia) ecosystem. In particular, the [examples in the GraphNeuralNetworks.jl repository](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) make use of the [MLDatasets.jl](https://github.com/JuliaML/MLDatasets.jl) package. There you will find common graph datasets such as Cora, PubMed, Citeseer, TUDataset and [many others](https://juliaml.github.io/MLDatasets.jl/dev/datasets/graphs/). + +GraphNeuralNetworks.jl provides the [`mldataset2gnngraph`](@ref) method for interfacing with MLDatasets.jl. + +```@docs +mldataset2gnngraph +``` diff --git a/GraphNeuralNetworks/docs/src/dev.md b/GraphNeuralNetworks/docs/src/dev.md index 2a2aae370..f67fac0cf 100644 --- a/GraphNeuralNetworks/docs/src/dev.md +++ b/GraphNeuralNetworks/docs/src/dev.md @@ -80,7 +80,7 @@ julia> compare(dfpr, dfmaster) Tutorials in GraphNeuralNetworks.jl are written in Pluto and rendered using [DemoCards.jl](https://github.com/JuliaDocs/DemoCards.jl) and [PlutoStaticHTML.jl](https://github.com/rikhuijzer/PlutoStaticHTML.jl). Rendering a Pluto notebook is time and resource-consuming, especially in a CI environment. So we use the [caching functionality](https://huijzer.xyz/PlutoStaticHTML.jl/dev/#Caching) provided by PlutoStaticHTML.jl to reduce CI time. -If you are contributing a new tutorial or making changes to the existing notebook, generate the docs locally before committing/pushing. For caching to work, the cache environment(your local) and the documenter CI should have the same Julia version (e.g. "v1.9.1", also the patch number must match). So use the [documenter CI Julia version](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/.github/workflows/docs.yml#L17) for generating docs locally. +If you are contributing a new tutorial or making changes to the existing notebook, generate the docs locally before committing/pushing. For caching to work, the cache environment(your local) and the documenter CI should have the same Julia version (e.g. "v1.9.1", also the patch number must match). So use the [documenter CI Julia version](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/blob/master/.github/workflows/docs.yml#L17) for generating docs locally. ```console julia --version # check julia version before generating docs @@ -95,6 +95,6 @@ During the doc generation process, DemoCards.jl stores the cache notebooks in do git add docs/pluto_output # add generated cache ``` -Check the [documenter CI logs](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/actions/workflows/docs.yml) to ensure that it used the local cache: +Check the [documenter CI logs](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/actions/workflows/docs.yml) to ensure that it used the local cache: ![](https://user-images.githubusercontent.com/55111154/210061301-c84b7274-9e66-46fd-b272-d45b1c681d00.png) \ No newline at end of file diff --git a/GraphNeuralNetworks/docs/src/index.md b/GraphNeuralNetworks/docs/src/index.md index ee5918c47..39413eef8 100644 --- a/GraphNeuralNetworks/docs/src/index.md +++ b/GraphNeuralNetworks/docs/src/index.md @@ -1,20 +1,20 @@ # GraphNeuralNetworks Monorepo -This repository is a monorepo that contains all the code for the GraphNeuralNetworks project. The project is organized as a monorepo to facilitate code sharing and reusability across different components of the project. The monorepo contains the following packages: +This is the documentation page for [GraphNeuralNetworks.jl](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl), a graph neural network library written in Julia and based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). +GraphNeuralNetworks.jl is largely inspired by [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), [Deep Graph Library](https://docs.dgl.ai/), +and [GeometricFlux.jl](https://fluxml.ai/GeometricFlux.jl/stable/). - `GraphNeuralNetwork.jl`: Package that contains stateful graph convolutional layers based on the machine learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/). This is fronted package for Flux users. It depends on GNNlib.jl, GNNGraphs.jl, and Flux.jl packages. -- `GNNLux.jl`: Package that contains stateless graph convolutional layers based on the machine learning framework [Lux.jl](https://lux.csail.mit.edu/stable/). This is fronted package for Lux users. It depends on GNNlib.jl, GNNGraphs.jl, and Lux.jl packages. - -- `GNNlib.jl`: Package that contains the core graph neural network layers and utilities. It depends on GNNGraphs.jl and GNNlib.jl packages and serves for code base for GraphNeuralNetwork.jl and GNNLux.jl packages. - -- `GNNGraphs.jl`: Package that contains the graph data structures and helper functions for working with graph data. It depends on Graphs.jl package. - -Here is a schema of the dependencies between the packages: - -![Monorepo schema](assets/schema.png) +* Implements common graph convolutional layers. +* Supports computations on batched graphs. +* Easy to define custom layers. +* CUDA support. +* Integration with [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl). +* [Examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks. +Usage examples on real datasets can be found in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) folder. diff --git a/GraphNeuralNetworks/docs/src/models.md b/GraphNeuralNetworks/docs/src/models.md index 672a7e9f8..4a7876390 100644 --- a/GraphNeuralNetworks/docs/src/models.md +++ b/GraphNeuralNetworks/docs/src/models.md @@ -122,4 +122,4 @@ X = randn(Float32, din, 10) y = model(X) ``` -An example of `WithGraph` usage is given in the graph neural ODE script in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) folder. +An example of `WithGraph` usage is given in the graph neural ODE script in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) folder. diff --git a/GraphNeuralNetworks/notebooks/gnn_intro.ipynb b/GraphNeuralNetworks/notebooks/gnn_intro.ipynb index db3721ea5..798d4cc33 100644 --- a/GraphNeuralNetworks/notebooks/gnn_intro.ipynb +++ b/GraphNeuralNetworks/notebooks/gnn_intro.ipynb @@ -18,7 +18,7 @@ "\\mathbf{x}_i^{(\\ell + 1)} = f^{(\\ell + 1)}_{\\theta} \\left( \\mathbf{x}_i^{(\\ell)}, \\left\\{ \\mathbf{x}_j^{(\\ell)} : j \\in \\mathcal{N}(i) \\right\\} \\right)\n", "$$\n", "\n", - "This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the **[GraphNeuralNetworks.jl library](https://github.com/CarloLucibello/GraphNeuralNetworks.jl)**.\n", + "This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the **[GraphNeuralNetworks.jl library](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl)**.\n", "GraphNeuralNetworks.jl is an extension library to the popular deep learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/), and consists of various methods and utilities to ease the implementation of Graph Neural Networks.\n", "\n", "Let's first import the packages we need:" diff --git a/GraphNeuralNetworks/notebooks/graph_classification.ipynb b/GraphNeuralNetworks/notebooks/graph_classification.ipynb index 832c5903d..f407b0559 100644 --- a/GraphNeuralNetworks/notebooks/graph_classification.ipynb +++ b/GraphNeuralNetworks/notebooks/graph_classification.ipynb @@ -654,7 +654,7 @@ "# Exercise 2 \n", "\n", "Define your own convolutional layer drawing inspiration from any of the already existing ones:\n", - "https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl\n", + "https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl\n", "\n", "You can try to:\n", "- use MLPs instead of linear operators\n", diff --git a/GraphNeuralNetworks/notebooks/graph_classification_solved.ipynb b/GraphNeuralNetworks/notebooks/graph_classification_solved.ipynb index a54c5b359..e938c7a43 100644 --- a/GraphNeuralNetworks/notebooks/graph_classification_solved.ipynb +++ b/GraphNeuralNetworks/notebooks/graph_classification_solved.ipynb @@ -647,7 +647,7 @@ "# Exercise 2 \n", "\n", "Define your own convolutional layer drawing inspiration from any of the already existing ones:\n", - "https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl\n", + "https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/blob/master/src/layers/conv.jl\n", "\n", "You can try to:\n", "- use MLPs instead of linear operators\n", diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index c9a227b8d..9ac46e8b1 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -10,6 +10,7 @@ using NNlib: scatter, gather using ChainRulesCore using Reexport using MLUtils: zeros_like +using Graphs: Graphs using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, @@ -66,4 +67,7 @@ export GlobalPool, include("deprecations.jl") +include("samplers.jl") +export NeighborLoader + end diff --git a/GraphNeuralNetworks/src/samplers.jl b/GraphNeuralNetworks/src/samplers.jl new file mode 100644 index 000000000..5c06c1681 --- /dev/null +++ b/GraphNeuralNetworks/src/samplers.jl @@ -0,0 +1,103 @@ +""" + NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size]) + +A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). +It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training +originally introduced in "Inductive Representation Learning on Large Graphs" paper. +[see https://arxiv.org/abs/1706.02216] + +# Fields +- `graph::GNNGraph`: The input graph. +- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. +- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. +- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). +- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. + +# Usage +```jldoctest +julia> loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) + +julia> batch_counter = 0 +julia> for mini_batch_gnn in loader + batch_counter += 1 + println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", nv(mini_batch_gnn)) +``` +""" +struct NeighborLoader + graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) + num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer + input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling) + num_layers::Int # Number of layers for neighborhood expansion + batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given + neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation +end + +function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}=nothing, + num_layers::Int, batch_size::Union{Int, Nothing}=nothing) + return NeighborLoader(graph, num_neighbors, input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes, num_layers, + batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}()) +end + +# Function to get cached neighbors or compute them +function get_neighbors(loader::NeighborLoader, node::Int) + if haskey(loader.neighbors_cache, node) + return loader.neighbors_cache[node] + else + neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph + loader.neighbors_cache[node] = neighbors + return neighbors + end +end + +# Function to sample neighbors for a given node at a specific layer +function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) + neighbors = get_neighbors(loader, node) + if isempty(neighbors) + return Int[] + else + num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer + return rand(neighbors, num_samples) # Randomly sample neighbors + end +end + +# Iterator protocol for NeighborLoader with lazy batch loading +function Base.iterate(loader::NeighborLoader, state=1) + if state > length(loader.input_nodes) + return nothing # End of iteration if batches are exhausted (state larger than amount of input nodes or current batch no >= batch number) + end + + # Determine the size of the current batch + batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # Conditional in case there is not enough nodes to fill the last batch + batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses different set of input nodes + + # Set for tracking the subgraph nodes + subgraph_nodes = Set(batch_nodes) + + for node in batch_nodes + # Initialize current layer of nodes (starting with the node itself) + sampled_neighbors = Set([node]) + + # For each GNN layer, sample the neighborhood + for layer in 1:loader.num_layers + new_neighbors = Set{Int}() + for n in sampled_neighbors + neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer + new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set + end + sampled_neighbors = new_neighbors + subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors + end + end + + # Collect subgraph nodes and their features + subgraph_node_list = collect(subgraph_nodes) + + if isempty(subgraph_node_list) + return GNNGraph(), state + batch_size + end + + mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes + + # Continue iteration for the next batch + return mini_batch_gnn, state + batch_size +end diff --git a/GraphNeuralNetworks/test/runtests.jl b/GraphNeuralNetworks/test/runtests.jl index 05cb6fd5f..f796651bb 100644 --- a/GraphNeuralNetworks/test/runtests.jl +++ b/GraphNeuralNetworks/test/runtests.jl @@ -30,6 +30,7 @@ tests = [ "layers/temporalconv", "layers/pool", "examples/node_classification_cora", + "samplers" ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") diff --git a/GraphNeuralNetworks/test/samplers.jl b/GraphNeuralNetworks/test/samplers.jl new file mode 100644 index 000000000..546291717 --- /dev/null +++ b/GraphNeuralNetworks/test/samplers.jl @@ -0,0 +1,125 @@ +# Helper function to create a simple graph with node features using GNNGraph +function create_test_graph() + source = [1, 2, 3, 4] # Define source nodes of edges + target = [2, 3, 4, 5] # Define target nodes of edges + node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes) + + return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features +end + +# Tests for NeighborLoader structure and its functionalities +@testset "NeighborLoader tests" begin + + # 1. Basic functionality: Check neighbor sampling and subgraph creation + @testset "Basic functionality" begin + g = create_test_graph() + + # Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2 + loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph is not empty + @test !isempty(mini_batch_gnn.graph) + + num_sampled_nodes = mini_batch_gnn.num_nodes + println("Number of nodes in mini-batch: ", num_sampled_nodes) + + @test num_sampled_nodes == 2 + + # Test if there are edges in the subgraph + @test mini_batch_gnn.num_edges > 0 + end + + # 2. Edge case: Single node with no neighbors + @testset "Single node with no neighbors" begin + g = SimpleDiGraph(1) # A graph with a single node and no edges + node_features = rand(Float32, 5, 1) + graph = GNNGraph(g, ndata = node_features) + + loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains only one node + @test size(mini_batch_gnn.x, 2) == 1 + end + + # 3. Edge case: A node with no outgoing edges (isolated node) + @testset "Node with no outgoing edges" begin + g = SimpleDiGraph(2) # Graph with 2 nodes, no edges + node_features = rand(Float32, 5, 2) + graph = GNNGraph(g, ndata = node_features) + + loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled) + @test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes + end + + # 4. Edge case: A fully connected graph + @testset "Fully connected graph" begin + g = SimpleDiGraph(3) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 3, 1) + node_features = rand(Float32, 5, 3) + graph = GNNGraph(g, ndata = node_features) + + loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if all nodes are included in the mini-batch since it's fully connected + @test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included + end + + # 5. Edge case: More layers than the number of neighbors + @testset "More layers than available neighbors" begin + g = SimpleDiGraph(3) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + node_features = rand(Float32, 5, 3) + graph = GNNGraph(g, ndata = node_features) + + # Test with 3 layers but only enough connections for 2 layers + loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains all available nodes + @test size(mini_batch_gnn.x, 2) == 1 + end + + # 6. Edge case: Large batch size greater than the number of input nodes + @testset "Large batch size" begin + g = create_test_graph() + + # Define NeighborLoader with a larger batch size than input nodes + loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph is not empty + @test !isempty(mini_batch_gnn.graph) + + # Test if the correct number of nodes are sampled + @test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected + end + + # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer + @testset "No neighbors sampled" begin + g = create_test_graph() + + # Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2 + loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains only the input nodes + @test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph + end + +end \ No newline at end of file diff --git a/README.md b/README.md index 68c49926a..1cb508a46 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ - + # GraphNeuralNetworks.jl -[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/stable) -[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/dev) -![](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg) -[![codecov](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl) +[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaGraphs.github.io/GraphNeuralNetworks.jl/stable) +[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://JuliaGraphs.github.io/GraphNeuralNetworks.jl/dev) +![](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg) +[![codecov](https://codecov.io/gh/JuliaGraphs/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaGraphs/GraphNeuralNetworks.jl) GraphNeuralNetworks.jl is a graph neural network library written in Julia and based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl). @@ -18,7 +18,7 @@ Among its features: * Easy to define custom layers. * CUDA support. * Integration with [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl). -* [Examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks. +* [Examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) of node, edge, and graph level machine learning tasks. * Heterogeneous and temporal graphs. ## Installation @@ -31,7 +31,7 @@ pkg> add GraphNeuralNetworks ## Usage -Usage examples can be found in the [examples](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/examples) and in the [notebooks](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/tree/master/notebooks) folder. Also, make sure to read the [documentation](https://CarloLucibello.github.io/GraphNeuralNetworks.jl/dev) for a comprehensive introduction to the library. +Usage examples can be found in the [examples](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/examples) and in the [notebooks](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/tree/master/notebooks) folder. Also, make sure to read the [documentation](https://JuliaGraphs.github.io/GraphNeuralNetworks.jl/dev) for a comprehensive introduction to the library. ## Citing @@ -43,7 +43,7 @@ If you use GraphNeuralNetworks.jl in a scientific publication, we would apprecia author = {Carlo Lucibello and other contributors}, title = {GraphNeuralNetworks.jl: a geometric deep learning library for the Julia programming language}, year = 2021, - url = {https://github.com/CarloLucibello/GraphNeuralNetworks.jl} + url = {https://github.com/JuliaGraphs/GraphNeuralNetworks.jl} } ``` diff --git a/tutorials/docs/src/pluto_output/gnn_intro_pluto.md b/tutorials/docs/src/pluto_output/gnn_intro_pluto.md index 9c3c1e52c..1174628d6 100644 --- a/tutorials/docs/src/pluto_output/gnn_intro_pluto.md +++ b/tutorials/docs/src/pluto_output/gnn_intro_pluto.md @@ -29,7 +29,7 @@ julia_version = "1.10.5" --> -

Hands-on introduction to Graph Neural Networks

This Pluto notebook is a Julia adaptation of the Pytorch Geometric tutorials that can be found here.

Recently, deep learning on graphs has emerged to one of the hottest research fields in the deep learning community. Here, Graph Neural Networks (GNNs) aim to generalize classical deep learning concepts to irregular structured data (in contrast to images or texts) and to enable neural networks to reason about objects and their relations.

This is done by following a simple neural message passing scheme, where node features \(\mathbf{x}_i^{(\ell)}\) of all nodes \(i \in \mathcal{V}\) in a graph \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\) are iteratively updated by aggregating localized information from their neighbors \(\mathcal{N}(i)\):

$$\mathbf{x}_i^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_i^{(\ell)}, \left\{ \mathbf{x}_j^{(\ell)} : j \in \mathcal{N}(i) \right\} \right)$$

This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the GraphNeuralNetworks.jl library. GraphNeuralNetworks.jl is an extension library to the popular deep learning framework Flux.jl, and consists of various methods and utilities to ease the implementation of Graph Neural Networks.

Let's first import the packages we need:

+

This Pluto notebook is a Julia adaptation of the Pytorch Geometric tutorials that can be found here.

Recently, deep learning on graphs has emerged to one of the hottest research fields in the deep learning community. Here, Graph Neural Networks (GNNs) aim to generalize classical deep learning concepts to irregular structured data (in contrast to images or texts) and to enable neural networks to reason about objects and their relations.

This is done by following a simple neural message passing scheme, where node features \(\mathbf{x}_i^{(\ell)}\) of all nodes \(i \in \mathcal{V}\) in a graph \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\) are iteratively updated by aggregating localized information from their neighbors \(\mathcal{N}(i)\):

$$\mathbf{x}_i^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_i^{(\ell)}, \left\{ \mathbf{x}_j^{(\ell)} : j \in \mathcal{N}(i) \right\} \right)$$

This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the GraphNeuralNetworks.jl library. GraphNeuralNetworks.jl is an extension library to the popular deep learning framework Flux.jl, and consists of various methods and utilities to ease the implementation of Graph Neural Networks.

Let's first import the packages we need:

begin
     using Flux
diff --git a/tutorials/tutorials/index.md b/tutorials/tutorials/index.md
index e22d4134c..e0a02c6e6 100644
--- a/tutorials/tutorials/index.md
+++ b/tutorials/tutorials/index.md
@@ -9,7 +9,7 @@
 ## Contributions
 
 If you have a suggestion on adding new tutorials, feel free to create a new issue
-[here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/issues/new).
+[here](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl/issues/new).
 Users are invited to contribute demonstrations of their own.
 If you want to contribute new tutorials and looking for inspiration,
 checkout these tutorials from
diff --git a/tutorials/tutorials/introductory_tutorials/gnn_intro_pluto.jl b/tutorials/tutorials/introductory_tutorials/gnn_intro_pluto.jl
index 76e74ddef..756c2d1f0 100644
--- a/tutorials/tutorials/introductory_tutorials/gnn_intro_pluto.jl
+++ b/tutorials/tutorials/introductory_tutorials/gnn_intro_pluto.jl
@@ -39,7 +39,7 @@ This is done by following a simple **neural message passing scheme**, where node
 \mathbf{x}_i^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_i^{(\ell)}, \left\{ \mathbf{x}_j^{(\ell)} : j \in \mathcal{N}(i) \right\} \right)
 ```
 
-This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the **[GraphNeuralNetworks.jl library](https://github.com/CarloLucibello/GraphNeuralNetworks.jl)**.
+This tutorial will introduce you to some fundamental concepts regarding deep learning on graphs via Graph Neural Networks based on the **[GraphNeuralNetworks.jl library](https://github.com/JuliaGraphs/GraphNeuralNetworks.jl)**.
 GraphNeuralNetworks.jl is an extension library to the popular deep learning framework [Flux.jl](https://fluxml.ai/Flux.jl/stable/), and consists of various methods and utilities to ease the implementation of Graph Neural Networks.
 
 Let's first import the packages we need: