Commit

fix conflict
Signed-off-by: yuwenzho <[email protected]>
yuwenzho committed Jul 23, 2024
2 parents 237073a + be276c6 commit 1e142b0
Showing 5 changed files with 72 additions and 53 deletions.
63 changes: 34 additions & 29 deletions onnx_neural_compressor/algorithms/layer_wise/core.py
@@ -46,6 +46,16 @@ def layer_wise_quant(
Returns:
_type_: _description_
"""
logger.warning(
"Layer-wise quantization requires data_type info for some tensors. "
"We will try to infer the data_type automatically if it doesn't exist."
"You can use model with symbolic shape inference before layer-wise quantization as well like follows:\n"
"import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n"
"model = onnx.load(your_model_path)\n"
"out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n"
"onnx.save_model(out, infer_shape_model_path, save_as_external_data=True)\n"
)

if not isinstance(model, onnx_model.ONNXModel):
model = onnx_model.ONNXModel(model, ignore_warning=True, load_external_data=False)
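
For reference, the pre-processing recommended by the warning above can be run as a standalone script. A minimal sketch, using placeholder paths:

import onnx
import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer

# Placeholder paths -- substitute your own.
your_model_path = "model.onnx"
infer_shape_model_path = "model_infer_shape.onnx"

# Symbolic shape inference fills in the tensor shape/data_type info that
# layer-wise quantization needs; auto_merge=True merges conflicting symbolic dims.
model = onnx.load(your_model_path)
out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save_model(out, infer_shape_model_path, save_as_external_data=True)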

@@ -80,7 +90,7 @@ def layer_wise_quant(
split_model = model_to_split.pop(0)
split_node = split_nodes.pop(0)
if require_data_reader:
current_data_reader = lwq_data_reader.pop(0)
complete_data_reader = lwq_data_reader.pop(0)

# If there are no remaining split nodes, this is the last split and both split models will be saved.
save_both_split_models = len(split_nodes) == 0
@@ -95,17 +105,22 @@ def layer_wise_quant(
model_to_split.append(split_model_part_2)

logger.info("Quantize split model {}".format(split_idx))

if require_data_reader:
# process data_reader for current split and next split

current_data_reader = _filter_data_reader_for_current_split_model(
split_model_part_1.model, current_data_reader, data_reader
split_model_part_1.model, complete_data_reader
)
# next_data_reader contains split_model_part_1 output data
next_data_reader = _prepare_data_reader_for_next_split_model(
split_model_part_1.model_path, current_data_reader, providers

# complete_data_reader contains split_model_part_1 output data
complete_data_reader = _prepare_data_reader_for_next_split_model(
split_model_part_1.model_path,
[i.name for i in split_model_part_2.model.graph.input],
complete_data_reader,
providers,
)
lwq_data_reader.append(next_data_reader)

lwq_data_reader.append(complete_data_reader)

# perform quantization
split_model_part_1_quantized = quant_func(
@@ -123,7 +138,7 @@ def layer_wise_quant(

# check split model is valid
try:
ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
ort.InferenceSession(split_model_part_1_quantized.model_path, providers=providers)
except Exception as e:
logger.error(
"Layer-wise quantized model {} can't be inferred correctly. "
@@ -148,7 +163,7 @@ def layer_wise_quant(
# process data_reader for current split
current_data_reader = lwq_data_reader.pop(0)
current_data_reader = _filter_data_reader_for_current_split_model(
split_model_part_2.model, current_data_reader, data_reader
split_model_part_2.model, complete_data_reader
)

# perform quantization
@@ -167,7 +182,7 @@ def layer_wise_quant(

# check split model is valid
try:
ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
ort.InferenceSession(split_model_part_2_quantized.model_path, providers=providers)
except Exception as e:
logger.error(
"Layer-wise quantized model {} can't be inferred correctly. "
@@ -206,22 +221,19 @@ def rewind(self):
def _filter_data_reader_for_current_split_model(
model: onnx.ModelProto,
current_data_reader: data_reader.CalibrationDataReader,
data_reader: data_reader.CalibrationDataReader,
):
"""Filter data reader to remove data that is not in model input.
Args:
model (onnx.ModelProto): onnx model.
current_data_reader (data_reader.CalibrationDataReader): data reader of current split model.
data_reader (data_reader.CalibrationDataReader): data reader of the original model.
Returns:
data_reader.CalibrationDataReader: filtered data reader.
"""
filter_inputs = []
input_names = [input.name for input in model.graph.input]
current_data_reader.rewind()
data_reader.rewind()

while True:
inputs = current_data_reader.get_next()
@@ -232,22 +244,12 @@ def _filter_data_reader_for_current_split_model(
}
filter_inputs.append(filter_input)

idx = 0
while True:
inputs = data_reader.get_next()
if not inputs:
break
filter_input = {
input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names
}
if len(filter_input) > 0:
filter_inputs[idx].update(filter_input)
idx += 1
return DataReader(filter_inputs)
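
The helper now filters each calibration batch against the current split's own graph inputs, instead of also merging entries from the original reader. A self-contained illustration of the filtering comprehension (names and arrays are hypothetical):

import numpy as np

# A hypothetical split-model input set and one calibration batch.
input_names = {"hidden_states"}
batch = {
    "hidden_states": np.zeros((1, 8), dtype=np.float32),
    "attention_mask": np.ones((1, 8), dtype=np.int64),
}

filtered = {name: tensor for name, tensor in batch.items() if name in input_names}
assert set(filtered) == {"hidden_states"}  # feeds the split does not consume are dropped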


def _prepare_data_reader_for_next_split_model(
model_path: str,
next_model_input_names: list,
data_reader: data_reader.CalibrationDataReader,
providers: List[str] = ["CPUExecutionProvider"],
):
@@ -263,16 +265,19 @@ def _prepare_data_reader_for_next_split_model(
Returns:
data_reader.CalibrationDataReader: data reader for next split model.
"""
data_reader = copy.deepcopy(data_reader)

data_reader.rewind()
data_reader_for_next_split_model = []
session = ort.InferenceSession(model_path, providers=providers)
output_names = [output.name for output in session.get_outputs()]
input_names = [input.name for input in session.get_inputs()]
while True:
inputs = data_reader.get_next()
if not inputs:
break
out = session.run(None, inputs)
inputs.update({name: value for name, value in zip(output_names, out)})
data_reader_for_next_split_model.append(inputs)
out = session.run(None, {name: inputs[name] for name in input_names})
filter_input = {name: value for name, value in zip(output_names, out)}
for name, value in inputs.items():
if name in next_model_input_names and name not in filter_input:
filter_input[name] = value
data_reader_for_next_split_model.append(filter_input)
return DataReader(data_reader_for_next_split_model)
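
Per batch, the rewritten loop runs the current split on just its declared inputs, stores the outputs under their graph names, and passes through any original inputs the next split still consumes. A minimal sketch of that merge, with hypothetical names:

import numpy as np

# Hypothetical: outputs produced by running split 1 ...
split1_outputs = {"hidden_states_12": np.zeros((1, 8), dtype=np.float32)}
# ... the original calibration batch ...
original_batch = {"attention_mask": np.ones((1, 8), dtype=np.int64)}
# ... and the inputs split 2 declares.
next_model_input_names = ["hidden_states_12", "attention_mask"]

feed = dict(split1_outputs)
for name, value in original_batch.items():
    # Pass through original inputs the next split needs, without
    # clobbering tensors split 1 just produced.
    if name in next_model_input_names and name not in feed:
        feed[name] = value
assert set(feed) == {"hidden_states_12", "attention_mask"}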
1 change: 1 addition & 0 deletions onnx_neural_compressor/algorithms/weight_only/gptq.py
@@ -380,6 +380,7 @@ def gptq_quantize(
if return_modelproto:
return model.model
else:
model.save(model.model_path + "_quant.onnx")
return model
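
One caveat about the added save call, which also applies to the identical rtn_quantize change below: the output name is built by appending to the existing path, not by replacing its extension. A tiny sketch with a hypothetical path:

model_path = "/tmp/gptj/model.onnx"
quant_path = model_path + "_quant.onnx"
assert quant_path == "/tmp/gptj/model.onnx_quant.onnx"  # suffix appended, ".onnx" kept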


1 change: 1 addition & 0 deletions onnx_neural_compressor/algorithms/weight_only/rtn.py
@@ -167,6 +167,7 @@ def rtn_quantize(
if return_modelproto:
return model.model
else:
model.save(model.model_path + "_quant.onnx")
return model


54 changes: 35 additions & 19 deletions onnx_neural_compressor/onnx_model.py
@@ -233,7 +233,7 @@ def is_graph_output(self, name):
def save(self, root):
"""Save ONNX model."""
if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]):
raise ValueError('"root" directory does not exists.')
os.mkdir(os.path.split(root)[0])
if self.is_large_model: # pragma: no cover
onnx.external_data_helper.load_external_data_for_model(self.model, os.path.split(self._model_path)[0])
onnx.save_model(
@@ -248,7 +248,9 @@ def save(self, root):
else:
onnx.save(self.model, root)

if self._config is not None:
self._model_path = root

if self._config is not None and not os.path.exists(os.path.join(os.path.split(root)[0], "config.json")):
model_type = "" if not hasattr(self._config, "model_type") else getattr(self._config, "model_type")
setattr(self._config.__class__, "model_type", model_type)
output_config_file = pathlib.Path(root).parent.joinpath("config.json").as_posix()
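
With this change, `save` creates a missing parent directory instead of raising. Note that `os.mkdir` creates a single directory level, so a save path whose deeper ancestors are also missing would still fail. A hedged usage sketch (paths hypothetical):

import onnx
from onnx_neural_compressor import onnx_model

m = onnx_model.ONNXModel(onnx.load("model.onnx"))
m.save("./out_dir/model.onnx")   # "./out_dir" is created if it does not exist
# m.save("./a/b/model.onnx")     # still fails when "./a" is absent, because
                                 # os.mkdir does not create parent directories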
@@ -897,30 +899,44 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_both_split_models
split_model_part_2.CopyFrom(self.model)
split_model_part_2.graph.ClearField("node")

split_node_output = None
part_idx = 1
split_node = None
nodes = []
for node in self.model.graph.node:
if part_idx == 1:
split_model_part_1.graph.node.append(node)
elif part_idx == 2:
split_model_part_2.graph.node.append(node)
nodes.append(node)

if node.name == split_node_name:
split_node_output = node.output
part_idx = 2
split_node = node
break

assert len(split_node_output) == 1, (
assert len(split_node.output) == 1, (
"Only support split at node with 1 output tensor, while "
"current split node {} has {} output tensors".format(split_node_name, len(split_node_output))
"current split node {} has {} output tensors".format(split_node_name, len(split_node.output))
)
split_tensor_name = split_node_output[0]
split_tensor_name = split_node.output[0]

split_tensor = self._build_input_output_tensor(split_tensor_name, value_info)

split_model_part_1.graph.node.extend(nodes)
split_model_part_1.graph.output.append(split_tensor)
split_model_part_2.graph.input.append(split_tensor)

split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True)

# remove isolated graphs which are not related to the split_node
output_name_to_node = split_model_part_1.output_name_to_node()
valid_nodes = [split_node]
while len(valid_nodes) > 0:
node = valid_nodes.pop(0)
for inp in node.input:
if inp in output_name_to_node:
valid_nodes.append(output_name_to_node[inp])
if node in nodes:
nodes.remove(node)
split_model_part_1.remove_nodes(nodes)

for node in self.model.graph.node:
if node not in split_model_part_1.nodes():
split_model_part_2.graph.node.append(node)

split_model_part_2.graph.input.append(split_tensor)
split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True)
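
The isolated-graph pruning above walks backward from the split node: each input is followed to its producer via output_name_to_node, every reachable producer is kept, and whatever is never reached is removed from part 1. A compact sketch of the same traversal over a generic producer map (all names hypothetical):

# Hypothetical producer map: tensor name -> name of the node producing it.
producers = {"t1": "A", "t2": "B", "t3": "C"}
node_inputs = {"split": ["t2"], "B": ["t1"], "A": [], "C": []}

reachable, frontier = set(), ["split"]
while frontier:
    node = frontier.pop(0)
    reachable.add(node)
    for inp in node_inputs[node]:
        if inp in producers:
            frontier.append(producers[inp])

# "C" does not feed the split node, so it is isolated and would be dropped.
assert "C" not in reachable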

# remove unused input & output
@@ -994,14 +1010,14 @@ def _remove_unused_input_output(self):
"""Remove unused input & output for split model."""
remove_outputs = []
remove_inputs = []
if len(self._input_name_to_nodes) == 0:
self._input_name_to_nodes = self.input_name_to_nodes()
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
for output in self.model.graph.output:
if output.name not in self._output_name_to_node.keys():
if output.name not in output_name_to_node.keys():
remove_outputs.append(output)

for input in self.model.graph.input:
if input.name not in self._input_name_to_nodes.keys():
if input.name not in input_name_to_nodes.keys():
remove_inputs.append(input)

for output in remove_outputs:
6 changes: 1 addition & 5 deletions test/utils/test_onnx_model.py
@@ -146,11 +146,7 @@ def test_save(self):
save_path = ".large_model_save.onnx"
model.save(save_path)

# test save path does not exist
with self.assertRaises(ValueError) as cm:
save_path = "./gptj_output/test.onnx"
model.save(save_path)
self.assertEqual(str(cm.exception), '"root" directory does not exists.')
self.assertEqual(model.model_path, ".large_model_save.onnx")

def test_get_initializer_share_num(self):
model = onnx_model.ONNXModel(self.matmul_add_model)
