Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix precision errors from casting rotary parameters to FP16 with AMP #27700

Merged
merged 10 commits into from
Nov 29, 2023
Merged

Fix precision errors from casting rotary parameters to FP16 with AMP #27700

merged 10 commits into from
Nov 29, 2023

Conversation

kevinhu
Copy link
Contributor

@kevinhu kevinhu commented Nov 25, 2023

What does this PR do?

When training with AMP, using einsum to multiply t and self.inv_freq will introduce precision errors because it casts the result to FP16. This can be avoided by using torch.outer instead, as originally mentioned here: https://github.com/Dao-AILab/flash-attention/blob/2c3baba4a63c4007c8a132c5380edc9430f88a22/flash_attn/layers/rotary.py#L396C1-L398C45

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kevinhu kevinhu changed the title Fix precision errors from casting rotary parameters to FP16 with AMP in Llama Fix precision errors from casting rotary parameters to FP16 with AMP Nov 25, 2023
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Thanks for opening this PR, it seems to me that the issue lies with AMP no?
My only concern would have been performances, outer might be a little bit slower but it seems to be negligible so LGTM.
Let's make sure that the failing test is fixed!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran the following script for benchmarking:

import torch
from torch.utils import benchmark

results = []
for b in [10, 10000, 2000000]:
    for n in [10, 100, 10000, 1000000]:
        if b * n >= 1000000000:
            continue

        description = f'[{b}, {n}]'

        x = torch.rand(b, device='mps')
        y = torch.rand(n, device='mps')

        results.append(benchmark.Timer(
            stmt='torch.outer(x,y)',
            globals={'x': x, 'y': y},
            description=description,
        ).blocked_autorange())

        results.append(benchmark.Timer(
            stmt='torch.einsum("i,j->ij",x,y)',
            globals={'x': x, 'y': y},
            description=description,
        ).blocked_autorange())

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize()
compare.print()

Got the following:
image

So looks good to me 😉

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

failing test is unrelated to the PR i'll fix it on main

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker ArthurZucker merged commit 083e369 into huggingface:main Nov 29, 2023
19 checks passed
@ArthurZucker
Copy link
Collaborator

FYI @gante and @Rocketknight1 if we see anything failing. I ran slow tests locally and it was all good

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants