Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monkeypatch flash attention in for llama #520

Merged
merged 16 commits into from
Aug 15, 2023
Merged

Conversation

dakinggg
Copy link
Collaborator

@dakinggg dakinggg commented Aug 11, 2023

This PR adds a monkeypatch of triton flash attention for llama2 models. If we start doing this for more models we can try to generalize into an algorithm, but until that time I think this monkeypatch is good enough.

TODO:

  • Copy profiling results for 7b
  • Copy profiling results for 70b
  • Paste in evidence that the test passes

Result for 7b scale to compare implementations:
Screenshot 2023-08-10 at 4 31 57 PM
Screenshot 2023-08-10 at 4 32 04 PM
Screenshot 2023-08-10 at 4 32 20 PM

Test results:

tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-7b-hf-True-torch] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:17 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:23 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 12%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-7b-hf-True-triton] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:23 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:24 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 25%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-7b-hf-False-torch] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:25 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:26 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 37%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-7b-hf-False-triton] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:27 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:27 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 50%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-70b-hf-True-torch] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:28 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:29 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 62%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-70b-hf-True-triton] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:30 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:32 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 75%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-70b-hf-False-torch] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:33 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:34 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [ 87%]
tests/test_llama_patch.py::test_patch_equivalence[meta-llama/Llama-2-70b-hf-False-triton] 
------------------------------------------------------------------------------------------------------------------- live log call -------------------------------------------------------------------------------------------------------------------
2023-08-11 22:18:36 [    INFO] Setting seed to 42 (reproducibility.py:159)
2023-08-11 22:18:37 [    INFO] Setting seed to 42 (reproducibility.py:159)
PASSED                                                                                                                                                                                                                                        [100%]

================================================================================================================= warnings summary ==================================================================================================================
../../miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29
  /mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives
    from distutils.util import strtobool

../../miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/deepspeed/ops/op_builder/builder.py:15
  /mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/deepspeed/ops/op_builder/builder.py:15: DeprecationWarning: The distutils.sysconfig module is deprecated, use sysconfig instead
    import distutils.sysconfig

../../miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/jupyter_client/connect.py:20
  /mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/jupyter_client/connect.py:20: DeprecationWarning: Jupyter is migrating its paths to use standard platformdirs
  given by the platformdirs library.  To remove this warning and
  see the appropriate new directories, set the environment variable
  `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
  The use of platformdirs will be the default in `jupyter_core` v6
    from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write

../../miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/comet_ml/monkey_patching.py:19
  /mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/comet_ml/monkey_patching.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib and slated for removal in Python 3.12; see the module's documentation for alternative uses
    import imp

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================== 8 passed, 4 warnings in 22.00s ===========================================================================================================

70b running on 16 and 32 gpus:
Screenshot 2023-08-11 at 2 59 36 PM
Screenshot 2023-08-11 at 2 59 45 PM
Screenshot 2023-08-11 at 2 59 50 PM
Screenshot 2023-08-11 at 3 00 00 PM

@dakinggg dakinggg marked this pull request as ready for review August 12, 2023 00:49
tests/test_llama_patch.py Outdated Show resolved Hide resolved
tests/test_llama_patch.py Outdated Show resolved Hide resolved
tests/test_llama_patch.py Outdated Show resolved Hide resolved
Copy link
Contributor

@vchiley vchiley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left a few comments

@germanjke
Copy link

Hi guys! @vchiley @dakinggg when do you planning to release this request? Thanks

@dakinggg dakinggg requested a review from vchiley August 14, 2023 21:06
Copy link
Contributor

@vchiley vchiley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@dakinggg dakinggg merged commit aff3eaa into mosaicml:main Aug 15, 2023
9 checks passed
@dakinggg dakinggg deleted the llama2-2 branch September 9, 2023 22:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants