diff --git a/models/fixmatch/fixmatch.py b/models/fixmatch/fixmatch.py index d5c0787f..5536bf92 100644 --- a/models/fixmatch/fixmatch.py +++ b/models/fixmatch/fixmatch.py @@ -266,6 +266,10 @@ def load_model(self, load_path): eval_model.load_state_dict(checkpoint[key]) elif key == 'it': self.it = checkpoint[key] + elif key == 'scheduler': + self.scheduler.load_state_dict(checkpoint[key]) + elif key == 'optimizer': + self.optimizer.load_state_dict(checkpoint[key]) else: getattr(self, key).load_state_dict(checkpoint[key]) self.print_fn(f"Check Point Loading: {key} is LOADED")