Implement AddPositionalEncoding transform (#4521, merged)
Changes from all commits (37 commits):
ea518ee  Add skeleton of AddPositionalEncoding (dongkwan-kim)
aea36a1  Implement laplacian_eigenvector_pe (dongkwan-kim)
7ecf572  Add random sign flip for laplacian_eigenvector_pe (dongkwan-kim)
c10bab5  Implement random_walk_pe (dongkwan-kim)
325b51c  Update docs & linting (dongkwan-kim)
88ff8a4  Add __init__ & tests (dongkwan-kim)
285961e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9f52a53  Fix indentation errors in docs. (dongkwan-kim)
c08f8ba  Move diagonal_weight to out of the class (dongkwan-kim)
9af3887  Handle Data where 'x' is None. (dongkwan-kim)
b2712a4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
2c42ca1  Separate PEs to different clases. (dongkwan-kim)
ffdd1e2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8e0aa93  Update test/transforms/test_add_positional_encoding.py (dongkwan-kim)
c8340b5  Update test/transforms/test_add_positional_encoding.py (dongkwan-kim)
603d784  Update tests & remove unnecessary lines (dongkwan-kim)
c965059  Update torch_geometric/transforms/add_positional_encoding.py (dongkwan-kim)
84e7be4  Remove unnecessary ':' (dongkwan-kim)
2a646f1  Add full_self_loop_attr (dongkwan-kim)
d900c23  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
b07e996  Improve docs of full_self_loop_attr (dongkwan-kim)
40cd871  Make add_node_attr (add_pe) outside the class. (dongkwan-kim)
516e429  Add value tests for AddRandomWalkPE (dongkwan-kim)
676bdf5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
24096e1  Fix an error in 'which' of eig_fn (dongkwan-kim)
6b1e051  Add kwargs for AddLaplacianEigenvectorPE (dongkwan-kim)
08ab837  Add output tests for AddLaplacianEigenvectorPE (dongkwan-kim)
75df78b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
9726b75  update (rusty1s)
620b6fc  Merge branch 'master' into master (rusty1s)
94df61e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
abf6445  changelog (rusty1s)
4688b70  Merge branch 'master' of github.com:dongkwan-kim/pytorch_geometric (rusty1s)
241f008  update (rusty1s)
c31e281  typo (rusty1s)
4ff3f21  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
161e793  update (rusty1s)
test/transforms/test_add_positional_encoding.py (new file, +93 lines):
import copy

import torch

from torch_geometric.data import Data
from torch_geometric.transforms import (
    AddLaplacianEigenvectorPE,
    AddRandomWalkPE,
)


def test_add_laplacian_eigenvector_pe():
    x = torch.randn(6, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    transform = AddLaplacianEigenvectorPE(k=3)
    assert str(transform) == 'AddLaplacianEigenvectorPE()'
    out = transform(copy.copy(data))
    assert out.laplacian_eigenvector_pe.size() == (6, 3)

    transform = AddLaplacianEigenvectorPE(k=3, attr_name=None)
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 4 + 3)

    transform = AddLaplacianEigenvectorPE(k=3, attr_name='x')
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 3)

    # Output tests:
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5, 2, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3, 5, 2]])
    data = Data(x=x, edge_index=edge_index)

    transform1 = AddLaplacianEigenvectorPE(k=1, is_undirected=True)
    transform2 = AddLaplacianEigenvectorPE(k=1, is_undirected=False)

    # Clustering test with first non-trivial eigenvector (Fiedler vector)
    pe = transform1(copy.copy(data)).laplacian_eigenvector_pe
    pe_cluster_1 = pe[[0, 1, 4]]
    pe_cluster_2 = pe[[2, 3, 5]]
    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())

    pe = transform2(copy.copy(data)).laplacian_eigenvector_pe
    pe_cluster_1 = pe[[0, 1, 4]]
    pe_cluster_2 = pe[[2, 3, 5]]
    assert not torch.allclose(pe_cluster_1, pe_cluster_2)
    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())
    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())


def test_add_random_walk_pe():
    x = torch.randn(6, 4)
    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
    data = Data(x=x, edge_index=edge_index)

    transform = AddRandomWalkPE(walk_length=3)
    assert str(transform) == 'AddRandomWalkPE()'
    out = transform(copy.copy(data))
    assert out.random_walk_pe.size() == (6, 3)

    transform = AddRandomWalkPE(walk_length=3, attr_name=None)
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 4 + 3)

    transform = AddRandomWalkPE(walk_length=3, attr_name='x')
    out = transform(copy.copy(data))
    assert out.x.size() == (6, 3)

    # Output tests:
    assert out.x.tolist() == [
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.00],
        [0.0, 1.0, 0.00],
        [0.0, 0.5, 0.25],
        [0.0, 0.5, 0.00],
    ]

    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])
    data = Data(edge_index=edge_index, num_nodes=4)
    out = transform(copy.copy(data))

    assert out.x.tolist() == [
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0],
    ]
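
The hard-coded expectations above can be checked by hand: step t of the random walk PE is the diagonal of (D^{-1} A)^t, i.e. the probability that a t-step random walk starting at node i returns to node i. A dense sanity-check sketch (ours, not part of the PR):

import torch

# Re-compute the expected random walk PE for the 6-node test graph above.
A = torch.zeros(6, 6)
for i, j in [(0, 1), (0, 4), (1, 4), (2, 3), (3, 5)]:
    A[i, j] = A[j, i] = 1.0
P = A / A.sum(dim=1, keepdim=True)  # D^{-1} A
pe = torch.stack(
    [torch.linalg.matrix_power(P, t).diagonal() for t in (1, 2, 3)], dim=-1)
# Each row matches the expected list in the test, e.g. node 3 returns with
# probability 1.0 after two steps, since both its neighbors have degree 1.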
torch_geometric/transforms/add_positional_encoding.py (new file, +139 lines):
from typing import Any, Optional

import numpy as np
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    to_scipy_sparse_matrix,
)


def add_node_attr(data: Data, value: Any,
                  attr_name: Optional[str] = None) -> Data:
    # TODO Move to `BaseTransform`.
    if attr_name is None:
        if 'x' in data:
            x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
            data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1)
        else:
            data.x = value
    else:
        data[attr_name] = value

    return data


@functional_transform('add_laplacian_eigenvector_pe')
class AddLaplacianEigenvectorPE(BaseTransform):
    r"""Adds the Laplacian eigenvector positional encoding from the
    `"Benchmarking Graph Neural Networks" <https://arxiv.org/abs/2003.00982>`_
    paper to the given graph
    (functional name: :obj:`add_laplacian_eigenvector_pe`).

    Args:
        k (int): The number of non-trivial eigenvectors to consider.
        attr_name (str, optional): The attribute name of the data object to
            add positional encodings to. If set to :obj:`None`, will be
            concatenated to :obj:`data.x`.
            (default: :obj:`"laplacian_eigenvector_pe"`)
        is_undirected (bool, optional): If set to :obj:`True`, this transform
            expects undirected graphs as input, and can hence speed up the
            computation of eigenvectors. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :meth:`scipy.sparse.linalg.eigs` (when :attr:`is_undirected` is
            :obj:`False`) or :meth:`scipy.sparse.linalg.eigsh` (when
            :attr:`is_undirected` is :obj:`True`).
    """
    def __init__(
        self,
        k: int,
        attr_name: Optional[str] = 'laplacian_eigenvector_pe',
        is_undirected: bool = False,
        **kwargs,
    ):
        self.k = k
        self.attr_name = attr_name
        self.is_undirected = is_undirected
        self.kwargs = kwargs

    def __call__(self, data: Data) -> Data:
        from scipy.sparse.linalg import eigs, eigsh
        eig_fn = eigs if not self.is_undirected else eigsh

        num_nodes = data.num_nodes
        edge_index, edge_weight = get_laplacian(
            data.edge_index,
            normalization='sym',
            num_nodes=num_nodes,
        )

        L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)

        eig_vals, eig_vecs = eig_fn(
            L,
            k=self.k + 1,
            which='SR' if not self.is_undirected else 'SA',
            return_eigenvectors=True,
            **self.kwargs,
        )

        eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])
        pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1])
        sign = -1 + 2 * torch.randint(0, 2, (self.k, ))
        pe *= sign

        data = add_node_attr(data, pe, attr_name=self.attr_name)
        return data
@functional_transform('add_random_walk_pe')
class AddRandomWalkPE(BaseTransform):
    r"""Adds the random walk positional encoding from the `"Graph Neural
    Networks with Learnable Structural and Positional Representations"
    <https://arxiv.org/abs/2110.07875>`_ paper to the given graph
    (functional name: :obj:`add_random_walk_pe`).

    Args:
        walk_length (int): The number of random walk steps.
        attr_name (str, optional): The attribute name of the data object to
            add positional encodings to. If set to :obj:`None`, will be
            concatenated to :obj:`data.x`.
            (default: :obj:`"random_walk_pe"`)
    """
    def __init__(
        self,
        walk_length: int,
        attr_name: Optional[str] = 'random_walk_pe',
    ):
        self.walk_length = walk_length
        self.attr_name = attr_name

    def __call__(self, data: Data) -> Data:
        num_nodes = data.num_nodes
        edge_index, edge_weight = data.edge_index, data.edge_weight

        adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                           sparse_sizes=(num_nodes, num_nodes))

        # Compute D^{-1} A:
        deg_inv = 1.0 / adj.sum(dim=1)
        deg_inv[deg_inv == float('inf')] = 0
        adj = adj * deg_inv.view(-1, 1)

        out = adj
        row, col, value = out.coo()
        pe_list = [get_self_loop_attr((row, col), value, num_nodes)]
        for _ in range(self.walk_length - 1):
            out = out @ adj
            row, col, value = out.coo()
            pe_list.append(get_self_loop_attr((row, col), value, num_nodes))
        pe = torch.stack(pe_list, dim=-1)

        data = add_node_attr(data, pe, attr_name=self.attr_name)
        return data
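
For reference, a minimal usage sketch of the two transforms on the graph from the tests above (ours, not part of the PR; pass `attr_name=None` to concatenate the encodings into `data.x` instead):

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE, AddRandomWalkPE

edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],
                           [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])
data = Data(edge_index=edge_index, num_nodes=6)

data = AddRandomWalkPE(walk_length=3)(data)
print(data.random_walk_pe.size())  # torch.Size([6, 3])

data = AddLaplacianEigenvectorPE(k=3, is_undirected=True)(data)
print(data.laplacian_eigenvector_pe.size())  # torch.Size([6, 3])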
Conversation:

Any chance we can also test for output rather than solely checking for shapes?

I added the output tests for RandomWalkPE. However, for LaplacianEigenvectorPE, I found that each run produces different outputs. I guess computing eigenvectors of the Laplacian is not a deterministic operation. How can we write output tests for LaplacianEigenvectorPE? Any suggestions?

You can try setting the random seed of torch and numpy?

I tried `np.random.seed` and `random.seed`, but it still creates different outputs. I found that using a fixed value for `v0` in `scipy.linalg.eigs` produces deterministic outputs. We can place `**kwargs` in `__call__` and pass the kwargs (including `v0`) into `eigs` explicitly. What do you think?
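
For context, the merged version above does forward `**kwargs` to `eigs`/`eigsh`, so fixing `v0` looks roughly like this (a sketch; the graph is arbitrary, and `torch.manual_seed` is still needed because of the transform's random sign flip):

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddLaplacianEigenvectorPE

torch.manual_seed(0)  # fixes the transform's random sign flip

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
data = Data(edge_index=edge_index, num_nodes=3)

# `v0` is the starting vector of scipy's Lanczos/Arnoldi iteration;
# fixing it makes the eigensolver deterministic on a given machine.
transform = AddLaplacianEigenvectorPE(k=1, is_undirected=True,
                                      v0=np.ones(data.num_nodes))
out = transform(data)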

Yeah, it looks like when it's not provided there is a flag `info=0` set.
How different are the outputs? Personally, I feel maybe adding some leniency in the checking of the result (i.e., using `allclose` with a high `atol` or `rtol`) might be better than adding the extra `**kwargs` just for this. Just my opinion though. I think both ways work.
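
A sketch of that leniency idea (the helper `allclose_up_to_sign` is ours, not part of the PR): eigenvectors are determined only up to sign, and the transform additionally flips signs at random, so a tolerant comparison should accept either orientation per column:

import torch

def allclose_up_to_sign(a: torch.Tensor, b: torch.Tensor,
                        atol: float = 1e-4) -> bool:
    # Compare [num_nodes, k] encodings column by column, accepting
    # either sign for each eigenvector.
    return all(
        torch.allclose(a[:, i], b[:, i], atol=atol)
        or torch.allclose(a[:, i], -b[:, i], atol=atol)
        for i in range(a.size(1)))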

I have tested various graphs for `AddLaplacianEigenvectorPE`. If there is no `v0`, the output really varies, so we cannot test with `torch.allclose`. Instead, I added a clustering test on a graph with two clusters, which uses the mean value of the first non-trivial eigenvector (Fiedler vector) of each cluster.
Also, I added tests with exact values using `seed_everything` and `v0`. The keyword argument `v0` will be added through `AddLaplacianEigenvectorPE.__init__`.

Well... but looking at the failed tests, using `seed_everything` and `v0` does not guarantee the same output on different machines.

Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. I think what you've done for the clustering test is a good solution. Thanks!