run_generative_inference.py
#!/usr/bin/env python3
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
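"""Generative (next-token) inference example.

Builds a small prompt dataset, optionally generates continuations with the
pretrained Hugging Face model on the host as a reference, then runs the same
generation with the popxl tensor-parallel inference session on IPUs and logs
the prompts and answers from both.
"""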
import logging
from functools import reduce
from copy import deepcopy

import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2Tokenizer

import popxl
from popxl.utils import to_numpy
from popxl_addons.array_munging import tensor_parallel_input, repeat
from popxl_addons import timer

from generative_inference import generative_inference
from modelling.hf_mapping import hf_mapping_lm_gen_inference_TP
from utils.setup import gpt_config_setup
from utils.inference import batch_inference
from data.mnli.mnli_data import form_validation_prompts, prepare_validation_features
from config import GPTConfig


def unwrap(dl):
    """Yield each dataset example as a torch.LongTensor, with a tqdm progress bar."""
    for example in tqdm(dl):
        yield torch.tensor(example, dtype=torch.long)

def run_inference_hf(dataset, tokenizer, hf_model, sequence_length, output_length, micro_batch_size):
    """Generate continuations on the host with the Hugging Face reference model."""
    logging.info("Running inference HF")

    def next_token(inputs, lengths):
        outputs = hf_model(input_ids=inputs)
        logits = outputs.logits  # Tensor[mbs, seq, vocab]
        # Batched index_select:
        # Flatten [mbs, seq] dimension and offset indices
        mbs = logits.shape[0]
        seq = logits.shape[1]
        offsets = (lengths - 1) + (torch.arange(0, mbs) * seq)
        next_token_logits = torch.index_select(logits.reshape(-1, *logits.shape[2:]), 0, offsets)  # Tensor[mbs, vocab]
        return torch.argmax(next_token_logits, dim=-1).reshape(-1)
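
    # For example: with mbs=2, seq=5 and lengths=[3, 5], the flattened logits
    # have shape [2 * 5, vocab] and offsets=[2, 9], selecting the logits at the
    # last non-padding position of each sequence in the micro-batch.
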
    answers = batch_inference(
        unwrap(dataset),
        next_token,
        sequence_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=0,
        output_length=output_length,
        micro_batch_size=micro_batch_size,
    )

    logging.info("HF output")
    for p, a in zip(dataset, answers):
        prompt = tokenizer.decode(p)
        answer = tokenizer.decode(a)
        logging.info(f"Prompt: {prompt}")
        logging.info(f"Answer: {answer}")

def run_inference_popxl(config: GPTConfig, dataset, tokenizer, hf_model, sequence_length, output_length):
    """Generate continuations on IPUs with the popxl tensor-parallel inference session."""
    config.model.sequence_length = sequence_length
    tp = config.execution.tensor_parallel
    rf = config.execution.tensor_parallel * config.execution.data_parallel

    session = generative_inference(config)

    if config.model.dtype == popxl.float16:
        hf_model.half()
    with timer("Loading HF pretrained model to IPU"):
        weights = hf_mapping_lm_gen_inference_TP(config, session, hf_model)
        session.write_variables_data(weights)

    def next_token(inputs, lengths):
        data_map = {}
        words = to_numpy(inputs, session.inputs.words.dtype).reshape(-1, *session.inputs.words.shape)
        data_map[session.inputs.words] = tensor_parallel_input(words, tp, rf).squeeze()
        data_map[session.inputs.last_token_indices] = repeat(lengths - 1, tp, axis=0)
        outputs = session.run(data_map)
        next_token_id = outputs[session.outputs.next_token][0]  # identical for all tp, take first
        return torch.LongTensor(next_token_id)
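
    # The closure above replicates the token ids and last-token indices across
    # the tensor-parallel replicas; every replica produces the same next token,
    # so only the first replica's output is kept.
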
logging.info("Attach to IPUs")
with session:
answers = batch_inference(
unwrap(dataset),
next_token,
config.model.sequence_length,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=0,
output_length=output_length,
micro_batch_size=config.execution.micro_batch_size,
)
logging.info("popxl output")
for p, a in zip(dataset, answers):
prompt = tokenizer.decode(p)
try:
answer = tokenizer.decode(a)
except TypeError as e:
raise Exception(f"Couldn't de-tokenize: {a}") from e
logging.info(f"Prompt: {prompt}")
logging.info(f"Answer: {answer}")
def get_mnli_dataset(config, tokenizer):
    """MNLI validation prompts, tokenized and trimmed to one micro-batch."""
    dataset = load_dataset("glue", "mnli", split="validation_mismatched")
    dataset = dataset.select(range(config.execution.micro_batch_size))
    dataset = dataset.map(form_validation_prompts, remove_columns=["hypothesis", "premise", "idx"])
    dataset = dataset.map(
        prepare_validation_features,
        batched=True,
        remove_columns=dataset.column_names,
        load_from_cache_file=True,
        fn_kwargs={"tokenizer": tokenizer},
    )
    dataset = [e["input_ids"] for e in dataset]
    return dataset


def get_dummy_dataset(config, tokenizer):
    """Dummy free-text prompts for a quick sanity check."""
    text = [
        "Mary had a little ",
        "Edinburgh is the capital of ",
        "My name is ",
    ]
    dataset = [tokenizer.encode(t, return_tensors="pt").flatten() for t in text]
    return dataset

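
# NOTE: main() below generates from the dummy prompts. To generate from the MNLI
# validation prompts instead, swap get_dummy_dataset for get_mnli_dataset in the
# "--- Dataset ---" step of main().
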
def main():
    # --- Setup ---
    config, args, pretrained = gpt_config_setup(
        "config/inference.yml", "release", "gpt2_small", wandb_setup=False, hf_model_setup=True
    )

    # --- Tokenizer ---
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    assert (
        config.model.embedding.vocab_size >= tokenizer.vocab_size
    ), f"Vocab size of model is too small for tokenizer: {config.model.embedding.vocab_size} < {tokenizer.vocab_size}"

    # --- Dataset ---
    dataset = list(get_dummy_dataset(config, tokenizer))  # Should just be input_ids
    max_len = reduce(lambda l, e: max(l, len(e)), dataset, 0)
    output_length = config.inference.generative_output_len

    # --- HF example ---
    if pretrained:
        logging.info("Initialising HF model")
        pretrained.eval()
        run_inference_hf(
            deepcopy(dataset),
            tokenizer,
            pretrained,
            max_len + output_length,
            output_length,
            config.execution.micro_batch_size,
        )

    # --- POPXL example ---
    run_inference_popxl(config, dataset, tokenizer, pretrained, max_len + output_length, output_length)

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logging.exception(e)  # Log time of exception
        raise