Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama #270

Merged
merged 60 commits into main from llama
Aug 30, 2023
Merged

Llama #270

Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
ca6ada5
Add Llama config, Mlp, Attention, and RotaryEmb
Ivan-Zhou Jul 30, 2023
667a56c
address integration test
Ivan-Zhou Jul 31, 2023
0e77701
test compare with hf implementation
Ivan-Zhou Aug 1, 2023
ed10ee1
Finish LlamaRotaryEmbedding
Ivan-Zhou Aug 1, 2023
6034aee
Add LlamaLinearScalingRotaryEmbedding
Ivan-Zhou Aug 1, 2023
e9a7c56
Add LlamaDynamicNTKScalingRotaryEmbedding
Ivan-Zhou Aug 1, 2023
6e97f5f
Refactor to simplified class differences
Ivan-Zhou Aug 1, 2023
271be4b
Implement _get_rotary_emb
Ivan-Zhou Aug 1, 2023
669e6e7
work on attention
Ivan-Zhou Aug 2, 2023
e45b39b
test initialize attention
Ivan-Zhou Aug 6, 2023
bbcd0ec
_apply_rotary_pos_emb
Ivan-Zhou Aug 8, 2023
442b93d
update llama and test
Ivan-Zhou Aug 8, 2023
673fa17
Merge branch 'main' into llama
Ivan-Zhou Aug 13, 2023
40bb7b4
update llama and test
Ivan-Zhou Aug 13, 2023
466c1db
Finish Llama Attention
Ivan-Zhou Aug 13, 2023
f9db049
Finish implementing LlamaLMHeadModel
Ivan-Zhou Aug 14, 2023
6781a8c
fix build
Ivan-Zhou Aug 14, 2023
abb70fa
fix build
Ivan-Zhou Aug 14, 2023
f21cd1c
remove max_position_embeddings
Ivan-Zhou Aug 16, 2023
6e222b7
Fix issues found from testing
Ivan-Zhou Aug 16, 2023
182327e
Fix issues found from end-to-end tests
Ivan-Zhou Aug 20, 2023
46939d5
Fix torch import issue
Ivan-Zhou Aug 20, 2023
92a7f23
Refactor RoPE
Ivan-Zhou Aug 20, 2023
65f5888
remove ()
Ivan-Zhou Aug 20, 2023
3f028b9
NamedArray type hint
Ivan-Zhou Aug 20, 2023
2fb1a3f
Remove position_ids
Ivan-Zhou Aug 20, 2023
97dadce
from/to HF config
Ivan-Zhou Aug 20, 2023
fefb93a
Update to state_dict
Ivan-Zhou Aug 22, 2023
8239cc3
ignore type in default_hf_checkpoint_converter
Ivan-Zhou Aug 22, 2023
73f1c90
attn to self_attn in LlamaDecoderLayer
Ivan-Zhou Aug 22, 2023
8f16ec0
Remove position embed from LlamaEmbedding
Ivan-Zhou Aug 23, 2023
5578db2
Set out_dims_first_in_dict
Ivan-Zhou Aug 23, 2023
d73c6f7
Add (incomplete) roundtrip test
Ivan-Zhou Aug 23, 2023
79f4fa9
remove unused Axis from mlp
Ivan-Zhou Aug 24, 2023
ee6e11c
Rename v's position
Ivan-Zhou Aug 24, 2023
f06e3fa
pretty sure this is the problem: weights weren't being deserialized/r…
dlwh Aug 24, 2023
41b103c
document when to use out_dims_first_in_dict
dlwh Aug 24, 2023
6ed1ee5
Add LlamaRMSNorm and add more consistency
Ivan-Zhou Aug 28, 2023
9c669fd
Fix issues from pre-commit tests
Ivan-Zhou Aug 28, 2023
6245f9b
tie llama weights by default
dlwh Aug 28, 2023
190cba3
add a todo for ivan
dlwh Aug 28, 2023
a022bb1
make test pass even without auth token
dlwh Aug 28, 2023
ec742c4
Intermediate -> Mlp
Ivan-Zhou Aug 29, 2023
a64e63a
Update src/levanter/models/llama.py
Ivan-Zhou Aug 29, 2023
1820ea9
Merge branch 'llama' of github.com:stanford-crfm/levanter into llama
Ivan-Zhou Aug 29, 2023
1bd25ca
Update src/levanter/models/llama.py
Ivan-Zhou Aug 29, 2023
d7412d4
Update src/levanter/models/llama.py
Ivan-Zhou Aug 29, 2023
2e3c1fb
Update src/levanter/models/llama.py
Ivan-Zhou Aug 29, 2023
05f684f
Fix issues from pre-commit checks
Ivan-Zhou Aug 29, 2023
f6157df
Start from llama 2 hf in roundtrip
Ivan-Zhou Aug 29, 2023
5109b55
Update model_id in the round trip test
Ivan-Zhou Aug 29, 2023
0a5d36d
Untie weight at LMHead Linear Layer
Ivan-Zhou Aug 30, 2023
c056942
fix round trip test, use compile time eval for the cos/sin cache
dlwh Aug 30, 2023
8380283
Update src/levanter/models/llama.py
dlwh Aug 30, 2023
93eeebe
Update tests/test_llama.py
dlwh Aug 30, 2023
dca9157
let's just use llama names where reasonable
dlwh Aug 30, 2023
9ccd795
implement LmHeadModel in LLama
dlwh Aug 30, 2023
5d8a102
use haliax's built in attention
dlwh Aug 30, 2023
d6e780c
Merge remote-tracking branch 'origin/main' into llama
dlwh Aug 30, 2023
8c6a20d
update for latest main: tokenizer resizing
dlwh Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Start from llama 2 hf in roundtrip
Ivan-Zhou committed Aug 29, 2023
commit f6157dfacff160f228dba43e1ce66fdf84fc98f1
15 changes: 14 additions & 1 deletion src/levanter/models/llama.py
Original file line number Diff line number Diff line change
@@ -93,7 +93,19 @@ def from_hf_config(cls, hf_config: HfConfig):
rope_scaling=hf_config.rope_scaling,
)

