From a28705c09f7be415fdd348a56cc1a300f9159a44 Mon Sep 17 00:00:00 2001 From: xinhe Date: Wed, 17 Aug 2022 22:38:07 +0800 Subject: [PATCH] enhance load API (#1162) --- neural_compressor/utils/pytorch.py | 16 +++++++++++++--- .../pytorch_adaptor/test_adaptor_pytorch.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/neural_compressor/utils/pytorch.py b/neural_compressor/utils/pytorch.py index 4219a1a1b4f..24f08fa2c9f 100644 --- a/neural_compressor/utils/pytorch.py +++ b/neural_compressor/utils/pytorch.py @@ -160,9 +160,19 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs): weights_file = checkpoint_dir stat_dict = torch.load(weights_file) elif os.path.isdir(checkpoint_dir): + try: + weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), + 'best_model.pt') + stat_dict = torch.load(weights_file) + except Exception: + tune_cfg_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), + 'best_configure.yaml') + weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), + 'best_model_weights.pt') + stat_dict = torch.load(weights_file) + with open(tune_cfg_file, 'r') as f: + tune_cfg = yaml.safe_load(f) + stat_dict['best_configure'] = tune_cfg - weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)), - 'best_model.pt') - stat_dict = torch.load(weights_file) else: logger.error("Unexpected checkpoint type:{}. 
\ Only file dir/path or state_dict is acceptable") diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch.py index 18986d19b71..06ef513c815 100644 --- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch.py +++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch.py @@ -873,6 +873,23 @@ def test_fx_dynamic_quant(self): {'preserved_attributes': []} }) self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule)) + + # Test the functionality of older model saving type + state_dict = torch.load("./saved/best_model.pt") + tune_cfg = state_dict.pop('best_configure') + import yaml + with open("./saved/best_configure.yaml", 'w') as f: + yaml.dump(tune_cfg, f, default_flow_style=False) + torch.save(state_dict, "./saved/best_model_weights.pt") + os.remove('./saved/best_model.pt') + model_fx = load("./saved", model, + **{'prepare_custom_config_dict': \ + {'non_traceable_module_name': ['a']}, + 'convert_custom_config_dict': \ + {'preserved_attributes': []} + }) + self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule)) + # recover int8 model with only tune_cfg history_file = './saved/history.snapshot' model_fx_recover = recover(model, history_file, 0,