PyTorch Geometric (PyG)
PyTorch Geometric (PyG) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data. It consists of various methods for deep learning on graphs and other irregular structures, providing easy-to-use mini-batch loaders, multi-GPU support, `torch.compile` support, a large number of common benchmark datasets, and helpful transforms. It is actively maintained with frequent minor releases delivering new features and bug fixes.
Warnings
- breaking PyG 2.7.0 dropped support for Python 3.9 and PyTorch versions 1.11 through 2.5. Ensure your environment uses Python >=3.10 and PyTorch >=2.6 for compatibility.
- gotcha Installing the optional C++/CUDA extensions (e.g., `torch-scatter`, `torch-sparse`) requires wheels that match your exact PyTorch and CUDA versions, usually from pre-built wheel indexes. Mismatched versions can cause compilation errors or runtime failures.
- deprecated The `torch_geometric.compile` utility and the `MessagePassing.jittable()` method are deprecated. Use `torch.compile` for model optimization instead.
- breaking The interface and implementation of `GraphMultisetTransformer` changed in PyG 2.7.0, potentially affecting existing models that use this layer.
Install
pip install torch_geometric
# Ensure PyTorch is installed with the appropriate CUDA support, e.g., for CUDA 12.6:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# Then, optionally, install PyG's C++/CUDA extensions (replace ${TORCH} and ${CUDA} with your PyTorch and CUDA versions):
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
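To fill in the `${TORCH}` and `${CUDA}` placeholders, you can read the installed versions from Python. The sketch below uses two hypothetical helpers, `check_environment` and `pyg_wheel_url`. The first compares the interpreter and PyTorch against the Python 3.10 / PyTorch 2.6 minimums from the Warnings list. The second builds a candidate wheel-index URL. It assumes the index is named after the `major.minor.0` PyTorch release and a `cuXYZ`/`cpu` tag. That naming scheme is an assumption, so confirm the URL exists before using it.

```python
import sys
from typing import Optional

import torch


def check_environment(min_python=(3, 10), min_torch=(2, 6)) -> bool:
    # Hypothetical helper: compares (major, minor) of Python and PyTorch
    torch_major, torch_minor = (int(p) for p in torch.__version__.split(".")[:2])
    return sys.version_info[:2] >= min_python and (torch_major, torch_minor) >= min_torch


def pyg_wheel_url(torch_version: str, cuda_version: Optional[str]) -> str:
    # Hypothetical helper; ASSUMES the index uses "<major>.<minor>.0" and a
    # "cuXYZ" tag (or "cpu" when PyTorch was built without CUDA)
    major, minor = torch_version.split("+")[0].split(".")[:2]
    cuda_tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{major}.{minor}.0+{cuda_tag}.html"


print(check_environment())
# torch.version.cuda is None for CPU-only builds of PyTorch
print(pyg_wheel_url(torch.__version__, torch.version.cuda))
```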
Imports
- Data
from torch_geometric.data import Data
- MessagePassing
from torch_geometric.nn import MessagePassing
- GCNConv
from torch_geometric.nn import GCNConv
- Planetoid
from torch_geometric.datasets import Planetoid
- T
import torch_geometric.transforms as T
Quickstart
import torch
from torch_geometric.data import Data
# Define an edge list (COO format: [source_nodes, target_nodes])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Define node features (3 nodes, 1 feature per node)
x = torch.tensor([[-1],
                  [0],
                  [1]], dtype=torch.float)
# Create a Data object to represent the graph
data = Data(x=x, edge_index=edge_index)
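print(data)  # Data(x=[3, 1], edge_index=[2, 4])
print(f"Number of nodes: {data.num_nodes}")  # Number of nodes: 3
print(f"Number of edges: {data.num_edges}")  # Number of edges: 4 (directed edges)
print(f"Is undirected: {data.is_undirected()}")  # Is undirected: True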