Reduce memory consumption in ONNXRT adaptor (#1266)
* reduce memory consumption

Signed-off-by: yuwenz <[email protected]>
yuwenzho authored and mengniwang95 committed Nov 10, 2023
1 parent 5ba9efe commit f64833d
Showing 4 changed files with 82 additions and 19 deletions.
16 changes: 7 additions & 9 deletions neural_compressor/adaptor/onnxrt.py
@@ -712,13 +712,7 @@ def _detect_domain(self, model):
         # 2. according to input
         # typically, NLP models have multiple inputs,
         # and the dimension of each input is usually 2 (batch_size, max_seq_len)
-        if not model.is_large_model:
-            sess = ort.InferenceSession(model.model.SerializeToString(), providers=["CPUExecutionProvider"])
-        elif model.model_path is not None:  # pragma: no cover
-            sess = ort.InferenceSession(model.model_path, providers=["CPUExecutionProvider"])
-        else:  # pragma: no cover
-            assert False, "Please use model path instead of onnx model object to quantize."
-        input_shape_lens = [len(input.shape) for input in sess.get_inputs()]
+        input_shape_lens = [len(inp.type.tensor_type.shape.dim) for inp in model.model.graph.input]
         if len(input_shape_lens) > 1 and all(shape_len == 2 for shape_len in input_shape_lens):
             is_nlp = True
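The removed branch built an ort.InferenceSession just to read input ranks, which loads the full model into memory; the new list comprehension reads each input's rank directly from the ModelProto. A minimal standalone sketch of the same idea, assuming an arbitrary on-disk model at a placeholder path:

# Sketch only: infer "is this an NLP model?" from input ranks without creating a session.
import onnx

model = onnx.load("my_model.onnx", load_external_data=False)  # placeholder path; weights stay on disk
input_shape_lens = [len(inp.type.tensor_type.shape.dim) for inp in model.graph.input]
# Heuristic used in the diff: several inputs that are all rank-2 (batch_size, max_seq_len).
is_nlp = len(input_shape_lens) > 1 and all(rank == 2 for rank in input_shape_lens)
print("NLP-style inputs" if is_nlp else "non-NLP inputs")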

@@ -778,11 +772,15 @@ def _pre_optimize(self, model, level=1):
 
         sess_options.register_custom_ops_library(get_library_path())
         if not model.is_large_model:
-            ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"])
+            sess = ort.InferenceSession(
+                model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
+            )
         elif model.model_path is not None:  # pragma: no cover
-            ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
+            model.model = onnx.ModelProto()  # clean memory for large model
+            sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
         else:  # pragma: no cover
             logger.warning("Please use model path instead of onnx model object to quantize")
+        del sess
 
         tmp_model = onnx.load(sess_options.optimized_model_filepath, load_external_data=False)
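In the large-model branch, the commit now drops the in-memory proto (model.model = onnx.ModelProto()) before onnxruntime re-reads the graph from its path, and deletes the session once the optimized file has been written, so the weights are never held twice. A hedged sketch of that pattern outside the adaptor, with placeholder input and output paths:

# Sketch only: optimize a >2GB model from its path while releasing the in-memory proto.
import onnx
import onnxruntime as ort

model_path = "big_model.onnx"  # assumed on-disk model with external data
proto = onnx.load(model_path, load_external_data=False)
# ... inspect the proto as needed ...
proto = onnx.ModelProto()  # drop the reference so Python can free the graph

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess_options.optimized_model_filepath = "optimized_model.onnx"  # placeholder output path
sess = ort.InferenceSession(model_path, sess_options, providers=["CPUExecutionProvider"])
del sess  # the session was only needed to emit the optimized file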

2 changes: 2 additions & 0 deletions neural_compressor/adaptor/ox_utils/util.py
@@ -85,6 +85,8 @@
     "DmlExecutionProvider": "onnxrt_dml_ep",
 }
 
+MAXIMUM_PROTOBUF = 2147483648
+
 
 def dtype_to_name(dtype_mapping, dtype):
     """Map data type and its string representation."""
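MAXIMUM_PROTOBUF is the 2 GiB ceiling that protobuf enforces on a serialized message, which is why a ModelProto above it cannot be serialized in memory. A quick arithmetic check of the constant:

# 2147483648 bytes is 2 GiB, the protobuf limit a serialized ModelProto must stay under.
MAXIMUM_PROTOBUF = 2147483648
assert MAXIMUM_PROTOBUF == 2**31 == 2 * 1024**3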
35 changes: 25 additions & 10 deletions neural_compressor/model/onnx_model.py
@@ -18,8 +18,10 @@
 
 import logging
 import os
+import sys
 from pathlib import Path
 
+from neural_compressor.adaptor.ox_utils.util import MAXIMUM_PROTOBUF
 from neural_compressor.model.base_model import BaseModel
 from neural_compressor.utils.utility import LazyImport
 
@@ -41,16 +43,9 @@ def __init__(self, model, **kwargs):
         """
         self._model = model if not isinstance(model, str) else onnx.load(model)
         self._model_path = None if not isinstance(model, str) else model
-        self._is_large_model = False
-        try:
-            ort.InferenceSession(self._model.SerializeToString(), providers=["CPUExecutionProvider"])
-        except Exception as e:  # pragma: no cover
-            if self._model_path is not None:
-                ort.InferenceSession(self._model_path, providers=["CPUExecutionProvider"])
-                self._is_large_model = True
-            else:
-                logger.warning("Please use model path instead of onnx model object to quantize")
-
+        self._is_large_model = self.check_large_model()
+        if self._is_large_model and self._model_path is None:
+            logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize")
         self._config = None
         if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()):
             from transformers import PretrainedConfig
@@ -66,6 +61,26 @@ self._get_graph_info()
         self._get_graph_info()
         self._q_config = None
 
+    def check_large_model(self):
+        """Check model > 2GB."""
+        init_size = 0
+        for init in self._model.graph.initializer:
+            # if initializer has external data location, return True
+            if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL:
+                return True
+            # if raise error of initializer size > 2GB, return True
+            try:
+                init_bytes = init.SerializeToString()
+                init_size += sys.getsizeof(init_bytes)
+            except Exception as e:
+                if "exceeds maximum protobuf size of 2GB" in str(e):
+                    return True
+                else:  # pragma: no cover
+                    raise e
+        if init_size > MAXIMUM_PROTOBUF:
+            return True
+        return False
+
     @property
     def is_large_model(self):
         """Check the onnx model is over 2GB."""
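check_large_model sizes the model from its initializers alone: any tensor stored as external data counts as large, otherwise the serialized bytes are summed and compared against MAXIMUM_PROTOBUF, so no InferenceSession (and no duplicate copy of the weights) is needed. A small usage sketch, assuming a placeholder model path:

# Sketch only: pick a quantization input based on the new size check.
from neural_compressor.model.onnx_model import ONNXModel

wrapped = ONNXModel("my_model.onnx")  # placeholder path
if wrapped.check_large_model():
    print("Model > 2GB: pass the file path so the weights can stay on disk.")
else:
    print("Model < 2GB: passing the in-memory ModelProto is fine.")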
48 changes: 48 additions & 0 deletions test/model/test_onnx_model.py
@@ -203,6 +203,7 @@ def setUp(self):
     def tearDownClass(self):
         shutil.rmtree("./gptj", ignore_errors=True)
         shutil.rmtree("./hf_test", ignore_errors=True)
+        os.remove("model.onnx")
 
     def test_hf_model(self):
         from optimum.onnxruntime import ORTModelForCausalLM
@@ -407,6 +408,53 @@ def test_remove_unused_nodes(self):
         self.model.remove_unused_nodes()
         self.assertEqual(len(self.model.nodes()), 6)
 
+    def test_check_large_model(self):
+        import onnx
+        import torch
+        import torch.nn as nn
+
+        from neural_compressor.model.onnx_model import ONNXModel
+
+        class Net(nn.Module):
+            def __init__(self, in_features, out_features):
+                super(Net, self).__init__()
+                self.fc = nn.Linear(in_features, out_features)
+
+            def forward(self, x):
+                x = self.fc(x)
+                return x
+
+        # model > 2GB
+        model = Net(512, 1024 * 1024)
+        input = torch.randn(512, requires_grad=True)
+        with torch.no_grad():
+            torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
+        model = onnx.load("model.onnx")
+        model = ONNXModel(model)  # pass ModelProto
+        self.assertTrue(model.check_large_model())
+
+        model = ONNXModel("model.onnx")  # pass string
+        self.assertTrue(model.check_large_model())
+
+        model = onnx.load("model.onnx", load_external_data=False)  # not load init
+        model = ONNXModel(model)
+        self.assertTrue(model.check_large_model())
+
+        # model < 2GB
+        model = Net(10, 10 * 10)
+        input = torch.randn(10, requires_grad=True)
+        with torch.no_grad():
+            torch.onnx.export(model, (input,), "model.onnx", do_constant_folding=True, opset_version=13)
+        model = onnx.load("model.onnx")
+        model = ONNXModel(model)  # pass ModelProto
+        self.assertFalse(model.check_large_model())
+
+        model = ONNXModel("model.onnx")  # pass string
+        self.assertFalse(model.check_large_model())
+
+        model = ONNXModel("model.onnx", load_external_data_for_model=False)  # not load init
+        self.assertFalse(model.check_large_model())
+
 
 if __name__ == "__main__":
     unittest.main()
