Skip to content

Commit

Permalink
Propagating onnx.export() parameters; requiring torch>=1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
borisfom committed Aug 31, 2020
1 parent 6471d10 commit 7963b4c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
26 changes: 15 additions & 11 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ class Exportable(ABC):
"""

def export(
self,
output: str,
input_example=None,
output_example=None,
onnx_opset_version: int = 12,
try_script: bool = False,
set_eval: bool = True,
self,
output: str,
input_example=None,
output_example=None,
verbose=False,
export_params=True,
do_constant_folding=True,
keep_initializers_as_inputs=False,
onnx_opset_version: int = 12,
try_script: bool = False,
set_eval: bool = True,
):
try:
# Disable typechecks
Expand Down Expand Up @@ -134,10 +138,10 @@ def export(
output,
input_names=input_names,
output_names=output_names,
verbose=False,
export_params=True,
do_constant_folding=True,
keep_initializers_as_inputs=True,
verbose=verbose,
export_params=export_params,
do_constant_folding=do_constant_folding,
keep_initializers_as_inputs=keep_initializers_as_inputs,
dynamic_axes=dynamic_axes,
opset_version=onnx_opset_version,
example_outputs=_out_example,
Expand Down
4 changes: 2 additions & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ numpy>=1.18.2
onnx>=1.7.0
pytorch-lightning==0.9.0
python-dateutil
torch
torch>=1.6.0
wget
wrapt
ruamel.yaml
scikit-learn
omegaconf==2.0.1rc12
hydra-core==1.0.0rc4
transformers>=2.11.0
transformers>=2.11.0

0 comments on commit 7963b4c

Please sign in to comment.