-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathpasskey.py
127 lines (107 loc) · 4.68 KB
/
passkey.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import random
import re
import sys
import torch
import warnings
from transformers import AutoTokenizer, pipeline
from tqdm import tqdm, trange
from tqdm.contrib import tenumerate
from model_loader import *
# from https://github.com/epfml/landmark-attention/blob/main/llama/run_test.py
def generate_prompt(n_garbage):
    """Build a pass-key retrieval prompt padded with ~n_garbage filler characters.

    A random pass key (1..50000) is embedded at a random offset inside
    repeated filler sentences, framed by a task description and a final
    question. Returns a tuple of (prompt_text, pass_key).
    """
    prefix_len = random.randint(0, n_garbage)
    suffix_len = n_garbage - prefix_len
    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    # One large filler pool; both garbage segments are slices of it.
    filler = " ".join([garbage] * 10000)
    assert len(filler) >= n_garbage
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    pieces = (
        task_description,
        filler[:prefix_len],
        information_line,
        filler[:suffix_len],
        final_question,
    )
    return "\n".join(pieces), pass_key
def test_model(pipe, prompt_text, pass_key):
response = pipe(prompt_text, num_return_sequences=1, max_new_tokens=10)[
0]["generated_text"][len(prompt_text):]
assert f"The pass key is {pass_key}" in prompt_text
try:
pass_key = int(re.search(r'\d+', response).group())
except:
pass_key = response[:20]
return pass_key
def main(args):
    """Benchmark pass-key retrieval accuracy for each model over a range of prompt lengths.

    For every model given via ``-m/--model``, generates prompts at each target
    length, asks the model for the hidden pass key ``args.iterations`` times,
    and records the fraction of correct answers. Optionally writes a CSV
    (one header row of token counts, one row per model) to ``args.output_file``.
    """
    # Each --model entry is a list (nargs="+"); the first element is the model name.
    models = [x[0] for x in args.model]
    # One tokenizer (from the first model) is reused for all models; assumes
    # all listed models share a tokenizer — TODO confirm with callers.
    tokenizer = AutoTokenizer.from_pretrained(
        models[0], model_max_length=sys.maxsize, padding_side="right", trust_remote_code=True)
    if args.fixed_length:
        # Single fixed garbage length; measure its token count once.
        lengths = [args.fixed_length]
        tokens = [len(tokenizer.encode(generate_prompt(args.fixed_length)[0]))]
        print(f"Prompt is {tokens[0]} tokens")
    else:
        if args.tokens_step:
            # Linear schedule of target token counts.
            tokens = [x for x in range(
                args.min_tokens, args.max_tokens + 1, args.tokens_step)]
        else:
            # Default schedule: double from min_tokens up to max_tokens.
            tokens = [args.min_tokens]
            while args.min_tokens < args.max_tokens:
                point = tokens[-1] * 2
                if point <= args.max_tokens:
                    tokens.append(point)
                else:
                    break
        # For each token target, search for the largest garbage length whose
        # prompt still encodes to fewer tokens than the target. `last_n`
        # carries over between targets so the search never restarts from 0.
        lengths = []
        last_n = 0
        for target in tqdm(tokens, desc="Determining sequence lengths"):
            num_tokens = 0
            n = last_n
            while num_tokens < target:
                last_n = n
                n += args.length_step
                prompt = generate_prompt(n)[0]
                num_tokens = len(tokenizer.encode(prompt))
            lengths.append(last_n)
    results = []
    for model in tqdm(models, desc="Model", leave=False):
        # Free any GPU memory held by the previously loaded model.
        torch.cuda.empty_cache()
        loaded = load_model_and_apply_patches(model, args)
        pipe = pipeline("text-generation", model=loaded,
                        tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id)
        # result[i] counts correct answers at lengths[i], then becomes a ratio.
        result = [0] * len(lengths)
        for i, length in tenumerate(lengths, desc="Lengths", leave=False):
            for _ in trange(0, args.iterations, desc="Iterations", leave=False):
                prompt_text, pass_key = generate_prompt(length)
                num_tokens = len(pipe.tokenizer.encode(prompt_text))
                answer = test_model(pipe, prompt_text, pass_key)
                if answer == pass_key:
                    result[i] += 1
            result[i] /= args.iterations
            print(f"{model}: {tokens[i]}={int(result[i]*100)}%")
        # Prepend the model name so each row is self-describing in the CSV.
        result.insert(0, model)
        results.append(result)
    if args.output_file:
        # CSV layout: header row of token counts (leading empty cell for the
        # model-name column), then one accuracy row per model.
        with open(args.output_file, "w", encoding="utf-8") as f:
            f.write(f",{','.join([str(x) for x in tokens])}\n")
            for result in results:
                f.write(f"{','.join([str(x) for x in result])}\n")
if __name__ == "__main__":
    # Hide library warning noise (e.g. generation-length notices) from output.
    warnings.simplefilter("ignore")
    cli = argparse.ArgumentParser()
    # Repeatable: each -m adds one model entry (a list, due to nargs="+").
    cli.add_argument("-m", "--model", action="append", nargs="+")
    # All integer options, registered table-style: (flag, default).
    for flag, default in (
        ("--fixed-length", None),
        ("--max-tokens", 8192),
        ("--min-tokens", 256),
        ("--tokens-step", None),
        ("--length-step", 128),
        ("--iterations", 20),
    ):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument("--output-file", type=str)
    # add_args (from model_loader) attaches its own model/patch options.
    main(add_args(cli).parse_args())