Skip to content

Commit

Permalink
enhance load API (#1162)
Browse files Browse the repository at this point in the history
  • Loading branch information
xin3he authored Aug 17, 2022
1 parent c915509 commit a28705c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
16 changes: 13 additions & 3 deletions neural_compressor/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,19 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
weights_file = checkpoint_dir
stat_dict = torch.load(weights_file)
elif os.path.isdir(checkpoint_dir):
weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)),
'best_model.pt')
stat_dict = torch.load(weights_file)
try:
weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)),
'best_model.pt')
stat_dict = torch.load(weights_file)
except:
tune_cfg_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)),
'best_configure.yaml')
weights_file = os.path.join(os.path.abspath(os.path.expanduser(checkpoint_dir)),
'best_model_weights.pt')
stat_dict = torch.load(weights_file)
with open(tune_cfg_file, 'r') as f:
tune_cfg = yaml.safe_load(f)
stat_dict['best_configure'] = tune_cfg
else:
logger.error("Unexpected checkpoint type:{}. \
Only file dir/path or state_dict is acceptable")
Expand Down
17 changes: 17 additions & 0 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,23 @@ def test_fx_dynamic_quant(self):
{'preserved_attributes': []}
})
self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule))

# Test the functionality of older model saving type
state_dict = torch.load("./saved/best_model.pt")
tune_cfg = state_dict.pop('best_configure')
import yaml
with open("./saved/best_configure.yaml", 'w') as f:
yaml.dump(tune_cfg, f, default_flow_style=False)
torch.save(state_dict, "./saved/best_model_weights.pt")
os.remove('./saved/best_model.pt')
model_fx = load("./saved", model,
**{'prepare_custom_config_dict': \
{'non_traceable_module_name': ['a']},
'convert_custom_config_dict': \
{'preserved_attributes': []}
})
self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule))

# recover int8 model with only tune_cfg
history_file = './saved/history.snapshot'
model_fx_recover = recover(model, history_file, 0,
Expand Down

0 comments on commit a28705c

Please sign in to comment.