Skip to content

Commit

Permalink
Fix DLRM OOM issue (#299)
Browse files Browse the repository at this point in the history
Signed-off-by: changwa1 <[email protected]>
  • Loading branch information
changwangss authored Dec 20, 2022
1 parent 3e9c291 commit ff17257
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
args.print_freq = ld_nbatches
args.test_freq = 0

del ld_model
del(ld_model)

print(
"Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ def loss_fn_wrap(Z, T, use_gpu, device):
args.print_freq = ld_nbatches
args.test_freq = 0

del ld_model
del(ld_model)

print(
"Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def run():
)
)
print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))
del ld_model
del(ld_model)

ext_dist.barrier()
print("time/loss/accuracy (if enabled):")
Expand Down
10 changes: 9 additions & 1 deletion neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2584,7 +2584,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
self.q_dataloader.batch(batch_size)
logger.info('Recovery `calibration.dataloader.batchsize` {} according \
to config.yaml' .format(batch_size))
del init_model
del(init_model)
with open(self.ipex_config_path, 'r') as f:
self.cfgs = json.load(f)
if self.version.release < Version("1.12.0").release:
Expand Down Expand Up @@ -2776,6 +2776,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
try:
q_model = copy.deepcopy(model)
q_model.fp32_model = model.fp32_model
except Exception as e: # pragma: no cover
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
repr(e)))
Expand Down Expand Up @@ -2983,6 +2984,13 @@ def _pre_hook_for_qat(self, dataloader=None):
# so set it to None.
example_inputs = None

# For export API, deepcopy fp32_model
try:
self.model.fp32_model = copy.deepcopy(self.model.fp32_model)
except Exception as e: # pragma: no cover
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
repr(e)))

if self.sub_module_list is None:
if self.version.release >= Version("1.13.0").release: # pragma: no cover
# pylint: disable=E1123
Expand Down
17 changes: 11 additions & 6 deletions neural_compressor/model/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ def __init__(self, model, **kwargs):
self.q_config = None
self._workspace_path = ''
self.is_quantized = False
try:
self.fp32_model = copy.deepcopy(model)
except Exception as e: # pragma: no cover
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(
repr(e)))
self.fp32_model = model
self.fp32_model = model
self.kwargs = kwargs if kwargs else None

def __repr__(self):
Expand Down Expand Up @@ -93,6 +88,16 @@ def model(self, model):
""" Setter to model """
self._model = model

@property
def fp32_model(self):
""" Getter for the FP32 model: returns the object stored in ``self._fp32_model``. """
return self._fp32_model

@fp32_model.setter
def fp32_model(self, fp32_model):
""" Setter for the FP32 model: stores the given object in ``self._fp32_model`` as-is (no copy is made here). """
self._fp32_model = fp32_model

def register_forward_pre_hook(self):
self.handles.append(
self._model.register_forward_pre_hook(self.generate_forward_pre_hook()))
Expand Down

0 comments on commit ff17257

Please sign in to comment.