-
-
Notifications
You must be signed in to change notification settings - Fork 586
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
3,000 additions
and
0 deletions.
There are no files selected for viewing
123 changes: 123 additions & 0 deletions
123
459_YOLOv9-Wholebody25/yolov9/convert_int64_to_int32.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import onnx | ||
|
||
from onnx import helper as h | ||
from onnx import checker as ch | ||
from onnx import TensorProto, GraphProto | ||
from onnx import numpy_helper as nph | ||
|
||
import numpy as np | ||
from collections import OrderedDict | ||
|
||
import logging as log | ||
import typer | ||
|
||
|
||
def make_param_dictionary(initializer): | ||
params = OrderedDict() | ||
for data in initializer: | ||
params[data.name] = data | ||
return params | ||
|
||
|
||
def convert_params_to_int32(params_dict): | ||
converted_params = [] | ||
for param in params_dict: | ||
data = params_dict[param] | ||
if data.data_type == TensorProto.INT64: | ||
data_cvt = nph.to_array(data).astype(np.int32) | ||
data = nph.from_array(data_cvt, data.name) | ||
converted_params += [data] | ||
return converted_params | ||
|
||
|
||
def convert_constant_nodes_to_int32(nodes): | ||
""" | ||
convert_constant_nodes_to_int32 Convert Constant nodes to INT32. If a constant node has data type INT64, a new version of the | ||
node is created with INT32 data type and stored. | ||
Args: | ||
nodes (list): list of nodes | ||
Returns: | ||
list: list of new nodes all with INT32 constants. | ||
""" | ||
new_nodes = [] | ||
for node in nodes: | ||
if ( | ||
node.op_type == "Constant" | ||
and node.attribute[0].t.data_type == TensorProto.INT64 | ||
): | ||
data = nph.to_array(node.attribute[0].t).astype(np.int32) | ||
new_t = nph.from_array(data) | ||
new_node = h.make_node( | ||
"Constant", | ||
inputs=[], | ||
outputs=node.output, | ||
name=node.name, | ||
value=new_t, | ||
) | ||
new_nodes += [new_node] | ||
else: | ||
new_nodes += [node] | ||
|
||
return new_nodes | ||
|
||
|
||
def convert_model_to_int32(model_path: str, out_path: str): | ||
""" | ||
convert_model_to_int32 Converts ONNX model with INT64 params to INT32 params.\n | ||
Args:\n | ||
model_path (str): path to original ONNX model.\n | ||
out_path (str): path to save converted model. | ||
""" | ||
log.info("ONNX INT64 --> INT32 Converter") | ||
log.info(f"Loading Model: {model_path}") | ||
# * load model. | ||
model = onnx.load_model(model_path) | ||
ch.check_model(model) | ||
# * get model opset version. | ||
opset_version = model.opset_import[0].version | ||
domain: str = model.domain | ||
ir_version: int = model.ir_version | ||
meta_data = {'domain': domain, 'ir_version': ir_version} | ||
metadata_props = None | ||
if hasattr(model, 'metadata_props'): | ||
metadata_props = model.metadata_props | ||
graph = model.graph | ||
# * The initializer holds all non-constant weights. | ||
init = graph.initializer | ||
# * collect model params in a dictionary. | ||
params_dict = make_param_dictionary(init) | ||
log.info("Converting INT64 model params to INT32...") | ||
# * convert all INT64 aprams to INT32. | ||
converted_params = convert_params_to_int32(params_dict) | ||
log.info("Converting constant INT64 nodes to INT32...") | ||
new_nodes = convert_constant_nodes_to_int32(graph.node) | ||
|
||
graph_name = f"{graph.name}-int32" | ||
log.info("Creating new graph...") | ||
# * create a new graph with converted params and new nodes. | ||
graph_int32 = h.make_graph( | ||
new_nodes, | ||
graph_name, | ||
graph.input, | ||
graph.output, | ||
initializer=converted_params, | ||
) | ||
log.info("Creating new int32 model...") | ||
model_int32 = h.make_model(graph_int32, producer_name="onnx-typecast") | ||
model_int32.opset_import[0].version = opset_version | ||
model_int32.domain = domain | ||
model_int32.ir_version = ir_version | ||
if metadata_props is not None: | ||
model_int32.metadata_props.extend(metadata_props) | ||
ch.check_model(model_int32) | ||
log.info(f"Saving converted model as: {out_path}") | ||
onnx.save_model(model_int32, out_path) | ||
log.info(f"Done Done London. 🎉") | ||
return | ||
|
||
|
||
if __name__ == "__main__": | ||
typer.run(convert_model_to_int32) |
Oops, something went wrong.