-
Notifications
You must be signed in to change notification settings - Fork 259
/
Copy pathtest_ipex_optimize_transformers.py
568 lines (532 loc) · 22.7 KB
/
test_ipex_optimize_transformers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
import unittest
import torch
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch._C as core
import sys
import subprocess
import os
import copy
import re
import tempfile
from intel_extension_for_pytorch.quantization import prepare, convert
from collections import namedtuple
import itertools
# ``transformers`` is required by these tests; when it is missing, pip-install
# the pinned version into the running interpreter and import it again.
try:
    import transformers
    from transformers import AutoConfig
except ImportError:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "transformers==4.45.0"]
    )
    import transformers
    from transformers import AutoConfig
from intel_extension_for_pytorch.cpu._auto_kernel_selection import _disable_tpp
from common_utils import TestCase

# Seed the global torch RNG so the randomly initialised test models are
# reproducible from run to run.
torch.manual_seed(128)

# Absolute directory of this test file; model configs are read from
# ``<curpath>/hf_configs/<model name>``.
curpath = os.path.dirname(os.path.abspath(__file__))
def _get_gptj_example_inputs(batch_size=8):
    """Return positional example inputs for the optimized test models.

    Despite its name, ``batch_size`` is the number of tokens in a single
    sequence: ``input_ids``, ``attention_mask`` and ``position_ids`` are all of
    shape ``(1, batch_size)``.

    Returns:
        ``(input_ids, attention_mask, past_key_values, position_ids)`` where
        ``input_ids`` is all ones (int64), ``attention_mask`` is all ones
        (float), ``position_ids`` is ``0 .. batch_size - 1`` (int64) and
        ``past_key_values`` is a one-element tuple holding a 4-tuple of zero
        tensors with shapes ``(1, 1, 0, 1)`` (int64), ``(1, 1, 1, 1)``,
        ``(1, 1, 1, 1)`` (float) and ``(1, 4)`` (int64).
        NOTE(review): this looks like the per-layer cache layout IPEX's
        optimized models expect, for a single decoder layer — confirm.
    """
    token_ids = torch.ones(batch_size).to(torch.long)
    seq_len = len(token_ids)
    mask = torch.ones(seq_len)
    positions = torch.arange(seq_len)

    layer_cache = (
        torch.zeros(1, 1, 0, 1, dtype=torch.long),
        torch.zeros([1, 1, 1, 1]),
        torch.zeros([1, 1, 1, 1]),
        torch.zeros(1, 4, dtype=torch.long),
    )
    past_key_values = (layer_cache,)

    # ``[None]`` adds the leading batch dimension (same as ``unsqueeze(0)``).
    return token_ids[None], mask[None], past_key_values, positions[None]
# Description of a model exercised by ``test_model_replacement``:
#   name             -- sub-directory of ``hf_configs`` holding the config
#   model_class      -- HuggingFace class instantiated from that config
#   has_position_ids -- whether ``position_ids`` is passed to forward()
#   attention_class  -- callable(model) -> class of the first attention module
#   decoder_class    -- callable(model) -> class of the first decoder layer
model_info = namedtuple(
    "model_info",
    ["name", "model_class", "has_position_ids", "attention_class", "decoder_class"],
)

