
[Bug Report] TransformerLens's use of einsum leads to different training dynamics on TPUs #591

Closed
jqhoogland opened this issue May 13, 2024 · 9 comments
Labels
wontfix This will not be worked on

Comments

@jqhoogland

I'm cross-posting an issue from torch_xla here so that TransformerLens users will have an easier time finding it. The message is: don't trust TransformerLens's HookedTransformer if you're using TPUs. I think most of the responsibility lies with torch_xla, but it might be worth adding a warning message until it's fixed upstream.

The other possible fix would be to replace each instance of torch.einsum with equivalent standard torch operations. I've already done this in my own fork, so let me know if you'd like to see a PR with this change.
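
For illustration, a minimal sketch of the kind of rewrite involved (hypothetical shapes and names, not the actual code from the fork): an attention-style einsum contraction expressed with permute + matmul instead.

```python
import torch

# Hypothetical attention-score contraction of the kind HookedTransformer uses.
# q, k: [batch, pos, head_index, d_head]
q = torch.randn(2, 8, 4, 16)
k = torch.randn(2, 8, 4, 16)

# einsum version
scores_einsum = torch.einsum("bqnh,bknh->bnqk", q, k)

# equivalent version using only permute + batched matmul, avoiding einsum entirely
scores_matmul = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)

assert torch.allclose(scores_einsum, scores_matmul, atol=1e-6)
```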

@jbloomAus
Collaborator

@jqhoogland Thanks for raising this. I'm looking into it, as I've had similar issues. @neelnanda-io @ArthurConmy and I are discussing how to deal with this (it's a bit complicated), so to avoid a lot of back-and-forth we'll probably just make a PR.

@jqhoogland
Author

Another possible fix I stumbled across yesterday is to use opt_einsum.contract instead of torch.einsum. This seems to solve the problem I was running into (even when the opt_einsum backend is already available for torch to use... don't ask me why).
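
A rough sketch of the swap, assuming opt_einsum is installed (illustrative shapes and equation string, not code from TransformerLens itself):

```python
import torch
from opt_einsum import contract  # pip install opt_einsum

x = torch.randn(2, 8, 4, 16)   # e.g. [batch, pos, head_index, d_head]
w = torch.randn(4, 16, 32)     # e.g. [head_index, d_head, d_model]

# the kind of einsum call that misbehaves under torch_xla on TPU
out_einsum = torch.einsum("bpnh,nhd->bpd", x, w)

# drop-in replacement: contract accepts the same equation string and tensors
out_contract = contract("bpnh,nhd->bpd", x, w)

assert torch.allclose(out_einsum, out_contract, atol=1e-5)
```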

@jbloomAus
Copy link
Collaborator

jbloomAus commented May 14, 2024 via email

@jbloomAus
Collaborator

@jqhoogland It seems that if you install opt_einsum, PyTorch will use it by default. I'm not sure whether TransformerLens should attempt to handle this directly, though some guidance, or a check to assist people who run into this issue, might be useful.
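
If it helps, a quick way to check whether PyTorch is actually routing einsum contraction optimisation through opt_einsum (available via torch.backends in PyTorch 1.13+, as far as I know):

```python
import torch

# True once opt_einsum is installed; torch.einsum then uses it to optimise
# the contraction path (behaviour as of PyTorch 1.13+, to my understanding).
print(torch.backends.opt_einsum.is_available())
print(torch.backends.opt_einsum.enabled)  # defaults to True when available
```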

@bryce13950 I'll defer to you. TL;DR: on some hardware (TPUs), one of our dependencies produces different results. This can be fixed by installing opt_einsum, which PyTorch will then use by default. I've also noticed that the Mac GPU backend (MPS) can induce errors (more noticeable in larger models). A good protocol would be for people to validate that models operate correctly on whatever hardware they use, rather than assuming this is always the case.
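
As a minimal sketch of such a validation step (model name, prompt, and tolerance here are placeholders, not a recommended standard), one could compare logits on the accelerator under test against a CPU reference:

```python
import torch
from transformer_lens import HookedTransformer

prompt = "The quick brown fox jumps over the lazy dog"

# CPU reference run
model_cpu = HookedTransformer.from_pretrained("gpt2", device="cpu")
with torch.no_grad():
    ref_logits = model_cpu(prompt)

# same model on the accelerator under test (here: MPS if present)
device = "mps" if torch.backends.mps.is_available() else "cpu"
model_dev = HookedTransformer.from_pretrained("gpt2", device=device)
with torch.no_grad():
    dev_logits = model_dev(prompt).to("cpu")

# large discrepancies suggest the hardware/backend is not producing correct results
max_diff = (ref_logits - dev_logits).abs().max().item()
print(f"max |logit diff| on {device}: {max_diff:.2e}")
```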

@jqhoogland
Author

Yes, this seems to be a problem with PyTorch. If you like, I can keep an eye on this and close the issue once PyTorch resolves it.

@jbloomAus
Collaborator

@jqhoogland thanks, that would be useful. Appreciate you flagging this too.

@bryce13950
Collaborator

@jqhoogland Thanks for bringing this up. I think @jbloomAus is already aware of this, but I am in the middle of setting up some benchmarking tools for TransformerLens: https://github.com/TransformerLensOrg/TransformerLens/tree/benchmark-utlitities. I was working on this before I got pulled away to deal with some logistical things a couple of weeks ago, but I should be able to get back to it either Friday or next week. It seems like there could be quite a bit of use for the tool I am building. I will keep this situation in mind and make sure there is something easy people can do to verify that the models they are using are producing the right results.

@bryce13950 added the wontfix label on May 23, 2024
@bryce13950
Collaborator

We are about ready to put up a release in which all einsum implementations have been replaced with standard PyTorch functions. Issues like this should no longer occur, so there is no need to keep tracking this.

@jqhoogland
Copy link
Author

jqhoogland commented Nov 25, 2024 via email
