diff --git a/CHANGELOG.md b/CHANGELOG.md
index 877a367b28b2..1be0426da651 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added the `AddLaplacianEigenvectorPE` and `AddRandomWalkPE` transforms ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
 - Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
 - Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
 - Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
diff --git a/test/transforms/test_add_positional_encoding.py b/test/transforms/test_add_positional_encoding.py
new file mode 100644
index 000000000000..ab64a3490c54
--- /dev/null
+++ b/test/transforms/test_add_positional_encoding.py
@@ -0,0 +1,93 @@
+import copy
+
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.transforms import (
+    AddLaplacianEigenvectorPE,
+    AddRandomWalkPE,
+)
+
+
+def test_add_laplacian_eigenvector_pe():
+    x = torch.randn(6, 4)
+    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
+                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
+    data = Data(x=x, edge_index=edge_index)
+
+    transform = AddLaplacianEigenvectorPE(k=3)
+    assert str(transform) == 'AddLaplacianEigenvectorPE()'
+    out = transform(copy.copy(data))
+    assert out.laplacian_eigenvector_pe.size() == (6, 3)
+
+    transform = AddLaplacianEigenvectorPE(k=3, attr_name=None)
+    out = transform(copy.copy(data))
+    assert out.x.size() == (6, 4 + 3)
+
+    transform = AddLaplacianEigenvectorPE(k=3, attr_name='x')
+    out = transform(copy.copy(data))
+    assert out.x.size() == (6, 3)
+
+    # Output tests:
+    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5, 2, 5],
+                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3, 5, 2]])
+    data = Data(x=x, edge_index=edge_index)
+
+    transform1 = AddLaplacianEigenvectorPE(k=1, is_undirected=True)
+    transform2 = AddLaplacianEigenvectorPE(k=1, is_undirected=False)
+
+    # Clustering test with first non-trivial eigenvector (Fiedler vector):
+    pe = transform1(copy.copy(data)).laplacian_eigenvector_pe
+    pe_cluster_1 = pe[[0, 1, 4]]
+    pe_cluster_2 = pe[[2, 3, 5]]
+    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
+    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
+    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())
+
+    pe = transform2(copy.copy(data)).laplacian_eigenvector_pe
+    pe_cluster_1 = pe[[0, 1, 4]]
+    pe_cluster_2 = pe[[2, 3, 5]]
+    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
+    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
+    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())
+
+
+def test_add_random_walk_pe():
+    x = torch.randn(6, 4)
+    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
+                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
+    data = Data(x=x, edge_index=edge_index)
+
+    transform = AddRandomWalkPE(walk_length=3)
+    assert str(transform) == 'AddRandomWalkPE()'
+    out = transform(copy.copy(data))
+    assert out.random_walk_pe.size() == (6, 3)
+
+    transform = AddRandomWalkPE(walk_length=3, attr_name=None)
+    out = transform(copy.copy(data))
+    assert out.x.size() == (6, 4 + 3)
+
+    transform = AddRandomWalkPE(walk_length=3, attr_name='x')
+    out = transform(copy.copy(data))
+    assert out.x.size() == (6, 3)
+
+    # Output tests:
+    assert out.x.tolist() == [
+        [0.0, 0.5, 0.25],
+        [0.0, 0.5, 0.25],
+        [0.0, 0.5, 0.00],
+        [0.0, 1.0, 0.00],
+        [0.0, 0.5, 0.25],
+        [0.0, 0.5, 0.00],
+    ]
+
+    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
+    data = Data(edge_index=edge_index, num_nodes=4)
+    out = transform(copy.copy(data))
+
+    assert out.x.tolist() == [
+        [1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0],
+        [0.0, 0.0, 0.0],
+    ]
diff --git a/test/utils/test_loop.py b/test/utils/test_loop.py
index cf245befcd51..ded7a69cb091 100644
--- a/test/utils/test_loop.py
+++ b/test/utils/test_loop.py
@@ -4,6 +4,7 @@
     add_remaining_self_loops,
     add_self_loops,
     contains_self_loops,
+    get_self_loop_attr,
     remove_self_loops,
     segregate_self_loops,
 )
@@ -103,3 +104,21 @@ def test_add_remaining_self_loops_without_initial_loops():
     edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight)
     assert edge_index.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]]
     assert edge_weight.tolist() == [0.5, 0.5, 1, 1]
+
+
+def test_get_self_loop_attr():
+    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
+    edge_weight = torch.tensor([0.2, 0.3, 0.5])
+
+    full_loop_weight = get_self_loop_attr(edge_index, edge_weight)
+    assert full_loop_weight.tolist() == [0.5, 0.0]
+
+    full_loop_weight = get_self_loop_attr(edge_index, edge_weight, num_nodes=4)
+    assert full_loop_weight.tolist() == [0.5, 0.0, 0.0, 0.0]
+
+    full_loop_weight = get_self_loop_attr(edge_index)
+    assert full_loop_weight.tolist() == [1.0, 0.0]
+
+    edge_attr = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 1.0]])
+    full_loop_attr = get_self_loop_attr(edge_index, edge_attr)
+    assert full_loop_attr.tolist() == [[0.5, 1.0], [0.0, 0.0]]
diff --git a/torch_geometric/transforms/__init__.py b/torch_geometric/transforms/__init__.py
index deb6fe11c28b..fbd52180fdb8 100644
--- a/torch_geometric/transforms/__init__.py
+++ b/torch_geometric/transforms/__init__.py
@@ -48,6 +48,7 @@
 from .add_metapaths import AddMetaPaths
 from .largest_connected_components import LargestConnectedComponents
 from .virtual_node import VirtualNode
+from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
 
 __all__ = [
     'BaseTransform',
@@ -100,6 +101,8 @@
     'AddMetaPaths',
     'LargestConnectedComponents',
     'VirtualNode',
+    'AddLaplacianEigenvectorPE',
+    'AddRandomWalkPE',
 ]
 
 classes = __all__
diff --git a/torch_geometric/transforms/add_positional_encoding.py b/torch_geometric/transforms/add_positional_encoding.py
new file mode 100644
index 000000000000..05aed9387dc0
--- /dev/null
+++ b/torch_geometric/transforms/add_positional_encoding.py
@@ -0,0 +1,139 @@
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from torch_sparse import SparseTensor
+
+from torch_geometric.data import Data
+from torch_geometric.data.datapipes import functional_transform
+from torch_geometric.transforms import BaseTransform
+from torch_geometric.utils import (
+    get_laplacian,
+    get_self_loop_attr,
+    to_scipy_sparse_matrix,
+)
+
+
+def add_node_attr(data: Data, value: Any,
+                  attr_name: Optional[str] = None) -> Data:
+    # TODO Move to `BaseTransform`.
+    if attr_name is None:
+        if 'x' in data:
+            x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
+            data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1)
+        else:
+            data.x = value
+    else:
+        data[attr_name] = value
+
+    return data
+
+
+@functional_transform('add_laplacian_eigenvector_pe')
+class AddLaplacianEigenvectorPE(BaseTransform):
+    r"""Adds the Laplacian eigenvector positional encoding from the
+    `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_
+    paper to the given graph
+    (functional name: :obj:`add_laplacian_eigenvector_pe`).
+
+    Args:
+        k (int): The number of non-trivial eigenvectors to consider.
+        attr_name (str, optional): The attribute name of the data object to
+            add positional encodings to. If set to :obj:`None`, will be
+            concatenated to :obj:`data.x`.
+            (default: :obj:`"laplacian_eigenvector_pe"`)
+        is_undirected (bool, optional): If set to :obj:`True`, this transform
+            expects undirected graphs as input, and can hence speed up the
+            computation of eigenvectors. (default: :obj:`False`)
+        **kwargs (optional): Additional arguments of
+            :meth:`scipy.sparse.linalg.eigs` (when :attr:`is_undirected` is
+            :obj:`False`) or :meth:`scipy.sparse.linalg.eigsh` (when
+            :attr:`is_undirected` is :obj:`True`).
+    """
+    def __init__(
+        self,
+        k: int,
+        attr_name: Optional[str] = 'laplacian_eigenvector_pe',
+        is_undirected: bool = False,
+        **kwargs,
+    ):
+        self.k = k
+        self.attr_name = attr_name
+        self.is_undirected = is_undirected
+        self.kwargs = kwargs
+
+    def __call__(self, data: Data) -> Data:
+        from scipy.sparse.linalg import eigs, eigsh
+        eig_fn = eigs if not self.is_undirected else eigsh
+
+        num_nodes = data.num_nodes
+        edge_index, edge_weight = get_laplacian(
+            data.edge_index,
+            normalization='sym',
+            num_nodes=num_nodes,
+        )
+
+        L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)
+
+        eig_vals, eig_vecs = eig_fn(
+            L,
+            k=self.k + 1,
+            which='SR' if not self.is_undirected else 'SA',
+            return_eigenvectors=True,
+            **self.kwargs,
+        )
+
+        eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])
+        pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1])
+        sign = -1 + 2 * torch.randint(0, 2, (self.k, ))
+        pe *= sign
+
+        data = add_node_attr(data, pe, attr_name=self.attr_name)
+        return data
+
+
+@functional_transform('add_random_walk_pe')
+class AddRandomWalkPE(BaseTransform):
+    r"""Adds the random walk positional encoding from the `"Graph Neural
+    Networks with Learnable Structural and Positional Representations"
+    <https://arxiv.org/abs/2110.07875>`_ paper to the given graph
+    (functional name: :obj:`add_random_walk_pe`).
+
+    Args:
+        walk_length (int): The number of random walk steps.
+        attr_name (str, optional): The attribute name of the data object to
+            add positional encodings to. If set to :obj:`None`, will be
+            concatenated to :obj:`data.x`.
+            (default: :obj:`"random_walk_pe"`)
+    """
+    def __init__(
+        self,
+        walk_length: int,
+        attr_name: Optional[str] = 'random_walk_pe',
+    ):
+        self.walk_length = walk_length
+        self.attr_name = attr_name
+
+    def __call__(self, data: Data) -> Data:
+        num_nodes = data.num_nodes
+        edge_index, edge_weight = data.edge_index, data.edge_weight
+
+        adj = SparseTensor.from_edge_index(edge_index, edge_weight,
+                                           sparse_sizes=(num_nodes, num_nodes))
+
+        # Compute D^{-1} A:
+        deg_inv = 1.0 / adj.sum(dim=1)
+        deg_inv[deg_inv == float('inf')] = 0
+        adj = adj * deg_inv.view(-1, 1)
+
+        out = adj
+        row, col, value = out.coo()
+        pe_list = [get_self_loop_attr((row, col), value, num_nodes)]
+        for _ in range(self.walk_length - 1):
+            out = out @ adj
+            row, col, value = out.coo()
+            pe_list.append(get_self_loop_attr((row, col), value, num_nodes))
+        pe = torch.stack(pe_list, dim=-1)
+
+        data = add_node_attr(data, pe, attr_name=self.attr_name)
+        return data
diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py
index 64c77fbda34e..a9c09230d6c0 100644
--- a/torch_geometric/utils/__init__.py
+++ b/torch_geometric/utils/__init__.py
@@ -6,7 +6,7 @@
 from .undirected import is_undirected, to_undirected
 from .loop import (contains_self_loops, remove_self_loops,
                    segregate_self_loops, add_self_loops,
-                   add_remaining_self_loops)
+                   add_remaining_self_loops, get_self_loop_attr)
 from .isolated import contains_isolated_nodes, remove_isolated_nodes
 from .subgraph import (get_num_hops, subgraph, k_hop_subgraph,
                        bipartite_subgraph)
@@ -46,6 +46,7 @@
     'segregate_self_loops',
     'add_self_loops',
     'add_remaining_self_loops',
+    'get_self_loop_attr',
     'contains_isolated_nodes',
     'remove_isolated_nodes',
     'get_num_hops',
diff --git a/torch_geometric/utils/loop.py b/torch_geometric/utils/loop.py
index 2d50d6004867..7c08b322d112 100644
--- a/torch_geometric/utils/loop.py
+++ b/torch_geometric/utils/loop.py
@@ -227,3 +227,39 @@ def add_remaining_self_loops(
     edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)
 
     return edge_index, edge_attr
+
+
+def get_self_loop_attr(edge_index: Tensor, edge_attr: OptTensor = None,
+                       num_nodes: Optional[int] = None) -> Tensor:
+    r"""Returns the edge features or weights of self-loops
+    :math:`(i, i)` of every node :math:`i \in \mathcal{V}` in the
+    graph given by :attr:`edge_index`. Edge features of missing self-loops not
+    present in :attr:`edge_index` will be filled with zeros. If
+    :attr:`edge_attr` is not given, it will be replaced by a vector of ones.
+
+    .. note::
+        This operation is analogous to getting the diagonal elements of the
+        dense adjacency matrix.
+
+    Args:
+        edge_index (LongTensor): The edge indices.
+        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
+            features. (default: :obj:`None`)
+        num_nodes (int, optional): The number of nodes, *i.e.*
+            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
+
+    :rtype: :class:`Tensor`
+    """
+    loop_mask = edge_index[0] == edge_index[1]
+    loop_index = edge_index[0][loop_mask]
+
+    if edge_attr is not None:
+        loop_attr = edge_attr[loop_mask]
+    else:  # A vector of ones:
+        loop_attr = torch.ones_like(loop_index, dtype=torch.float)
+
+    num_nodes = maybe_num_nodes(edge_index, num_nodes)
+    full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])
+    full_loop_attr[loop_index] = loop_attr
+
+    return full_loop_attr
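For review context, a minimal usage sketch of the two new transforms, not part of the patch itself. The toy cycle graph and the `k`/`walk_length` values below are arbitrary choices of mine; only the class names and the default attribute names come from the diff above.

```python
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE, AddRandomWalkPE

# A 6-node undirected cycle; both directions of each edge are listed.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 0],
                           [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 0, 5]])
data = Data(edge_index=edge_index, num_nodes=6)

# Stored under `data.laplacian_eigenvector_pe` unless `attr_name` is given:
data = AddLaplacianEigenvectorPE(k=2, is_undirected=True)(data)
assert data.laplacian_eigenvector_pe.size() == (6, 2)

# Stored under `data.random_walk_pe` by default:
data = AddRandomWalkPE(walk_length=4)(data)
assert data.random_walk_pe.size() == (6, 4)
```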
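The hard-coded expectations in `test_add_random_walk_pe` can be re-derived by hand: dimension `t` of the encoding is the probability that a `t`-step random walk returns to its start node, i.e. the diagonal of `(D^{-1} A)^t`. A dense re-computation of the asserted matrix (again an illustration, not part of the patch):

```python
import torch

# Same graph as in the test: a triangle (0, 1, 4) and a path 2 - 3 - 5.
edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                           [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
A = torch.zeros(6, 6)
A[edge_index[0], edge_index[1]] = 1.0
P = A / A.sum(dim=1, keepdim=True)  # Transition matrix D^{-1} A.

M, pe = torch.eye(6), []
for _ in range(3):
    M = M @ P
    pe.append(M.diagonal())
print(torch.stack(pe, dim=-1))
# Matches the asserted values, e.g. node 3 returns with probability 1 after
# two steps because both of its neighbors (2 and 5) have degree 1.
```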
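Likewise, the `.. note::` on `get_self_loop_attr` (self-loop extraction is analogous to reading the diagonal of the dense adjacency matrix) can be made concrete with the existing `to_dense_adj` utility; this check is purely illustrative:

```python
import torch

from torch_geometric.utils import get_self_loop_attr, to_dense_adj

edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
edge_weight = torch.tensor([0.2, 0.3, 0.5])

# The self-loop weights are exactly the diagonal of the dense adjacency:
diag = to_dense_adj(edge_index, edge_attr=edge_weight)[0].diagonal()
assert torch.equal(get_self_loop_attr(edge_index, edge_weight), diag)
```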