def to_hf_config(self, vocab_size: int = 32000, config_overrides=None) -> HfLlamaConfig:
def to_hf_config(
self, vocab_size: int = 32000, tie_word_embeddings: bool = False, config_overrides: Optional[Dict] = None
) -> HfLlamaConfig:
"""Convert to HuggingFace's LlamaConfig

Args:
vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000.
tie_word_embeddings (bool, optional): Whether to tie the input and output word embeddings. HuggingFace's default value is False.
config_overrides (dict, optional): Overrides for the config. Defaults to None.

Returns:
HfLlamaConfig: HuggingFace's LlamaConfig
"""
if config_overrides is None:
config_overrides = {}

@@ -108,6 +120,7 @@ def to_hf_config(self, vocab_size: int = 32000, config_overrides=None) -> HfLlam
rms_norm_eps=self.layer_norm_epsilon,
rope_scaling=self.rope_scaling,
vocab_size=vocab_size,
tie_word_embeddings=tie_word_embeddings,
**config_overrides,
)

44 changes: 22 additions & 22 deletions tests/test_llama.py
Original file line number Diff line number Diff line change
@@ -222,26 +222,27 @@ def test_llama_lm_head_model():
@skip_if_no_torch
def test_llama_roundtrip():
Ivan-Zhou marked this conversation as resolved.
Show resolved Hide resolved
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, LlamaForCausalLM

converter = LlamaConfig.default_hf_checkpoint_converter

config = _get_llama_config()
Vocab = hax.Axis("vocab", 1000)
config = LlamaConfig()
Vocab = hax.Axis("vocab", 32000)

# TODO: load the first torch model with model_id from HF
# Make input and attn_mask
input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size)
attn_mask = hax.nn.attention.causal_mask(config.Pos, config.KeyPos)
input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0)

# randomly initialize a levanter model
# TODO: use converter.load_pretrained
model = LlamaLMHeadModel.init(
Vocab=Vocab,
config=config,
key=random.PRNGKey(0),
)
torch_config = config.to_hf_config(vocab_size=Vocab.size)
torch_model = LlamaForCausalLM(torch_config)
torch_model.eval()

input = hax.random.randint(random.PRNGKey(0), model.Pos, 0, model.Vocab.size)
attn_mask = hax.nn.attention.causal_mask(model.Pos, model.config.KeyPos)
input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0)
torch_out = torch_model(input_torch)
torch_out = torch_out.logits[0].detach().cpu().numpy()
torch_out = jax.nn.softmax(torch_out, axis=-1)

model = converter.load_pretrained(LlamaLMHeadModel)

def compute(input):
model_output = model(input, attn_mask=attn_mask)
@@ -250,6 +251,9 @@ def compute(input):
compute = jax.jit(compute)
jax_out = compute(input).array

assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}"
assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}"

with tempfile.TemporaryDirectory() as tmpdir:
converter.save_pretrained(model, tmpdir, save_reference_code=False)
torch_model2 = AutoModelForCausalLM.from_pretrained(tmpdir)
@@ -263,20 +267,16 @@ def compute(input):


def _get_llama_config() -> LlamaConfig:
seq_len = 128
hidden_dim = 16
num_heads = 4
rope_scaling = {
"type": "linear",
"factor": 2.0,
}
return LlamaConfig(
seq_len=seq_len,
hidden_dim=hidden_dim,
num_heads=num_heads,
seq_len=128,
hidden_dim=16,
num_heads=4,
rope_scaling=rope_scaling,
# disable for tests so debugging is easier
gradient_checkpointing=False,
gradient_checkpointing=False, # disable for tests so debugging is easier
)