[Bug Report] TransformerLens's use of einsum leads to different training dynamics on TPUs #591

Comments
@jqhoogland Thanks for raising this. I'm looking into it, as I've had similar issues. @neelnanda-io, @ArthurConmy, and I are discussing how to deal with this (it's a bit complicated), so to avoid a bunch of back and forth we'll probably just make a PR.
Another possible fix I stumbled across yesterday is to use `opt_einsum.contract` instead of `torch.einsum`. This seems to solve the problem I was running into (even when the opt_einsum backend is already available for torch to use... don't ask me why).
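For reference, a minimal sketch of what that swap looks like (assuming the `opt_einsum` package is installed; the tensor shapes and einsum pattern here are illustrative, not TransformerLens's actual code):

```python
import torch
import opt_einsum

# Hypothetical [batch, pos, head, d_head] tensors, just for illustration
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)

# Attention-score style contraction via torch.einsum
scores_torch = torch.einsum("bqhd,bkhd->bhqk", q, k)

# Same contraction via opt_einsum.contract, which reportedly avoids the
# TPU (torch_xla) discrepancy described in this issue
scores_opt = opt_einsum.contract("bqhd,bkhd->bhqk", q, k)

print(torch.allclose(scores_torch, scores_opt))
```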
Fantastic! Much easier to implement. Will try this out shortly.
@jqhoogland It seems like if you install opt_einsum, it will be used by default by PyTorch. I'm not sure TransformerLens should attempt to handle this directly, though some guidance, or a check that helps people who hit this issue, might be useful. @bryce13950 Will defer to you. TLDR: On some hardware (TPU), one of our dependencies produces different results. This can be fixed by installing the `opt_einsum` package.
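As a rough illustration of the kind of check that guidance could point users to (assuming PyTorch 1.13+, which exposes `torch.backends.opt_einsum`; the function name and warning text are hypothetical, not part of TransformerLens):

```python
import warnings
import torch

def check_einsum_backend() -> None:
    """Warn if torch.einsum cannot route through the opt_einsum backend."""
    if not torch.backends.opt_einsum.is_available():
        warnings.warn(
            "opt_einsum is not installed; torch.einsum may produce different "
            "results on TPU/XLA backends. Consider `pip install opt_einsum`."
        )

check_einsum_backend()
```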
Yes, this seems to be a problem with PyTorch. If you like, I can keep an eye on this and close the issue when PyTorch resolves it.
@jqhoogland Thanks, that would be useful. Appreciate you flagging this too.
@jqhoogland Thanks for bringing this up. I think @jbloomAus is already aware of this, but I am in the middle of setting up some benchmarking tools for TransformerLens: https://github.com/TransformerLensOrg/TransformerLens/tree/benchmark-utlitities. I was in the middle of this before I got pulled away to deal with some logistical things a couple of weeks ago, but I should be able to get back to it either Friday or next week. It seems like there could be quite a bit of use for the tool I am building. I will keep this situation in mind to make sure there is something easy someone can do to ensure that the models they are using are producing the right results.
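In the meantime, here is a rough sketch of the kind of sanity check a user could run themselves (this is not the benchmarking tool described above; it assumes `transformer_lens` is installed, and the helper name and tolerance are hypothetical):

```python
import torch
from transformer_lens import HookedTransformer

def logits_match(model_name: str = "gpt2", device: str = "cpu", atol: float = 1e-4) -> bool:
    """Compare HookedTransformer logits on a target device against a CPU reference."""
    prompt = "The quick brown fox"
    reference = HookedTransformer.from_pretrained(model_name, device="cpu")
    candidate = HookedTransformer.from_pretrained(model_name, device=device)
    with torch.no_grad():
        ref_logits = reference(prompt)
        cand_logits = candidate(prompt).to("cpu")
    return torch.allclose(ref_logits, cand_logits, atol=atol)

print(logits_match())
```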
We are about ready to put up a release where all einsum implementations have been replaced with standard PyTorch functions. Issues like this will no longer occur, so there's no need to keep track of this anymore.
Thanks Bryce!
I'm cross-posting an issue from `torch_xla` to here, so TransformerLens users will have an easier time finding it. The message is: don't trust TransformerLens's `HookedTransformer` if you're using TPUs. I think most of the responsibility lies with `torch_xla`, but it might be worth adding a warning message until it's been fixed upstream. Otherwise, the other possible fix would be to replace each instance of `torch.einsum` with other torch operations. I've already done this in my own fork, so let me know if you'd like to see a PR with this change.
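For illustration, a minimal sketch of the kind of replacement described here (not the actual code from the fork or from the later release; the shapes and einsum pattern are illustrative):

```python
import torch

# Hypothetical [batch, pos, head, d_head] tensors
q = torch.randn(2, 4, 8, 16)  # queries
k = torch.randn(2, 4, 8, 16)  # keys

# einsum version of an attention-score contraction
scores_einsum = torch.einsum("bqhd,bkhd->bhqk", q, k)

# Equivalent using permute + matmul, avoiding einsum entirely
q_ = q.permute(0, 2, 1, 3)  # [batch, head, query_pos, d_head]
k_ = k.permute(0, 2, 3, 1)  # [batch, head, d_head, key_pos]
scores_matmul = q_ @ k_     # [batch, head, query_pos, key_pos]

print(torch.allclose(scores_einsum, scores_matmul, atol=1e-6))
```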