Skip to content

Commit

Permalink
add metric
Browse files Browse the repository at this point in the history
  • Loading branch information
Lupin1998 committed Jun 5, 2023
1 parent a9bb392 commit 5882a70
Show file tree
Hide file tree
Showing 26 changed files with 194 additions and 33 deletions.
2 changes: 1 addition & 1 deletion configs/kitticaltech/PhyDNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
# model
patch_size = 4
# training
lr = 1e-3
lr = 5e-3
batch_size = 16
sched = 'onecycle'
3 changes: 3 additions & 0 deletions configs/kitticaltech/PredNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@
A_activation = 'relu'
LSTM_activation = 'tanh'
LSTM_inner_activation = 'hard_sigmoid'
# training
lr = 1e-3
batch_size = 16
sched = 'onecycle'
2 changes: 1 addition & 1 deletion configs/kitticaltech/TAU.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
N_S = 2
alpha = 0.1
# training
lr = 1e-3
lr = 1e-2
drop_path = 0.1
batch_size = 16
sched = 'onecycle'
6 changes: 5 additions & 1 deletion configs/kth/ConvLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@
filter_size = 5
stride = 1
patch_size = 2
layer_norm = 0
layer_norm = 0
# training
# lr = 5e-3
batch_size = 16
sched = 'onecycle'
16 changes: 16 additions & 0 deletions configs/kth/DMVFN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# DMVFN configuration for the KTH dataset (configs/kth/DMVFN.py).
method = 'DMVFN'
# model
routing_out_channels = 32
in_planes = 4 * 3 + 1 + 4 # terms in order: `4 * 3` is the data-channel term (the earlier note described this factor as 1 -- TODO confirm 3 is intended here), `1` is the mask channel, the trailing `4` is the flow channel
num_block = 9
# num_features and scale each list 9 values, matching num_block
# (presumably one entry per block -- TODO confirm against the model code).
num_features = [160, 160, 160, 80, 80, 80, 44, 44, 44]
scale = [4, 4, 4, 2, 2, 2, 1, 1, 1]
training = True
# loss
beta = 0.5
gamma = 0.8
coef = 0.5
# training
# NOTE: lr is intentionally left commented out, so this file does not set it.
# lr = 1e-4
batch_size = 16
sched = 'onecycle'
7 changes: 7 additions & 0 deletions configs/kth/PhyDNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# PhyDNet configuration for the KTH dataset (configs/kth/PhyDNet.py).
method = 'PhyDNet'
# model
patch_size = 4
# training
# NOTE: lr is left commented out, so this file does not set it.
# lr = 5e-3
batch_size = 16
sched = 'onecycle'
16 changes: 16 additions & 0 deletions configs/kth/PredNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# PredNet configuration for the KTH dataset (configs/kth/PredNet.py).
method = 'PredNet'
# model
stack_sizes = (3, 32, 64, 128, 256) # first entry is the number of input channels (3 here; the earlier note said 1 -- TODO confirm)
# R_stack_sizes reuses the same tuple as stack_sizes.
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3, 3)
R_filt_sizes = (3, 3, 3, 3, 3)
pixel_max = 1.0
weight_mode = 'L_0'
error_activation = 'relu'
A_activation = 'relu'
LSTM_activation = 'tanh'
LSTM_inner_activation = 'hard_sigmoid'
# training
# NOTE: lr is left commented out, so this file does not set it.
# lr = 1e-3
batch_size = 16
sched = 'onecycle'
6 changes: 5 additions & 1 deletion configs/kth/PredRNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@
filter_size = 5
stride = 1
patch_size = 2
layer_norm = 0
layer_norm = 0
# training
lr = 1e-3
batch_size = 16
sched = 'onecycle'
6 changes: 5 additions & 1 deletion configs/kth/PredRNNpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,8 @@
filter_size = 5
stride = 1
patch_size = 2
layer_norm = 0
layer_norm = 0
# training
lr = 1e-3
batch_size = 16
sched = 'onecycle'
6 changes: 5 additions & 1 deletion configs/kth/PredRNNv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@
stride = 1
patch_size = 2
layer_norm = 0
decouple_beta = 0.01
decouple_beta = 0.01
# training
lr = 1e-3
batch_size = 16
sched = 'onecycle'
11 changes: 8 additions & 3 deletions configs/kth/SimVP.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
spatio_kernel_dec = 3
# model_type = None
hid_S = 64
hid_T = 512
N_T = 8
N_S = 4
hid_T = 256
N_T = 6
N_S = 2
# training
lr = 1e-3
batch_size = 16
drop_path = 0.1
sched = 'onecycle'
1 change: 0 additions & 1 deletion configs/kth/TAU.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
lr = 1e-3
batch_size = 16
drop_path = 0.1
batch_size = 16
sched = 'onecycle'
8 changes: 8 additions & 0 deletions configs/weather/t2m_1_40625/PhyDNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# PhyDNet configuration stored under configs/weather/t2m_1_40625/.
method = 'PhyDNet'
# model
patch_size = 4
# training
lr = 1e-4
batch_size = 16
# Unlike the KTH configs in this commit, this one uses the 'cosine' schedule
# and sets warmup_epoch to 0.
sched = 'cosine'
warmup_epoch = 0
3 changes: 2 additions & 1 deletion docs/en/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ Release version to OpenSTL V0.2.0 as [#20](https://github.com/chengtan9907/OpenS

* Update the Weather Bench dataloader with `5.625deg`, `2.8125deg`, and `1.40625deg` settings. Add Human3.6M dataloader (supporting augmentations) and config files. Add Moving FMNIST and MMNIST_CIFAR as two advanced variants of MMNIST datasets.
* Update tools for dataset preparation of Human3.6M, Weather Bench, and Moving FMNIST.
* Support [TAU](https://arxiv.org/abs/2206.12126) and [DMVFN](https://arxiv.org/abs/2303.09875) with configs and benchmark results. And fix bugs in these new STL methods.
* Support [PredNet](https://openreview.net/forum?id=B1ewdt9xe), [TAU](https://arxiv.org/abs/2206.12126), and [DMVFN](https://arxiv.org/abs/2303.09875) with configs and benchmark results, and fix bugs in these new STL methods.
* Support multi-variant versions of Weather Bench with dataloader and metrics.
* Support [lpips](https://github.com/richzhang/PerceptualSimilarity/tree/master) metric for video prediction benchmarks.

#### Update Documents

Expand Down
6 changes: 3 additions & 3 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ bash tools/dist_train.sh ${CONFIG_FILE} ${GPUS} [optional arguments]
- `${CONFIG_FILE}` : The path of a model config file, which will provide detailed settings for an STL method.
- `${GPUS}` : The number of GPUs for DDP training.

Examples of multiple GPUs training on Moving MNIST dataset with a machine with 8 GPUs.
Examples of multi-GPU training on the Moving MNIST dataset on a machine with 8 GPUs. Note that some recurrent-based STL methods (e.g., ConvLSTM, PredRNN++) need `--find_unused_parameters` during DDP training.
```shell
PORT=29001 CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_train.sh configs/mmnist/simvp/SimVP_gSTA.py 2 -d mmnist --lr 1e-3 --batch_size 8
PORT=29002 CUDA_VISIBLE_DEVICES=2,3 bash tools/dist_train.sh configs/mmnist/PredRNN.py 2 -d mmnist --lr 1e-3 --batch_size 8
PORT=29003 CUDA_VISIBLE_DEVICES=4,5,6,7 bash tools/dist_train.sh configs/mmnist/PredRNNpp.py 4 -d mmnist --lr 1e-3 --batch_size 4
PORT=29002 CUDA_VISIBLE_DEVICES=2,3 bash tools/dist_train.sh configs/mmnist/ConvLSTM.py 2 -d mmnist --lr 1e-3 --batch_size 8 --find_unused_parameters
PORT=29003 CUDA_VISIBLE_DEVICES=4,5,6,7 bash tools/dist_train.sh configs/mmnist/PredRNNpp.py 4 -d mmnist --lr 1e-3 --batch_size 4 --find_unused_parameters
```

An example of multi-GPU testing on the Moving MNIST dataset. The bash script is `bash tools/dist_train.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} [optional arguments]`.
Expand Down
1 change: 1 addition & 0 deletions docs/en/model_zoos/traffic_benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<summary>Currently supported spatiotemporal prediction methods</summary>

- [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NeurIPS'2015)
- [x] [PredNet](https://openreview.net/forum?id=B1ewdt9xe) (ICLR'2017)
- [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NeurIPS'2017)
- [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
- [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018)
Expand Down
1 change: 1 addition & 0 deletions docs/en/model_zoos/video_benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<summary>Currently supported spatiotemporal prediction methods</summary>

- [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NeurIPS'2015)
- [x] [PredNet](https://openreview.net/forum?id=B1ewdt9xe) (ICLR'2017)
- [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NeurIPS'2017)
- [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
- [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018)
Expand Down
1 change: 1 addition & 0 deletions docs/en/model_zoos/weather_benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<summary>Currently supported spatiotemporal prediction methods</summary>

- [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NeurIPS'2015)
- [x] [PredNet](https://openreview.net/forum?id=B1ewdt9xe) (ICLR'2017)
- [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NeurIPS'2017)
- [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
- [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018)
Expand Down
5 changes: 3 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: SimVP
name: OpenSTL
channels:
- pytorch
- conda-forge
Expand All @@ -13,7 +13,8 @@ dependencies:
- xarray==0.19.0
- pip:
- fvcore
- lpips
- scikit-image
- timm
- tqdm
prefix: /opt/anaconda3/envs/simvp
prefix: /opt/anaconda3/envs/openstl
5 changes: 2 additions & 3 deletions openstl/api/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,10 @@ def test(self):
self.call_hook('after_val_epoch')

if 'weather' in self.args.dataname:
metric_list, spatial_norm = ['mse', 'rmse', 'mae'], True
metric_list, spatial_norm = self.args.metrics, True
channel_names = self.test_loader.dataset.data_name if 'mv' in self.args.dataname else None
else:
metric_list, spatial_norm = ['mse', 'mae', 'ssim', 'psnr'], False
channel_names = None
metric_list, spatial_norm, channel_names = self.args.metrics, False, None
eval_res, eval_log = metric(results['preds'], results['trues'],
self.test_loader.dataset.mean, self.test_loader.dataset.std,
metrics=metric_list, channel_names=channel_names, spatial_norm=spatial_norm)
Expand Down
69 changes: 67 additions & 2 deletions openstl/core/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import cv2
import numpy as np
from skimage.metrics import structural_similarity as cal_ssim
import torch

try:
import lpips
from skimage.metrics import structural_similarity as cal_ssim
except:
lpips = None
cal_ssim = None


def rescale(x):
Expand Down Expand Up @@ -37,6 +45,53 @@ def PSNR(pred, true):
return 20 * np.log10(255) - 10 * np.log10(mse)


def SSIM(pred, true, **kwargs):
    """Return the mean structural similarity (SSIM) between two images.

    Local statistics are computed with an 11x11 Gaussian window
    (sigma = 1.5) and the stabilising constants are derived from a dynamic
    range of 255. Five pixels are trimmed from each border of the first two
    axes of every filtered map, so only the "valid" region contributes.

    Args:
        pred: Predicted image as a NumPy array.
        true: Reference image as a NumPy array (multiplied element-wise
            with ``pred``, so presumably the same shape).
        **kwargs: Accepted and ignored.

    Returns:
        The mean of the SSIM map.
    """
    c1, c2 = (0.01 * 255)**2, (0.03 * 255)**2

    x = pred.astype(np.float64)
    y = true.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def _valid_blur(img):
        # Gaussian-filter the image, then keep only the "valid" interior.
        return cv2.filter2D(img, -1, window)[5:-5, 5:-5]

    mean_x = _valid_blur(x)
    mean_y = _valid_blur(y)
    mean_x_sq = mean_x**2
    mean_y_sq = mean_y**2
    mean_xy = mean_x * mean_y

    var_x = _valid_blur(x**2) - mean_x_sq
    var_y = _valid_blur(y**2) - mean_y_sq
    cov_xy = _valid_blur(x * y) - mean_xy

    numerator = (2 * mean_xy + c1) * (2 * cov_xy + c2)
    denominator = (mean_x_sq + mean_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()


class LPIPS(torch.nn.Module):
    """Learned Perceptual Image Patch Similarity, LPIPS.

    Thin wrapper around ``lpips.LPIPS`` that accepts NumPy images and
    returns the distance as a NumPy value.

    Modified from
    https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips_2imgs.py

    Args:
        net (str): Backbone passed to ``lpips.LPIPS``; one of ``'alex'``,
            ``'squeeze'`` or ``'vgg'``.
        use_gpu (bool): Request GPU execution. The GPU is only used when
            CUDA is actually available; otherwise the metric runs on CPU.

    Raises:
        ImportError: If the optional ``lpips`` package is not installed.
    """

    def __init__(self, net='alex', use_gpu=True):
        super().__init__()
        assert net in ['alex', 'squeeze', 'vgg']
        if lpips is None:
            # The module-level import falls back to ``lpips = None``; fail
            # here with a clear message instead of an AttributeError on None.
            raise ImportError(
                'The `lpips` package is required for the LPIPS metric. '
                'Install it with `pip install lpips`.')
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.loss_fn = lpips.LPIPS(net=net)
        # Check the resolved flag, not the raw argument, so requesting the
        # GPU on a CPU-only machine falls back to CPU instead of raising.
        if self.use_gpu:
            self.loss_fn.cuda()

    def forward(self, img1, img2):
        """Return the LPIPS distance between two images.

        Args:
            img1, img2: Images min-max normalised to [0, 1]
                (presumably H x W x C NumPy arrays -- TODO confirm).

        Returns:
            The squeezed LPIPS output as a NumPy array on the CPU.
        """
        # Scale [0, 1] inputs to [0, 255]; im2tensor then maps to [-1, 1].
        img1 = lpips.im2tensor(img1 * 255)
        img2 = lpips.im2tensor(img2 * 255)
        if self.use_gpu:
            img1, img2 = img1.cuda(), img2.cuda()
        return self.loss_fn.forward(img1, img2).squeeze().detach().cpu().numpy()


def metric(pred, true, mean=None, std=None, metrics=['mae', 'mse'],
clip_range=[0, 1], channel_names=None,
spatial_norm=False, return_log=True):
Expand All @@ -61,7 +116,7 @@ def metric(pred, true, mean=None, std=None, metrics=['mae', 'mse'],
true = true * std + mean
eval_res = {}
eval_log = ""
allowed_metrics = ['mae', 'mse', 'rmse', 'ssim', 'psnr',]
allowed_metrics = ['mae', 'mse', 'rmse', 'ssim', 'psnr', 'lpips']
invalid_metrics = set(metrics) - set(allowed_metrics)
if len(invalid_metrics) != 0:
raise ValueError(f'metric {invalid_metrics} is not supported.')
Expand Down Expand Up @@ -122,6 +177,16 @@ def metric(pred, true, mean=None, std=None, metrics=['mae', 'mse'],
psnr += PSNR(pred[b, f], true[b, f])
eval_res['psnr'] = psnr / (pred.shape[0] * pred.shape[1])

if 'lpips' in metrics:
lpips = 0
cal_lpips = LPIPS(net='alex', use_gpu=False)
pred = pred.transpose(0, 1, 3, 4, 2)
true = true.transpose(0, 1, 3, 4, 2)
for b in range(pred.shape[0]):
for f in range(pred.shape[1]):
lpips += cal_lpips(pred[b, f], true[b, f])
eval_res['lpips'] = lpips / (pred.shape[0] * pred.shape[1])

if return_log:
for k, v in eval_res.items():
eval_str = f"{k}:{v}" if len(eval_log) == 0 else f", {k}:{v}"
Expand Down
Loading

0 comments on commit 5882a70

Please sign in to comment.