PyG (PyTorch Geometric) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.
It consists of various methods for deep learning on graphs and other irregular structures, also known as geometric deep learning, from a variety of published papers.
In addition, it provides easy-to-use mini-batch loaders for operating on many small graphs and on single giant graphs, multi-GPU support, torch.compile support, DataPipe support, a large number of common benchmark datasets (based on simple interfaces to create your own), the GraphGym experiment manager, and helpful transforms, both for learning on arbitrary graphs and on 3D meshes or point clouds.
- Library Highlights
- Quick Tour for New Users
- Architecture Overview
- Implemented GNN Models
- Installation
Whether you are a machine learning researcher or a first-time user of machine learning toolkits, here are some reasons to try out PyG for machine learning on graph-structured data.
- Easy-to-use and unified API: All it takes is 10-20 lines of code to get started with training a GNN model (see the next section for a quick tour). PyG is PyTorch-on-the-rocks: It utilizes a tensor-centric API and keeps design principles close to vanilla PyTorch. If you are already familiar with PyTorch, utilizing PyG is straightforward.
- Comprehensive and well-maintained GNN models: Most of the state-of-the-art Graph Neural Network architectures have been implemented by library developers or authors of research papers and are ready to be applied.
- Great flexibility: Existing PyG models can easily be extended for conducting your own research with GNNs. Making modifications to existing models or creating new architectures is simple, thanks to its easy-to-use message passing API and a variety of operators and utility functions.
- Large-scale real-world GNN models: We focus on the needs of GNN applications in challenging real-world scenarios, and support learning on diverse types of graphs, including but not limited to: scalable GNNs for graphs with millions of nodes; dynamic GNNs for node predictions over time; heterogeneous GNNs with multiple node types and edge types.
- GraphGym integration: GraphGym lets users easily reproduce GNN experiments, is able to launch and analyze thousands of different GNN configurations, and is customizable by registering new modules to a GNN learning pipeline.
In this quick tour, we highlight the ease of creating and training a GNN model with only a few lines of code.
In the first glimpse of PyG, we implement the training of a GNN for classifying papers in a citation graph.
For this, we load the Cora dataset and create a simple 2-layer GCN model using the pre-defined GCNConv:
import torch
from torch import Tensor
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='.', name='Cora')

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GCN(dataset.num_features, 16, dataset.num_classes)
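To make the `edge_index` layout (two rows: source nodes, then target nodes) more concrete, here is a framework-free sketch of the GCN-style propagation from Kipf and Welling: add a self-loop to every node, then sum neighbor features weighted by 1/sqrt(deg(i) * deg(j)). It is a simplified illustration in plain Python lists, not PyG's implementation; it assumes an undirected graph stored with both edge directions and omits the learnable weight matrix and bias that a real GCN layer applies.

```python
import math

def gcn_propagate(x, edge_index):
    """Symmetric-normalized neighbor aggregation with self-loops (no weights, no bias).

    x: list of feature vectors, one per node ([num_nodes][num_features]).
    edge_index: pair of lists [sources, targets], i.e. shape [2, num_edges].
    """
    num_nodes = len(x)
    sources, targets = edge_index
    # Add a self-loop for every node, as in the GCN formulation.
    sources = list(sources) + list(range(num_nodes))
    targets = list(targets) + list(range(num_nodes))

    # Degree of each target node, counting its self-loop.
    deg = [0] * num_nodes
    for t in targets:
        deg[t] += 1

    out = [[0.0] * len(x[0]) for _ in range(num_nodes)]
    for j, i in zip(sources, targets):
        norm = 1.0 / math.sqrt(deg[j] * deg[i])
        for k, value in enumerate(x[j]):
            out[i][k] += norm * value
    return out

# Undirected path graph 0 - 1 - 2, stored with both edge directions.
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
x = [[1.0], [2.0], [3.0]]
print(gcn_propagate(x, edge_index))
```

Each node mixes its own features with those of its neighbors, and high-degree nodes contribute less per edge because of the normalization.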
We can now optimize the model in a training loop, similar to the standard PyTorch training procedure.
import torch.nn.functional as F

data = dataset[0]

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    pred = model(data.x, data.edge_index)
    loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask])

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
More information about evaluating final model performance can be found in the corresponding example.
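A common way to evaluate such a node classifier is accuracy on the held-out nodes: take the arg-max over each node's output logits and compare it with the label wherever the test mask is set. The sketch below shows that computation in plain Python lists; the function name `masked_accuracy` is hypothetical. With the model above you would typically switch to `model.eval()`, run the forward pass under `torch.no_grad()`, and apply the same logic to `data.y` and `data.test_mask`.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose arg-max prediction equals the label."""
    correct = total = 0
    for node_logits, label, selected in zip(logits, labels, mask):
        if not selected:
            continue  # skip nodes outside the evaluation split
        predicted = max(range(len(node_logits)), key=node_logits.__getitem__)
        correct += int(predicted == label)
        total += 1
    return correct / total if total else 0.0

logits = [[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]]
labels = [0, 1, 1, 1]
test_mask = [True, True, True, False]
print(masked_accuracy(logits, labels, test_mask))
```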
In addition to the easy application of existing GNNs, PyG makes it simple to implement custom Graph Neural Networks (see here for the accompanying tutorial). For example, this is all it takes to implement the edge convolutional layer from Wang et al.:
import torch
from torch import Tensor
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr="max")  # "Max" aggregation.
        self.mlp = Sequential(
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        return self.propagate(edge_index, x=x)  # shape [num_nodes, out_channels]

    def message(self, x_j: Tensor, x_i: Tensor) -> Tensor:
        # x_j: Source node features of shape [num_edges, in_channels]
        # x_i: Target node features of shape [num_edges, in_channels]
        edge_features = torch.cat([x_i, x_j - x_i], dim=-1)
        return self.mlp(edge_features)  # shape [num_edges, out_channels]
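To see the data flow that `propagate` drives for this layer, the following plain-Python sketch mimics it: for every edge from source `j` to target `i`, it builds a message from `x_i` and `x_j - x_i`, then reduces all messages arriving at node `i` with an element-wise max. It is an illustration only. The message function `concat` is a hypothetical stand-in for the MLP (it just concatenates `[x_i, x_j - x_i]`), and giving nodes without incoming edges a zero vector is a choice made for this sketch.

```python
def edge_conv_max(x, edge_index, message_fn):
    """Max-aggregate message_fn(x_i, x_j - x_i) over all edges j -> i."""
    sources, targets = edge_index  # row 0: source j, row 1: target i
    out = [None] * len(x)
    for j, i in zip(sources, targets):
        x_i, x_j = x[i], x[j]
        diff = [a - b for a, b in zip(x_j, x_i)]
        msg = message_fn(x_i, diff)
        # Element-wise max with messages already collected for node i.
        out[i] = msg if out[i] is None else [max(a, b) for a, b in zip(out[i], msg)]
    # Sketch-specific choice: nodes with no incoming edges get zeros.
    width = next((len(m) for m in out if m is not None), 0)
    return [m if m is not None else [0.0] * width for m in out]

def concat(x_i, diff):
    # Hypothetical placeholder for the MLP: concatenate [x_i, x_j - x_i].
    return list(x_i) + list(diff)

x = [[0.0], [1.0], [4.0]]
edge_index = [[1, 2, 0], [0, 0, 1]]  # edges 1->0, 2->0, 0->1
print(edge_conv_max(x, edge_index, concat))
```

Node 0 receives two messages and keeps the larger value per feature, which is what `aggr="max"` expresses in the layer above.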
GraphGym allows you to manage and launch GNN experiments, using a highly modularized pipeline (see here for the accompanying tutorial).
git clone https://github.com/pyg-team/pytorch_geometric.git
cd pytorch_geometric/graphgym
bash run_single.sh  # run a single GNN experiment (node/edge/graph-level)
bash run_batch.sh   # run a batch of GNN experiments, using different GNN designs/datasets/tasks
Users are highly encouraged to check out the documentation, which contains additional tutorials on the essential functionalities of PyG, including data handling, creation of datasets and a full list of implemented methods, transforms, and datasets.
For a quick start, check out our examples in examples/.
PyG provides a multi-layer framework that enables users to build Graph Neural Network solutions on both low and high levels. It comprises the following components:
- The PyG engine utilizes the powerful PyTorch deep learning framework with full torch.compile and TorchScript support, as well as additional efficient CPU/CUDA libraries for operating on sparse data, e.g., pyg-lib.
- The PyG storage handles data processing, transformation and loading pipelines. It is capable of handling and processing large-scale graph datasets, and provides effective solutions for heterogeneous graphs. It further provides a variety of sampling solutions, which enable training of GNNs on large-scale graphs.
- The PyG operators bundle essential functionalities for implementing Graph Neural Networks. PyG supports important GNN building blocks that can be combined and applied to various parts of a GNN model, ensuring rich flexibility of GNN design.
- Finally, PyG provides an abundant set of GNN models, and examples that showcase GNN models on standard graph benchmarks. Thanks to its flexibility, users can easily build and modify custom GNN models to fit their specific needs.
We list currently supported PyG models, layers and operators according to category:
GNN layers:
All Graph Neural Network layers are implemented via the nn.MessagePassing interface.
A GNN layer specifies how to perform message passing, i.e., by designing different message, aggregation and update functions as defined here.
These GNN layers can be stacked together to create Graph Neural Network models.
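The message / aggregate / update decomposition can be sketched independently of PyG. The helper below is hypothetical (`message_passing`, `sum_aggr` and `add_update` are illustrative names, not PyG API): it computes one message per edge, reduces the messages per target node with a chosen aggregation, and then combines each node's previous features with its aggregate.

```python
def message_passing(x, edge_index, message, aggregate, update):
    """One round of message passing over edges stored as [sources, targets]."""
    inbox = [[] for _ in x]
    for j, i in zip(*edge_index):
        inbox[i].append(message(x[i], x[j]))  # message from source j to target i
    return [update(x[i], aggregate(inbox[i])) for i in range(len(x))]

def sum_aggr(messages):
    # Element-wise sum; an empty inbox yields an empty list.
    return [sum(values) for values in zip(*messages)]

def add_update(h, agg):
    # Residual-style update: previous features plus aggregate (zeros if no messages).
    return [a + b for a, b in zip(h, agg or [0.0] * len(h))]

x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
edge_index = [[0, 2, 1], [1, 1, 2]]  # edges 0->1, 2->1, 1->2
h = message_passing(x, edge_index, lambda x_i, x_j: x_j, sum_aggr, add_update)
print(h)
```

Calling `message_passing` again on `h` corresponds to stacking a second layer; swapping `sum_aggr` for a mean or max reduction changes the aggregation scheme without touching the rest.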
- GCNConv from Kipf and Welling: Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017) [Example]
- ChebConv from Defferrard et al.: Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (NIPS 2016) [Example]
- GATConv from Veličković et al.: Graph Attention Networks (ICLR 2018) [Example]
- GCN2Conv from Chen et al.: Simple and Deep Graph Convolutional Networks (ICML 2020) [Example1, Example2]
- SplineConv from Fey et al.: SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels (CVPR 2018) [Example1, Example2]
- NNConv from Gilmer et al.: Neural Message Passing for Quantum Chemistry (ICML 2017) [Example1, Example2]
- CGConv from Xie and Grossman: Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties (Physical Review Letters 120, 2018)
- ECConv from Simonovsky and Komodakis: Edge-Conditioned Convolution on Graphs (CVPR 2017)
- EGConv from Tailor et al.: Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions (GNNSys 2021) [Example]
- GATv2Conv from Brody et al.: How Attentive are Graph Attention Networks? (ICLR 2022)
- TransformerConv from Shi et al.: Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification (CoRR 2020) [Example]
- SAGEConv from Hamilton et al.: Inductive Representation Learning on Large Graphs (NIPS 2017) [Example1, Example2, Example3, Example4]
- GraphConv from, e.g., Morris et al.: Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks (AAAI 2019)
- GatedGraphConv from Li et al.: Gated Graph Sequence Neural Networks (ICLR 2016)
- ResGatedGraphConv from Bresson and Laurent: Residual Gated Graph ConvNets (CoRR 2017)
- GINConv from Xu et al.: How Powerful are Graph Neural Networks? (ICLR 2019) [Example]
- GINEConv from Hu et al.: Strategies for Pre-training Graph Neural Networks (ICLR 2020)
- ARMAConv from Bianchi et al.: Graph Neural Networks with Convolutional ARMA Filters (CoRR 2019) [Example]
- SGConv from Wu et al.: Simplifying Graph Convolutional Networks (CoRR 2019) [Example]
- APPNP from Klicpera et al.: Predict then Propagate: Graph Neural Networks meet Personalized PageRank (ICLR 2019) [Example]
- MFConv from Duvenaud et al.: Convolutional Networks on Graphs for Learning Molecular Fingerprints (NIPS 2015)
- AGNNConv from Thekumparampil et al.: Attention-based Graph Neural Network for Semi-Supervised Learning (CoRR 2017) [Example]
- TAGConv from Du et al.: Topology Adaptive Graph Convolutional Networks (CoRR 2017) [Example]
- PNAConv from Corso et al.: Principal Neighbourhood Aggregation for Graph Nets (CoRR 2020) [Example]
- FAConv from Bo et al.: Beyond Low-Frequency Information in Graph Convolutional Networks (AAAI 2021)
- PDNConv from Rozemberczki et al.: Pathfinder Discovery Networks for Neural Message Passing (WWW 2021)
- RGCNConv from Schlichtkrull et al.: Modeling Relational Data with Graph Convolutional Networks (ESWC 2018) [Example1, Example2]
- RGATConv from Busbridge et al.: Relational Graph Attention Networks (CoRR 2019) [Example]
- FiLMConv from Brockschmidt: GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation (ICML 2020) [Example]
- SignedConv from Derr et al.: Signed Graph Convolutional Network (ICDM 2018) [Example]
- DNAConv from Fey: Just Jump: Dynamic Neighborhood Aggregation in Graph Neural Networks (ICLR-W 2019) [Example]
- PANConv from Ma et al.: Path Integral Based Convolution and Pooling for Graph Neural Networks (NeurIPS 2020)
- PointNetConv (including Iterative Farthest Point Sampling, dynamic graph generation based on nearest neighbor or maximum distance, and k-NN interpolation for upsampling) from Qi et al.: PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation (CVPR 2017) and PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space (NIPS 2017) [Example1, Example2]
- EdgeConv from Wang et al.: Dynamic Graph CNN for Learning on Point Clouds (CoRR 2018) [Example1, Example2]
- XConv from Li et al.: PointCNN: Convolution On X-Transformed Points (NeurIPS 2018) [Example]
- PPFConv from Deng et al.: PPFNet: Global Context Aware Local Features for Robust 3D Point Matching (CVPR 2018)
- GMMConv from Monti et al.: Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs (CVPR 2017)
- FeaStConv from Verma et al.: FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis (CVPR 2018)
- PointTransformerConv from Zhao et al.: Point Transformer (2020)
- HypergraphConv from Bai et al.: Hypergraph Convolution and Hypergraph Attention (CoRR 2019)
- GravNetConv from Qasim et al.: Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks (European Physical Journal C, 2019)
- SuperGAT from Kim and Oh: How To Find Your Friendly Neighborhood: Graph Attention Design With Self-Supervision (ICLR 2021) [Example]
- HGTConv from Hu et al.: Heterogeneous Graph Transformer (WWW 2020) [Example]
- HEATConv from Mo et al.: Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction (CoRR 2021)
- SSGConv from Zhu et al.: Simple Spectral Graph Convolution (ICLR 2021)
- FusedGATConv from Zhang et al.: Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective (MLSys 2022)
- GPSConv from Rampášek et al.: Recipe for a General, Powerful, Scalable Graph Transformer (NeurIPS 2022) [Example]
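Iterative farthest point sampling, mentioned above for PointNet++, selects a subset of points that covers a point cloud well: start from one point, then repeatedly add the point whose distance to the already-selected set is largest. The sketch below is a plain-Python O(N * k) version for illustration; starting at index 0 is an assumption made here for determinism (implementations often start from a random point).

```python
import math

def farthest_point_sampling(points, k, start=0):
    """Return indices of k points chosen by iterative farthest point sampling."""
    selected = [start]
    # Distance from every point to its nearest selected point so far.
    min_dist = [math.dist(p, points[start]) for p in points]
    while len(selected) < min(k, len(points)):
        nxt = max(range(len(points)), key=min_dist.__getitem__)
        selected.append(nxt)
        for idx, p in enumerate(points):
            min_dist[idx] = min(min_dist[idx], math.dist(p, points[nxt]))
    return selected

points = [(0.0, 0.0), (0.1, 0.0), (5.0, 0.0), (5.0, 5.0), (0.0, 4.0)]
print(farthest_point_sampling(points, 3))
```

Point 1 sits right next to point 0, so it is picked last: the sampled subset spreads out over the cloud instead of clustering.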
Pooling layers: Graph pooling layers combine the vectorial representations of a set of nodes in a graph (or a subgraph) into a single vector representation that summarizes the properties of its nodes. They are commonly applied to graph-level tasks, which require combining node features into a single graph representation.
- Top-K Pooling from Gao and Ji: Graph U-Nets (ICML 2019), Cangea et al.: Towards Sparse Hierarchical Graph Classifiers (NeurIPS-W 2018) and Knyazev et al.: Understanding Attention and Generalization in Graph Neural Networks (ICLR-W 2019) [Example]
- DiffPool from Ying et al.: Hierarchical Graph Representation Learning with Differentiable Pooling (NeurIPS 2018) [Example]
- Attentional Aggregation from Li et al.: Graph Matching Networks for Learning the Similarity of Graph Structured Objects (ICML 2019) [Example]
- Set2Set from Vinyals et al.: Order Matters: Sequence to Sequence for Sets (ICLR 2016) [Example]
- Sort Aggregation from Zhang et al.: An End-to-End Deep Learning Architecture for Graph Classification (AAAI 2018) [Example]
- MinCut Pooling from Bianchi et al.: Spectral Clustering with Graph Neural Networks for Graph Pooling (ICML 2020) [Example]
- DMoN Pooling from Tsitsulin et al.: Graph Clustering with Graph Neural Networks (CoRR 2020) [Example]
- Graclus Pooling from Dhillon et al.: Weighted Graph Cuts without Eigenvectors: A Multilevel Approach (PAMI 2007) [Example]
- Voxel Grid Pooling from, e.g., Simonovsky and Komodakis: Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs (CVPR 2017) [Example]
- SAG Pooling from Lee et al.: Self-Attention Graph Pooling (ICML 2019) and Knyazev et al.: Understanding Attention and Generalization in Graph Neural Networks (ICLR-W 2019) [Example]
- Edge Pooling from Diehl et al.: Towards Graph Pooling by Edge Contraction (ICML-W 2019) and Diehl: Edge Contraction Pooling for Graph Neural Networks (CoRR 2019) [Example]
- ASAPooling from Ranjan et al.: ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations (AAAI 2020) [Example]
- PANPooling from Ma et al.: Path Integral Based Convolution and Pooling for Graph Neural Networks (NeurIPS 2020)
- MemPooling from Khasahmadi et al.: Memory-Based Graph Networks (ICLR 2020) [Example]
- Graph Multiset Transformer from Baek et al.: Accurate Learning of Graph Representations with Graph Multiset Pooling (ICLR 2021) [Example]
- Equilibrium Aggregation from Bartunov et al.: (UAI 2022) [Example]
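The simplest graph-level readout is a per-graph average of node features. When PyG batches several graphs, it merges them into one disconnected graph and keeps a `batch` vector that maps each node to its graph index; readouts such as `torch_geometric.nn.global_mean_pool(x, batch)` build on that convention. The plain-Python `mean_pool` below is a hypothetical sketch of the same idea, not the library function.

```python
def mean_pool(x, batch, num_graphs=None):
    """Average node features per graph; batch[n] is the graph index of node n."""
    if num_graphs is None:
        num_graphs = max(batch) + 1
    width = len(x[0])
    sums = [[0.0] * width for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for features, g in zip(x, batch):
        counts[g] += 1
        for k, value in enumerate(features):
            sums[g][k] += value
    # Graphs without nodes get a zero vector in this sketch.
    return [[s / c if c else 0.0 for s in row] for row, c in zip(sums, counts)]

# Two graphs in one mini-batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 2.0]]
batch = [0, 0, 0, 1, 1]
print(mean_pool(x, batch))
```

Mean, sum and max readouts differ only in the reduction; the learned pooling operators listed above replace this fixed reduction with trainable node selection or clustering.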
GNN models: Our supported GNN models incorporate multiple message passing layers, and users can directly use these pre-defined models to make predictions on graphs. Unlike simple stacking of GNN layers, these models could involve pre-processing, additional learnable parameters, skip connections, graph coarsening, etc.
- SchNet from Schütt et al.: SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions (NIPS 2017) [Example]
- DimeNet and DimeNetPlusPlus from Klicpera et al.: Directional Message Passing for Molecular Graphs (ICLR 2020) and Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules (NeurIPS-W 2020) [Example]
- Node2Vec from Grover and Leskovec: node2vec: Scalable Feature Learning for Networks (KDD 2016) [Example]
- Deep Graph Infomax from Veličković et al.: Deep Graph Infomax (ICLR 2019) [Example1, Example2]
- Deep Multiplex Graph Infomax from Park et al.: Unsupervised Attributed Multiplex Network Embedding (AAAI 2020) [Example]
- Masked Label Prediction from Shi et al.: Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification (CoRR 2020) [Example]
- PMLP from Yang et al.: Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs (ICLR 2023)
- Jumping Knowledge from Xu et al.: Representation Learning on Graphs with Jumping Knowledge Networks (ICML 2018) [Example]
- A MetaLayer for building any kind of graph network similar to the TensorFlow Graph Nets library from Battaglia et al.: Relational Inductive Biases, Deep Learning, and Graph Networks (CoRR 2018)
- MetaPath2Vec from Dong et al.: metapath2vec: Scalable Representation Learning for Heterogeneous Networks (KDD 2017) [Example]
- All variants of Graph Autoencoders and Variational Autoencoders from:
  - Variational Graph Auto-Encoders from Kipf and Welling (NIPS-W 2016) [Example]
  - Adversarially Regularized Graph Autoencoder for Graph Embedding from Pan et al. (IJCAI 2018) [Example]
  - Simple and Effective Graph Autoencoders with One-Hop Linear Models from Salha et al. (ECML 2020) [Example]
- SEAL from Zhang and Chen: Link Prediction Based on Graph Neural Networks (NeurIPS 2018) [Example]
- RENet from Jin et al.: Recurrent Event Network for Reasoning over Temporal Knowledge Graphs (ICLR-W 2019) [Example]
- GraphUNet from Gao and Ji: Graph U-Nets (ICML 2019) [Example]
- AttentiveFP from Xiong et al.: Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism (J. Med. Chem. 2020) [Example]
- DeepGCN and the GENConv from Li et al.: DeepGCNs: Can GCNs Go as Deep as CNNs? (ICCV 2019) and DeeperGCN: All You Need to Train Deeper GCNs (CoRR 2020) [Example]
- RECT from Wang et al.: Network Embedding with Completely-imbalanced Labels (TKDE 2020) [Example]
- GNNExplainer from Ying et al.: GNNExplainer: Generating Explanations for Graph Neural Networks (NeurIPS 2019) [Example1, Example2, Example3]
- Graph-less Neural Networks from Zhang et al.: Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation (CoRR 2021) [Example]
- LINKX from Lim et al.: Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods (NeurIPS 2021) [Example]
- RevGNN from Li et al.: Training Graph Neural Networks with 1000 Layers (ICML 2021) [Example]
- TransE from Bordes et al.: Translating Embeddings for Modeling Multi-Relational Data (NIPS 2013) [Example]
- ComplEx from Trouillon et al.: Complex Embeddings for Simple Link Prediction (ICML 2016) [Example]
- DistMult from Yang et al.: Embedding Entities and Relations for Learning and Inference in Knowledge Bases (ICLR 2015) [Example]
- RotatE from Sun et al.: RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (ICLR 2019) [Example]
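The knowledge-graph embedding models at the end of this list score a triple (head, relation, tail) with a simple formula. In TransE, relations act as translations, so a plausible triple should satisfy h + r ≈ t and is scored by the negative distance -||h + r - t||. The plain-Python sketch below illustrates only that scoring rule; the default L1 norm and the toy two-dimensional embeddings are assumptions for illustration, and training (e.g. with a margin-based ranking loss) is omitted.

```python
def transe_score(head, relation, tail, p=1):
    """TransE-style plausibility: negative p-norm of head + relation - tail."""
    residual = [h + r - t for h, r, t in zip(head, relation, tail)]
    return -(sum(abs(v) ** p for v in residual) ** (1.0 / p))

head = [1.0, 0.0]
relation = [0.0, 1.0]
good_tail = [1.0, 1.0]   # exactly head + relation
bad_tail = [-1.0, 0.0]
print(transe_score(head, relation, good_tail), transe_score(head, relation, bad_tail))
```

Higher scores mean more plausible triples; the other models (ComplEx, DistMult, RotatE) keep this interface but replace the translation with a different interaction between head, relation and tail.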
GNN operators and utilities: PyG comes with a rich set of neural network operators that are commonly used in many GNN models. They follow an extensible design: It is easy to apply these operators and graph utilities to existing GNN layers and models to further enhance model performance.
- DropEdge from Rong et al.: DropEdge: Towards Deep Graph Convolutional Networks on Node Classification (ICLR 2020)
- DropNode, MaskFeature and AddRandomEdge from You et al.: Graph Contrastive Learning with Augmentations (NeurIPS 2020)
- DropPath from Li et al.: MaskGAE: Masked Graph Modeling Meets Graph Autoencoders (arXiv 2022)
- ShuffleNode from Veličković et al.: Deep Graph Infomax (ICLR 2019)
- GraphNorm from Cai et al.: GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training (ICML 2021)
- GDC from Klicpera et al.: Diffusion Improves Graph Learning (NeurIPS 2019) [Example]
- GraphSizeNorm from Dwivedi et al.: Benchmarking Graph Neural Networks (CoRR 2020)
- PairNorm from Zhao and Akoglu: PairNorm: Tackling Oversmoothing in GNNs (ICLR 2020)
- MeanSubtractionNorm from Yang et al.: Revisiting "Over-smoothing" in Deep GCNs (CoRR 2020)
- DiffGroupNorm from Zhou et al.: Towards Deeper Graph Neural Networks with Differentiable Group Normalization (NeurIPS 2020)
- Tree Decomposition from Jin et al.: Junction Tree Variational Autoencoder for Molecular Graph Generation (ICML 2018)
- TGN from Rossi et al.: Temporal Graph Networks for Deep Learning on Dynamic Graphs (GRL+ 2020) [Example]
- Weisfeiler Lehman Operator from Weisfeiler and Lehman: A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction (Nauchno-Technicheskaya Informatsia 1968) [Example]
- Continuous Weisfeiler Lehman Operator from Togninalli et al.: Wasserstein Weisfeiler-Lehman Graph Kernels (NeurIPS 2019)
- Label Propagation from Zhu and Ghahramani: Learning from Labeled and Unlabeled Data with Label Propagation (CMU-CALD 2002) [Example]
- Local Degree Profile from Cai and Wang: A Simple yet Effective Baseline for Non-attribute Graph Classification (CoRR 2018)
- CorrectAndSmooth from Huang et al.: Combining Label Propagation And Simple Models Out-performs Graph Neural Networks (CoRR 2020) [Example]
- Gini and BRO regularization from Henderson et al.: Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity (ICML 2021)
- RootedEgoNets and RootedRWSubgraph from Zhao et al.: From Stars to Subgraphs: Uplifting Any GNN with Local Structure Awareness (ICLR 2022)
- FeaturePropagation from Rossi et al.: On the Unreasonable Effectiveness of Feature Propagation in Learning on Graphs with Missing Node Features (CoRR 2021)
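DropEdge, listed above, is a training-time augmentation that removes each edge independently with probability p. The following plain-Python sketch assumes edges stored as `[sources, targets]` and a seeded `random.Random` for reproducibility; `drop_edge` is a hypothetical name used for illustration (recent PyG versions expose a tensor-based utility, `torch_geometric.utils.dropout_edge`).

```python
import random

def drop_edge(edge_index, p, rng=None):
    """Keep each edge independently with probability 1 - p."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    rng = rng or random.Random()
    sources, targets = edge_index
    kept = [(j, i) for j, i in zip(sources, targets) if rng.random() >= p]
    return [[j for j, _ in kept], [i for _, i in kept]]

edge_index = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
print(drop_edge(edge_index, p=0.5, rng=random.Random(0)))
```

Because both directions of an undirected edge are sampled independently here, the result may no longer be symmetric; PyG's `dropout_edge` offers a `force_undirected` option for that case.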
Scalable GNNs: PyG supports the implementation of Graph Neural Networks that can scale to large-scale graphs. Such applications are challenging since the entire graph, its associated features and the GNN parameters often cannot fit into GPU memory at once. Many state-of-the-art scalability approaches tackle this challenge by sampling neighborhoods for mini-batch training, graph clustering and partitioning, or by using simplified GNN models. These approaches have been implemented in PyG, and can benefit from the above GNN layers, operators and models.
- NeighborLoader from Hamilton et al.: Inductive Representation Learning on Large Graphs (NIPS 2017) [Example1, Example2, Example3, Example4]
- ClusterGCN from Chiang et al.: Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks (KDD 2019) [Example1, Example2]
- GraphSAINT from Zeng et al.: GraphSAINT: Graph Sampling Based Inductive Learning Method (ICLR 2020) [Example]
- ShaDow from Zeng et al.: Decoupling the Depth and Scope of Graph Neural Networks (NeurIPS 2021) [Example]
- SIGN from Rossi et al.: SIGN: Scalable Inception Graph Neural Networks (CoRR 2020) [Example]
- HGTLoader from Hu et al.: Heterogeneous Graph Transformer (WWW 2020) [Example]
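The idea behind neighborhood sampling for mini-batch training (as in GraphSAGE and `NeighborLoader`, whose `num_neighbors` argument gives a fan-out per hop) is to expand a batch of seed nodes hop by hop while keeping at most a fixed number of sampled neighbors per node. The plain-Python sketch below assumes an adjacency-list graph and returns the sampled nodes and edges; `sample_neighbors` is a hypothetical helper and does not reproduce PyG's subgraph format.

```python
import random

def sample_neighbors(adj, seeds, fanouts, rng=None):
    """Sample a multi-hop neighborhood around the seed nodes.

    adj: dict mapping node -> list of neighbors whose messages flow into that node.
    seeds: nodes in the mini-batch.
    fanouts: maximum number of neighbors sampled per node at each hop.
    Returns (nodes, edges), where edges are (neighbor, node) pairs.
    """
    rng = rng or random.Random()
    nodes = list(dict.fromkeys(seeds))  # de-duplicate, keep seed order
    seen = set(nodes)
    edges = []
    frontier = list(nodes)  # copy, so appending to `nodes` does not extend this hop
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            for nbr in rng.sample(neighbors, min(fanout, len(neighbors))):
                edges.append((nbr, node))
                if nbr not in seen:
                    seen.add(nbr)
                    nodes.append(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return nodes, edges

adj = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0], 4: [1]}
nodes, edges = sample_neighbors(adj, seeds=[0], fanouts=[2, 1], rng=random.Random(0))
print(nodes, edges)
```

The work per mini-batch is bounded by the product of the fan-outs rather than by the size of the full graph, which is what makes this approach applicable to graphs with millions of nodes.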
PyG is available for Python 3.8 to Python 3.12.
You can now install PyG via Anaconda for all major OS/PyTorch/CUDA combinations 🤗
If you have not yet installed PyTorch, install it via conda as described in the official PyTorch documentation.
Given that you have PyTorch installed (>=1.8.0), simply run
conda install pyg -c pyg
From PyG 2.3 onwards, you can install and use PyG without any external library required except for PyTorch. For this, simply run
pip install torch_geometric
If you want to utilize the full set of features from PyG, there exist several additional libraries you may want to install:
- pyg-lib: Heterogeneous GNN operators and graph sampling routines
- torch-scatter: Accelerated and efficient sparse reductions
- torch-sparse: SparseTensor support
- torch-cluster: Graph clustering routines
- torch-spline-conv: SplineConv support
These packages come with their own CPU and GPU kernel implementations based on the PyTorch C++/CUDA/HIP (ROCm) extension interface. For basic usage of PyG, these dependencies are fully optional. We recommend starting with a minimal installation and installing additional dependencies once you actually need them.
For ease of installation of these extensions, we provide pip wheels for all major OS/PyTorch/CUDA combinations, see here.
To install the binaries for PyTorch 2.2.0, simply run
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html
where ${CUDA} should be replaced by either cpu, cu118, or cu121 depending on your PyTorch installation.
| | cpu | cu118 | cu121 |
|---|---|---|---|
| Linux | ✅ | ✅ | ✅ |
| Windows | ✅ | ✅ | ✅ |
| macOS | ✅ | | |
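If you script the installation, the `-f` URL can be derived from the installed PyTorch build. The helper below is a hypothetical sketch that assumes the URL pattern shown above (`https://data.pyg.org/whl/torch-<version>+<cuda>.html`) and takes the version strings as arguments. In practice they could come from `torch.__version__` (e.g. `2.2.0+cu121`) and `torch.version.cuda` (e.g. `12.1`, or `None` for CPU-only builds). It does not check that wheels actually exist for the resulting combination; the tables list which ones are provided.

```python
def pyg_wheel_url(torch_version, cuda_version=None):
    """Build the find-links URL for the PyG extension wheels.

    torch_version: e.g. "2.2.0" or "2.2.0+cu121" (the local "+..." suffix is dropped).
    cuda_version: e.g. "12.1" or "11.8"; None selects the CPU wheels.
    """
    base = torch_version.split("+")[0]
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"

print(pyg_wheel_url("2.2.0+cu121", "12.1"))
print(pyg_wheel_url("2.1.0", None))
```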
To install the binaries for PyTorch 2.1.0, simply run
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
where ${CUDA} should be replaced by either cpu, cu118, or cu121 depending on your PyTorch installation.
| | cpu | cu118 | cu121 |
|---|---|---|---|
| Linux | ✅ | ✅ | ✅ |
| Windows | ✅ | ✅ | ✅ |
| macOS | ✅ | | |
Note: Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, and PyTorch 2.0.0 (following the same procedure).
For older versions, you might need to explicitly specify the latest supported version number or install via pip install --no-index in order to prevent a manual installation from source.
You can look up the latest supported version number here.
NVIDIA provides a PyG docker container for effortlessly training and deploying GPU accelerated GNNs with PyG, see here.
In case you want to experiment with the latest PyG features which are not fully released yet, either install the nightly version of PyG via
pip install pyg-nightly
or install PyG from master via
pip install git+https://github.com/pyg-team/pytorch_geometric.git
The external pyg-rocm-build repository provides wheels and detailed instructions on how to install PyG for ROCm.
If you have any questions about it, please open an issue here.
Please cite our paper (and the respective papers of the methods used) if you use this code in your own work:
@inproceedings{Fey/Lenssen/2019,
title={Fast Graph Representation Learning with {PyTorch Geometric}},
author={Fey, Matthias and Lenssen, Jan E.},
booktitle={ICLR Workshop on Representation Learning on Graphs and Manifolds},
year={2019},
}
Feel free to email us if you wish your work to be listed in the external resources. If you notice anything unexpected, please open an issue and let us know. If you have any questions or are missing a specific feature, feel free to discuss them with us. We are motivated to constantly make PyG even better.