Commit 4a08342 (parent: ff68b8d)

fix dataset and metrics

Lupin1998 committed Feb 24, 2023

Showing 17 changed files with 66 additions and 24 deletions.
4 changes: 2 additions & 2 deletions simvp/api/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) CAIRI AI Lab. All rights reserved
 
-from .train import NodDistExperiment
+from .train import NonDistExperiment
 
-__all__ = ['NodDistExperiment']
+__all__ = ['NonDistExperiment']
9 changes: 5 additions & 4 deletions simvp/api/train.py
@@ -21,7 +21,7 @@
 has_nni = False
 
 
-class NodDistExperiment(object):
+class NonDistExperiment(object):
     """ Experiment with non-dist PyTorch training and evaluation """
 
     def __init__(self, args):
@@ -92,8 +92,9 @@ def _preparation(self):
         for handler in logging.root.handlers[:]:
             logging.root.removeHandler(handler)
         timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+        prefix = 'train' if not self.args.test else 'test'
         logging.basicConfig(level=logging.INFO,
-                            filename=osp.join(self.path, 'train_{}.log'.format(timestamp)),
+                            filename=osp.join(self.path, '{}_{}.log'.format(prefix, timestamp)),
                             filemode='a', format='%(asctime)s - %(message)s')
         # prepare data
         self._get_data()
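
Note: with the new --test flag (added in simvp/utils/parser.py below), the log file is now named after the run mode. A minimal sketch of the resulting name, using a hypothetical work directory:

import time
import os.path as osp

timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
prefix = 'test'  # 'train' when args.test is unset
print(osp.join('work_dirs/exp', '{}_{}.log'.format(prefix, timestamp)))
# e.g. work_dirs/exp/test_20230224_153000.log
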
@@ -156,7 +157,7 @@ def train(self):
     def vali(self, vali_loader):
         preds, trues, val_loss = self.method.vali_one_epoch(self.vali_loader)
 
-        if self.args.dataname=='weather':
+        if 'weather' in self.args.dataname:
             metric_list, spatial_norm = ['mse', 'rmse', 'mae'], True
         else:
             metric_list, spatial_norm = ['mse', 'mae'], False
@@ -170,7 +171,7 @@

     def test(self):
         inputs, trues, preds = self.method.test_one_epoch(self.test_loader)
-        if self.args.dataname=='weather':
+        if 'weather' in self.args.dataname:
             metric_list, spatial_norm = ['mse', 'rmse', 'mae'], True
         else:
             metric_list, spatial_norm = ['mse', 'mae', 'ssim', 'psnr'], False
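
The equality check is relaxed to a substring match so that weather variants such as 'weather_t2m' (registered in simvp/datasets/dataset_constant.py below) pick up the weather metrics and spatial normalization. A toy sketch of the dispatch; pick_metrics is a hypothetical helper, not code from the repo:

def pick_metrics(dataname, is_test=False):
    # substring match: 'weather' and 'weather_t2m' both take the weather branch
    if 'weather' in dataname:
        return ['mse', 'rmse', 'mae'], True   # spatial_norm enabled
    metrics = ['mse', 'mae', 'ssim', 'psnr'] if is_test else ['mse', 'mae']
    return metrics, False

assert pick_metrics('weather_t2m') == (['mse', 'rmse', 'mae'], True)
assert pick_metrics('mmnist', is_test=True) == (['mse', 'mae', 'ssim', 'psnr'], False)
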
2 changes: 1 addition & 1 deletion simvp/datasets/dataloader.py
@@ -17,6 +17,6 @@ def load_data(dataname, batch_size, val_batch_size, num_workers, data_root, **kw
         return load_data(batch_size, val_batch_size, data_root, num_workers, pre_seq_length, aft_seq_length)
     elif dataname == 'weather':
         from .dataloader_weather import load_data
-        return load_data(batch_size, val_batch_size, data_root, num_workers)
+        return load_data(batch_size, val_batch_size, data_root, num_workers, **kwargs)
     else:
         raise ValueError(f'Dataname {dataname} is unsupported')
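
Before this fix the dispatcher swallowed any extra options, so per-dataset settings never reached the weather loader. A minimal sketch of the forwarding pattern, with hypothetical function names:

def weather_load(batch_size, data_name='t2m', train_time=None, **kwargs):
    return batch_size, data_name, train_time

def dispatch(dataname, batch_size, **kwargs):
    if dataname == 'weather':
        # without **kwargs here, data_name/train_time were silently dropped
        return weather_load(batch_size, **kwargs)

print(dispatch('weather', 16, data_name='t2m', train_time=['2010', '2015']))
# (16, 't2m', ['2010', '2015'])
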
2 changes: 1 addition & 1 deletion simvp/datasets/dataloader_kitticaltech.py
@@ -144,7 +144,7 @@ def load_data(batch_size, val_batch_size, data_root,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)
     dataloader_test = torch.utils.data.DataLoader(test_set,
-                                                  batch_size=1, shuffle=False,
+                                                  batch_size=val_batch_size, shuffle=False,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)

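
The same batch_size=1 → val_batch_size change is applied to the moving_mnist, taxibj, and weather loaders below; it makes evaluation much faster. One caveat: because drop_last=True is kept, a test batch size that does not divide the test-set size silently skips the final partial batch. A runnable illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

test_set = TensorDataset(torch.randn(10, 3))                 # 10 samples
loader = DataLoader(test_set, batch_size=4, drop_last=True)
print(sum(batch[0].size(0) for batch in loader))             # 8: the last 2 samples are dropped
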
2 changes: 1 addition & 1 deletion simvp/datasets/dataloader_moving_mnist.py
@@ -169,7 +169,7 @@ def load_data(batch_size, val_batch_size, data_root,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)
     dataloader_test = torch.utils.data.DataLoader(test_set,
-                                                  batch_size=1, shuffle=False,
+                                                  batch_size=val_batch_size, shuffle=False,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)

2 changes: 1 addition & 1 deletion simvp/datasets/dataloader_taxibj.py
@@ -40,7 +40,7 @@ def load_data(batch_size, val_batch_size, data_root,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)
     dataloader_test = torch.utils.data.DataLoader(test_set,
-                                                  batch_size=1, shuffle=False,
+                                                  batch_size=val_batch_size, shuffle=False,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)

5 changes: 3 additions & 2 deletions simvp/datasets/dataloader_weather.py
@@ -155,7 +155,8 @@ def load_data(batch_size,
               test_time=['2017', '2018'],
               idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
               idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
-              step=1):
+              step=1,
+              **kwargs):
 
     weather_dataroot = osp.join(data_root, 'weather')

@@ -191,7 +192,7 @@ def load_data(batch_size,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)
     dataloader_test = torch.utils.data.DataLoader(test_set,
-                                                  batch_size=1, shuffle=False,
+                                                  batch_size=val_batch_size, shuffle=False,
                                                   pin_memory=True, drop_last=True,
                                                   num_workers=num_workers)

10 changes: 7 additions & 3 deletions simvp/datasets/dataset_constant.py
@@ -35,10 +35,14 @@
         'aft_seq_length': 1,
         'total_length': 11
     },
-    'weather':{
+    **dict.fromkeys(['weather', 'weather_t2m'], {
         'in_shape': [12, 1, 32, 64],
         'pre_seq_length': 12,
         'aft_seq_length': 12,
-        'total_length': 24
-    }
+        'total_length': 24,
+        'data_name': 't2m',
+        'train_time': ['2010', '2015'],
+        'val_time': ['2016', '2016'],
+        'test_time': ['2017', '2018']
+    })
 }
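
dict.fromkeys is a compact way to register 'weather' and 'weather_t2m' with identical settings, but both keys then point at the same dict object. That is harmless for read-only constants like these, yet any runtime mutation would show through both entries:

cfg = {**dict.fromkeys(['weather', 'weather_t2m'], {'data_name': 't2m'})}
assert cfg['weather'] is cfg['weather_t2m']   # one shared object, two keys
cfg['weather']['data_name'] = 'r'             # hypothetical mutation...
print(cfg['weather_t2m']['data_name'])        # ...is visible here too: 'r'
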
4 changes: 2 additions & 2 deletions simvp/methods/crevnet.py
@@ -53,8 +53,8 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, **kwargs)
             num_updates += 1
             loss_mean += loss.item()
             losses_m.update(loss.item(), batch_x.size(0))
-            self.scheduler.step()
-            self.scheduler2.step()
+            self.scheduler.step(epoch)
+            self.scheduler2.step(epoch)
             train_pbar.set_description('train loss: {:.4f}'.format(
                 loss.item() / (self.args.pre_seq_length + self.args.aft_seq_length)))

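
Passing the epoch index to step() matches the timm scheduler API, which these method classes appear to use (PyTorch's built-in schedulers, by contrast, deprecate the epoch argument). The same fix is applied to mau.py, phydnet.py, predrnn.py, predrnnv2.py, and simvp.py below. A minimal sketch under the timm assumption:

import torch
from timm.scheduler import CosineLRScheduler

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
sched = CosineLRScheduler(opt, t_initial=10)
for epoch in range(10):
    # ... one epoch of training ...
    sched.step(epoch)   # timm schedulers expect the epoch index here
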
2 changes: 1 addition & 1 deletion simvp/methods/mau.py
@@ -49,7 +49,7 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, eta, **kw
             loss_mean += loss.item()
             losses_m.update(loss.item(), batch_x.size(0))
 
-            self.scheduler.step()
+            self.scheduler.step(epoch)
 
             train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))

2 changes: 1 addition & 1 deletion simvp/methods/phydnet.py
@@ -54,7 +54,7 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, **kwargs)
             num_updates += 1
             loss_mean += loss.item()
             losses_m.update(loss.item(), batch_x.size(0))
-            self.scheduler.step()
+            self.scheduler.step(epoch)
             train_pbar.set_description('train loss: {:.4f}'.format(
                 loss.item() / (self.args.pre_seq_length + self.args.aft_seq_length)))

2 changes: 1 addition & 1 deletion simvp/methods/predrnn.py
@@ -51,7 +51,7 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, eta, **kw
             img_gen, loss = self.model(ims, real_input_flag)
             loss.backward()
             self.model_optim.step()
-            self.scheduler.step()
+            self.scheduler.step(epoch)
 
             num_updates += 1
             loss_mean += loss.item()
2 changes: 1 addition & 1 deletion simvp/methods/predrnnv2.py
@@ -53,7 +53,7 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, eta, **kw
             num_updates += 1
             loss_mean += loss.item()
             losses_m.update(loss.item(), batch_x.size(0))
-            self.scheduler.step()
+            self.scheduler.step(epoch)
             train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
 
             if hasattr(self.model_optim, 'sync_lookahead'):
2 changes: 1 addition & 1 deletion simvp/methods/simvp.py
@@ -61,7 +61,7 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, **kwargs)
             loss = self.criterion(pred_y, batch_y)
             loss.backward()
             self.model_optim.step()
-            self.scheduler.step()
+            self.scheduler.step(epoch)
 
             num_updates += 1
             loss_mean += loss.item()
1 change: 1 addition & 0 deletions simvp/utils/parser.py
@@ -17,6 +17,7 @@ def create_parser():
     parser.add_argument('--seed', default=42, type=int)
     parser.add_argument('--fps', action='store_true', default=False,
                         help='Whether to measure inference speed (FPS)')
+    parser.add_argument('--test', action='store_true', default=False, help='Only performs testing')
 
     # dataset parameters
     parser.add_argument('--batch_size', '-b', default=16, type=int, help='Training batch size')
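
The new --test switch is a standard store_true flag; tools/non_dist_test.py below additionally forces config['test'] = True, so the test entry point works even without it. A quick check:

import argparse

p = argparse.ArgumentParser()
p.add_argument('--test', action='store_true', default=False, help='Only performs testing')
print(p.parse_args([]).test)          # False
print(p.parse_args(['--test']).test)  # True
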
35 changes: 35 additions & 0 deletions tools/non_dist_test.py
@@ -0,0 +1,35 @@
+# Copyright (c) CAIRI AI Lab. All rights reserved
+
+import os.path as osp
+import warnings
+warnings.filterwarnings('ignore')
+
+from simvp.api import NonDistExperiment
+from simvp.utils import create_parser, load_config, update_config
+
+try:
+    import nni
+    has_nni = True
+except ImportError:
+    has_nni = False
+
+
+if __name__ == '__main__':
+    args = create_parser().parse_args()
+    config = args.__dict__
+
+    if has_nni:
+        tuner_params = nni.get_next_parameter()
+        config.update(tuner_params)
+
+    assert args.config_file is not None, "Config file is required for testing"
+    config = update_config(config, load_config(args.config_file),
+                           exclude_keys=['batch_size', 'val_batch_size'])
+    config['test'] = True
+
+    exp = NonDistExperiment(args)
+
+    print('>'*35 + ' testing ' + '<'*35)
+    mse = exp.test()
+    if has_nni:
+        nni.report_final_result(mse)
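
A hypothetical invocation of the new test entry point (the config path is a placeholder, not taken from the diff; --config_file is the flag wired to args.config_file):

python tools/non_dist_test.py --config_file path/to/config.py
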
4 changes: 2 additions & 2 deletions tools/non_dist_train.py
@@ -4,7 +4,7 @@
 import warnings
 warnings.filterwarnings('ignore')
 
-from simvp.api import NodDistExperiment
+from simvp.api import NonDistExperiment
 from simvp.utils import create_parser, load_config, update_config
 
 try:
try:
@@ -27,7 +27,7 @@
     config = update_config(config, load_config(cfg_path),
                            exclude_keys=['batch_size', 'val_batch_size'])
 
-    exp = NodDistExperiment(args)
+    exp = NonDistExperiment(args)
     print('>'*35 + ' training ' + '<'*35)
     exp.train()

