Skip to content

Commit

Permalink
Merge pull request #4 from PINTO0309/fix_irversion
Browse files Browse the repository at this point in the history
Fix to preserve domain and ir_version
  • Loading branch information
PINTO0309 authored Apr 30, 2024
2 parents ccaaa51 + 8516417 commit cbd28bd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ $ sog4onnx -h
usage: sog4onnx [-h]
--ot OP_TYPE
--os OPSET
--ir IR_VERSION
--on OP_NAME
[-iv NAME TYPE VALUE]
[-ov NAME TYPE VALUE]
Expand All @@ -57,6 +58,9 @@ optional arguments:
-os OPSET, --opset OPSET
ONNX opset number.
-ir IR_VERSION, --ir_version IR_VERSION
ONNX ir_version number.
-on OP_NAME, --op_name OP_NAME
OP name.
Expand Down Expand Up @@ -112,6 +116,7 @@ Help on function generate in module sog4onnx.onnx_operation_generator:
generate(
op_type: str,
opset: int,
ir_version: int,
op_name: str,
input_variables: dict,
output_variables: dict,
Expand All @@ -134,6 +139,11 @@ generate(

e.g. 11

ir_version: int
ONNX ir_version number.

e.g. 9

op_name: str
OP name.

Expand Down
2 changes: 1 addition & 1 deletion sog4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sog4onnx.onnx_operation_generator import generate, main

__version__ = '1.0.16'
__version__ = '1.0.17'
16 changes: 15 additions & 1 deletion sog4onnx/onnx_operation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def generate(
op_type: str,
opset: int,
op_name: str,
ir_version: int = 9,
input_variables: Optional[OrderedDict] = None,
output_variables: Optional[OrderedDict] = None,
attributes: Optional[OrderedDict] = None,
Expand All @@ -89,6 +90,10 @@ def generate(
ONNX opset number.\n\
e.g. 11
ir_version: int
ONNX ir_version number.\n\
e.g. 9
op_name: str
OP name.
Expand Down Expand Up @@ -208,7 +213,7 @@ def generate(
outputs=output_gs_variables,
opset=opset,
)
single_op_graph = gs.export_onnx(graph)
single_op_graph = gs.export_onnx(graph, do_type_check=False, **{'ir_version': ir_version})
else:
graph_def = onnx.helper.make_graph(
nodes=[node],
Expand All @@ -223,6 +228,7 @@ def generate(
single_op_graph = onnx.helper.make_model(
graph=graph_def,
opset_imports=[opset_id_proto],
ir_version=ir_version,
)

# 4. Graph Check
Expand Down Expand Up @@ -270,6 +276,13 @@ def main():
required=True,
help='ONNX opset number.'
)
parser.add_argument(
'-ir',
'--ir_version',
type=int,
default=9,
help='ONNX ir_version number.'
)
parser.add_argument(
'-on',
'--op_name',
Expand Down Expand Up @@ -457,6 +470,7 @@ def main():
single_op_graph = generate(
op_type=args.op_type,
opset=args.opset,
ir_version=args.ir_version,
op_name=args.op_name,
input_variables=input_variables_tmp,
output_variables=output_variables_tmp,
Expand Down

0 comments on commit cbd28bd

Please sign in to comment.