Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama #270

Merged
merged 60 commits into from
Aug 30, 2023
Merged

Llama #270

merged 60 commits into from
Aug 30, 2023

Conversation

Ivan-Zhou
Copy link
Contributor

@Ivan-Zhou Ivan-Zhou commented Aug 2, 2023

Implement Llama based on HF implementation and the paper.

A few noteable difference from Gpt2:

  • Rotary Positional Embedding;
  • No dropout being used (therefore, many of the call() doesn't need key as input).

Tasks

  • Add Llama code based on HF implementation
  • Refactor with Haliax
  • Match Levanter's state dict with HF's
  • Roundtrip test
  • Ensure perf roughly matches GPT-2 implementation at scale

src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
@Ivan-Zhou Ivan-Zhou marked this pull request as ready for review August 14, 2023 00:56
@dlwh
Copy link
Member

dlwh commented Aug 16, 2023 via email

src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
src/levanter/models/llama.py Outdated Show resolved Hide resolved
tests/test_llama.py Show resolved Hide resolved
@Ivan-Zhou
Copy link
Contributor Author

Untie word_embeddings at LMHead is done. The only pending issue is Jax leakage when loading HF weight to Levanter's model.

tests/test_llama.py Outdated Show resolved Hide resolved
@Ivan-Zhou
Copy link
Contributor Author

Great thanks to @dlwh for helping with the roundtrip tests and massively improve the code style & taste in this PR 👍

@Ivan-Zhou Ivan-Zhou merged commit d07ff37 into main Aug 30, 2023
@Ivan-Zhou Ivan-Zhou deleted the llama branch August 30, 2023 19:59
@dlwh dlwh mentioned this pull request Aug 30, 2023
2 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants