Skip to content

Commit

Permalink
Release calibration model memory to fix OOM for tensorflow (#542)
Browse files Browse the repository at this point in the history
Signed-off-by: Lv, Liang1 <[email protected]>
  • Loading branch information
lvliang-intel authored Feb 14, 2023
1 parent 5acea85 commit ad0f1e0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
27 changes: 17 additions & 10 deletions neural_compressor/adaptor/tf_utils/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,6 @@ def __init__(self,
self._sampling_model.output_tensor_names = self.output_tensor_names
self._sampling_model.input_tensor_names = self.input_tensor_names

self._itex_model = Model(self.model._model, **self.model.kwargs)
self._itex_model.graph_def = self.model.graph_def
self._itex_model.output_tensor_names = self.output_tensor_names
self._itex_model.input_tensor_names = self.input_tensor_names
self._tmp_graph_def = copy.deepcopy(self.model.graph_def)
self.new_api = new_api #bool(version1_gte_version2(tf.version.VERSION, '2.8.0'))
self.performance_only = performance_only
Expand Down Expand Up @@ -347,11 +343,11 @@ def convert(self):

if self.itex_mode:
host_const_graph_def = \
PostHostConstConverter(self._itex_model.graph_def).do_transformation()
PostHostConstConverter(self._tmp_model.graph_def).do_transformation()
host_const_graph_def.library.CopyFrom(self.model.graph_def.library)
self._itex_model.graph_def = host_const_graph_def
self._tmp_model.graph_def = host_const_graph_def

return self._itex_model
return self._tmp_model

if self.exclude_node_names:
self.bf16_ops.extend(self.exclude_node_names)
Expand Down Expand Up @@ -490,7 +486,6 @@ def quantize(self):
sampling_graph_def = copy.deepcopy(self._fp32_model.graph_def)
# TODO: this is a workaround to make Min/Max node be completly eliminated in int8 graph
# after enabling pad+conv2d in new API.

non_pad_ops = list(list(set(self.fp32_ops).union(set(self.bf16_ops))))
sampling_graph_def = FusePadWithFP32Conv2DOptimizer(
sampling_graph_def,
Expand All @@ -512,6 +507,12 @@ def quantize(self):
self._inference(self._sampling_model)
self._calibration_data = Helper.gen_valid_sampling_log(tmp_dump_file)

del output_tensor_names
del sampling_graph_def
del self._sampling_model
import gc
gc.collect()

if len(self._calibration_data) > 0:
self._freeze_requantization_ranges(self._kl_op_dict)
self._fuse_requantize_with_fused_quantized_node()
Expand Down Expand Up @@ -807,6 +808,12 @@ def _insert_qdq_pairs(self):
self._inference(self._sampling_model)
self._calibration_data = Helper.gen_valid_sampling_log(tmp_dump_file)

del sampling_graph_def
del output_tensor_names
del self._sampling_model
import gc
gc.collect()

# Insert QDQ pattern
self._tmp_graph_def = GenerateGraphWithQDQPattern(
self._tmp_graph_def, self._calibration_data, self.op_wise_config,
Expand Down Expand Up @@ -847,8 +854,8 @@ def _convert_qdq(self):
self._tmp_graph_def = MergeDuplicatedQDQOptimizer(self._tmp_graph_def).do_transformation()

self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)
self._itex_model.graph_def = self._tmp_graph_def
self._itex_model.graph_def.library.CopyFrom(self.model.graph_def.library)
self._tmp_model.graph_def = self._tmp_graph_def
self._tmp_model.graph_def.library.CopyFrom(self.model.graph_def.library)
else:
self._tmp_graph_def, exclude_node_names = OptimizeQDQGraph(self._tmp_graph_def,
self._tmp_model.input_node_names,
Expand Down
9 changes: 9 additions & 0 deletions neural_compressor/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,12 @@ def set_tensorboard(tensorboard: bool):
"""Set the tensorboard in config."""
from neural_compressor.config import options
options.tensorboard = tensorboard

def show_memory_info(hint):
"""Show process full memory."""
pid = os.getpid()
p = psutil.Process(pid)

info = p.memory_full_info()
memory = info.uss / 1024. / 1024
print('{} memory used: {} MB'.format(hint, memory))

0 comments on commit ad0f1e0

Please sign in to comment.