supported_models = [
    model_info(
        name="gptj",
        model_class=transformers.models.gptj.modeling_gptj.GPTJForCausalLM,
        has_position_ids=True,
        attention_class=lambda model: model.transformer.h[0].attn.__class__,
        decoder_class=lambda model: model.transformer.h[0].__class__,
    ),
    model_info(
        name="llama",
        model_class=transformers.models.llama.modeling_llama.LlamaForCausalLM,
        has_position_ids=True,
        attention_class=lambda model: model.model.layers[0].self_attn.__class__,
        decoder_class=lambda model: model.model.layers[0].__class__,
    ),
]
class OptimizeTransformersTester(TestCase):
    """Tests for ``ipex.llm.optimize`` on small, randomly initialised LLMs.

    Configs are loaded from ``hf_configs/<name>`` next to this file. Because the
    weights are random, each test either compares an IPEX-optimized model with
    a reference built from the same weights, or only checks that it runs.
    """

    @staticmethod
    def _gptj_first_layer_linears(model):
        """Return the linear modules of GPT-J decoder layer 0 after optimization.

        Uses the attribute layout of the IPEX-replaced decoder layer
        (``concat_qkv``, ``linear_add_add``, ``linear_gelu``).
        """
        layer = model.transformer.h[0]
        return [
            layer.attn.concat_qkv.concat_linear,
            layer.attn.out_proj,
            layer.linear_add_add.linear,
            layer.linear_gelu.linear,
        ]

    @staticmethod
    def _randomize_weights(model, weight_names):
        """Overwrite the named state-dict entries of ``model`` with ``torch.rand`` values."""
        state_dict = model.state_dict()
        for name in weight_names:
            state_dict[name] = torch.rand(state_dict[name].shape)
        model.load_state_dict(state_dict)

    def model_replacement_check(
        self, m, dtype, deployment_mode, torchcompile=False, return_dict=False
    ):
        """Optimize the model described by ``m`` and compare it with the original.

        Checks that the first attention (and decoder) layer was replaced by the
        IPEX CPU module and that the logits match the unoptimized model
        (``prec=0.1``).

        Args:
            m: a ``model_info`` entry (name, model class, layer accessors).
            dtype: dtype for ``ipex.llm.optimize``; CPU autocast is enabled
                when it is bfloat16 or float16.
            deployment_mode: forwarded to ``ipex.llm.optimize``.
            torchcompile: additionally wrap the optimized model with
                ``torch.compile(backend="ipex")``.
            return_dict: ``return_dict`` value for the model config; selects
                whether a dict-like or a tuple output is expected.
        """
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/{m.name}",
            return_dict=return_dict,
            trust_remote_code=True,
        )
        model = m.model_class(config).eval()
        # Model-specific preparation for the models that need it.
        if m.name == "falcon":
            with torch.no_grad():
                ipex.nn.utils._model_convert.replace_customized_linear_with_linear(
                    model.eval()
                )
        elif m.name == "chatglm":
            self._randomize_weights(
                model,
                [
                    "transformer.encoder.layers.0.input_layernorm.weight",
                    "transformer.encoder.layers.0.post_attention_layernorm.weight",
                    "transformer.encoder.final_layernorm.weight",
                ],
            )
        elif m.name == "baichuan":
            self._randomize_weights(
                model,
                [
                    "model.layers.0.input_layernorm.weight",
                    "model.layers.0.post_attention_layernorm.weight",
                    "model.norm.weight",
                ],
            )
        model.eval()
        ref_m = copy.deepcopy(model)
        ipex_m = ipex.llm.optimize(
            copy.deepcopy(model),
            dtype=dtype,
            deployment_mode=deployment_mode,
            inplace=True,
        )
        if torchcompile:
            torch._dynamo.reset()
            ipex._set_compiler_backend("inductor")
            ipex_m = torch.compile(ipex_m, backend="ipex")

        cpu_modules = ipex.transformers.models.cpu.modules
        assert m.attention_class(ipex_m) is cpu_modules.attentions._IPEXAttentionCPU
        if m.decoder_class is not None:
            assert m.decoder_class(ipex_m) is cpu_modules.decoder._IPEXDecoderLayerCPU

        input_ids = torch.ones(10).to(torch.long)
        attention_mask = torch.ones(len(input_ids))
        position_ids = torch.arange(len(input_ids))
        decoder_input_ids = torch.ones(1).to(torch.long)
        input_dict = {
            "input_ids": input_ids.unsqueeze(0),
            "attention_mask": attention_mask.unsqueeze(0),
            "use_cache": True,
        }
        if m.has_position_ids:
            input_dict["position_ids"] = position_ids.unsqueeze(0)
        # T5-style architectures additionally take decoder inputs.
        if re.search("t5", model.config.architectures[0], re.IGNORECASE):
            input_dict["decoder_input_ids"] = decoder_input_ids.unsqueeze(0)

        with torch.no_grad():
            key_hf = ref_m(**input_dict)
        with torch.no_grad(), torch.cpu.amp.autocast(
            enabled=dtype in (torch.bfloat16, torch.float16),
            dtype=dtype,
        ):
            key_ipex = ipex_m(**input_dict)

        error_message = (
            f"model={m.name}, deployment_mode={deployment_mode}, "
            f"torchcompile={torchcompile}, return_dict={return_dict}"
        )
        if return_dict:
            assert isinstance(key_ipex, dict)
            self.assertEqual(
                key_hf["logits"], key_ipex["logits"], prec=0.1, message=error_message
            )
        else:
            assert isinstance(key_ipex, tuple)
            self.assertEqual(key_hf[0], key_ipex[0], prec=0.1, message=error_message)

    def test_model_replacement(self):
        """Run ``model_replacement_check`` over models, dtypes and optimize options."""
        dtypes = [torch.bfloat16]
        if core.onednn_has_fp16_support():
            dtypes.append(torch.float16)
        enable_torchcompile = [False, True]
        deployment_modes = [True, False]
        return_dict_options = [False, True]
        cases = itertools.product(
            supported_models,
            enable_torchcompile,
            dtypes,
            deployment_modes,
            return_dict_options,
        )
        try:
            for m, torchcompile, dtype, deployment_mode, return_dict in cases:
                # torch.compile is only combined with deployment_mode=False. Test
                # the per-case flag here, not the (non-empty, hence always truthy)
                # list of modes.
                if torchcompile and deployment_mode:
                    continue
                self.model_replacement_check(
                    m, dtype, deployment_mode, torchcompile, return_dict
                )
        finally:
            # Always call _disable_tpp(), even when a case fails.
            _disable_tpp()

    def _model_replacement_check_woq(self, model):
        """Check weight-only-quantized ``ipex.llm.optimize`` against plain WOQ.

        The reference is ``prepare``/``convert`` with the same default
        weight-only qconfig. The optimized model must have a ``trace_graph`` and
        IPEX attention/decoder/WOQ-linear modules in layer 0, and must match the
        reference logits (``prec=1e-4``).
        """
        qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping()
        orig_woq_model = prepare(copy.deepcopy(model), qconfig_mapping, inplace=True)
        orig_woq_model = convert(orig_woq_model, inplace=True)

        model = ipex.llm.optimize(
            model,
            dtype=torch.float,
            quantization_config=qconfig_mapping,
            deployment_mode=True,
        )
        self.assertTrue(
            hasattr(model, "trace_graph"), "optimized model has no trace_graph"
        )

        _IPEXAttentionCPU = (
            ipex.transformers.models.cpu.modules.attentions._IPEXAttentionCPU
        )
        _IPEXDecoderLayerCPU = (
            ipex.transformers.models.cpu.modules.decoder._IPEXDecoderLayerCPU
        )
        WeightOnlyQuantizedLinear = ipex.nn.modules.WeightOnlyQuantizedLinear
        architecture = model.config.architectures[0]
        if re.search("GPTJ", architecture):
            layer = model.transformer.h[0]
            assert layer.attn.__class__ is _IPEXAttentionCPU
            assert layer.__class__ is _IPEXDecoderLayerCPU
            woq_linears = self._gptj_first_layer_linears(model)
        elif re.search("llama", architecture, re.IGNORECASE):
            layer = model.model.layers[0]
            assert layer.self_attn.__class__ is _IPEXAttentionCPU
            assert layer.__class__ is _IPEXDecoderLayerCPU
            woq_linears = [
                layer.self_attn.concat_qkv.concat_linear,
                layer.mha_linear_add.linear,
                layer.mlp_linear_add.linear,
                layer.linear_silu_mul.linear_s,
                layer.linear_silu_mul.linear_m,
            ]
        else:
            woq_linears = []
        assert all(mod.__class__ is WeightOnlyQuantizedLinear for mod in woq_linears)

        # Ensure model can run without errors and matches the reference.
        with torch.no_grad():
            example_inputs = _get_gptj_example_inputs()
            input_ids, attention_mask, _, position_ids = example_inputs
            y = model(*example_inputs)
            y_ref = orig_woq_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=True,
            )
            self.assertEqual(y[0], y_ref[0], prec=1e-4)

    def test_weight_only_quant_flow_for_gptj(self):
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        self._model_replacement_check_woq(m)

    def test_weight_only_quant_flow_for_llama(self):
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/llama", return_dict=False
        )
        m = transformers.models.llama.modeling_llama.LlamaForCausalLM(config).eval()
        self._model_replacement_check_woq(m)

    def test_weight_only_quant_cache_weight_for_large_batch(self):
        """WOQ with ``cache_weight_for_large_batch=True`` must match the uncached model."""
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        WoqWeightDtype = ipex.quantization.WoqWeightDtype
        WoqLowpMode = ipex.quantization.WoqLowpMode
        cases = itertools.product(
            [WoqWeightDtype.INT8, WoqWeightDtype.INT4, WoqWeightDtype.NF4],
            [WoqLowpMode.BF16, WoqLowpMode.INT8],
        )
        for weight_dtype, lowp_mode in cases:
            # lowp_mode=INT8 is only exercised together with INT4 weights.
            if weight_dtype != WoqWeightDtype.INT4 and lowp_mode == WoqLowpMode.INT8:
                continue
            qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
                weight_dtype=weight_dtype,
                lowp_mode=lowp_mode,
            )
            model_ref = ipex.llm.optimize(
                copy.deepcopy(m),
                dtype=torch.bfloat16,
                quantization_config=qconfig_mapping,
                deployment_mode=True,
                cache_weight_for_large_batch=False,
            )
            model = ipex.llm.optimize(
                copy.deepcopy(m),
                dtype=torch.bfloat16,
                quantization_config=qconfig_mapping,
                deployment_mode=True,
                cache_weight_for_large_batch=True,
            )
            linear_list = self._gptj_first_layer_linears(model)
            with torch.no_grad(), torch.cpu.amp.autocast(enabled=True):
                example_inputs = _get_gptj_example_inputs(batch_size=128)
                y = model(*example_inputs)
                y_ref = model_ref(*example_inputs)
                for linear in linear_list:
                    if linear._op_context.get_weight().ndim == 4:
                        assert linear._op_context.get_cached_weight() is not None
                tol = 1.5e-1 if weight_dtype == WoqWeightDtype.NF4 else 5e-2
                self.assertEqual(y[0], y_ref[0], prec=tol)

    def test_static_quant_flow(self):
        """Calibrate with the smooth-quant qconfig, then re-optimize from the saved summary."""
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        quant_m = copy.deepcopy(m)
        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping()
        example_inputs = _get_gptj_example_inputs()
        quant_m = ipex.llm.optimize(
            quant_m, dtype=torch.float, quantization_config=qconfig, inplace=True
        )
        prepared_model = prepare(
            quant_m.eval(), qconfig, example_inputs=example_inputs, inplace=True
        )
        # Calibration run on the prepared model.
        with torch.no_grad():
            prepared_model(*example_inputs)

        with tempfile.NamedTemporaryFile() as fp:
            prepared_model.save_qconf_summary(qconf_summary=fp.name)
            for dtype in [torch.float, torch.bfloat16]:
                ipex_m = copy.deepcopy(m)
                if dtype is torch.bfloat16:
                    ipex_m = ipex_m.to(torch.bfloat16)
                qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping()
                ipex_m = ipex.llm.optimize(
                    ipex_m,
                    dtype=dtype,
                    quantization_config=qconfig,
                    qconfig_summary_file=fp.name,
                    inplace=True,
                )
                self.assertTrue(
                    hasattr(ipex_m, "trace_graph"),
                    f"optimized model has no trace_graph (dtype={dtype})",
                )

    def _check_woq_dummy_checkpoint(self, quant_method):
        """Optimize GPT-J from a random packed checkpoint and run the traced model.

        Every ``*proj``/``fc_in``/``fc_out`` linear weight of shape ``(N, K)`` is
        replaced by random ``qweight``/``scales``/``qzeros``/``g_idx`` tensors
        (group size 128, ``comp_ratio`` = 8 values packed per int32). The
        checkpoint is round-tripped through ``torch.save``/``torch.load`` and
        handed to ``ipex.llm.optimize`` with ``WoqLowpMode.INT8``.

        Args:
            quant_method: ``"gptq"`` stores ``qweight`` as ``(K // 8, N)`` and
                passes the bare state dict; ``"awq"`` stores it as
                ``(K, N // 8)`` and passes ``(state_dict, {"quant_method": "awq"})``.

        Raises:
            ValueError: if ``quant_method`` is neither ``"gptq"`` nor ``"awq"``.
        """
        if quant_method not in ("gptq", "awq"):
            raise ValueError(f"unsupported quant_method: {quant_method!r}")
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        ipex_m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        group_size = 128
        comp_ratio = 8
        with tempfile.TemporaryDirectory() as work_dir:
            # Generate dummy checkpoint
            checkpoint_file_name = os.path.join(work_dir, "checkpoint.pt")
            state_dict = ipex_m.state_dict()
            linear_keys = [
                k[: -len(".weight")]
                for k in state_dict
                if k.endswith(("proj.weight", "fc_in.weight", "fc_out.weight"))
            ]
            for k in linear_keys:
                weight_shape = state_dict[k + ".weight"].shape
                N, K = weight_shape[0], weight_shape[1]
                del state_dict[k + ".weight"]
                n_groups = K // group_size
                if quant_method == "awq":
                    stored_weight_shape = (K, N // comp_ratio)
                else:
                    stored_weight_shape = (K // comp_ratio, N)
                state_dict[k + ".qweight"] = torch.randint(
                    -(2**31), 2**31 - 1, stored_weight_shape, dtype=torch.int32
                )
                state_dict[k + ".scales"] = torch.randn(
                    (n_groups, N), dtype=torch.half
                )
                state_dict[k + ".qzeros"] = torch.randint(
                    -(2**31),
                    2**31 - 1,
                    (n_groups, N // comp_ratio),
                    dtype=torch.int32,
                )
                # Randomly permuted group index: each of the n_groups groups is
                # assigned group_size of the K input positions.
                g_idx = torch.arange(n_groups).repeat(group_size)
                g_idx[:] = g_idx[torch.randperm(K)]
                state_dict[k + ".g_idx"] = g_idx
            torch.save(state_dict, checkpoint_file_name)
            state_dict = torch.load(checkpoint_file_name)

            # test loading checkpoint and quant info
            qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
                lowp_mode=ipex.quantization.WoqLowpMode.INT8
            )
            if quant_method == "awq":
                low_precision_checkpoint = (state_dict, {"quant_method": "awq"})
            else:
                low_precision_checkpoint = state_dict
            ipex_m = ipex.llm.optimize(
                ipex_m,
                dtype=torch.float,
                quantization_config=qconfig,
                low_precision_checkpoint=low_precision_checkpoint,
                deployment_mode=True,
                inplace=True,
            )
            self.assertTrue(hasattr(ipex_m, "trace_graph"))

            # Ensure model can run without errors; the optimized model is
            # ipex_m.trace_graph.
            with torch.no_grad():
                ipex_m.trace_graph(*_get_gptj_example_inputs())

    def test_weight_only_quant_gptq(self):
        # Test the HuggingFace Optimum format
        self._check_woq_dummy_checkpoint("gptq")

    def test_weight_only_quant_awq(self):
        self._check_woq_dummy_checkpoint("awq")

    def test_generate_functions(self):
        """``generate`` on the optimized GPT-J must match the reference per decoding mode."""
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        dtypes = [torch.bfloat16]
        if core.onednn_has_fp16_support():
            dtypes.append(torch.float16)
        # Sampling modes use a temperature of 0.001 to constrain text generation
        # diversity in UT.
        generate_kwargs_list = [
            # beam_search, beam=4
            dict(
                do_sample=False,
                temperature=0.9,
                num_beams=4,
                max_new_tokens=2,
                min_new_tokens=2,
            ),
            # greedy_search
            dict(do_sample=False, temperature=0.9, max_new_tokens=2, min_new_tokens=2),
            # sample
            dict(do_sample=True, temperature=0.001, max_new_tokens=2, min_new_tokens=2),
            # beam_sample
            dict(
                do_sample=True,
                temperature=0.001,
                num_beams=4,
                max_new_tokens=2,
                min_new_tokens=2,
            ),
        ]
        for dtype in dtypes:
            m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
            ref_m = copy.deepcopy(m)
            ipex_m = ipex.llm.optimize(
                m, dtype=dtype, deployment_mode=True, inplace=True
            )
            input_ids = torch.ones(8).unsqueeze(0).to(torch.long)
            for generate_kwargs in generate_kwargs_list:
                # inference_mode already disables gradient tracking.
                with torch.inference_mode(), torch.cpu.amp.autocast(
                    enabled=True, dtype=dtype
                ):
                    ipex_res = ipex_m.generate(input_ids, **generate_kwargs)
                    ref_res = ref_m.generate(input_ids, **generate_kwargs)
                    self.assertEqual(ipex_res, ref_res)
                    ipex_res_dict = ipex_m.generate(
                        input_ids, return_dict_in_generate=True, **generate_kwargs
                    )
                    ref_res_dict = ref_m.generate(
                        input_ids, return_dict_in_generate=True, **generate_kwargs
                    )
                    self.assertEqual(ipex_res_dict.sequences, ref_res_dict.sequences)

    def test_cache_weight_for_large_batch(self):
        """``cache_weight_for_large_batch=True`` must populate the cache and keep outputs."""
        config = AutoConfig.from_pretrained(
            f"{curpath}/hf_configs/gptj", return_dict=False
        )
        model = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
        model_ref = ipex.llm.optimize(
            copy.deepcopy(model),
            dtype=torch.bfloat16,
            deployment_mode=True,
            cache_weight_for_large_batch=False,
        )
        model = ipex.llm.optimize(
            model,
            dtype=torch.bfloat16,
            deployment_mode=True,
            cache_weight_for_large_batch=True,
        )
        linear_list = self._gptj_first_layer_linears(model)
        with torch.no_grad(), torch.cpu.amp.autocast(enabled=True):
            example_inputs = _get_gptj_example_inputs(batch_size=512)
            y = model(*example_inputs)
            y_ref = model_ref(*example_inputs)
            assert all(
                hasattr(linear, "weight_for_large_batch") for linear in linear_list
            )
            assert all(
                linear.weight_for_large_batch is not None for linear in linear_list
            )
            self.assertEqual(y[0], y_ref[0])
if __name__ == "__main__":
    # Run every test in this module through the unittest command-line runner.
    unittest.main()