
Implement AddPositionalEncoding transform #4521

Merged: 37 commits into pyg-team:master on May 10, 2022
Conversation

@dongkwan-kim (Contributor)

This PR implements the Laplacian eigenvector positional encoding and the random-walk positional encoding, based on the original authors' implementation.

Any comments are welcome. Also, the code looks fine to me, but I do not know why there is an indentation error in the documentation build...
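
A rough usage sketch of the two transforms (class and default attribute names follow the merged version of this PR; treat the exact signatures as approximate):

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import AddLaplacianEigenvectorPE, AddRandomWalkPE

    # An undirected ring graph on 8 nodes.
    num_nodes = 8
    row = torch.arange(num_nodes)
    edge_index = torch.stack([
        torch.cat([row, (row + 1) % num_nodes]),
        torch.cat([(row + 1) % num_nodes, row]),
    ])
    data = Data(edge_index=edge_index, num_nodes=num_nodes)

    # Laplacian eigenvector PE: eigenvectors for the k smallest non-trivial
    # eigenvalues of the graph Laplacian, one k-dimensional vector per node.
    data = AddLaplacianEigenvectorPE(k=2)(data)
    print(data.laplacian_eigenvector_pe.shape)  # torch.Size([8, 2])

    # Random-walk PE: return probabilities of 1..walk_length-step random walks,
    # i.e. the diagonal entries of powers of the random-walk matrix.
    data = AddRandomWalkPE(walk_length=4)(data)
    print(data.random_walk_pe.shape)  # torch.Size([8, 4])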

@codecov (bot) commented Apr 23, 2022

Codecov Report

Merging #4521 (161e793) into master (8fdf895) will increase coverage by 0.06%.
The diff coverage is 98.55%.

@@            Coverage Diff             @@
##           master    #4521      +/-   ##
==========================================
+ Coverage   82.81%   82.88%   +0.06%     
==========================================
  Files         315      316       +1     
  Lines       16605    16674      +69     
==========================================
+ Hits        13752    13820      +68     
- Misses       2853     2854       +1     
Impacted Files                                           Coverage Δ
torch_geometric/utils/__init__.py                        100.00% <ø> (ø)
...ch_geometric/transforms/add_positional_encoding.py     98.27% <98.27%> (ø)
torch_geometric/transforms/__init__.py                   100.00% <100.00%> (ø)
torch_geometric/utils/loop.py                             84.78% <100.00%> (+1.85%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@Padarn (Contributor) commented Apr 24, 2022

I guess the paper is this one: "Graph Neural Networks with Learnable Structural and Positional Representations"? Perhaps we can add a reference to it in the docs?

@dongkwan-kim (Contributor, Author)

> I guess the paper is this one: "Graph Neural Networks with Learnable Structural and Positional Representations"? Perhaps we can add a reference to it in the docs?

@Padarn I added the paper you mentioned to the docstring at lines 32--34:

    positional encoding from the `"Graph Neural Networks with Learnable
    Structural and Positional Representations"
    <https://arxiv.org/abs/2110.07875>`_ paper.

Is this what you meant? If there is anything I missed, could you point it out? Thank you.
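
For context, in PyG's reStructuredText docstring convention the reference ends up looking roughly like this (a sketch, not the exact docstring from this PR):

    class AddRandomWalkPE(BaseTransform):
        r"""Adds the random walk positional encoding from the `"Graph Neural
        Networks with Learnable Structural and Positional Representations"
        <https://arxiv.org/abs/2110.07875>`_ paper to the given graph.

        Args:
            walk_length (int): The number of random walk steps.
        """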

@Padarn (Contributor) commented Apr 24, 2022

Yes, that's all I meant; it's just common practice across the docstrings. Thanks!

@rusty1s (Member) left a comment

This looks pretty good. Thank you very much! I only have some minor points regarding code structure.

[Review comments on test/transforms/test_add_positional_encoding.py (outdated, resolved)]
Inline review thread on `def test_add_laplacian_eigenvector_pe():`
@rusty1s (Member)

Any chance we can also test for output rather than solely checking for shapes?

@dongkwan-kim (Contributor, Author)

I added the output tests for RandomWalkPE.
However, for LaplacianEigenvectorPE, I found that each run produces different outputs. I guess computing the eigenvectors of the Laplacian is not a deterministic operation. How can we write output tests for LaplacianEigenvectorPE? Any suggestions?

Contributor

You could try setting the random seeds of torch and numpy?

@dongkwan-kim (Contributor, Author) commented May 4, 2022

I tried np.random.seed and random.seed, but the outputs still differ. I found that passing a fixed value for v0 to scipy.sparse.linalg.eigs produces deterministic outputs. We could accept **kwargs in __call__ and pass them (including v0) into eigs explicitly. What do you think?
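
To illustrate the point about v0, a standalone sketch (the matrix below is a stand-in, not the Laplacian built by the transform):

    import numpy as np
    import scipy.sparse as sp
    from scipy.sparse.csgraph import laplacian
    from scipy.sparse.linalg import eigs

    # A small random symmetric adjacency matrix as a stand-in graph.
    n = 20
    adj = sp.random(n, n, density=0.3, random_state=0)
    adj = adj + adj.T
    lap = laplacian(adj)

    # ARPACK starts from a random residual vector when v0 is omitted, so
    # repeated runs can return different eigenvectors (at minimum, flipped
    # signs). Fixing v0 makes the iteration, and thus the output, deterministic.
    v0 = np.ones(n)
    w1, vec1 = eigs(lap, k=2, which='SR', v0=v0)
    w2, vec2 = eigs(lap, k=2, which='SR', v0=v0)
    assert np.allclose(vec1, vec2)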

Contributor

Yeah, it looks like when v0 is not provided, ARPACK falls back to a random starting vector (the info=0 flag).

How different are the outputs? Personally, I feel that adding some leniency to the check (i.e., using allclose with a high atol or rtol) might be better than adding the extra **kwargs just for this.

Just my opinion, though. I think both ways work.
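
As an aside, a looser tolerance alone cannot absorb the sign ambiguity of eigenvectors (each column is only determined up to sign); a sign-aware check would be needed, e.g. this hypothetical helper:

    import torch

    def allclose_up_to_sign(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-4) -> bool:
        # Compare each PE column against its counterpart and its negation,
        # since an eigenvector v and -v are equally valid solutions.
        return all(
            torch.allclose(a[:, i], b[:, i], atol=atol)
            or torch.allclose(a[:, i], -b[:, i], atol=atol)
            for i in range(a.size(1))
        )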

@dongkwan-kim (Contributor, Author)

I have tested various graphs with AddLaplacianEigenvectorPE. Without v0, the output really varies, so we cannot test it with torch.allclose.

Instead, I added a clustering test on a graph with two clusters, which uses the mean value of the first non-trivial eigenvector (the Fiedler vector) within each cluster.

I also added tests with exact values using seed_everything and v0. The keyword argument v0 will be passed through AddLaplacianEigenvectorPE.__init__.
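
A minimal sketch of that clustering-style test (the graph and attribute names here are illustrative, not the exact test added in the PR):

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import AddLaplacianEigenvectorPE

    # Two triangles (nodes 0-2 and 3-5) joined by a single bridge edge (2, 3).
    edge_index = torch.tensor([
        [0, 1, 0, 2, 1, 2, 3, 4, 3, 5, 4, 5, 2, 3],
        [1, 0, 2, 0, 2, 1, 4, 3, 5, 3, 5, 4, 3, 2],
    ])
    data = Data(edge_index=edge_index, num_nodes=6)

    data = AddLaplacianEigenvectorPE(k=1)(data)
    fiedler = data.laplacian_eigenvector_pe[:, 0]

    # The Fiedler vector separates the two clusters: its mean over one cluster
    # has the opposite sign of its mean over the other. The product of the two
    # means is negative regardless of the eigenvector's arbitrary global sign,
    # so the assertion is robust to the non-determinism discussed above.
    assert fiedler[:3].mean() * fiedler[3:].mean() < 0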

@dongkwan-kim (Contributor, Author)

Well... but looking at the failed tests, using seed_everything and v0 does not guarantee the same output on different machines.

Contributor

Got it. I think what you've done for the clustering test is a good solution. Thanks!

[Review comments on torch_geometric/transforms/add_positional_encoding.py (outdated, resolved)]
@rusty1s changed the title to Implement AddPositionalEncoding transform on May 10, 2022
@rusty1s merged commit f35c85f into pyg-team:master on May 10, 2022
@dongkwan-kim (Contributor, Author)

This is super great. Thank you all. I will send the other PRs mentioned in this thread soon.

@rusty1s (Member) commented May 10, 2022

Thank you!

@Padarn (Contributor) commented May 10, 2022

Great!
