Implement AddPositionalEncoding transform #4521

Merged (37 commits, May 10, 2022)

Commits:
- ea518ee  Add skeleton of AddPositionalEncoding (dongkwan-kim, Apr 21, 2022)
- aea36a1  Implement laplacian_eigenvector_pe (dongkwan-kim, Apr 21, 2022)
- 7ecf572  Add random sign flip for laplacian_eigenvector_pe (dongkwan-kim, Apr 21, 2022)
- c10bab5  Implement random_walk_pe (dongkwan-kim, Apr 23, 2022)
- 325b51c  Update docs & linting (dongkwan-kim, Apr 23, 2022)
- 88ff8a4  Add __init__ & tests (dongkwan-kim, Apr 23, 2022)
- 285961e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2022)
- 9f52a53  Fix indentation errors in docs (dongkwan-kim, Apr 24, 2022)
- c08f8ba  Move diagonal_weight out of the class (dongkwan-kim, Apr 25, 2022)
- 9af3887  Handle Data where 'x' is None (dongkwan-kim, Apr 25, 2022)
- b2712a4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2022)
- 2c42ca1  Separate PEs into different classes (dongkwan-kim, Apr 25, 2022)
- ffdd1e2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2022)
- 8e0aa93  Update test/transforms/test_add_positional_encoding.py (dongkwan-kim, May 2, 2022)
- c8340b5  Update test/transforms/test_add_positional_encoding.py (dongkwan-kim, May 2, 2022)
- 603d784  Update tests & remove unnecessary lines (dongkwan-kim, May 2, 2022)
- c965059  Update torch_geometric/transforms/add_positional_encoding.py (dongkwan-kim, May 2, 2022)
- 84e7be4  Remove unnecessary ':' (dongkwan-kim, May 2, 2022)
- 2a646f1  Add full_self_loop_attr (dongkwan-kim, May 3, 2022)
- d900c23  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 3, 2022)
- b07e996  Improve docs of full_self_loop_attr (dongkwan-kim, May 3, 2022)
- 40cd871  Make add_node_attr (add_pe) outside the class (dongkwan-kim, May 3, 2022)
- 516e429  Add value tests for AddRandomWalkPE (dongkwan-kim, May 4, 2022)
- 676bdf5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 4, 2022)
- 24096e1  Fix an error in 'which' of eig_fn (dongkwan-kim, May 10, 2022)
- 6b1e051  Add kwargs for AddLaplacianEigenvectorPE (dongkwan-kim, May 10, 2022)
- 08ab837  Add output tests for AddLaplacianEigenvectorPE (dongkwan-kim, May 10, 2022)
- 75df78b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 10, 2022)
- 9726b75  update (rusty1s, May 10, 2022)
- 620b6fc  Merge branch 'master' into master (rusty1s, May 10, 2022)
- 94df61e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 10, 2022)
- abf6445  changelog (rusty1s, May 10, 2022)
- 4688b70  Merge branch 'master' of github.com:dongkwan-kim/pytorch_geometric (rusty1s, May 10, 2022)
- 241f008  update (rusty1s, May 10, 2022)
- c31e281  typo (rusty1s, May 10, 2022)
- 4ff3f21  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 10, 2022)
- 161e793  update (rusty1s, May 10, 2022)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
93 changes: 93 additions & 0 deletions test/transforms/test_add_positional_encoding.py
@@ -0,0 +1,93 @@
import copy

import torch

from torch_geometric.data import Data
from torch_geometric.transforms import (
    AddLaplacianEigenvectorPE,
    AddRandomWalkPE,
)


def test_add_laplacian_eigenvector_pe():

[Review thread attached to this test]

Member: Any chance we can also test for output rather than solely checking for shapes?

dongkwan-kim (author): I added the output tests for RandomWalkPE. However, for LaplacianEigenvectorPE, I found that each run produces different outputs; I guess computing the eigenvectors of the Laplacian is not a deterministic operation. How can we write output tests for LaplacianEigenvectorPE? Any suggestions?

Contributor: You can try setting the random seed of torch and numpy?

dongkwan-kim (author, May 4, 2022): I tried np.random.seed and random.seed, but it still creates different outputs. I found that using a fixed value for v0 in scipy.sparse.linalg.eigs produces deterministic outputs. We could place **kwargs in __call__ and pass them (including v0) into eigs explicitly. What do you think?

Contributor: Yeah, it looks like when it's not provided, the info=0 flag is set.

How different are the outputs? Personally, I feel that adding some leniency to the result check (i.e., using allclose with a high atol or rtol) might be better than adding the extra **kwargs just for this. Just my opinion though; I think both ways work.

dongkwan-kim (author): I have tested various graphs for AddLaplacianEigenvectorPE. Without v0, the output varies so much that we cannot test with torch.allclose.

Instead, I added a clustering test on a graph with two clusters, which uses the mean value of the first non-trivial eigenvector (the Fiedler vector) within each cluster.

I also added tests with exact values using seed_everything and v0. The keyword argument v0 is passed through AddLaplacianEigenvectorPE.__init__.

dongkwan-kim (author): Well... looking at the failed tests, using seed_everything and v0 does not guarantee the same output on different machines.

Contributor: Got it. I think what you've done for the clustering test is a good solution. Thanks!
    x = torch.randn(6, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    transform = AddLaplacianEigenvectorPE(k=3)
    assert str(transform) == 'AddLaplacianEigenvectorPE()'
    out = transform(copy.copy(data))
    assert out.laplacian_eigenvector_pe.size() == (6, 3)

    transform = AddLaplacianEigenvectorPE(k=3, attr_name=None)
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 4 + 3)

    transform = AddLaplacianEigenvectorPE(k=3, attr_name='x')
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 3)

    # Output tests:
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5, 2, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3, 5, 2]])
    data = Data(x=x, edge_index=edge_index)

    transform1 = AddLaplacianEigenvectorPE(k=1, is_undirected=True)
    transform2 = AddLaplacianEigenvectorPE(k=1, is_undirected=False)

    # Clustering test with the first non-trivial eigenvector (Fiedler vector):
    pe = transform1(copy.copy(data)).laplacian_eigenvector_pe
    pe_cluster_1 = pe[[0, 1, 4]]
    pe_cluster_2 = pe[[2, 3, 5]]
    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())

    pe = transform2(copy.copy(data)).laplacian_eigenvector_pe
    pe_cluster_1 = pe[[0, 1, 4]]
    pe_cluster_2 = pe[[2, 3, 5]]
    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())
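Why these assertions hold: the test graph consists of two triangles, {0, 1, 4} and {2, 3, 5}, with no edges between them, so the zero eigenvalue of the normalized Laplacian has multiplicity two and every eigenvector in that eigenspace is constant within each triangle. A small dense sketch (setup assumed, mirroring the test graph) illustrating this:

import torch

# Two triangles, {0, 1, 4} and {2, 3, 5}, with no edges between them:
A = torch.zeros(6, 6)
for i, j in [(0, 1), (0, 4), (1, 4), (2, 3), (3, 5), (2, 5)]:
    A[i, j] = A[j, i] = 1.0

deg = A.sum(dim=1)
lap = torch.diag(deg) - A
# Symmetrically normalized Laplacian, matching normalization='sym':
lap_sym = torch.diag(deg.pow(-0.5)) @ lap @ torch.diag(deg.pow(-0.5))

eig_vals, eig_vecs = torch.linalg.eigh(lap_sym)
fiedler = eig_vecs[:, 1]  # first non-trivial eigenvector
print(fiedler)  # piecewise constant on {0, 1, 4} and {2, 3, 5}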


def test_add_random_walk_pe():
    x = torch.randn(6, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    transform = AddRandomWalkPE(walk_length=3)
    assert str(transform) == 'AddRandomWalkPE()'
    out = transform(copy.copy(data))
    assert out.random_walk_pe.size() == (6, 3)

    transform = AddRandomWalkPE(walk_length=3, attr_name=None)
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 4 + 3)

    transform = AddRandomWalkPE(walk_length=3, attr_name='x')
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 3)

    # Output tests:
    assert out.x.tolist() == [
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.00],
        [0.0, 1.0, 0.00],
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.00],
    ]

    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    data = Data(edge_index=edge_index, num_nodes=4)
    out = transform(copy.copy(data))

    assert out.x.tolist() == [
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0],
    ]
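For intuition, a dense sketch (same first graph as assumed in the test above) of what the asserted values encode: the k-th entry of a node's encoding is its k-step return probability under the transition matrix D^{-1} A.

import torch

# First test graph: a triangle {0, 1, 4} plus a path 2-3-5:
A = torch.zeros(6, 6)
for i, j in [(0, 1), (0, 4), (1, 4), (2, 3), (3, 5)]:
    A[i, j] = A[j, i] = 1.0
P = A / A.sum(dim=1, keepdim=True)  # transition matrix D^{-1} A

Pk = P.clone()
for step in range(1, 4):
    print(step, torch.diagonal(Pk).tolist())  # k-step return probabilities
    Pk = Pk @ P
# step 1: all zeros (the graph has no self-loops)
# step 2: 0.5 on the triangle; node 3 returns with probability 1.0 since
#         both of its neighbors (2 and 5) lead straight back to it
# step 3: 0.25 on the triangle; 0.0 on the path, where every return walk
#         has even length

The second test graph follows the same logic: nodes 0 to 2 carry only self-loops and return with probability 1.0 at every step, while the isolated node 3 has a zero row in D^{-1} A and hence an all-zero encoding.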
19 changes: 19 additions & 0 deletions test/utils/test_loop.py
@@ -4,6 +4,7 @@
    add_remaining_self_loops,
    add_self_loops,
    contains_self_loops,
    get_self_loop_attr,
    remove_self_loops,
    segregate_self_loops,
)
@@ -103,3 +104,21 @@ def test_add_remaining_self_loops_without_initial_loops():
    edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight)
    assert edge_index.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]]
    assert edge_weight.tolist() == [0.5, 0.5, 1, 1]


def test_get_self_loop_attr():
    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
    edge_weight = torch.tensor([0.2, 0.3, 0.5])

    full_loop_weight = get_self_loop_attr(edge_index, edge_weight)
    assert full_loop_weight.tolist() == [0.5, 0.0]

    full_loop_weight = get_self_loop_attr(edge_index, edge_weight, num_nodes=4)
    assert full_loop_weight.tolist() == [0.5, 0.0, 0.0, 0.0]

    full_loop_weight = get_self_loop_attr(edge_index)
    assert full_loop_weight.tolist() == [1.0, 0.0]

    edge_attr = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 1.0]])
    full_loop_attr = get_self_loop_attr(edge_index, edge_attr)
    assert full_loop_attr.tolist() == [[0.5, 1.0], [0.0, 0.0]]
3 changes: 3 additions & 0 deletions torch_geometric/transforms/__init__.py
@@ -48,6 +48,7 @@
from .add_metapaths import AddMetaPaths
from .largest_connected_components import LargestConnectedComponents
from .virtual_node import VirtualNode
from .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE

__all__ = [
    'BaseTransform',
@@ -100,6 +101,8 @@
    'AddMetaPaths',
    'LargestConnectedComponents',
    'VirtualNode',
    'AddLaplacianEigenvectorPE',
    'AddRandomWalkPE',
]

classes = __all__
139 changes: 139 additions & 0 deletions torch_geometric/transforms/add_positional_encoding.py
@@ -0,0 +1,139 @@
from typing import Any, Optional

import numpy as np
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    to_scipy_sparse_matrix,
)


def add_node_attr(data: Data, value: Any,
                  attr_name: Optional[str] = None) -> Data:
    # TODO Move to `BaseTransform`.
    if attr_name is None:
        if 'x' in data:
            x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
            data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1)
        else:
            data.x = value
    else:
        data[attr_name] = value

    return data


@functional_transform('add_laplacian_eigenvector_pe')
class AddLaplacianEigenvectorPE(BaseTransform):
    r"""Adds the Laplacian eigenvector positional encoding from the
    `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_
    paper to the given graph
    (functional name: :obj:`add_laplacian_eigenvector_pe`).

    Args:
        k (int): The number of non-trivial eigenvectors to consider.
        attr_name (str, optional): The attribute name of the data object to
            add positional encodings to. If set to :obj:`None`, they will be
            concatenated to :obj:`data.x`.
            (default: :obj:`"laplacian_eigenvector_pe"`)
        is_undirected (bool, optional): If set to :obj:`True`, this transform
            expects undirected graphs as input, and can hence speed up the
            computation of eigenvectors. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :meth:`scipy.sparse.linalg.eigs` (when :attr:`is_undirected` is
            :obj:`False`) or :meth:`scipy.sparse.linalg.eigsh` (when
            :attr:`is_undirected` is :obj:`True`).
    """
    def __init__(
        self,
        k: int,
        attr_name: Optional[str] = 'laplacian_eigenvector_pe',
        is_undirected: bool = False,
        **kwargs,
    ):
        self.k = k
        self.attr_name = attr_name
        self.is_undirected = is_undirected
        self.kwargs = kwargs

    def __call__(self, data: Data) -> Data:
        from scipy.sparse.linalg import eigs, eigsh
        eig_fn = eigs if not self.is_undirected else eigsh

        num_nodes = data.num_nodes
        edge_index, edge_weight = get_laplacian(
            data.edge_index,
            normalization='sym',
            num_nodes=num_nodes,
        )

        L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)

        # Request the `k + 1` eigenvectors with the smallest (real)
        # eigenvalues; the first, trivial one is dropped below:
        eig_vals, eig_vecs = eig_fn(
            L,
            k=self.k + 1,
            which='SR' if not self.is_undirected else 'SA',
            return_eigenvectors=True,
            **self.kwargs,
        )

        eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])
        pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1])
        # The sign of an eigenvector is arbitrary, so flip it at random:
        sign = -1 + 2 * torch.randint(0, 2, (self.k, ))
        pe *= sign

        data = add_node_attr(data, pe, attr_name=self.attr_name)
        return data


@functional_transform('add_random_walk_pe')
class AddRandomWalkPE(BaseTransform):
    r"""Adds the random walk positional encoding from the `"Graph Neural
    Networks with Learnable Structural and Positional Representations"
    <https://arxiv.org/abs/2110.07875>`_ paper to the given graph
    (functional name: :obj:`add_random_walk_pe`).

    Args:
        walk_length (int): The number of random walk steps.
        attr_name (str, optional): The attribute name of the data object to
            add positional encodings to. If set to :obj:`None`, they will be
            concatenated to :obj:`data.x`.
            (default: :obj:`"random_walk_pe"`)
    """
    def __init__(
        self,
        walk_length: int,
        attr_name: Optional[str] = 'random_walk_pe',
    ):
        self.walk_length = walk_length
        self.attr_name = attr_name

    def __call__(self, data: Data) -> Data:
        num_nodes = data.num_nodes
        edge_index, edge_weight = data.edge_index, data.edge_weight

        adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                           sparse_sizes=(num_nodes, num_nodes))

        # Compute D^{-1} A:
        deg_inv = 1.0 / adj.sum(dim=1)
        deg_inv[deg_inv == float('inf')] = 0
        adj = adj * deg_inv.view(-1, 1)

        # The diagonal of the k-th power of D^{-1} A holds the k-step
        # return probabilities:
        out = adj
        row, col, value = out.coo()
        pe_list = [get_self_loop_attr((row, col), value, num_nodes)]
        for _ in range(self.walk_length - 1):
            out = out @ adj
            row, col, value = out.coo()
            pe_list.append(get_self_loop_attr((row, col), value, num_nodes))
        pe = torch.stack(pe_list, dim=-1)

        data = add_node_attr(data, pe, attr_name=self.attr_name)
        return data
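A minimal end-to-end usage sketch (graph and feature sizes assumed for illustration) combining both new transforms via Compose:

import torch

from torch_geometric.data import Data
from torch_geometric.transforms import (
    AddLaplacianEigenvectorPE,
    AddRandomWalkPE,
    Compose,
)

edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                           [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
data = Data(x=torch.randn(6, 8), edge_index=edge_index)

transform = Compose([
    AddLaplacianEigenvectorPE(k=3, is_undirected=True),
    AddRandomWalkPE(walk_length=4),
])
data = transform(data)

print(data.laplacian_eigenvector_pe.size())  # torch.Size([6, 3])
print(data.random_walk_pe.size())            # torch.Size([6, 4])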
3 changes: 2 additions & 1 deletion torch_geometric/utils/__init__.py
@@ -6,7 +6,7 @@
from .undirected import is_undirected, to_undirected
from .loop import (contains_self_loops, remove_self_loops,
                   segregate_self_loops, add_self_loops,
                   add_remaining_self_loops, get_self_loop_attr)
from .isolated import contains_isolated_nodes, remove_isolated_nodes
from .subgraph import (get_num_hops, subgraph, k_hop_subgraph,
                       bipartite_subgraph)
@@ -46,6 +46,7 @@
    'segregate_self_loops',
    'add_self_loops',
    'add_remaining_self_loops',
    'get_self_loop_attr',
    'contains_isolated_nodes',
    'remove_isolated_nodes',
    'get_num_hops',
36 changes: 36 additions & 0 deletions torch_geometric/utils/loop.py
@@ -227,3 +227,39 @@ def add_remaining_self_loops(

    edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)
    return edge_index, edge_attr


def get_self_loop_attr(edge_index: Tensor, edge_attr: OptTensor = None,
                       num_nodes: Optional[int] = None) -> Tensor:
    r"""Returns the edge features or weights of self-loops
    :math:`(i, i)` of every node :math:`i \in \mathcal{V}` in the
    graph given by :attr:`edge_index`. Edge features of missing self-loops not
    present in :attr:`edge_index` will be filled with zeros. If
    :attr:`edge_attr` is not given, it will be the vector of ones.

    .. note::
        This operation is analogous to getting the diagonal elements of the
        dense adjacency matrix.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    loop_mask = edge_index[0] == edge_index[1]
    loop_index = edge_index[0][loop_mask]

    if edge_attr is not None:
        loop_attr = edge_attr[loop_mask]
    else:  # A vector of ones:
        loop_attr = torch.ones_like(loop_index, dtype=torch.float)

    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])
    full_loop_attr[loop_index] = loop_attr

    return full_loop_attr
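As a sanity check of the docstring's note, a short sketch (toy graph assumed, matching the test above) showing that get_self_loop_attr agrees with the diagonal of the dense adjacency matrix:

import torch

from torch_geometric.utils import get_self_loop_attr, to_dense_adj

edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
edge_weight = torch.tensor([0.2, 0.3, 0.5])

# Diagonal of the dense adjacency matrix built from the same weights ...
dense_diag = to_dense_adj(edge_index, edge_attr=edge_weight)[0].diagonal()
# ... equals the per-node self-loop weight (zero where no loop exists):
assert torch.equal(get_self_loop_attr(edge_index, edge_weight), dense_diag)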