-
Notifications
You must be signed in to change notification settings - Fork 258
/
util.py
1273 lines (1095 loc) · 48.1 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util Class and Functions."""
import copy
import json
import re
from collections import UserDict
from functools import partial
import numpy as np
from packaging.version import Version
from ...utils import logger
from ...utils.utility import CpuInfo, LazyImport
tqdm = LazyImport("tqdm")
torch = LazyImport("torch")
def get_embedding_contiguous(model):
    """Force a contiguous first input on every ``Embedding`` submodule.

    A forward pre-hook is attached to each submodule whose class name is exactly
    ``"Embedding"``. The hook swaps ``input[0]`` for ``input[0].contiguous()`` and
    forwards the remaining positional inputs untouched.

    Args:
        model (object): the input model; must expose ``modules()``.

    Returns:
        None
    """

    def _make_first_input_contiguous(module, inputs):
        # Only the first positional input is made contiguous.
        return (inputs[0].contiguous(), *inputs[1:])

    # Matched by class name rather than isinstance.
    embeddings = [sub for sub in model.modules() if sub.__class__.__name__ == "Embedding"]
    for embedding in embeddings:
        embedding.register_forward_pre_hook(_make_first_input_contiguous)
def is_fused_module(module):
    """Tell whether ``module`` is a fused module (helper for `_propagate_qconfig_helper`).

    The test is textual: it looks for the substring ``"fused"`` in ``str(type(module))``,
    i.e. anywhere in the class's module path or class name.

    Args:
        module (object): the input module

    Returns:
        (bool): is fused or not
    """
    type_repr = str(type(module))
    return "fused" in type_repr
def collate_torch_preds(results):
    """Fetch collated results.

    Concatenates per-batch predictions into numpy arrays.

    - If each batch is a list or tuple of outputs, outputs at the same position are
      concatenated across batches and a list of arrays is returned.
    - Otherwise (tensors, numpy arrays, ...) all batches are concatenated into one array.

    Tensors are converted with ``detach().cpu().numpy()``, so tensors that require
    grad or live on an accelerator are handled too.

    Args:
        results (list): per-batch predictions; must be non-empty.

    Returns:
        collate_results (list or numpy.ndarray): collated results.

    Raises:
        IndexError: if ``results`` is empty.
    """

    def _to_numpy(item):
        if isinstance(item, torch.Tensor):
            return item.detach().cpu().numpy()
        return item

    first_batch = results[0]
    if isinstance(first_batch, (list, tuple)):
        # One array per output position, concatenated across batches.
        return [np.concatenate([_to_numpy(item) for item in output]) for output in zip(*results)]
    # Tensors, numpy arrays and other array-likes. Any type other than list or Tensor
    # used to fall through to an UnboundLocalError.
    return np.concatenate([_to_numpy(item) for item in results])
def input2tuple(input):
    """Convert mapping values or a list/tuple to a tuple.

    Args:
        input (dict, UserDict, list, tuple or any object): value to convert.

    Returns:
        A tuple of the mapping's values or of the sequence's elements; any other
        input is returned unchanged.
    """
    if isinstance(input, (dict, UserDict)):
        return tuple(input.values())
    if isinstance(input, (list, tuple)):
        return tuple(input)
    return input
def append_attr(fx_model, model, fx_white_list=None):
    """This is a helper method to append attributes for the symbolic traced model.

    Attributes of ``model`` that ``fx_model`` lacks are copied onto ``fx_model``, except
    those whose name:

    - starts with "_", "quant", "dequant", "weight", "bias" or "activation_post_process"; or
    - contains "_scale_", "_zero_point_" or "_activation_post_process_".

    Hook attributes (names containing "_forward_hooks", "_forward_pre_hooks" or
    "_backward_hooks") are always copied, unless ``type(model)`` is in ``fx_white_list``
    and is not ``torch.nn.Sequential``. Values that are ``torch.nn.Module`` or ``QConfig``
    instances are never copied.

    If ``fx_model.module.weight`` exists, it is exposed as ``fx_model.weight`` (called
    first when it is not a tensor).

    Args:
        fx_model (torch.fx.GraphModule): The symbolic traced model; modified in place.
        model (torch.nn.Module): The original model.
        fx_white_list (list, optional): model types whose hook attributes are not copied.
            ``None`` (default) means an empty list.

    Returns:
        fx_model (torch.fx.GraphModule): The same symbolic traced model with additional attributes.
    """
    if fx_white_list is None:
        fx_white_list = []
    # Snapshot fx_model's attributes before ``weight`` may be added below.
    fx_attr = set(dir(fx_model))
    org_attr = dir(model)
    ignore_match_patterns = [r"_", r"quant", r"dequant", r"weight", r"bias", r"activation_post_process"]
    ignore_search_patterns = [r"_scale_", r"_zero_point_", r"_activation_post_process_"]
    add_special_patterns = [r"_forward_hooks", r"_forward_pre_hooks", r"_backward_hooks"]
    if hasattr(fx_model, "module") and hasattr(fx_model.module, "weight"):
        if not isinstance(fx_model.module.weight, torch.Tensor):
            fx_model.weight = fx_model.module.weight()
        else:
            fx_model.weight = fx_model.module.weight

    skip_hook_attrs = type(model) in fx_white_list and type(model) is not torch.nn.Sequential
    attr_names = []
    for name in org_attr:
        if any(re.search(p, name) for p in add_special_patterns):
            if not skip_hook_attrs:
                attr_names.append(name)
            continue
        if (
            name not in fx_attr
            and not any(re.match(p, name) for p in ignore_match_patterns)
            and not any(re.search(p, name) for p in ignore_search_patterns)
        ):
            attr_names.append(name)

    for name in attr_names:
        attr = getattr(model, name, None)
        # Submodules and qconfigs are owned by the traced graph; never overwrite them.
        if isinstance(attr, torch.nn.Module) or isinstance(attr, torch.quantization.qconfig.QConfig):
            continue
        setattr(fx_model, name, attr)
    return fx_model
def generate_activation_observer(scheme, algorithm):  # pragma: no cover
    """Build an activation-observer config dict.

    Args:
        scheme (str): Quantization scheme to be used. "sym" selects signed symmetric int8
            (qint8, [-128, 127]); any other value keeps unsigned affine uint8 (quint8, [0, 255]).
        algorithm (str): "kl" returns a HistogramObserver config; "minmax" returns a
            MinMaxObserver config.

    Returns:
        An observer config dict, or None when ``algorithm`` is neither "kl" nor "minmax".
    """
    # reduce_range is enabled only when the CPU does not report VNNI support.
    reduce_range = False if CpuInfo().vnni else True

    if scheme == "sym":
        shared = {
            "dtype": "torch.qint8",
            "qscheme": "torch.per_tensor_symmetric",
            "reduce_range": reduce_range,
            "quant_min": -128,
            "quant_max": 127,
        }
    else:
        shared = {
            "dtype": "torch.quint8",
            "qscheme": "torch.per_tensor_affine",
            "reduce_range": reduce_range,
            "quant_min": 0,
            "quant_max": 255,
        }

    if algorithm == "kl":
        return {"name": "HistogramObserver", "bins": 2048, "upsample_rate": 128, **shared}
    if algorithm == "minmax":
        return {"name": "MinMaxObserver", **shared}
    return None
def check_cfg_and_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_op_name):  # pragma: no cover
    """Check configs and quantization configs.

    Writes the int8/fp32 decisions of ``tune_cfg`` back into the IPEX config ``cfgs``.
    For every op of each tuned group, inputs whose ``force_dtype`` is ``torch.qint8`` or
    ``torch.quint8`` are handled as follows:

    - If the group's weight dtype is "int8", the input stays int8: ``torch.qint8`` for
      scheme "sym", ``torch.quint8`` for "asym". The op's ``activation_observer`` is
      regenerated from the activation scheme and algorithm.
    - Otherwise the input is switched to ``torch.float32``.

    For the first op of a group, the ``inf_dtype`` of the matching output of the producing
    op is updated to the new input dtype.

    Args:
        tune_cfg (dict): dictionary of quantization configuration. The first element of each
            key is iterated as a sequence of op names (presumably the ops of one fused group -
            TODO confirm).
        cfgs (dict): the input configs, indexed as cfgs[module][state][index]; updated in place.
        op_infos_from_cfgs (dict): op infos from configs keyed by op name; entries are mutated in place.
        output_tensor_ids_op_name (dict): maps an output tensor id to the list of op names producing it.

    Returns:
        cfgs (dict): the same ``cfgs`` object after the in-place updates.
    """
    for op_name in tune_cfg:
        inc_op_cfg = tune_cfg[op_name]
        for i, name in enumerate(op_name[0]):
            # to int8
            ipex_op_cfg = op_infos_from_cfgs[name]
            input_tensor_infos = ipex_op_cfg["input_tensor_infos"]
            for index, input_tensor_info in enumerate(input_tensor_infos):
                if "force_dtype" not in input_tensor_info.keys():
                    continue
                if (
                    input_tensor_info["force_dtype"] == "torch.qint8"
                    or input_tensor_info["force_dtype"] == "torch.quint8"
                ):
                    # int8 -> int8
                    # The keep-int8 decision is taken from the op's *weight* dtype.
                    if inc_op_cfg["weight"]["dtype"] == "int8":
                        inc_scheme = inc_op_cfg["activation"]["scheme"]
                        inc_algorithm = inc_op_cfg["activation"]["algorithm"]
                        # Re-assigns the same list object (no-op).
                        ipex_op_cfg["input_tensor_infos"] = input_tensor_infos
                        activation_observer = generate_activation_observer(inc_scheme, inc_algorithm)
                        if inc_scheme == "sym":
                            input_tensor_infos[index]["force_dtype"] = "torch.qint8"
                        if inc_scheme == "asym":
                            input_tensor_infos[index]["force_dtype"] = "torch.quint8"
                        ipex_op_cfg["activation_observer"] = activation_observer
                    # int8 -> fp32
                    else:
                        input_tensor_infos[index]["force_dtype"] = "torch.float32"
                    # modify pre_op output inf_dtype
                    # Only for the first op of the group; its producer lives outside the group.
                    if i == 0:
                        input_tensor_id = input_tensor_info["id"]
                        input_tensor_dtype = input_tensor_info["force_dtype"]
                        if input_tensor_id in output_tensor_ids_op_name.keys():
                            # Only the first recorded producer of this tensor is updated.
                            pre_op_name = output_tensor_ids_op_name[input_tensor_id]
                            pre_op_module = pre_op_name[0][0]
                            pre_op_state = pre_op_name[0][1]
                            pre_op_index = pre_op_name[0][2]
                            pre_op_infos = cfgs[pre_op_module][pre_op_state][pre_op_index]
                            pre_op_output_infos = pre_op_infos["output_tensor_infos"]
                            # NOTE(review): this loop reuses the name ``index``; harmless because the
                            # enclosing loop rebinds it on every iteration and it is not read afterwards.
                            for index, pre_op_output in enumerate(pre_op_output_infos):
                                if pre_op_output["id"] == input_tensor_id:
                                    pre_op_output_infos[index]["inf_dtype"] = input_tensor_dtype
                                else:
                                    pass
                            pre_op_infos["output_tensor_infos"] = pre_op_output_infos
                            cfgs[pre_op_module][pre_op_state][pre_op_index] = pre_op_infos
                        else:
                            pass
            cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
    return cfgs
def paser_cfgs(cfgs):  # pragma: no cover
    """Parse configs.

    Every op is identified by a ``(module_key, state, op_id)`` tuple. For the
    "layer_output_infos" state, ``op_id`` is the list index. For every other state it is
    the key in that state's dict, and the op's tensor ids are also indexed.

    Args:
        cfgs (dict): the input configs, ``{module_key: {state: ops}}``. ``ops`` is a list for
            "layer_output_infos" and a dict keyed by op id for every other state.

    Returns:
        ops_name (list): list of op names, in traversal order.
        op_infos_from_cfgs (dict): op name -> op info dict.
        input_tensor_ids_op_name (dict): input tensor id -> list of op names consuming it.
        output_tensor_ids_op_name (dict): output tensor id -> list of op names producing it.

    Raises:
        AssertionError: if the same op name is encountered twice.
    """
    ops_name = []
    seen_names = set()  # O(1) duplicate detection; mirrors ops_name
    op_infos_from_cfgs = {}
    # record input_tensor_id and op_name
    # {"0": [(" ", "q_op_infos", "0"), (" ", "q_op_infos", "1")]}
    input_tensor_ids_op_name = {}
    output_tensor_ids_op_name = {}
    for module_key, module_cfg in cfgs.items():
        for state in module_cfg:
            if state == "layer_output_infos":
                for index, op_info in enumerate(module_cfg[state]):
                    name = (module_key, state, index)
                    ops_name.append(name)
                    seen_names.add(name)
                    op_infos_from_cfgs[name] = op_info
                continue
            for op_cfg_id, op_info in module_cfg[state].items():
                name = (module_key, state, op_cfg_id)
                if name in seen_names:
                    # Explicit raise so the check survives ``python -O``.
                    raise AssertionError("Please check IPEX int8 configure json whether have the same name ops")
                ops_name.append(name)
                seen_names.add(name)
                op_infos_from_cfgs[name] = op_info
                for input_tensor in op_info["input_tensor_infos"]:
                    if "id" in input_tensor:
                        input_tensor_ids_op_name.setdefault(input_tensor["id"], []).append(name)
                for output_tensor in op_info["output_tensor_infos"]:
                    if "id" in output_tensor:
                        output_tensor_ids_op_name.setdefault(output_tensor["id"], []).append(name)
    return ops_name, op_infos_from_cfgs, input_tensor_ids_op_name, output_tensor_ids_op_name
def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids_op_name):  # pragma: no cover
    """Get quantizable ops from configs, combine fused ops as one op.

    Starting from each not-yet-seen "q_op_infos" op, the function follows output tensor ids
    to the ops that consume them. It stops a path when an output's ``inf_dtype`` is
    ``torch.qint8``/``torch.quint8``, when a tensor has no recorded consumer, or when a
    non-"q_op_infos" op is reached. Every visited op is marked as seen.

    Args:
        ops_name (list): list of op names, each a (module_key, state, op_id) tuple.
        op_infos_from_cfgs (dict): op infos from configs, keyed by op name.
        input_tensor_ids_op_name (dict): maps a tensor id to the list of op names that take it
            as input.

    Returns:
        quantizable_ops (list): one entry per discovered path; each entry is a list of op names.
    """
    quantizable_ops = []
    seen_ops = []
    for name in ops_name:
        start = True
        if name in seen_ops:
            continue
        elif name[1] not in ["q_op_infos"]:
            continue
        else:
            # judge fuse ops the first op
            op_info = op_infos_from_cfgs[name]
            output_tensors = op_info["output_tensor_infos"]
            input_tensors = op_info["input_tensor_infos"]
            for input_tensor in input_tensors:
                if "inf_dtype" not in input_tensor.keys():
                    continue
                # NOTE(review): compares against the torch dtype object, whereas the loop below
                # compares inf_dtype with strings such as "torch.qint8"; if inf_dtype holds strings
                # this branch may never be taken - verify.
                if input_tensor["inf_dtype"] == torch.float32:
                    # NOTE(review): input_tensor_ids_op_name maps a tensor id to its *consumers*, and its
                    # value is a list, so ``pre_op_name[1]`` is the second consumer (IndexError if there
                    # is only one), not the producer's state - looks suspicious, confirm intent.
                    pre_op_name = input_tensor_ids_op_name[input_tensor["id"]]
                    if pre_op_name[1] in ["q_op_infos"]:
                        print(pre_op_name, "is not the fuse ops first op.")
                        start = False
                        continue
            if not start:
                continue
            # add quantizable ops, include op and fuse ops.
            # Depth-first walk: each stack entry is (op name, path of op names leading to it).
            q_ops, stack = [], [(name, [])]
            while stack:
                cur_name, cur = stack.pop()
                seen_ops.append(cur_name)
                if cur_name[1] not in ["q_op_infos"]:
                    # NOTE(review): ``break`` abandons any remaining stack entries as well.
                    q_ops.append(cur)
                    break
                op_info = op_infos_from_cfgs[cur_name]
                output_tensors = op_info["output_tensor_infos"]
                for output_tensor in output_tensors:
                    if output_tensor["inf_dtype"] == "torch.qint8" or output_tensor["inf_dtype"] == "torch.quint8":
                        q_ops.append(cur + [cur_name])
                        break
                    try:
                        next_op_names = input_tensor_ids_op_name[output_tensor["id"]]
                        for next_op_name in next_op_names:
                            stack.append((next_op_name, cur + [cur_name]))
                    except:
                        # Presumably a KeyError when the tensor has no consumer (bare except also
                        # swallows anything else).
                        next_op_name = None
                    # NOTE(review): if next_op_names is empty, next_op_name keeps a stale value from an
                    # earlier iteration or is unbound (UnboundLocalError) - verify.
                    if next_op_name is None:
                        q_ops.append(cur + [cur_name])
            for q_op in q_ops:
                quantizable_ops.append(q_op)
    return quantizable_ops
def update_sq_scale(ipex_config_path, smoothquant_scale_info):
    """Update ipex_config.json with smoothquant scale info generated by our algorithm.

    Only entries under a module's non-empty "q_op_infos" are considered, and only those
    whose "fqn" is a key of ``smoothquant_scale_info``. For each such entry:

    - the first input tensor info gets the smooth-quant scaling factor, scale and zero point;
    - the first weight tensor info gets the inverse scaling factor and the weight scale;
    - the activation/weight observers are replaced by SmoothQuant observer configs that
      carry the op's "alpha".

    The file is then rewritten in place as JSON indented by 4 spaces.

    Args:
        ipex_config_path (str): a path to temporary ipex_config.json file.
        smoothquant_scale_info (dict): a dict contains smoothquant scale info, keyed by op fqn.
            Each value holds "input_scale_for_mul", "input_scale_after_mul",
            "input_zero_point_after_mul" and "weight_scale_after_mul" (objects supporting
            ``.tolist()``; presumably numpy arrays or tensors), plus "alpha".
    """
    with open(ipex_config_path, "r") as reader:
        ipex_config = json.load(reader)

    for module_cfg in ipex_config.values():
        q_op_infos = module_cfg["q_op_infos"] if "q_op_infos" in module_cfg else None
        if not q_op_infos:
            continue
        for op_cfg in q_op_infos.values():
            # update alpha data instead of updating weight scale
            fqn = op_cfg["fqn"]  # fqn always exists even it's empty.
            if fqn not in smoothquant_scale_info:
                continue
            sq_info = smoothquant_scale_info[fqn]
            input_scale_for_mul = sq_info["input_scale_for_mul"].tolist()
            input_scale_after_mul = sq_info["input_scale_after_mul"].tolist()
            input_zero_point_after_mul = sq_info["input_zero_point_after_mul"].tolist()
            weight_scale_for_mul = (1 / sq_info["input_scale_for_mul"]).tolist()
            weight_scale_after_mul = sq_info["weight_scale_after_mul"].tolist()

            input_info = op_cfg["input_tensor_infos"][0]
            input_info["smooth_quant_scaling_factor"] = input_scale_for_mul
            input_info["scale"] = input_scale_after_mul
            input_info["zero_point"] = input_zero_point_after_mul
            weight_info = op_cfg["weight_tensor_infos"][0]
            weight_info["smooth_quant_scaling_factor"] = weight_scale_for_mul
            weight_info["scale"] = weight_scale_after_mul

            # observers were overridden by the fallback step, setting it back.
            op_cfg["activation_observer"] = dict(
                name="SmoothQuantActivationObserver",
                smooth_quant_enabled=True,
                dtype="torch.quint8",
                qscheme="torch.per_tensor_affine",
                reduce_range=False,
                quant_min=0,
                quant_max=255,
                alpha=sq_info["alpha"],
                act_observer=dict(
                    name="HistogramObserver",
                    bins=2048,
                    upsample_rate=128,
                    dtype="torch.quint8",
                    qscheme="torch.per_tensor_affine",
                    reduce_range=False,
                    quant_min=0,
                    quant_max=255,
                ),
                act_ic_observer=dict(
                    name="PerChannelMinMaxObserver",
                    ch_axis=-1,
                    dtype="torch.quint8",
                    qscheme="torch.per_channel_affine",
                    reduce_range=False,
                    quant_min=0,
                    quant_max=255,
                ),
            )
            op_cfg["weight_observer"] = dict(
                name="SmoothQuantWeightObserver",
                smooth_quant_enabled=True,
                dtype="torch.qint8",
                qscheme="torch.per_channel_symmetric",
                reduce_range=False,
                quant_min=-128,
                quant_max=127,
                alpha=sq_info["alpha"],
                wei_observer=dict(
                    name="PerChannelMinMaxObserver",
                    ch_axis=0,
                    dtype="torch.qint8",
                    qscheme="torch.per_channel_symmetric",
                    reduce_range=False,
                    quant_min=-128,
                    quant_max=127,
                ),
                wei_ic_observer=dict(
                    name="PerChannelMinMaxObserver",
                    ch_axis=1,
                    dtype="torch.qint8",
                    qscheme="torch.per_channel_affine",
                    reduce_range=False,
                    quant_min=-128,
                    quant_max=127,
                ),
            )

    # overwrite ipex_config_path
    with open(ipex_config_path, "w") as writer:
        json.dump(ipex_config, writer, indent=4)
def auto_copy(module):  # pragma: no cover
    """Get an IPEX prepared model and return a fp32 model.

    Makes a deep copy of ``module``, deletes the IPEX quantization-state attributes
    ``_qconf_summary``, ``_fqn_to_auto_quant_state_map`` and ``q_config`` from the copy,
    and swaps the copy's class for ``CopyDispatchModule``.

    ``CopyDispatchModule`` subclasses the single base class of ``module``'s class. For the
    duration of each call, its ``__call__`` patches ``torch.nn.Module.__call__`` and
    ``torch.nn.Sequential.forward`` at class level, so the patch affects every module in the
    process while the call runs. The patched ``Sequential.forward`` skips
    ``AutoQuantizationStateModuleDict`` children. Tensor arguments are wrapped as
    ``CopyTensorProxy`` on the way in and unwrapped on the way out.

    Args:
        module (object): IPEX prepared model.

    Returns:
        fp32 model.
    """
    from intel_extension_for_pytorch.quantization._quantization_state import AutoQuantizationStateModuleDict

    # Replacement for nn.Sequential.forward: run children in order, skipping IPEX state dicts.
    def _nn_sequential_patched_forward(cls, x):
        for module in cls:
            if not isinstance(module, AutoQuantizationStateModuleDict):
                x = module(x)
        return x

    new_module = copy.deepcopy(module)
    if hasattr(new_module, "_qconf_summary"):
        del new_module._qconf_summary
    if hasattr(new_module, "_fqn_to_auto_quant_state_map"):
        del new_module._fqn_to_auto_quant_state_map
    if hasattr(new_module, "q_config"):
        del new_module.q_config

    def convert_to_dispatch_proxy(x):
        if isinstance(x, torch.Tensor):
            return x.as_subclass(CopyTensorProxy)  # type: ignore[arg-type]
        else:
            return x

    # NOTE(review): never set to True anywhere in this function.
    global_disable_torch_function_override = False

    class CopyTensorProxy(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            nonlocal global_disable_torch_function_override
            if (
                # global override means disable the override here
                global_disable_torch_function_override
                or
                # to prevent printing things from going into an infinite loop
                func == torch.Tensor.__repr__
                or
                # we don't need to override getters in this framework
                func.__name__ == "__get__"
            ):
                return super().__torch_function__(func, types, args, kwargs)
            kwargs = kwargs if kwargs else {}
            output = super().__torch_function__(func, types, args, kwargs)
            if output is NotImplemented:
                # NOTE(review): CopyConvertTensorProxy is not defined in the visible source (hence the
                # noqa F821); reaching this branch would raise NameError - verify.
                with torch._C.DisableTorchFunction():
                    output = func(*args, **kwargs).as_subclass(
                        CopyConvertTensorProxy  # pylint: disable=E0602 # noqa: F821
                    )
                assert output is not NotImplemented
            return output

        def __repr__(self):
            return f"CopyTensorProxy({super().__repr__()})"

    cur_module = None
    # ``List`` is not imported (noqa F821); annotations on local variables are not evaluated at runtime.
    module_stack: List[torch.nn.Module] = []  # pylint: disable=E0602 # noqa: F821

    assert len(module.__class__.__bases__) == 1

    class CopyDispatchModule(module.__class__.__bases__[0]):
        def __call__(self, *args, **kwargs):
            new_args = torch.fx.node.map_aggregate(args, convert_to_dispatch_proxy)
            new_kwargs = torch.fx.node.map_aggregate(kwargs, convert_to_dispatch_proxy)
            orig_module_call = torch.nn.Module.__call__
            orig_nn_sequential_forward = torch.nn.Sequential.forward

            # Tracks the currently executing module and a stack of active modules.
            def _patched_module_call(self, *args, **kwargs):
                nonlocal cur_module
                old_module = cur_module
                cur_module = self
                nonlocal global_disable_torch_function_override
                try:
                    parent_module = module_stack[-1] if len(module_stack) else None  # unused
                    module_stack.append(self)
                    output = orig_module_call(self, *args, **kwargs)
                    return output
                finally:
                    module_stack.pop()
                    cur_module = old_module

            torch.nn.Module.__call__ = _patched_module_call
            torch.nn.Sequential.forward = _nn_sequential_patched_forward  # type: ignore[assignment]
            try:
                output = super().__call__(*new_args, **new_kwargs)

                # Turn proxy tensors in the output back into plain torch.Tensor.
                def unwrap_proxy(a):
                    if isinstance(a, CopyTensorProxy):
                        a.__class__ = torch.Tensor  # type: ignore[assignment]
                    return a

                output = torch.fx.node.map_aggregate(output, unwrap_proxy)
                return output
            finally:
                # Always restore the class-level patches, even if the forward raised.
                torch.nn.Module.__call__ = orig_module_call
                torch.nn.Sequential.forward = orig_nn_sequential_forward  # type: ignore[assignment]

    new_module.__class__ = CopyDispatchModule
    return new_module
def fetch_module(model, op_name):
    """Get module with a given op name.

    Walks the dot-separated ``op_name`` starting at ``model``. A segment that is not an
    attribute of the current object is skipped (the current object is kept), so a
    partly unknown path returns the deepest object that could be resolved.

    Args:
        model (object): the input model.
        op_name (str): name of op, e.g. "layer1.conv".

    Returns:
        module (object).
    """
    current = model
    for attr_name in op_name.split("."):
        current = getattr(current, attr_name, current)
    return current
def set_module(model, op_name, new_module):
    """Set module with a given op name.

    Resolves every segment of the dot-separated ``op_name`` except the last, skipping
    segments that are not attributes (the current object is kept). It then assigns
    ``new_module`` to the last segment on the resolved owner.

    Args:
        model (object): the input model.
        op_name (str): name of op.
        new_module (object): the module to install.

    Returns:
        module (object): the owner object the attribute was set on.
    """
    *parent_names, leaf_name = op_name.split(".")
    owner = model
    for attr_name in parent_names:
        owner = getattr(owner, attr_name, owner)
    setattr(owner, leaf_name, new_module)
    return owner
def simple_inference(model, input):
    """Record model output tensor.

    Runs ``model`` under ``torch.no_grad()``:

    - a dict/UserDict ``input`` is passed as keyword arguments;
    - a list/tuple is first tried as positional arguments and, if that raises an
      ``Exception``, passed again as a single argument;
    - anything else is passed as a single argument.

    Args:
        model (object): the input model.
        input (object): example input.

    Returns:
        output (object): whatever ``model`` returns.
    """
    with torch.no_grad():
        if isinstance(input, (dict, UserDict)):
            output = model(**input)
        elif isinstance(input, (list, tuple)):
            try:
                output = model(*input)
            except Exception:
                # The model may expect the whole sequence as one argument. A bare ``except``
                # here would also swallow KeyboardInterrupt/SystemExit and re-run the model.
                output = model(input)
        else:
            output = model(input)
    return output
def get_example_input(dataloader, i=1):
    """Get the example input.

    Iterates ``dataloader`` and returns the input of the batch at position ``i``, or the
    last batch's input if there are fewer than ``i + 1`` batches. Batches are first
    unpacked as ``(input, label)``. If that raises, iteration restarts from the beginning
    and each batch is used whole as the input. The restart requires a re-iterable
    dataloader; a one-shot iterator will not restart.

    Args:
        dataloader (object): calibration dataset.
        i (int): zero-based index of the wanted batch. Defaults to 1 (the second batch).

    Returns:
        example_inp (object).

    Raises:
        ValueError: if the dataloader yields no batch at all.
    """
    missing = object()
    example_inp = missing
    try:
        for idx, (example_inp, _label) in enumerate(dataloader):
            if idx == i:
                break
    except Exception:
        # Batches are not (input, label) pairs; use each batch whole.
        for idx, example_inp in enumerate(dataloader):
            if idx == i:
                break
    if example_inp is missing:
        raise ValueError("The dataloader is empty, cannot get an example input.")
    return example_inp
def get_fallback_order(
    adaptor, fp32_model, dataloader, tune_cfg, confidence_batches, fallback=False, requantize_cfgs=None
):
    """Get the fall back order for strategy.

    For each of the first ``confidence_batches`` batches, a per-batch op order is computed
    by ``get_mse_order_per_fp32`` (``fallback=True``) or ``get_mse_order_per_int8``. An op
    ranked ``r`` (0-based) in a batch order of length ``n`` scores ``n - r``. Ops are
    returned by descending total score, with ties kept in first-seen order.

    Previously the score used ``len(order_dict)``, which grows while the first batch is
    processed, so every first-batch score was 0. The int8 branch also returned only the
    last batch's order. With one batch the result is unchanged.

    Args:
        adaptor (object): the adaptor passed through to the order helpers.
        fp32_model (object): the input model; put into eval mode.
        dataloader(torch.utils.data.DataLoader): The calibration dataloader.
        tune_cfg (dict): dictionary of quantization configuration.
        confidence_batches (int): number of confidence batches.
        fallback (bool): if the order is fallback.
        requantize_cfgs: unused; kept for interface compatibility.

    Returns:
        ordered_ops (list): The fallback order for strategy (empty if ``confidence_batches`` <= 0).
    """
    fp32_model.eval()
    order_dict = {}
    for batch_idx in range(confidence_batches):
        example_input = get_example_input(dataloader, batch_idx)
        if fallback:
            batch_order = get_mse_order_per_fp32(adaptor, fp32_model, example_input, tune_cfg)
        else:
            batch_order = get_mse_order_per_int8(adaptor, fp32_model, example_input, tune_cfg)
        num_ops = len(batch_order)
        for rank, name in enumerate(batch_order):
            order_dict[name] = order_dict.get(name, 0) + num_ops - rank
    return sorted(order_dict, key=lambda k: order_dict[k], reverse=True)
# Module-level state shared across calls: op name -> qconfig. get_mse_order_per_fp32
# records the first qconfig it sees for each op; get_mse_order_per_int8 reads it to
# re-quantize fp32 ops.
op_cfg_mapping = dict()
def get_mse_order_per_fp32(adaptor, model, example_inp, tune_cfg):
    """This is a helper method to check the mse influence to last module after QDQ(quant/dequant).

    For every op with a qconfig, ``model`` is quantized with that op left in fp32, and the
    error of the last module's output against the fp32 run is recorded. The error is the
    sum of squared differences, despite the "mse" naming. Ops are sorted by ascending error.

    - If the smallest error is below 0.8x the largest, that order is returned.
    - Otherwise the worst op is also left in fp32, and the first
      ``min(len // 10 + 1, 5)`` ops are re-measured. Only those ops are returned, sorted by
      the new error.

    Side effect: the first qconfig seen for each op is recorded in the module-level
    ``op_cfg_mapping``. The forward hook used to capture the fp32 output is removed from
    ``model`` afterwards; it used to stay registered and was copied into every deep copy.

    Args:
        adaptor (object): adaptor providing ``sub_module_list`` and ``version``.
        model (torch.fx.GraphModule/torch.nn.Module): A torch model.
        example_inp (object): example inputs.
        tune_cfg (dict): dictionary of quantization configuration.

    Returns:
        fallback_order (list): (op_name, op_type) tuples in fallback order.
    """
    global op_cfg_mapping
    from ..pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _cfgs_to_fx_cfgs

    inner_output = None

    def output_hook(self, input, output):
        nonlocal inner_output
        inner_output = output
        return output

    def _last_output_error_after_qdq(fx_op_cfgs):
        """Quantize a copy of ``model`` with ``fx_op_cfgs`` and return its last-output error vs fp32."""
        from torch.quantization.quantize_fx import convert_fx, prepare_fx

        tmp_model = copy.deepcopy(model)
        # do quantization
        if adaptor.sub_module_list is None:
            if adaptor.version.release >= Version("1.13.0").release:  # pragma: no cover
                tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
            else:
                tmp_model = prepare_fx(tmp_model, fx_op_cfgs)
        else:
            PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, tmp_model, prefix="")
        simple_inference(tmp_model, example_inp)  # calibration pass
        if adaptor.sub_module_list is None:
            tmp_model = convert_fx(tmp_model)
        else:
            PyTorch_FXAdaptor.convert_sub_graph(adaptor.sub_module_list, tmp_model, prefix="")
        # insert hook to get output tensor from last module (tmp_model is discarded afterwards)
        fetch_module(tmp_model, last_module_name).register_forward_hook(output_hook)
        simple_inference(tmp_model, example_inp)
        inner_output_int8 = inner_output.dequantize() if inner_output.dtype == torch.quint8 else inner_output
        return (inner_output_fp32 - inner_output_int8).pow(2).sum()

    op_type_dict = {}
    for k, v in tune_cfg["op"].keys():
        op_type_dict[k] = v

    op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg["approach"])
    last_module_name = list(op_cfgs.keys())[-1]
    # record fp32 output of the last module, then drop the hook so it does not linger on ``model``
    handle = fetch_module(model, last_module_name).register_forward_hook(output_hook)
    try:
        simple_inference(model, example_inp)
    finally:
        handle.remove()
    inner_output_fp32 = inner_output

    fallback_order = {}
    logger.info("Evaluate the sensitivity for each int8 operation")
    for op_name, qconfig in tqdm(op_cfgs.items()):
        if op_name == "bf16_ops_list":
            continue
        if op_name not in op_cfg_mapping:
            op_cfg_mapping[op_name] = qconfig
        if not qconfig:
            logger.debug(f"No qconfig for {op_name}, next op.")
            continue
        # Temporarily leave this op in fp32 while building the FX config.
        op_cfgs[op_name] = None
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg["approach"])
        op_cfgs[op_name] = qconfig
        fallback_order[(op_name, op_type_dict[op_name])] = _last_output_error_after_qdq(fx_op_cfgs)

    logger.debug(f"fallback order: {fallback_order}")
    ordered_ops = sorted(fallback_order.keys(), key=lambda key: fallback_order[key], reverse=False)
    if not ordered_ops:
        return ordered_ops
    min_mse, max_mse = fallback_order[ordered_ops[0]], fallback_order[ordered_ops[-1]]
    if min_mse < 0.8 * max_mse:
        logger.debug("Return the sorted ops early.")
        return ordered_ops

    check_num = min(len(ordered_ops) // 10 + 1, 5)
    double_check_list = ordered_ops[:check_num]
    logger.debug(f"double check list: {double_check_list}")
    worst_op_name = ordered_ops[-1]
    op_cfgs[worst_op_name[0]] = None  # fallback worst module first
    new_fallback_order = {}
    logger.info("Evaluate the sensitivity gradient for selected operations")
    for op_name, op_type in tqdm(double_check_list):
        qconfig = op_cfgs[op_name]
        op_cfgs[op_name] = None
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg["approach"])
        op_cfgs[op_name] = qconfig
        new_fallback_order[(op_name, op_type_dict[op_name])] = _last_output_error_after_qdq(fx_op_cfgs)
    ordered_ops = sorted(new_fallback_order.keys(), key=lambda key: new_fallback_order[key], reverse=False)
    return ordered_ops
def get_mse_order_per_int8(adaptor, fp32_model, example_input, tune_cfg):
    """This is a helper method to check the mse influence to last module after QDQ(quant/dequant).

    Candidates are ops whose weight dtype in ``tune_cfg`` is "fp32", excluding LayerNorm,
    Dropout and InstanceNorm3d. Each candidate that has a recorded qconfig in
    ``op_cfg_mapping`` is re-quantized alone, i.e. the config is restored after every op,
    as in ``get_mse_order_per_fp32``. The error of the last module's output against the
    fp32 run (sum of squared differences) is recorded. Candidates without a recorded
    qconfig are skipped.

    The forward hook used to capture the fp32 output is removed from ``fp32_model``
    afterwards.

    Args:
        adaptor (object): adaptor providing ``sub_module_list`` and ``version``.
        fp32_model (torch.fx.GraphModule/torch.nn.Module): A torch model.
        example_input (object): example inputs.
        tune_cfg (dict): dictionary of quantization configuration.

    Returns:
        fallback_order (list): (op_name, op_type) tuples sorted by ascending error.
    """
    from torch.quantization.quantize_fx import convert_fx, prepare_fx

    from ..pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _cfgs_to_fx_cfgs

    inner_output = None

    def output_hook(self, input, output):
        nonlocal inner_output
        inner_output = output
        return output

    op_type_dict = {}
    for k, v in tune_cfg["op"].keys():
        op_type_dict[k] = v

    example_inp = example_input
    op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg["approach"])
    # Fixed up front: op_cfgs keys are restored after each op, so the last module never changes.
    last_module_name = list(op_cfgs.keys())[-1]
    # record fp32 model output tensor at first; drop the hook so it does not linger on fp32_model
    handle = fetch_module(fp32_model, last_module_name).register_forward_hook(output_hook)
    try:
        simple_inference(fp32_model, example_inp)
    finally:
        handle.remove()
    inner_output_fp32 = inner_output

    quant_list = [
        k
        for k, v in tune_cfg["op"].items()
        if k[1] not in ["LayerNorm", "Dropout", "InstanceNorm3d"] and v["weight"]["dtype"] == "fp32"
    ]

    absent = object()
    fallback_order = {}
    logger.info("Evaluate the sensitivity for each fp32 operation")
    for op_name, op_type in tqdm(quant_list):
        if op_name not in op_cfg_mapping:
            continue
        # Quantize only this op, then restore op_cfgs so ops are measured independently.
        previous_cfg = op_cfgs.get(op_name, absent)
        op_cfgs[op_name] = op_cfg_mapping[op_name]
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg["approach"])
        if previous_cfg is absent:
            del op_cfgs[op_name]
        else:
            op_cfgs[op_name] = previous_cfg

        tmp_model = copy.deepcopy(fp32_model)
        # do quantization
        if adaptor.sub_module_list is None:
            if adaptor.version.release >= Version("1.13.0").release:  # pragma: no cover
                tmp_model = prepare_fx(tmp_model, fx_op_cfgs, example_inp)
            else:
                tmp_model = prepare_fx(tmp_model, fx_op_cfgs)
        else:
            PyTorch_FXAdaptor.prepare_sub_graph(adaptor.sub_module_list, fx_op_cfgs, tmp_model, prefix="")
        simple_inference(tmp_model, example_inp)  # calibration pass
        if adaptor.sub_module_list is None:
            tmp_model = convert_fx(tmp_model)
        else:
            PyTorch_FXAdaptor.convert_sub_graph(adaptor.sub_module_list, tmp_model, prefix="")

        # record int8 model output tensor
        fetch_module(tmp_model, last_module_name).register_forward_hook(output_hook)
        simple_inference(tmp_model, example_inp)
        inner_output_int8 = inner_output
        if inner_output_fp32.dtype == torch.quint8:
            inner_output_fp32 = inner_output_fp32.dequantize()
        if inner_output_int8.dtype == torch.quint8:
            inner_output_int8 = inner_output_int8.dequantize()
        mse_val = (inner_output_fp32 - inner_output_int8).pow(2).sum()
        fallback_order[(op_name, op_type_dict[op_name])] = mse_val

    ordered_ops = sorted(fallback_order.keys(), key=lambda key: fallback_order[key], reverse=False)
    return ordered_ops
def get_torch_version():
    """Get torch version.

    Returns:
        packaging.version.Version: the installed torch version, with any local build
        suffix (the part after "+", e.g. "+cpu") stripped.

    Raises:
        AssertionError: if the version string cannot be parsed.
    """
    from packaging.version import Version

    torch_version = torch.__version__.split("+")[0]
    # Version() is what can fail (InvalidVersion subclasses ValueError), so guard it rather
    # than the split. Raise explicitly so the check survives ``python -O``.
    try:
        version = Version(torch_version)
    except ValueError as e:  # pragma: no cover
        raise AssertionError("Got an unknown version of torch: {}".format(e)) from e
    return version
def match_datatype_pattern(datatype, pattern=None):
    """Check the datatype pattern.

    Args:
        datatype (str): datatype string such as "int8" or "uint4".
        pattern (str, optional): regex applied with ``re.match``. A falsy value selects the
            default ``(uint|int)([1-8])``, a prefix match for 1- to 8-bit integer types.

    Returns:
        re.Match or None: the match object, or None if ``datatype`` does not match.
    """
    import re

    regex = pattern if pattern else r"(uint|int)([1-8])"
    return re.match(regex, datatype)
def _get_signed_and_bits(datatype):
"""Parse sign and bits from datatype."""
unsigned = datatype[0] == "u"
if unsigned:
num_bits = int(datatype[4:])
else:
num_bits = int(datatype[3:])
return unsigned, num_bits
def calculate_quant_min_max(unsigned, num_bits):
    """Calculate the qmin and qmax according to the datatype.

    Returns float bounds: [0, 2**num_bits - 1] when unsigned, otherwise
    [-2**(num_bits - 1), 2**(num_bits - 1) - 1].
    """
    # TODO handle reduce range
    if unsigned:
        return 0.0, 2.0 ** (num_bits) - 1.0
    half_range = 2.0 ** (num_bits - 1)
    return -half_range, half_range - 1
def get_depth(d) -> int:
    """Query the depth of the dict.

    A non-dict has depth 0 and a dict has depth 1 plus the depth of its deepest value.
    An empty dict has depth 1; it previously raised ValueError from ``max`` of an empty
    sequence.
    """
    if isinstance(d, dict):
        return 1 + max((get_depth(v) for v in d.values()), default=0)
    return 0
def get_dict_at_depth(d, target_depth, result, depth=0):
    """Get all sub-dicts that are at a specified depth in a nested dict.

    Appends to ``result`` every value found exactly ``target_depth`` levels below ``d``
    (``d`` itself when ``target_depth`` equals ``depth``). Branches that end in a non-dict
    before reaching that depth contribute nothing. Returns None.
    """
    if depth == target_depth:
        result.append(d)
        return
    if depth < target_depth and isinstance(d, dict):
        for child in d.values():
            get_dict_at_depth(child, target_depth, result, depth=depth + 1)
def get_element_under_depth(d, ops_lst):
    """Get all values in a nested dict.

    Appends every non-dict leaf of ``d`` to ``ops_lst`` in depth-first, insertion order.
    A non-dict ``d`` is itself appended. Returns None.
    """
    if not isinstance(d, dict):
        ops_lst.append(d)
        return
    for child in d.values():
        get_element_under_depth(child, ops_lst)
def get_op_type_by_name(op_name, quantizable_ops):
"""Get op type by op name."""