Skip to content

Commit

Permalink
Support trace with dictionary type example_inputs (#1353)
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Chang1 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
changwangss and pre-commit-ci[bot] authored Oct 27, 2023
1 parent d8a035a commit afe3159
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
12 changes: 9 additions & 3 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2765,14 +2765,20 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False):
q_model._model = ipex.quantization.convert(model._model, inplace=inplace)
try:
if isinstance(self.example_inputs, dict):
q_model._model = torch.jit.trace(q_model._model, example_kwarg_inputs=self.example_inputs)
q_model._model = torch.jit.trace(
q_model._model,
example_kwarg_inputs=self.example_inputs,
)
else:
q_model._model = torch.jit.trace(q_model._model, self.example_inputs)
q_model._model = torch.jit.freeze(q_model._model.eval())
except:
if isinstance(self.example_inputs, dict):
q_model._model = torch.jit.trace(
q_model._model, example_kwarg_inputs=self.example_inputs, strict=False
q_model._model,
example_kwarg_inputs=self.example_inputs,
strict=False,
check_trace=False,
)
else:
q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False)
Expand All @@ -2789,7 +2795,7 @@ def _ipex_post_quant_process(self, model, q_model, dataloader, inplace=False):
except:
if isinstance(self.example_inputs, dict):
q_model._model = torch.jit.trace(
q_model._model, example_kwarg_inputs=self.example_inputs, strict=False
q_model._model, example_kwarg_inputs=self.example_inputs, strict=False, check_trace=False
)
else:
q_model._model = torch.jit.trace(q_model._model, self.example_inputs, strict=False)
Expand Down
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/torch_utils/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def ipex_mixed_precision(model, example_inputs=None):
try:
mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs)
except:
mp_model = torch.jit.trace(mp_model, example_kwarg_inputs=example_inputs, strict=False)
mp_model = torch.jit.trace(
mp_model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False
)
else:
try:
mp_model = torch.jit.trace(mp_model, example_inputs)
Expand Down
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/torch_utils/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,9 @@ def trace(self, model, dummy_input):
dummy_input = move_input_to_device(dummy_input, "cpu")
if isinstance(dummy_input, dict) or isinstance(dummy_input, UserDict):
try:
traced_model = torch.jit.trace(model, example_kwarg_inputs=dict(dummy_input), strict=False)
traced_model = torch.jit.trace(
model, example_kwarg_inputs=dict(dummy_input), strict=False, check_trace=False
)
traced_model = torch.jit.freeze(traced_model.eval(), optimize_numerics=optimize_numerics)
except Exception as e:
logger.warning(e)
Expand Down

0 comments on commit afe3159

Please sign in to comment.