-
Notifications
You must be signed in to change notification settings - Fork 23
/
generate.py
57 lines (48 loc) · 1.87 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import jax
import jax.numpy as jnp
import jax.random as rand
from transformers import LlamaTokenizer
from lib.generation import generate
from lib.llama import Llama
from lib.logits_processing import PresencePenaltyProcessor, TopKSampler, TopPSampler, make_logits_processor
from lib.param_utils import load_params
from lib.multihost_utils import shard_model_params
from lib.seeding import BEST_INTEGER
def load_params_from_disk() -> Llama:
    """Load the model parameters from disk, cast them to bfloat16 and shard them.

    The parameters are read from ``'llama2-7B.pickle'`` via ``load_params``
    while the first CPU device is the JAX default device, and every leaf of
    the resulting pytree is converted with ``astype(jnp.bfloat16)`` inside
    that same context. The converted tree is then handed to
    ``shard_model_params`` and the result is returned.

    Returns:
        The value returned by ``shard_model_params`` for the bfloat16
        parameter tree.
    """
    cpu_device = jax.devices('cpu')[0]
    with jax.default_device(cpu_device):
        params = load_params('llama2-7B.pickle')
        # `jax.tree_map` is a deprecated alias that newer JAX releases no
        # longer provide; `jax.tree_util.tree_map` is the stable spelling.
        params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
    # NOTE(review): the original indentation was lost; sharding is assumed to
    # happen outside the CPU default-device context - confirm.
    params = shard_model_params(params)
    return params
def main():
    """Generate and print continuations for two batches of example prompts.

    Loads the parameters with ``load_params_from_disk``, creates a PRNG key
    from ``BEST_INTEGER`` using the ``'rbg'`` implementation, sets up a
    left-padding ``LlamaTokenizer`` whose pad token is its EOS token, and
    builds a logits processor from a presence penalty followed by top-k
    sampling. Each batch is then passed to ``generate`` with its own subkey
    and every returned sentence is printed followed by a blank line.
    """
    params = load_params_from_disk()
    print('Successfully loaded model parameters!')

    # Sampling configuration.
    top_k = 6
    max_len = 256
    key = rand.key(BEST_INTEGER, impl='rbg')

    tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', padding_side='left')
    tokenizer.pad_token = tokenizer.eos_token

    # A TopPSampler stage (top_p = 0.05) was previously listed here as a
    # disabled option; only the presence penalty and top-k stages are active.
    presence_penalty = PresencePenaltyProcessor(penalty=0.05)
    top_k_sampler = TopKSampler(top_k=top_k)
    logits_processor = make_logits_processor(presence_penalty, top_k_sampler)

    gettysburg_prompt = 'Four score and seven years ago our fathers'
    surayt_prompt = 'Marta is a Syriac student. Aday asks her: ‘Why do you want to learn Surayt?’'
    # The same two prompts are submitted twice, in opposite order.
    prompt_batches = (
        [gettysburg_prompt, surayt_prompt],
        [surayt_prompt, gettysburg_prompt],
    )

    for prompts in prompt_batches:
        # Split off a fresh subkey so each batch samples with its own key.
        key, batch_key = rand.split(key)
        completions = generate(
            prompts,
            tokenizer,
            params,
            logits_processor,
            max_len=max_len,
            key=batch_key,
        )
        for completion in completions:
            print(completion, end='\n\n')
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()