From 7963b4ce2bb6f248f0a1c0a5420ba4690e45afce Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Mon, 31 Aug 2020 16:48:22 -0700 Subject: [PATCH] Propagating onnx.export() parameters; requiring torch>=1.6 --- nemo/core/classes/exportable.py | 26 +++++++++++++++----------- requirements/requirements.txt | 4 ++-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 9424a889aaad..fe672875b98f 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -47,13 +47,17 @@ class Exportable(ABC): """ def export( - self, - output: str, - input_example=None, - output_example=None, - onnx_opset_version: int = 12, - try_script: bool = False, - set_eval: bool = True, + self, + output: str, + input_example=None, + output_example=None, + verbose=False, + export_params=True, + do_constant_folding=True, + keep_initializers_as_inputs=False, + onnx_opset_version: int = 12, + try_script: bool = False, + set_eval: bool = True, ): try: # Disable typechecks @@ -134,10 +138,10 @@ def export( output, input_names=input_names, output_names=output_names, - verbose=False, - export_params=True, - do_constant_folding=True, - keep_initializers_as_inputs=True, + verbose=verbose, + export_params=export_params, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=keep_initializers_as_inputs, dynamic_axes=dynamic_axes, opset_version=onnx_opset_version, example_outputs=_out_example, diff --git a/requirements/requirements.txt b/requirements/requirements.txt index c229303e7c88..e0a69d02aeed 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,11 +2,11 @@ numpy>=1.18.2 onnx>=1.7.0 pytorch-lightning==0.9.0 python-dateutil -torch +torch>=1.6.0 wget wrapt ruamel.yaml scikit-learn omegaconf==2.0.1rc12 hydra-core==1.0.0rc4 -transformers>=2.11.0 \ No newline at end of file +transformers>=2.11.0