Skip to content

Commit

Permalink
fix base_method
Browse files Browse the repository at this point in the history
  • Loading branch information
Lupin1998 committed Apr 19, 2023
1 parent 03c543b commit 5cc979f
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 59 deletions.
7 changes: 5 additions & 2 deletions openstl/api/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, args):
def _acquire_device(self):
"""Setup devices"""
if self.args.use_gpu:
self._use_gpu = True
if self.args.dist:
device = f'cuda:{self._rank}'
torch.cuda.set_device(self._rank)
Expand All @@ -61,6 +62,7 @@ def _acquire_device(self):
device = torch.device('cuda:0')
print('Use non-distributed mode with GPU:', device)
else:
self._use_gpu = False
device = torch.device('cpu')
print('Use CPU')
if self.args.dist:
Expand Down Expand Up @@ -173,7 +175,7 @@ def call_hook(self, fn_name: str) -> None:
def _get_hook_info(self):
# Get hooks info in each stage
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
for hook in self._hooks:
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
Expand Down Expand Up @@ -267,7 +269,6 @@ def train(self):

eta = 1.0 # PredRNN variants
for epoch in range(self._epoch, self._max_epochs):

num_updates, loss_mean, eta = self.method.train_one_epoch(self, self.train_loader,
epoch, num_updates, eta)

Expand All @@ -283,6 +284,8 @@ def train(self):
epoch + 1, len(self.train_loader), cur_lr, loss_mean.avg, vali_loss))
recorder(vali_loss, self.method.model, self.path)
self._save(name='latest')
if self._use_gpu and self.args.empty_cache:
torch.cuda.empty_cache()

if not check_dir(self.path): # exit training when work_dir is removed
assert False and "Exit training because work_dir is removed"
Expand Down
8 changes: 7 additions & 1 deletion openstl/datasets/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def load_data(dataname, batch_size, val_batch_size, num_workers, data_root, dist
pre_seq_length, aft_seq_length, distributed=distributed)
elif 'weather' in dataname: # 'weather', 'weather_t2m', etc.
from .dataloader_weather import load_data
return load_data(batch_size, val_batch_size, data_root, num_workers, distributed=distributed **kwargs)
data_split_pool = ['5_625', '2_8125', '1_40625']
data_split = '5_625'
for k in data_split_pool:
if dataname.find(k) != -1:
data_split = k
return load_data(batch_size, val_batch_size, data_root, num_workers,
distributed=distributed, data_split=data_split, **kwargs)
else:
raise ValueError(f'Dataname {dataname} is unsupported')
2 changes: 1 addition & 1 deletion openstl/datasets/dataloader_kitticaltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.data import Dataset
from skimage.transform import resize

from .utils import create_loader
from openstl.datasets.utils import create_loader

try:
import hickle as hkl
Expand Down
2 changes: 1 addition & 1 deletion openstl/datasets/dataloader_kth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data import Dataset
from PIL import Image

from .utils import create_loader
from openstl.datasets.utils import create_loader

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion openstl/datasets/dataloader_moving_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch.utils.data import Dataset

from .utils import create_loader
from openstl.datasets.utils import create_loader


def load_mnist(root):
Expand Down
2 changes: 1 addition & 1 deletion openstl/datasets/dataloader_taxibj.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from torch.utils.data import Dataset

from .utils import create_loader
from openstl.datasets.utils import create_loader


class TaxibjDataset(Dataset):
Expand Down
74 changes: 48 additions & 26 deletions openstl/datasets/dataloader_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os.path as osp
import torch
from torch.utils.data import Dataset
from .utils import create_loader
from openstl.datasets.utils import create_loader

try:
import xarray as xr
Expand Down Expand Up @@ -52,11 +52,12 @@ def xyz2latlon(x, y, z):
class ClimateDataset(Dataset):

