# example_usage.py
from typing import List

import torch
import tokenizers
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizers_version = tuple(int(n) for n in tokenizers.__version__.split('.'))
if tokenizers_version < (0, 12, 1):
    print("warning: Your tokenizers version looks old and you will likely have formatting issues. We recommend installing tokenizers >= 0.12.1")
# set BIG_MODEL to use the 6.7B parameter model
BIG_MODEL = False

# use a GPU
CUDA = True

# print intermediate outputs of infilling
VERBOSE = False

if BIG_MODEL:
    model_name = "facebook/incoder-6B"

    # the arguments added below will load a half precision version of the model,
    # which requires less RAM than loading the full float32 version. this
    # should fit in ~16GB of RAM
    # NOTE: half precision should *not* be used if you plan to fine-tune the
    # model. You'll need full precision and a lot of GPU memory. We have not
    # tested fine-tuning in `transformers` (the model was trained in fairseq)
    if CUDA:
        kwargs = dict(
            revision="float16",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
    else:
        kwargs = dict(
            low_cpu_mem_usage=True,
        )
else:
    model_name = "facebook/incoder-1B"
    kwargs = {}
print("loading model")
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
print("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("loading complete")
if CUDA:
# if you plan to fine-tune the model, you should not use half precision.
model = model.half().cuda()
# signals the start of a document
BOS = "<|endoftext|>"
# signals the end of a generated infill
EOM = "<|endofmask|>"


def make_sentinel(i):
    # signals (1) a location to insert an infill and (2) the start of the infill generation
    return f"<|mask:{i}|>"

def generate(input: str, max_to_generate: int=128, temperature: float=0.2):
    """
    Do standard left-to-right completion of the prefix `input` by sampling from the model
    """
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    max_length = max_to_generate + input_ids.flatten().size(0)
    if max_length > 2048:
        print("warning: max_length {} is greater than the context window {}".format(max_length, 2048))
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=max_length)
    # pass clean_up_tokenization_spaces=False to avoid removing spaces before punctuation, e.g. "from ." -> "from."
    detok_hypo_str = tokenizer.decode(output.flatten(), clean_up_tokenization_spaces=False)
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str
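
# Illustrative helper (not in the original script): a minimal left-to-right
# completion using `generate`. The prompt string below is arbitrary.
def left_to_right_example(max_to_generate=64, temperature=0.2):
    prompt = "def count_lines(filename):\n"
    completion = generate(prompt, max_to_generate=max_to_generate, temperature=temperature)
    print("left-to-right completion:")
    print(completion)
    return completion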

def infill(parts: List[str], max_to_generate: int=128, temperature: float=0.2, extra_sentinel: bool=True, max_retries: int=1):
    """
    Generate infills to complete a partial document, e.g.
    [A C E] -> [A B C D E], where B and D are infills that have been generated.

    parts: List[str]. list of parts of the document. One string will be
        inserted in between each element, i.e. infilling N-1 locations for a list
        of length N.
    max_to_generate: int. maximum number of tokens to generate. Keep in mind
        that the model context size is 2048.
    temperature: float. temperature parameter for sampling.
    extra_sentinel: bool. we recommend setting this to True, as it makes it
        easier for the model to end generated infills. See the footnote in
        section 2.2 of our paper for details.
    max_retries: int. if > 1, use rejection sampling to keep sampling infills until
        all infills sample a completion token.

    returns a dictionary containing the following:
        text: str, the completed document (with infills inserted)
        parts: List[str], length N. Same as passed to the method
        infills: List[str], length N-1. The list of infills generated
        retries_attempted: number of retries used (if max_retries > 1)
    """
    assert isinstance(parts, list)
    retries_attempted = 0
    done = False

    while (not done) and (retries_attempted < max_retries):
        retries_attempted += 1

        if VERBOSE:
            print(f"retry {retries_attempted}")

        ## (1) build the prompt
        if len(parts) == 1:
            prompt = parts[0]
        else:
            prompt = ""
            # encode parts separated by sentinels
            for sentinel_ix, part in enumerate(parts):
                prompt += part
                if extra_sentinel or (sentinel_ix < len(parts) - 1):
                    prompt += make_sentinel(sentinel_ix)

        infills = []
        complete = []

        done = True

        ## (2) generate infills
        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            prompt += make_sentinel(sentinel_ix)
            # TODO: this is inefficient as it requires re-encoding prefixes repeatedly
            completion = generate(prompt, max_to_generate, temperature)
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    print(f"warning: {EOM} not found")
                completion += EOM
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)
    if VERBOSE:
        print("generated text:")
        print(prompt)
        print()
        print("parts:")
        print(parts)
        print()
        print("infills:")
        print(infills)
        print()
        print("restitched text:")
        print(text)
        print()

    return {
        'text': text,  # str, the completed document (with infills inserted)
        'parts': parts,  # List[str], length N. Same as passed to the method
        'infills': infills,  # List[str], length N-1. The list of infills generated
        'retries_attempted': retries_attempted,  # number of retries used (if max_retries > 1)
    }
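
# Illustrative usage sketch (not in the original script): fill a single gap and
# inspect the returned fields. max_retries > 1 turns on the rejection sampling
# described in the docstring above.
def infill_usage_example():
    parts = ["def add(a, b):\n    ", "\n    return result\n"]
    result = infill(parts, max_to_generate=64, temperature=0.2, max_retries=3)
    print("infilled body:")
    print(result["infills"][0])
    print("retries attempted:", result["retries_attempted"])
    return result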

def code_to_docstring(max_to_generate=128, temperature=0.2):
    # this will sometimes generate extra functions! this can be avoided by truncating generation when e.g. a """ is produced
    example = '''\
def count_words(filename):
    """ <insert> """
    counts = Counter()
    with open(filename) as file:
        for line in file:
            words = line.split(' ')
            counts.update(words)
    return counts'''
    parts = example.split("<insert>")
    result = infill(parts, max_to_generate=max_to_generate, temperature=temperature)
    print("completed document:")
    print(result["text"])
    return result["text"]

def docstring_to_code(max_to_generate=128, temperature=0.2):
    # <|/ file |> marks the end of a file in InCoder's training format; including it
    # encourages the model to stop after completing the function.
    example = '''\
def <insert>
    """ Count the number of occurrences of each word in the file. """
    <insert>
<|/ file |>'''
    parts = example.split("<insert>")
    result = infill(parts, max_to_generate=max_to_generate, temperature=temperature)
    print("completed document:")
    print(result["text"])
    return result["text"]

if __name__ == "__main__":
    print("code to docstring test:")
    code_to_docstring()
    print()
    print("docstring to code test:")
    docstring_to_code()