
Commit

fix psnr and update docs
Lupin1998 committed Jun 8, 2023
1 parent 5882a70 commit 0d74406
Showing 13 changed files with 205 additions and 137 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -137,3 +137,4 @@ figs
configs_test/
configs/kth/simvp
configs/human/simvp
api/test.py
2 changes: 1 addition & 1 deletion configs/kitticaltech/simvp/SimVP_VAN.py
@@ -8,7 +8,7 @@
N_T = 6
N_S = 2
# training
lr = 5e-3
lr = 1e-2
drop_path = 0.1
batch_size = 16
sched = 'onecycle'
2 changes: 1 addition & 1 deletion configs/kth/DMVFN.py
@@ -1,7 +1,7 @@
method = 'DMVFN'
# model
routing_out_channels = 32
in_planes = 4 * 3 + 1 + 4 # 4 * data channels (3 for RGB), + 1 mask channel, + 4 flow channels
in_planes = 4 * 1 + 1 + 4 # 4 * data channels (1 for grayscale KTH), + 1 mask channel, + 4 flow channels
num_block = 9
num_features = [160, 160, 160, 80, 80, 80, 44, 44, 44]
scale = [4, 4, 4, 2, 2, 2, 1, 1, 1]
2 changes: 1 addition & 1 deletion configs/kth/PredNet.py
@@ -1,5 +1,5 @@
method = 'PredNet'
stack_sizes = (3, 32, 64, 128, 256) # the first entry is the number of input channels
stack_sizes = (1, 32, 64, 128, 256) # the first entry is the number of input channels
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3, 3)
2 changes: 1 addition & 1 deletion configs/mfmnist/simvp/SimVP_VAN.py
@@ -2,7 +2,7 @@
# model
spatio_kernel_enc = 3
spatio_kernel_dec = 3
model_type = 'convmixer'
model_type = 'van'
hid_S = 64
hid_T = 512
N_T = 8
2 changes: 1 addition & 1 deletion configs/mmnist/simvp/SimVP_VAN.py
@@ -2,7 +2,7 @@
# model
spatio_kernel_enc = 3
spatio_kernel_dec = 3
model_type = 'convmixer'
model_type = 'van'
hid_S = 64
hid_T = 512
N_T = 8
2 changes: 1 addition & 1 deletion configs/mmnist_cifar/simvp/SimVP_VAN.py
@@ -2,7 +2,7 @@
# model
spatio_kernel_enc = 3
spatio_kernel_dec = 3
model_type = 'convmixer'
model_type = 'van'
hid_S = 64
hid_T = 512
N_T = 8
1 change: 1 addition & 0 deletions docs/en/changelog.md
@@ -31,6 +31,7 @@ Release version to OpenSTL V0.2.0 as [#20](https://github.com/chengtan9907/OpenS
* Fix bugs in building distributed dataloaders and the preparation of DDP training.
* Fix bugs in some STL methods (CrevNet, DMVFN, PredNet, and TAU).
* Fix bugs in datasets: fix the Caltech dataset for evaluation (updated on 28/05/2023 on [Baidu Cloud](https://pan.baidu.com/s/1fudsBHyrf3nbt-7d42YWWg?pwd=kjfk)).
* Fix the `PSNR` bug (replacing the E3D-LSTM implementation with the correct version) and update the results in the benchmarks.
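
For reference, the fixed `PSNR` follows the standard definition; a minimal sketch is shown below, assuming float frames normalized to [0, 1] (an illustration, not the exact OpenSTL code):

```python
import numpy as np

def psnr(pred: np.ndarray, true: np.ndarray, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio between predicted and ground-truth frames.

    Assumes both arrays have the same shape and are scaled to [0, max_val].
    """
    mse = np.mean((pred.astype(np.float64) - true.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")  # identical frames
    return 20.0 * np.log10(max_val) - 10.0 * np.log10(mse)
```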

### v0.1.0 (18/02/2023)

5 changes: 3 additions & 2 deletions docs/en/get_started.md
@@ -71,14 +71,15 @@ PORT=29002 CUDA_VISIBLE_DEVICES=2,3 bash tools/dist_train.sh configs/mmnist/Conv
PORT=29003 CUDA_VISIBLE_DEVICES=4,5,6,7 bash tools/dist_train.sh configs/mmnist/PredRNNpp.py 4 -d mmnist --lr 1e-3 --batch_size 4 --find_unused_parameters
```

An example of multi-GPU testing on the Moving MNIST dataset. The bash script is `bash tools/dist_test.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} [optional arguments]`.
An example of multi-GPU testing on the Moving MNIST dataset. The bash script is `bash tools/dist_test.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} [optional arguments]`, where the first three arguments are required.
```shell
PORT=29001 CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_test.sh configs/mmnist/simvp/SimVP_gSTA.py 2 work_dirs/mmnist/simvp/SimVP_gSTA -d mmnist
```

**Note**:
* During DDP training, the number of GPUs `ngpus` should be provided, and checkpoints and logs are saved in the same folder structure as the config file under `work_dirs/` (this is the default setting if `--ex_name` is not specified). The default learning rate `lr` and batch size `bs` in the config files are for a single GPU. If you use a different number of GPUs, the total batch size changes in proportion, so you have to scale the learning rate and batch size following `lr = base_lr * ngpus` and `bs = base_bs * ngpus` (known as the `linear scaling rule`); see the short sketch after this list. Other arguments should be added as in single-GPU training.
* Experiments run with different GPU settings will produce different results. We have noticed that single-GPU training with DP and DDP setups produces similar results, while multi-GPU runs using linear scaling rules differ because of DDP training. For example, training SimVP+gSTA for 200 epochs on MMNIST with `1GPU (DP)`, `1GPU (DDP)`, `2GPUs (2xbs8)`, and `4GPUs (4xbs4)` at the same learning rate (lr=1e-3) yields MSE of 26.73, 26.78, 30.01, and 31.36, respectively. Therefore, we provide the GPU setting used for each benchmark result together with the corresponding learning rate for fair comparison and reproducibility.
* Experiments run with different GPU settings will produce different results. We have noticed that single-GPU training with DP and DDP setups produces similar results, while multi-GPU runs using the **linear scaling rule** differ because of DDP training. For example, training SimVP+gSTA for 200 epochs on MMNIST with `1GPU (DP)`, `1GPU (DDP)`, `2GPUs (2xbs8)`, and `4GPUs (4xbs4)` at the same learning rate (lr=1e-3) yields MSE of 26.73, 26.78, 30.01, and 31.36, respectively. Therefore, we provide the GPU setting used for each benchmark result together with the corresponding learning rate for fair comparison and reproducibility.
* The DDP training and testing error `WARNING:torch.distributed.elastic.multiprocessing.api:Sending process xxx closing signal SIGTERM` may sometimes occur with PyTorch 1.9 to 1.10. You can use PyTorch 1.8 or PyTorch 2.0 to get rid of these errors, or run `1GPU` experiments.
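
As a concrete illustration of the `linear scaling rule` mentioned in the notes above, here is a hypothetical helper (not part of OpenSTL; the name and values are illustrative only):

```python
def linear_scale(base_lr: float, base_bs: int, ngpus: int):
    """Scale the single-GPU learning rate and compute the total batch size for ngpus GPUs."""
    return base_lr * ngpus, base_bs * ngpus

# Illustrative values: base lr 1e-3 with a per-GPU batch size of 16, run on 4 GPUs.
lr, total_bs = linear_scale(1e-3, 16, 4)
print(lr, total_bs)  # 0.004 64
```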

## Mixed Precision Training

49 changes: 25 additions & 24 deletions docs/en/model_zoos/traffic_benchmarks.md
@@ -51,36 +51,37 @@ For a fair comparison of different methods, we report the best results when mode

| Method | Setting | Params | FLOPs | FPS | MSE | MAE | SSIM | PSNR | Download |
|--------------|:--------:|:------:|:------:|:----:|:------:|:-----:|:------:|:-----:|:------------:|
| ConvLSTM-S | 50 epoch | 14.98M | 20.74G | 815 | 0.3358 | 15.32 | 0.9836 | 39.73 | model \| log |
| E3D-LSTM\* | 50 epoch | 50.99M | 98.19G | 60 | 0.3421 | 14.96 | 0.9842 | 39.92 | model \| log |
| PhyDNet | 50 epoch | 3.09M | 5.60G | 982 | 0.3622 | 15.53 | 0.9828 | 39.76 | model \| log |
| PredRNN | 50 epoch | 23.66M | 42.40G | 416 | 0.3194 | 15.31 | 0.9838 | 39.79 | model \| log |
| MIM | 50 epoch | 37.86M | 64.10G | 275 | 0.3110 | 14.96 | 0.9847 | 39.88 | model \| log |
| MAU | 50 epoch | 4.41M | 6.02G | 540 | 0.3268 | 15.26 | 0.9834 | 39.80 | model \| log |
| PredRNN++ | 50 epoch | 38.40M | 62.95G | 301 | 0.3348 | 15.37 | 0.9834 | 39.76 | model \| log |
| PredRNN.V2 | 50 epoch | 23.67M | 42.63G | 378 | 0.3834 | 15.55 | 0.9826 | 39.75 | model \| log |
| DMVFN | 50 epoch | 3.54M | 0.057G | 6347 | | | | | model \| log |
| SimVP+IncepU | 50 epoch | 13.79M | 3.61G | 533 | 0.3282 | 15.45 | 0.9835 | 39.72 | model \| log |
| SimVP+gSTA-S | 50 epoch | 9.96M | 2.62G | 1217 | 0.3246 | 15.03 | 0.9844 | 39.95 | model \| log |
| TAU | 50 epoch | 9.55M | 2.49G | 1268 | 0.3108 | 14.93 | 0.9848 | 39.97 | model \| log |
| ConvLSTM-S | 50 epoch | 14.98M | 20.74G | 815 | 0.3358 | 15.32 | 0.9836 | 39.45 | model \| log |
| E3D-LSTM\* | 50 epoch | 50.99M | 98.19G | 60 | 0.3427 | 14.98 | 0.9842 | 39.64 | model \| log |
| PhyDNet | 50 epoch | 3.09M | 5.60G | 982 | 0.3622 | 15.53 | 0.9828 | 39.46 | model \| log |
| PredNet | 50 epoch | 12.5M | 0.85G | 5031 | 0.3516 | 15.91 | 0.9828 | 39.29 | model \| log |
| PredRNN | 50 epoch | 23.66M | 42.40G | 416 | 0.3194 | 15.31 | 0.9838 | 39.51 | model \| log |
| MIM | 50 epoch | 37.86M | 64.10G | 275 | 0.3110 | 14.96 | 0.9847 | 39.65 | model \| log |
| MAU | 50 epoch | 4.41M | 6.02G | 540 | 0.3268 | 15.26 | 0.9834 | 39.52 | model \| log |
| PredRNN++ | 50 epoch | 38.40M | 62.95G | 301 | 0.3348 | 15.37 | 0.9834 | 39.47 | model \| log |
| PredRNN.V2 | 50 epoch | 23.67M | 42.63G | 378 | 0.3834 | 15.55 | 0.9826 | 39.49 | model \| log |
| DMVFN | 50 epoch | 3.54M | 0.057G | 6347 | 0.3517 | 15.72 | 0.9833 | 39.33 | model \| log |
| SimVP+IncepU | 50 epoch | 13.79M | 3.61G | 533 | 0.3282 | 15.45 | 0.9835 | 39.45 | model \| log |
| SimVP+gSTA-S | 50 epoch | 9.96M | 2.62G | 1217 | 0.3246 | 15.03 | 0.9844 | 39.71 | model \| log |
| TAU | 50 epoch | 9.55M | 2.49G | 1268 | 0.3108 | 14.93 | 0.9848 | 39.74 | model \| log |

### **Benchmark of MetaFormers on SimVP**

Similar to [Moving MNIST Benchmarks](#moving-mnist-benchmarks), we benchmark popular MetaFormer architectures on [SimVP](https://arxiv.org/abs/2211.12509) with 50 epochs of training. We provide config files in [configs/taxibj/simvp](https://github.com/chengtan9907/OpenSTL/configs/taxibj/simvp/).
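
The provided config files share a common layout; below is an illustrative sketch (hypothetical values, mirroring the `SimVP_VAN.py` configs touched in this commit rather than the actual TaxiBJ settings), where `model_type` selects the MetaFormer block:

```python
# Illustrative SimVP config sketch -- values are placeholders taken from the
# SimVP_VAN.py configs changed in this commit, not the actual TaxiBJ settings.
method = 'SimVP'
# model
spatio_kernel_enc = 3
spatio_kernel_dec = 3
model_type = 'van'  # MetaFormer variant, e.g. 'van' or 'convmixer'
hid_S = 64
hid_T = 512
N_T = 8
N_S = 2
# training
lr = 1e-2
batch_size = 16
sched = 'onecycle'
```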

| MetaFormer | Setting | Params | FLOPs | FPS | MSE | MAE | SSIM | PSNR | Download |
|------------------|:--------:|:------:|:-----:|:----:|:------:|:-----:|:------:|:-----:|:------------:|
| IncepU (SimVPv1) | 50 epoch | 13.79M | 3.61G | 533 | 0.3282 | 15.45 | 0.9835 | 39.72 | model \| log |
| gSTA (SimVPv2) | 50 epoch | 9.96M | 2.62G | 1217 | 0.3246 | 15.03 | 0.9844 | 39.95 | model \| log |
| ViT | 50 epoch | 9.66M | 2.80G | 1301 | 0.3171 | 15.15 | 0.9841 | 39.89 | model \| log |
| Swin Transformer | 50 epoch | 9.66M | 2.56G | 1506 | 0.3128 | 15.07 | 0.9847 | 39.89 | model \| log |
| Uniformer | 50 epoch | 9.52M | 2.71G | 1333 | 0.3268 | 15.16 | 0.9844 | 39.89 | model \| log |
| MLP-Mixer | 50 epoch | 8.24M | 2.18G | 1974 | 0.3206 | 15.37 | 0.9841 | 39.78 | model \| log |
| ConvMixer | 50 epoch | 0.84M | 0.23G | 4793 | 0.3634 | 15.63 | 0.9831 | 39.69 | model \| log |
| Poolformer | 50 epoch | 7.75M | 2.06G | 1827 | 0.3273 | 15.39 | 0.9840 | 39.75 | model \| log |
| ConvNeXt | 50 epoch | 7.84M | 2.08G | 1918 | 0.3106 | 14.90 | 0.9845 | 39.99 | model \| log |
| VAN | 50 epoch | 9.48M | 2.49G | 1273 | 0.3125 | 14.96 | 0.9848 | 39.95 | model \| log |
| HorNet | 50 epoch | 9.68M | 2.54G | 1350 | 0.3186 | 15.01 | 0.9843 | 39.91 | model \| log |
| MogaNet | 50 epoch | 9.96M | 2.61G | 1005 | 0.3114 | 15.06 | 0.9847 | 39.92 | model \| log |
| SimVP+IncepU | 50 epoch | 13.79M | 3.61G | 533 | 0.3282 | 15.45 | 0.9835 | 39.45 | model \| log |
| SimVP+gSTA-S | 50 epoch | 9.96M | 2.62G | 1217 | 0.3246 | 15.03 | 0.9844 | 39.71 | model \| log |
| ViT | 50 epoch | 9.66M | 2.80G | 1301 | 0.3171 | 15.15 | 0.9841 | 39.64 | model \| log |
| Swin Transformer | 50 epoch | 9.66M | 2.56G | 1506 | 0.3128 | 15.07 | 0.9847 | 39.65 | model \| log |
| Uniformer | 50 epoch | 9.52M | 2.71G | 1333 | 0.3268 | 15.16 | 0.9844 | 39.64 | model \| log |
| MLP-Mixer | 50 epoch | 8.24M | 2.18G | 1974 | 0.3206 | 15.37 | 0.9841 | 39.49 | model \| log |
| ConvMixer | 50 epoch | 0.84M | 0.23G | 4793 | 0.3634 | 15.63 | 0.9831 | 39.41 | model \| log |
| Poolformer | 50 epoch | 7.75M | 2.06G | 1827 | 0.3273 | 15.39 | 0.9840 | 39.46 | model \| log |
| ConvNeXt | 50 epoch | 7.84M | 2.08G | 1918 | 0.3106 | 14.90 | 0.9845 | 39.76 | model \| log |
| VAN | 50 epoch | 9.48M | 2.49G | 1273 | 0.3125 | 14.96 | 0.9848 | 39.72 | model \| log |
| HorNet | 50 epoch | 9.68M | 2.54G | 1350 | 0.3186 | 15.01 | 0.9843 | 39.66 | model \| log |
| MogaNet | 50 epoch | 9.96M | 2.61G | 1005 | 0.3114 | 15.06 | 0.9847 | 39.70 | model \| log |

<p align="right">(<a href="#top">back to top</a>)</p>
