quantize.py
import argparse
import logging

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def onnx_quantize(model_dir: str, model_name: str, output_dir: str) -> None:
    """Apply dynamic INT8 quantization to an ONNX model and save the result."""
    # Operator types to include in the quantization pass.
    operators_to_quantize = [
        "MatMul",
        "Attention",
        "Gather",
        "LSTM",
        "Transpose",
        "EmbedLayerNormalization",
    ]
    # Dynamic (is_static=False), per-tensor quantization configuration
    # targeting CPUs with AVX-512 VNNI instructions.
    dqconfig = AutoQuantizationConfig.avx512_vnni(
        is_static=False,
        per_channel=False,
        operators_to_quantize=operators_to_quantize,
    )
    quantizer = ORTQuantizer.from_pretrained(model_dir, file_name=model_name)
    quantizer.quantize(
        save_dir=output_dir,
        quantization_config=dqconfig,
    )
    logger.info(f"Quantized model saved to {output_dir}")
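
# Note: avx512_vnni assumes a CPU with AVX-512 VNNI support. On plain AVX2
# hardware, AutoQuantizationConfig.avx2() accepts the same is_static,
# per_channel, and operators_to_quantize keyword arguments and would be the
# analogous choice (a suggestion, not part of the original script).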

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Quantize an ONNX model.")
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model file name, e.g. model.onnx",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Path to the directory containing the model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for the quantized ONNX model",
    )
    args = parser.parse_args()
    onnx_quantize(
        model_name=args.model_name,
        model_dir=args.model_dir,
        output_dir=args.output_dir,
    )
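
# Example invocation (file and directory names are illustrative, not taken
# from the repository):
#
#   python quantize.py \
#       --model_name model.onnx \
#       --model_dir ./onnx_model \
#       --output_dir ./onnx_model_quantized
#
# With is_static=False this is dynamic quantization, so no calibration
# dataset is required; the quantized model is written into --output_dir.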