# py_quant_utils.py
# Utilities for preparing PyTorch models with NVIDIA's pytorch-quantization toolkit.
from typing import List
import json

# PyTorch
import torch

# PyTorch Quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn.modules import _utils as quant_nn_utils
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import quant_modules
from absl import logging as quant_logging


class SummaryTool:
    """Accumulate items and rewrite them to a JSON file on every append."""

    def __init__(self, file):
        self.file = file
        self.data = []

    def append(self, item):
        self.data.append(item)
        # Use a context manager so the file handle is closed promptly.
        with open(self.file, "w") as f:
            json.dump(self.data, f, indent=4)
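

# Usage sketch for SummaryTool; the file name and payload are illustrative,
# not part of this module:
#
#   summary = SummaryTool("calib_summary.json")
#   summary.append({"step": "calibration", "top1": 0.761})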


def have_quantizer(module):
    """Return True if any submodule of `module` is a TensorQuantizer."""
    for name, submodule in module.named_modules():
        if isinstance(submodule, quant_nn.TensorQuantizer):
            return True
    return False


class disable_quantization:
    """Context manager that disables every TensorQuantizer in a model."""

    def __init__(self, model):
        self.model = model

    def apply(self, disabled=True):
        for name, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module._disabled = disabled

    def __enter__(self):
        self.apply(True)
        return self

    def __exit__(self, *args, **kwargs):
        self.apply(False)


class enable_quantization:
    """Context manager that (re-)enables every TensorQuantizer in a model."""

    def __init__(self, model):
        self.model = model

    def apply(self, enabled=True):
        for name, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module._disabled = not enabled

    def __enter__(self):
        self.apply(True)
        return self

    def __exit__(self, *args, **kwargs):
        self.apply(False)
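

# Usage sketch for the two context managers above; `model` and `evaluate` are
# placeholders for a quantized model and an evaluation routine, not part of
# this file:
#
#   with disable_quantization(model):
#       fp32_metric = evaluate(model)   # baseline without fake-quant
#   with enable_quantization(model):
#       int8_metric = evaluate(model)   # with fake-quant enabled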


# Initialize PyTorch Quantization
def initialize_calib_method(per_channel_quantization=True, calib_method="histogram"):
    """Set default quantizer descriptors; call before quantized modules are built."""
    if per_channel_quantization:
        quant_desc_input = QuantDescriptor(calib_method=calib_method)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
    else:
        # Force per-tensor quantization (axis=None) for ONNX Runtime.
        quant_desc_input = QuantDescriptor(calib_method=calib_method, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

        # Only weighted layers take a weight descriptor; pooling has no weights.
        quant_desc_weight = QuantDescriptor(calib_method=calib_method, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
        quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)
    # Silence pytorch-quantization's info/warning chatter.
    quant_logging.set_verbosity(quant_logging.ERROR)


def transfer_torch_to_quantization(nninstance: torch.nn.Module, quantmodule):
    """Build a quant_nn counterpart of `nninstance`, reusing its attributes."""
    quant_instance = quantmodule.__new__(quantmodule)
    for k, val in vars(nninstance).items():
        setattr(quant_instance, k, val)

    def __init__(self):
        # Re-run only the quantizer-setup portion of the quant module's constructor;
        # the float module's weights and hyperparameters were already copied above.
        if isinstance(self, quant_nn_utils.QuantInputMixin):
            # Input-only modules (e.g. pooling) have no weight quantizer.
            quant_desc_input = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__, input_only=True)
            self.init_quantizer(quant_desc_input)
            # Turn on torch_hist to speed up histogram calibration.
            if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator):
                self._input_quantizer._calibrator._torch_hist = True
        else:
            quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__)
            self.init_quantizer(quant_desc_input, quant_desc_weight)
            # Turn on torch_hist to speed up histogram calibration.
            if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator):
                self._input_quantizer._calibrator._torch_hist = True
                self._weight_quantizer._calibrator._torch_hist = True

    __init__(quant_instance)
    return quant_instance
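

# Usage sketch: converting a single float layer by hand (names illustrative):
#
#   conv = torch.nn.Conv2d(3, 16, 3)
#   qconv = transfer_torch_to_quantization(conv, quant_nn.QuantConv2d)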


def replace_to_quantization_module(model: torch.nn.Module, ignore_policy: List[str] = None):
    """Replace supported float modules with quant_nn counterparts, in place.

    `ignore_policy` is a list of dotted module paths (e.g. "layer1.0.conv1")
    that should be left unquantized.
    """
    # Map each float module type to its quantized replacement.
    module_dict = {}
    for entry in quant_modules._DEFAULT_QUANT_MAP:
        module = getattr(entry.orig_mod, entry.mod_name)
        module_dict[id(module)] = entry.replace_mod

    def recursive_and_replace_module(module, prefix=""):
        for name in module._modules:
            submodule = module._modules[name]
            path = name if prefix == "" else prefix + "." + name
            recursive_and_replace_module(submodule, path)

            submodule_id = id(type(submodule))
            if submodule_id in module_dict:
                if ignore_policy is not None and path in ignore_policy:
                    continue
                module._modules[name] = transfer_torch_to_quantization(submodule, module_dict[submodule_id])

    recursive_and_replace_module(model)
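

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming torchvision is installed; the model
    # choice and ignore list are illustrative, and a real workflow would also
    # feed calibration data through the model before computing amax ranges.
    import torchvision.models as models

    initialize_calib_method(per_channel_quantization=True, calib_method="histogram")
    model = models.resnet18(weights=None)
    replace_to_quantization_module(model, ignore_policy=["fc"])
    print("model contains quantizers:", have_quantizer(model))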