def __init__(self, data_root, data_name, training_time,
idx_in, idx_out, step,
idx_in, idx_out, step, data_split='5_625',
mean=None, std=None,
transform_data=None, transform_labels=None):
super().__init__()
self.dataname = data_name
self.data_split = data_split
self.training_time = training_time
self.idx_in = np.array(idx_in)
self.idx_out = np.array(idx_out)
Expand All @@ -67,17 +68,25 @@ def __init__(self, data_root, data_name, training_time,
self.transform_labels = transform_labels

self.time = None
shape = int(32 * 5.625 / float(data_split.replace('_', '.')))
self.shape = (shape, shape * 2)

if isinstance(data_name, list):
data_name = data_name[0]

if data_name != 'uv10':
try:
# dataset = xr.open_mfdataset(
# data_root+'/{}/*.nc'.format(data_map[data_name]), combine='by_coords')
print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[data_name]))
dataset = xr.open_mfdataset(
data_root+'/{}/*.nc'.format(data_map[data_name]), combine='by_coords')
data_root+'/{}/*.nc'.format(data_map[data_name]), combine='by_coords', parallel=False, chunks={'time':168})
except AttributeError:
assert False and 'Please install the latest xarray, e.g.,' \
'pip install git+https://github.com/pydata/xarray/@v2022.03.0'
except OSError:
print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[data_name]))
assert False
dataset = dataset.sel(time=slice(*training_time))
dataset = dataset.isel(time=slice(None, -1, step))
if self.time is None:
Expand All @@ -89,10 +98,11 @@ def __init__(self, data_root, data_name, training_time,
lon, lat = np.meshgrid(
(dataset.lon-180) * d2r, dataset.lat*d2r)
x, y, z = latlon2xyz(lat, lon)
self.V = np.stack([x, y, z]).reshape(3, 32*64).T
self.V = np.stack([x, y, z]).reshape(3, self.shape[0]*self.shape[1]).T
# input_datasets.append(dataset.get(key).values[:, np.newaxis, :, :])
# self.data = np.concatenate(input_datasets, axis=1)
self.data = dataset.get(data_name).values[:, np.newaxis, :, :]

elif data_name == 'uv10':
input_datasets = []
for key in ['u10', 'v10']:
Expand All @@ -103,6 +113,9 @@ def __init__(self, data_root, data_name, training_time,
assert False and 'Please install the latest xarray, e.g.,' \
'pip install git+https://github.com/pydata/xarray/@v2022.03.0,' \
'pip install netcdf4 h5netcdf dask'
except OSError:
print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[data_name]))
assert False
dataset = dataset.sel(time=slice(*training_time))
dataset = dataset.isel(time=slice(None, -1, step))
if self.time is None:
Expand All @@ -114,7 +127,7 @@ def __init__(self, data_root, data_name, training_time,
lon, lat = np.meshgrid(
(dataset.lon-180) * d2r, dataset.lat*d2r)
x, y, z = latlon2xyz(lat, lon)
self.V = np.stack([x, y, z]).reshape(3, 32*64).T
self.V = np.stack([x, y, z]).reshape(3, self.shape[0]*self.shape[1]).T
input_datasets.append(dataset.get(key).values[:, np.newaxis, :, :])
self.data = np.concatenate(input_datasets, axis=1)

Expand Down Expand Up @@ -151,6 +164,7 @@ def load_data(batch_size,
val_batch_size,
data_root,
num_workers=4,
data_split='5_625',
data_name='t2m',
train_time=['1979', '2015'],
val_time=['2016', '2016'],
Expand All @@ -161,28 +175,30 @@ def load_data(batch_size,
distributed=False,
**kwargs):

weather_dataroot = osp.join(data_root, 'weather')
assert data_split in ['5_625', '2_8125', '1_40625']
_dataroot = osp.join(data_root, f'weather_{data_split}deg')
weather_dataroot = _dataroot if osp.exists(_dataroot) else osp.join(data_root, 'weather')

train_set = ClimateDataset(data_root=weather_dataroot,
data_name=data_name,
data_name=data_name, data_split=data_split,
training_time=train_time,
idx_in=idx_in,
idx_out=idx_out,
step=step)
vali_set = ClimateDataset(weather_dataroot,
data_name,
val_time,
idx_in,
idx_out,
step,
data_name=data_name, data_split=data_split,
training_time=val_time,
idx_in=idx_in,
idx_out=idx_out,
step=step,
mean=train_set.mean,
std=train_set.std)
test_set = ClimateDataset(weather_dataroot,
data_name,
test_time,
idx_in,
idx_out,
step,
data_name, data_split=data_split,
training_time=test_time,
idx_in=idx_in,
idx_out=idx_out,
step=step,
mean=train_set.mean,
std=train_set.std)

Expand All @@ -206,15 +222,21 @@ def load_data(batch_size,


if __name__ == '__main__':
dataloader_train, _, _ = load_data(batch_size=128,
val_batch_size=128,
data_root='../../data',
num_workers=4, data_name='t2m',
train_time=['1979', '2015'],
val_time=['2016', '2016'],
test_time=['2017', '2018'],
idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], step=24)
dataloader_train, _, dataloader_test = \
load_data(batch_size=128,
val_batch_size=32,
data_root='../../data',
num_workers=2, data_name='t2m',
data_split='1_40625',
train_time=['1979', '2015'],
val_time=['2016', '2016'],
test_time=['2017', '2018'],
idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], step=24)

for item in dataloader_train:
print(item[0].shape)
break
for item in dataloader_test:
print(item[0].shape)
break
32 changes: 32 additions & 0 deletions openstl/datasets/dataset_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,36 @@
'data_name': 'tcc',
'train_time': ['2010', '2015'], 'val_time': ['2016', '2016'], 'test_time': ['2017', '2018'],
},
'weather_t2m_1_40625': { # 2m_temperature
'in_shape': [12, 1, 128, 256],
'pre_seq_length': 12,
'aft_seq_length': 12,
'total_length': 24,
'data_name': 't2m',
'train_time': ['2010', '2015'], 'val_time': ['2016', '2016'], 'test_time': ['2017', '2018'],
},
'weather_r_1_40625': { # relative_humidity
'in_shape': [12, 1, 128, 256],
'pre_seq_length': 12,
'aft_seq_length': 12,
'total_length': 24,
'data_name': 'r',
'train_time': ['2010', '2015'], 'val_time': ['2016', '2016'], 'test_time': ['2017', '2018'],
},
'weather_uv10_1_40625': { # u10+v10, component_of_wind
'in_shape': [12, 2, 128, 256],
'pre_seq_length': 12,
'aft_seq_length': 12,
'total_length': 24,
'data_name': 'uv10',
'train_time': ['2010', '2015'], 'val_time': ['2016', '2016'], 'test_time': ['2017', '2018'],
},
'weather_tcc_1_40625': { # total_cloud_cover
'in_shape': [12, 1, 128, 256],
'pre_seq_length': 12,
'aft_seq_length': 12,
'total_length': 24,
'data_name': 'tcc',
'train_time': ['2010', '2015'], 'val_time': ['2016', '2016'], 'test_time': ['2017', '2018'],
},
}
Loading

0 comments on commit 5cc979f

Please sign in to comment.