
fix checkpoint and test
Lupin1998 committed Feb 26, 2023
1 parent 4a08342 commit c195aed
Showing 6 changed files with 67 additions and 23 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -5,8 +5,11 @@
  <img src="https://img.shields.io/badge/arXiv-2211.12509-b31b1b.svg?style=flat" /></a>
<a href="https://github.com/Westlake-AI/MogaNet/blob/main/LICENSE" alt="license">
  <img src="https://img.shields.io/badge/license-Apache--2.0-%23B7A800" /></a>
+<a href="https://simvpv2.readthedocs.io/en/latest/" alt="docs">
+  <img src="https://readthedocs.org/projects/simvpv2/badge/?version=latest" /></a>
</p>

+[📘Documentation](https://simvpv2.readthedocs.io/en/latest/) |
[🛠️Installation](docs/en/install.md) |
[🚀Model Zoo](docs/en/model_zoos/video_benchmarks.md) |
[🆕News](docs/en/changelog.md)
47 changes: 35 additions & 12 deletions simvp/api/train.py
@@ -12,7 +12,7 @@
from simvp.core import metric, Recorder
from simvp.methods import method_maps
from simvp.utils import (set_seed, print_log, output_namespace, check_dir,
-                         get_dataset, measure_throughput)
+                         get_dataset, measure_throughput, weights_to_cpu)

try:
    import nni
@@ -29,6 +29,7 @@ def __init__(self, args):
        self.config = self.args.__dict__
        self.device = self._acquire_device()
        self.args.method = self.args.method.lower()
+        self._epoch = 0

        self._preparation()
        print_log(output_namespace(self.args))
@@ -100,6 +101,11 @@ def _preparation(self):
        self._get_data()
        # build the method
        self._build_method()
+        # resume training
+        if self.args.auto_resume:
+            self.args.resume_from = osp.join(self.checkpoints_path, 'latest.pth')
+        if self.args.resume_from is not None:
+            self._load(name=self.args.resume_from)

    def _build_method(self):
        steps_per_epoch = len(self.train_loader)
@@ -111,23 +117,34 @@ def _get_data(self):
        self.vali_loader = self.test_loader

    def _save(self, name=''):
-        torch.save(self.method.model.state_dict(), osp.join(self.checkpoints_path, name + '.pth'))
-        fw = open(osp.join(self.checkpoints_path, name + '.pkl'), 'wb')
-        state = self.method.scheduler.state_dict()
-        pickle.dump(state, fw)
-
-    def _load(self, epoch):
-        self.method.model.load_state_dict(torch.load(osp.join(self.checkpoints_path, str(epoch) + '.pth')))
-        fw = open(osp.join(self.checkpoints_path, str(epoch) + '.pkl'), 'rb')
-        state = pickle.load(fw)
-        self.method.scheduler.load_state_dict(state)
+        checkpoint = {
+            'epoch': self._epoch + 1,
+            'optimizer': self.method.model_optim.state_dict(),
+            'state_dict': weights_to_cpu(self.method.model.state_dict()),
+            'scheduler': self.method.scheduler.state_dict()}
+        torch.save(checkpoint, osp.join(self.checkpoints_path, name + '.pth'))
+
+    def _load(self, name=''):
+        filename = name if osp.isfile(name) else osp.join(self.checkpoints_path, name + '.pth')
+        try:
+            checkpoint = torch.load(filename)
+        except:
+            return
+        # OrderedDict is a subclass of dict
+        if not isinstance(checkpoint, dict):
+            raise RuntimeError(f'No state_dict found in checkpoint file {filename}')
+        self.method.model.load_state_dict(checkpoint['state_dict'])
+        if checkpoint.get('epoch', None) is not None:
+            self._epoch = checkpoint['epoch']
+            self.method.model_optim.load_state_dict(checkpoint['optimizer'])
+            self.method.scheduler.load_state_dict(checkpoint['scheduler'])

    def train(self):
        recorder = Recorder(verbose=True)
        num_updates = 0
        # constants for other methods:
        eta = 1.0  # PredRNN
-        for epoch in range(self.config['epoch']):
+        for epoch in range(self._epoch, self.config['epoch']):
            loss_mean = 0.0

            if self.args.method in ['simvp', 'crevnet', 'phydnet']:
@@ -139,6 +156,7 @@
            else:
                raise ValueError(f'Invalid method name {self.args.method}')

+            self._epoch = epoch
            if epoch % self.args.log_step == 0:
                cur_lr = self.method.current_lr()
                cur_lr = sum(cur_lr) / len(cur_lr)
@@ -148,6 +166,7 @@
                print_log('Epoch: {0}, Steps: {1} | Lr: {2:.7f} | Train Loss: {3:.7f} | Vali Loss: {4:.7f}\n'.format(
                    epoch + 1, len(self.train_loader), cur_lr, loss_mean, vali_loss))
                recorder(vali_loss, self.method.model, self.path)
+                self._save(name='latest')

            if not check_dir(self.path):  # exit training when work_dir is removed
                assert False, "Exit training because work_dir is removed"
@@ -170,6 +189,10 @@ def vali(self, vali_loader):
        return val_loss

    def test(self):
+        if self.args.test:
+            best_model_path = osp.join(self.path, 'checkpoint.pth')
+            self.method.model.load_state_dict(torch.load(best_model_path))
+
        inputs, trues, preds = self.method.test_one_epoch(self.test_loader)
        if 'weather' in self.args.dataname:
            metric_list, spatial_norm = ['mse', 'rmse', 'mae'], True
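To make the new checkpoint layout concrete, here is a minimal, self-contained sketch of the save/resume round-trip that the reworked `_save`/`_load` pair implements. The toy model, optimizer, scheduler, epoch value, and file name below are illustrative only, not part of this commit:

import torch
import torch.nn as nn
from collections import OrderedDict

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

def to_cpu(state_dict):
    # Mirrors the weights_to_cpu helper added in simvp/utils/main_utils.py.
    return OrderedDict((k, v.cpu()) for k, v in state_dict.items())

# Save: one dict bundles epoch, optimizer, CPU-mapped weights, and scheduler.
checkpoint = {
    'epoch': 5,
    'optimizer': optimizer.state_dict(),
    'state_dict': to_cpu(model.state_dict()),
    'scheduler': scheduler.state_dict(),
}
torch.save(checkpoint, 'latest.pth')

# Load: restore the weights, then resume epoch/optimizer/scheduler if saved.
ckpt = torch.load('latest.pth')
model.load_state_dict(ckpt['state_dict'])
if ckpt.get('epoch', None) is not None:
    start_epoch = ckpt['epoch']
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])

Bundling optimizer and scheduler state alongside the weights is what lets training restart from `latest.pth` with the learning-rate schedule intact, rather than only reloading model parameters as the old pickle-based code did.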
2 changes: 1 addition & 1 deletion simvp/datasets/dataloader_weather.py
@@ -187,7 +187,7 @@ def load_data(batch_size,
                                                  batch_size=batch_size, shuffle=True,
                                                  pin_memory=True, drop_last=True,
                                                  num_workers=num_workers)
-    dataloader_vali = torch.utils.data.DataLoader(validation_set,
+    dataloader_vali = torch.utils.data.DataLoader(test_set, # validation_set,
                                                  batch_size=val_batch_size, shuffle=False,
                                                  pin_memory=True, drop_last=True,
                                                  num_workers=num_workers)
4 changes: 2 additions & 2 deletions simvp/utils/__init__.py
@@ -2,14 +2,14 @@

from .config_utils import Config, check_file_exist
from .main_utils import (set_seed, print_log, output_namespace, check_dir, get_dataset,
-                         count_parameters, measure_throughput, load_config, update_config)
+                         count_parameters, measure_throughput, load_config, update_config, weights_to_cpu)
from .parser import create_parser
from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling, reshape_patch,
                            reshape_patch_back)

__all__ = [
    'Config', 'check_file_exist', 'create_parser',
    'set_seed', 'print_log', 'output_namespace', 'check_dir', 'get_dataset', 'count_parameters',
-    'measure_throughput', 'load_config', 'update_config',
+    'measure_throughput', 'load_config', 'update_config', 'weights_to_cpu',
    'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back',
]
27 changes: 21 additions & 6 deletions simvp/utils/main_utils.py
@@ -6,6 +6,7 @@
import torch
import random
import torch.backends.cudnn as cudnn
+from collections import OrderedDict
from .config_utils import Config


@@ -77,9 +78,7 @@ def measure_throughput(model, input_dummy):


def load_config(filename:str = None):
-    '''
-    load and print config
-    '''
+    """load and print config"""
    print('loading config from ' + filename + ' ...')
    try:
        configfile = Config(filename=filename)
@@ -91,9 +90,7 @@


def update_config(args, config, exclude_keys=list()):
-    '''
-    update the args dict with a new config
-    '''
+    """update the args dict with a new config"""
    assert isinstance(args, dict) and isinstance(config, dict)
    for k in config.keys():
        if args.get(k, False):
@@ -104,3 +101,21 @@
            else:
                args[k] = config[k]
    return args
+
+
+def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model weights on GPU.
+
+    Returns:
+        OrderedDict: Model weights on CPU.
+    """
+    state_dict_cpu = OrderedDict()
+    for key, val in state_dict.items():
+        state_dict_cpu[key] = val.cpu()
+    # Keep metadata in state_dict
+    state_dict_cpu._metadata = getattr(  # type: ignore
+        state_dict, '_metadata', OrderedDict())
+    return state_dict_cpu
7 changes: 5 additions & 2 deletions simvp/utils/parser.py
@@ -17,6 +17,9 @@ def create_parser():
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--fps', action='store_true', default=False,
                        help='Whether to measure inference speed (FPS)')
+    parser.add_argument('--resume_from', type=str, default=None, help='the checkpoint file to resume from')
+    parser.add_argument('--auto_resume', action='store_true', default=False,
+                        help='When training was interrupted, resume from the latest checkpoint')
    parser.add_argument('--test', action='store_true', default=False, help='Only performs testing')

    # dataset parameters
@@ -62,8 +65,8 @@
    parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
    parser.add_argument('--lr_k_decay', type=float, default=1.0,
                        help='learning rate k-decay for cosine/poly (default: 1.0)')
-    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
-                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--warmup_lr', type=float, default=1e-5, metavar='LR',
+                        help='warmup learning rate (default: 1e-5)')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
    parser.add_argument('--final_div_factor', type=float, default=1e4,
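A small usage sketch of the two new resume flags, assuming the parser's remaining arguments all have defaults; the checkpoint path is illustrative:

from simvp.utils import create_parser

# Resume explicitly from a checkpoint file:
args = create_parser().parse_args(['--resume_from', 'work_dirs/exp/checkpoints/latest.pth'])
assert args.resume_from.endswith('latest.pth')

# Or let the trainer pick up <checkpoints_path>/latest.pth automatically:
args = create_parser().parse_args(['--auto_resume'])
assert args.auto_resume and args.resume_from is None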
