
Add export examples for new API (#225)
Signed-off-by: Xin He <[email protected]>
Signed-off-by: zehao-intel <[email protected]>
xin3he authored and zehao-intel committed Dec 9, 2022
1 parent 7e8c755 commit 19d0cf7
Showing 12 changed files with 143 additions and 8 deletions.
@@ -198,5 +198,10 @@ Shapley values originate from cooperative game theory that come with desirable p
> **Note**: run_glue_tune_with_shap.py is an example for the "SST2" task. If you want to run other GLUE tasks, you may need to make slight changes to the "ShapleyMSE" class.

# Appendix

## Export to ONNX

Right now, we experimentally support exporting PyTorch models to ONNX, including both FP32 and INT8 models.

By enabling the `--onnx` argument, Intel Neural Compressor will export an FP32 ONNX model, an INT8 QDQ ONNX model, and an INT8 QLinear ONNX model.
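
In code, the export amounts to building a `Torch2ONNXConfig` and calling `export` on the quantized model returned by `fit`. Below is a condensed sketch of the flow added in this commit; here `batch` stands for one batch taken from the evaluation dataloader with its `labels` entry removed:

```python
# Condensed sketch of the export flow in this commit's script.
# `batch` is assumed to be one dict-like batch from the eval dataloader
# with the 'labels' entry popped off; q_model comes from neural_compressor's fit().
from neural_compressor.config import Torch2ONNXConfig

symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}  # mark both axes as dynamic
int8_onnx_config = Torch2ONNXConfig(
    dtype="int8",
    opset_version=14,
    example_inputs=tuple(batch.values()),
    input_names=list(batch.keys()),
    output_names=['labels'],
    dynamic_axes={k: symbolic_names for k in batch.keys()},
)
q_model.export('int8-nlp-model.onnx', int8_onnx_config)
```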
@@ -6,4 +6,6 @@ torch >= 1.3
transformers>=4.10.0
shap
scipy
sacremoses
onnx
onnxruntime
@@ -144,18 +144,25 @@ class ModelArguments:
    tune: bool = field(
        default=False,
        metadata={"help": "tune quantized model with Intel Neural Compressor."},
    )
    benchmark: bool = field(
        default=False,
        metadata={"help": "run benchmark."},
    )
    int8: bool = field(
        default=False,
        metadata={"help": "initialize int8 model."},
    )
    accuracy_only: bool = field(
        default=False,
        metadata={"help": "Whether to only test accuracy for model tuned by Neural Compressor."},
    )
    onnx: bool = field(
        default=False, metadata={"help": "convert PyTorch model to ONNX"}
    )


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@@ -439,6 +446,24 @@ def eval_func_for_nc(model_tuned):
q_model = fit(model, conf=conf, eval_func=eval_func_for_nc)
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)

if model_args.onnx:
    eval_dataloader = trainer.get_eval_dataloader()
    it = iter(eval_dataloader)
    input = next(it)
    input.pop('labels')
    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
    dynamic_axes = {k: symbolic_names for k in input.keys()}
    from neural_compressor.config import Torch2ONNXConfig
    int8_onnx_config = Torch2ONNXConfig(
        dtype="int8",
        opset_version=14,
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    q_model.export('int8-nlp-model.onnx', int8_onnx_config)
    exit(0)

if model_args.accuracy_only:
@@ -86,6 +86,7 @@ function run_tuning {
--no_cuda \
--output_dir ${tuned_checkpoint} \
--tune \
--onnx \
${extra_cmd}
}

@@ -187,3 +187,11 @@ quantizer.model = common.Model(model)
model = quantizer.fit()
model.save(training_args.output_dir)
```

# Appendix

## Export to ONNX

Right now, we experimentally support exporting PyTorch models to ONNX, including both FP32 and INT8 models.

By enabling the `--onnx` argument, Intel Neural Compressor will export an FP32 ONNX model, an INT8 QDQ ONNX model, and an INT8 QLinear ONNX model.
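
Once exported, the ONNX files can be sanity-checked with onnxruntime, which this example adds to its requirements. This is a sketch, not part of the commit; the file name follows the script below, and the all-ones feed is purely illustrative:

```python
# Sketch (not part of this commit): run the exported model with onnxruntime.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("int8-nlp-qdq-model.onnx")
# Feed dict keyed by the exported input names (input_ids, attention_mask, ...);
# batch_size and max_seq_len are dynamic axes, so any reasonable shape works.
feed = {inp.name: np.ones((1, 128), dtype=np.int64) for inp in sess.get_inputs()}
outputs = sess.run(None, feed)
```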
@@ -4,6 +4,8 @@ protobuf
scipy
scikit-learn
Keras-Preprocessing
onnx
onnxruntime
transformers >= 4.16.0
--find-links https://download.pytorch.org/whl/torch_stable.html
torch >= 1.8.0+cpu
@@ -195,6 +195,9 @@ class ModelArguments:
    accuracy_only: bool = field(
        default=False, metadata={"help": "get accuracy"}
    )
    onnx: bool = field(
        default=False, metadata={"help": "convert PyTorch model to ONNX"}
    )


def main():
@@ -502,9 +505,46 @@ def eval_func(model):
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
tuning_criterion = TuningCriterion(max_trials=600)
conf = PostTrainingQuantConfig(approach="static", backend="pytorch_fx", tuning_criterion=tuning_criterion)
q_model = fit(model, conf=conf, calib_dataloader=eval_dataloader, eval_func=eval_func)
from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(q_model, tokenizer, training_args.output_dir)

if model_args.onnx:
    it = iter(eval_dataloader)
    input = next(it)
    input.pop('labels')
    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
    dynamic_axes = {k: symbolic_names for k in input.keys()}
    from neural_compressor.config import Torch2ONNXConfig
    fp32_onnx_config = Torch2ONNXConfig(
        dtype="fp32",
        opset_version=14,
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    q_model.export('fp32-model.onnx', fp32_onnx_config)
    int8_onnx_config = Torch2ONNXConfig(
        dtype="int8",
        opset_version=14,
        quant_format="QDQ",
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    q_model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
    int8_onnx_config = Torch2ONNXConfig(
        dtype="int8",
        opset_version=14,
        quant_format="QLinear",
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    q_model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
    return

if model_args.benchmark or model_args.accuracy_only:
@@ -92,6 +92,7 @@ function run_tuning {
--no_cuda \
--output_dir ${tuned_checkpoint} \
--tune \
--onnx \
--overwrite_output_dir \
${extra_cmd}
}
@@ -117,3 +117,11 @@ model = OptimizedModel.from_pretrained(
```

We also upstreamed several int8 models into the HuggingFace [model hub](https://huggingface.co/models?other=Intel%C2%AE%20Neural%20Compressor) to help users ramp up quickly.
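
For instance, loading one of those int8 models mirrors the `OptimizedModel.from_pretrained` call shown above. This is a sketch; the import path and model id here are assumptions, so substitute a real id from the linked hub page:

```python
# Sketch: load an int8 model with the OptimizedModel API shown above.
# Import path and model id are assumptions; pick a real id from the hub page.
from neural_compressor.utils.load_huggingface import OptimizedModel

model = OptimizedModel.from_pretrained("Intel/bert-base-uncased-mrpc-int8-static")
```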

# Appendix

## Export to ONNX

Right now, we experimentally support exporting PyTorch models to ONNX, including both FP32 and INT8 models.

By enabling the `--onnx` argument, Intel Neural Compressor will export an FP32 ONNX model, an INT8 QDQ ONNX model, and an INT8 QLinear ONNX model.
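
A quick structural check of an exported file can be done with the onnx package (a sketch, not part of this commit; the file name matches the export calls in the accompanying script):

```python
# Sketch (not part of this commit): verify an exported model loads cleanly.
import onnx

onnx_model = onnx.load("int8-nlp-qlinear-model.onnx")
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
print([inp.name for inp in onnx_model.graph.input])
```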
@@ -4,5 +4,7 @@ datasets == 1.18.0
sentencepiece != 0.1.92
protobuf
scipy
onnx
onnxruntime
--find-links https://download.pytorch.org/whl/torch_stable.html
torch >= 1.8.0+cpu
@@ -194,6 +194,9 @@ class ModelArguments:
    benchmark: bool = field(
        default=False, metadata={"help": "get benchmark instead of accuracy"}
    )
    onnx: bool = field(
        default=False, metadata={"help": "convert PyTorch model to ONNX"}
    )


def main():
@@ -533,6 +536,43 @@ def benchmark(model):

from neural_compressor.utils.load_huggingface import save_for_huggingface_upstream
save_for_huggingface_upstream(model, tokenizer, training_args.output_dir)

if model_args.onnx:
    it = iter(eval_dataloader)
    input = next(it)
    input.pop('labels')
    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
    dynamic_axes = {k: symbolic_names for k in input.keys()}
    from neural_compressor.config import Torch2ONNXConfig
    fp32_onnx_config = Torch2ONNXConfig(
        dtype="fp32",
        opset_version=14,
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    model.export('fp32-model.onnx', fp32_onnx_config)
    int8_onnx_config = Torch2ONNXConfig(
        dtype="int8",
        opset_version=14,
        quant_format="QDQ",
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    model.export('int8-nlp-qdq-model.onnx', int8_onnx_config)
    int8_onnx_config = Torch2ONNXConfig(
        dtype="int8",
        opset_version=14,
        quant_format="QLinear",
        example_inputs=tuple(input.values()),
        input_names=list(input.keys()),
        output_names=['labels'],
        dynamic_axes=dynamic_axes,
    )
    model.export('int8-nlp-qlinear-model.onnx', int8_onnx_config)
    return

if model_args.benchmark:
@@ -60,6 +60,7 @@ function run_tuning {
--save_strategy steps \
--metric_for_best_model f1 \
--save_total_limit 1 \
--onnx \
--tune
}

