From 1f3d9ba80b80ed971893601ab03ef8efaeba5bb7 Mon Sep 17 00:00:00 2001 From: Lupin1998 <1070535169@qq.com> Date: Mon, 20 Feb 2023 20:45:59 +0000 Subject: [PATCH] update docs and mmnist benchmarks --- .gitignore | 1 + .readthedocs.yml | 9 +++ .style.yapf | 4 + README.md | 13 +++- configs/mmnist/simvp/SimVP_ConvMixer.py | 11 +++ configs/mmnist/simvp/SimVP_ConvNeXt.py | 2 + configs/mmnist/simvp/SimVP_HorNet.py | 2 + configs/mmnist/simvp/SimVP_MLPMixer.py | 11 +++ configs/mmnist/simvp/SimVP_MogaNet.py | 2 + configs/mmnist/simvp/SimVP_Poolformer.py | 2 + configs/mmnist/simvp/SimVP_Swin.py | 2 + configs/mmnist/simvp/SimVP_Uniformer.py | 2 + configs/mmnist/simvp/SimVP_VAN.py | 11 +++ configs/mmnist/simvp/SimVP_ViT.py | 2 + configs/taxibj/SimVP.py | 8 ++ docs/en/Makefile | 20 +++++ docs/en/changelog.md | 20 +++++ docs/en/get_started.md | 9 +++ docs/en/index.rst | 38 ++++++++++ docs/en/install.md | 22 ++++++ docs/en/model_zoos/video_benchmarks.md | 90 +++++++++++++++++++++++ docs/en/switch_language.md | 1 + simvp/api/train.py | 18 +++-- simvp/datasets/dataloader_kitticaltech.py | 2 +- simvp/datasets/dataloader_weather.py | 4 +- simvp/modules/simvp_modules.py | 1 + simvp/utils/__init__.py | 4 +- simvp/utils/main_utils.py | 29 ++++++++ simvp/utils/parser.py | 12 +-- 29 files changed, 330 insertions(+), 22 deletions(-) create mode 100644 .readthedocs.yml create mode 100644 .style.yapf create mode 100644 configs/mmnist/simvp/SimVP_ConvMixer.py create mode 100644 configs/mmnist/simvp/SimVP_MLPMixer.py create mode 100644 configs/mmnist/simvp/SimVP_VAN.py create mode 100644 configs/taxibj/SimVP.py create mode 100644 docs/en/Makefile create mode 100644 docs/en/changelog.md create mode 100644 docs/en/get_started.md create mode 100644 docs/en/index.rst create mode 100644 docs/en/install.md create mode 100644 docs/en/model_zoos/video_benchmarks.md create mode 100644 docs/en/switch_language.md diff --git a/.gitignore b/.gitignore index 00768af7..479a2c21 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,4 @@ figs # temp configs/kitticaltech/simvp +configs/kth/simvp diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 00000000..332647c9 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,9 @@ +version: 2 + +formats: [] + +python: + version: 3.7 + install: + - requirements: requirements/docs.txt + - requirements: requirements/readthedocs.txt diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 00000000..286a3f1d --- /dev/null +++ b/.style.yapf @@ -0,0 +1,4 @@ +[style] +BASED_ON_STYLE = pep8 +BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true +SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true diff --git a/README.md b/README.md index 01a06bd4..a43f6d26 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,10 @@

+[🛠️Installation](docs/en/install.md) | +[🚀Model Zoo](docs/en/model_zoos/video_benchmarks.md) | +[🆕News](docs/en/changelog.md) + This repository is an open-source project for video prediction benchmarks, which contains the implementation code for paper: **SimVP: Towards Simple yet Powerful Spatiotemporal Predictive learning** @@ -125,10 +129,11 @@ We support various video prediction methods and will provide benchmarks on vario
Currently supported datasets - - [x] [KTH Action](https://ieeexplore.ieee.org/document/1334462) (ICPR'2004) [[download](https://www.csc.kth.se/cvap/actions/)] - - [x] [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) (IJRR'2013) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)] - - [x] [Moving MNIST](http://arxiv.org/abs/1502.04681) (ICML'2015) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)] - - [x] [TaxiBJ](https://arxiv.org/abs/1610.00081) (AAAI'2017) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)] + - [x] [KTH Action](https://ieeexplore.ieee.org/document/1334462) (ICPR'2004) [[download](https://www.csc.kth.se/cvap/actions/)] [[config](https://github.com/chengtan9907/SimVPv2/configs/kth)] + - [x] [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) (IJRR'2013) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)] [[config](https://github.com/chengtan9907/SimVPv2/configs/kitticaltech)] + - [x] [Moving MNIST](http://arxiv.org/abs/1502.04681) (ICML'2015) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)] [[config](https://github.com/chengtan9907/SimVPv2/configs/mmnist)] + - [x] [TaxiBJ](https://arxiv.org/abs/1610.00081) (AAAI'2017) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)] [[config](https://github.com/chengtan9907/SimVPv2/configs/taxibj)] + - [x] [WeatherBench](https://arxiv.org/abs/2002.00469) (ArXiv'2020) [[download](https://github.com/pangeo-data/WeatherBench)] [[config](https://github.com/chengtan9907/SimVPv2/configs/weather)]
diff --git a/configs/mmnist/simvp/SimVP_ConvMixer.py b/configs/mmnist/simvp/SimVP_ConvMixer.py new file mode 100644 index 00000000..5121eb68 --- /dev/null +++ b/configs/mmnist/simvp/SimVP_ConvMixer.py @@ -0,0 +1,11 @@ +method = 'SimVP' +spatio_kernel_enc = 3 +spatio_kernel_dec = 3 +model_type = 'convmixer' +hid_S = 64 +hid_T = 512 +N_T = 8 +N_S = 4 +lr = 1e-2 +batch_size = 16 +drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_ConvNeXt.py b/configs/mmnist/simvp/SimVP_ConvNeXt.py index cb1822e8..f7c95c57 100644 --- a/configs/mmnist/simvp/SimVP_ConvNeXt.py +++ b/configs/mmnist/simvp/SimVP_ConvNeXt.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 1e-2 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_HorNet.py b/configs/mmnist/simvp/SimVP_HorNet.py index 96e76850..d896f059 100644 --- a/configs/mmnist/simvp/SimVP_HorNet.py +++ b/configs/mmnist/simvp/SimVP_HorNet.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 1e-3 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_MLPMixer.py b/configs/mmnist/simvp/SimVP_MLPMixer.py new file mode 100644 index 00000000..8b3cf0ab --- /dev/null +++ b/configs/mmnist/simvp/SimVP_MLPMixer.py @@ -0,0 +1,11 @@ +method = 'SimVP' +spatio_kernel_enc = 3 +spatio_kernel_dec = 3 +model_type = 'mlp' +hid_S = 64 +hid_T = 512 +N_T = 8 +N_S = 4 +lr = 1e-3 +batch_size = 16 +drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_MogaNet.py b/configs/mmnist/simvp/SimVP_MogaNet.py index 638d8476..d8d39061 100644 --- a/configs/mmnist/simvp/SimVP_MogaNet.py +++ b/configs/mmnist/simvp/SimVP_MogaNet.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 1e-3 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_Poolformer.py b/configs/mmnist/simvp/SimVP_Poolformer.py index 21463b64..3305cae2 100644 --- a/configs/mmnist/simvp/SimVP_Poolformer.py +++ b/configs/mmnist/simvp/SimVP_Poolformer.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 1e-3 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_Swin.py b/configs/mmnist/simvp/SimVP_Swin.py index e433852f..79ffcc51 100644 --- a/configs/mmnist/simvp/SimVP_Swin.py +++ b/configs/mmnist/simvp/SimVP_Swin.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 1e-3 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_Uniformer.py b/configs/mmnist/simvp/SimVP_Uniformer.py index 99bfb060..c7f10531 100644 --- a/configs/mmnist/simvp/SimVP_Uniformer.py +++ b/configs/mmnist/simvp/SimVP_Uniformer.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 5e-4 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_VAN.py b/configs/mmnist/simvp/SimVP_VAN.py new file mode 100644 index 00000000..3b5632e0 --- /dev/null +++ b/configs/mmnist/simvp/SimVP_VAN.py @@ -0,0 +1,11 @@ +method = 'SimVP' +spatio_kernel_enc = 3 +spatio_kernel_dec = 3 +model_type = 'convmixer' +hid_S = 64 +hid_T = 512 +N_T = 8 +N_S = 4 +lr = 1e-3 +batch_size = 16 +drop_path = 0 \ No newline at end of file diff --git a/configs/mmnist/simvp/SimVP_ViT.py b/configs/mmnist/simvp/SimVP_ViT.py index fadbf20a..1035acd8 100644 --- a/configs/mmnist/simvp/SimVP_ViT.py +++ b/configs/mmnist/simvp/SimVP_ViT.py @@ -6,4 +6,6 @@ hid_T = 512 N_T = 8 N_S = 4 +lr = 1e-3 +batch_size = 16 drop_path = 0 \ No newline at end of file diff --git a/configs/taxibj/SimVP.py b/configs/taxibj/SimVP.py new file mode 100644 index 00000000..8a63cb8c --- /dev/null +++ b/configs/taxibj/SimVP.py @@ -0,0 +1,8 @@ +method = 'SimVP' +spatio_kernel_enc = 3 +spatio_kernel_dec = 3 +# model_type = None +hid_S = 64 +hid_T = 512 +N_T = 8 +N_S = 4 \ No newline at end of file diff --git a/docs/en/Makefile b/docs/en/Makefile new file mode 100644 index 00000000..d4bb2cbb --- /dev/null +++ b/docs/en/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/en/changelog.md b/docs/en/changelog.md new file mode 100644 index 00000000..94a9fcd0 --- /dev/null +++ b/docs/en/changelog.md @@ -0,0 +1,20 @@ +## Changelog + +### v0.1.0 (18/02/2023) + +Release version to V0.1.0 with code refactoring. + +#### Code Refactoring + +* Refactor code structures as `simvp/api`, `simvp/core`, `simvp/datasets`, `simvp/methods`, `simvp/models`, `simvp/modules`. We support non-distributed training and evaluation by the executable python file `tools/non_dist_train.py`. Refactor config files for SimVP models. +* Fix bugs in tools/nondist_train.py, simvp/utils, environment.yml, and .gitignore, etc. + +#### New Features + +* Support Timm optimizers and schedulers. +* Update popular Metaformer models as the hidden Translator $h$ in SimVP, supporting [ViT](https://arxiv.org/abs/2010.11929), [Swin-Transformer](https://arxiv.org/abs/2103.14030), [MLP-Mixer](https://arxiv.org/abs/2105.01601), [ConvMixer](https://arxiv.org/abs/2201.09792), [UniFormer](https://arxiv.org/abs/2201.09450), [PoolFormer](https://arxiv.org/abs/2111.11418), [ConvNeXt](https://arxiv.org/abs/2201.03545), [VAN](https://arxiv.org/abs/2202.09741), [HorNet](https://arxiv.org/abs/2207.14284), and [MogaNet](https://arxiv.org/abs/2211.03295). +* Update implementations of dataset and dataloader, supporting [KTH Action](https://ieeexplore.ieee.org/document/1334462), [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297), [Moving MNIST](http://arxiv.org/abs/1502.04681), [TaxiBJ](https://arxiv.org/abs/1610.00081), and [WeatherBench](https://arxiv.org/abs/2002.00469). + +### Update Documents + +* Upload `readthedocs` documents. Summarize video prediction benchmark results on MMNIST in [video_benchmarks.md](https://github.com/chengtan9907/SimVPv2/docs/en/model_zoos/video_benchmarks.md). diff --git a/docs/en/get_started.md b/docs/en/get_started.md new file mode 100644 index 00000000..d467eaac --- /dev/null +++ b/docs/en/get_started.md @@ -0,0 +1,9 @@ +# Getting Started + +This page provides basic tutorials about the usage of SimVP. For installation instructions, please see [Install](docs/en/install.md). + +An example of single GPU training SimVP+gSTA on Moving MNIST dataset. +```shell +bash tools/prepare_data/download_mmnist.sh +python tools/non_dist_train.py -d mmnist -m SimVP --model_type gsta --lr 1e-3 --ex_name mmnist_simvp_gsta +``` diff --git a/docs/en/index.rst b/docs/en/index.rst new file mode 100644 index 00000000..93c09f01 --- /dev/null +++ b/docs/en/index.rst @@ -0,0 +1,38 @@ +.. SimVP documentation master file, created by + sphinx-quickstart on Thu June 15 05:11:34 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to SimVP's documentation! +===================================== + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + + install.md + get_started.md + +.. toctree:: + :maxdepth: 1 + :caption: Model Zoos + + model_zoos/video_benchmarks.md + +.. toctree:: + :maxdepth: 1 + :caption: Notes + + changelog.md + +.. toctree:: + :caption: Switch Language + + switch_language.md + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`search` diff --git a/docs/en/install.md b/docs/en/install.md new file mode 100644 index 00000000..ebe57f8d --- /dev/null +++ b/docs/en/install.md @@ -0,0 +1,22 @@ +# Installation + +This project has provided an environment setting file of conda, users can easily reproduce the environment by the following commands: +```shell +git clone https://github.com/chengtan9907/SimVPv2 +cd SimVPv2 +conda env create -f environment.yml +conda activate SimVP +python setup.py develop +``` + +
+Dependencies + +* argparse +* numpy +* hickle +* scikit-image=0.16.2 +* torch +* timm +* tqdm +
diff --git a/docs/en/model_zoos/video_benchmarks.md b/docs/en/model_zoos/video_benchmarks.md new file mode 100644 index 00000000..85449ca2 --- /dev/null +++ b/docs/en/model_zoos/video_benchmarks.md @@ -0,0 +1,90 @@ +# Video Prediction Benchmarks + +**We provide benchmark results of video prediction methods on video datasets. More video prediction methods will be supported in the future. Issues and PRs are welcome!** + +
+Currently supported video prediction methods + +- [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NIPS'2015) +- [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NIPS'2017) +- [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018) +- [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018) +- [x] [MAU](https://arxiv.org/abs/1811.07490) (CVPR'2019) +- [x] [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2020) +- [x] [PhyDNet](https://arxiv.org/abs/2003.01460) (CVPR'2020) +- [x] [PredRNN.V2](https://arxiv.org/abs/2103.09504v4) (TPAMI'2022) +- [x] [SimVP](https://arxiv.org/abs/2206.05099) (CVPR'2022) +- [x] [SimVP.V2](https://arxiv.org/abs/2211.12509) (ArXiv'2022) + +
+ +
+Currently supported MetaFormer models for SimVP + +- [x] [ViT](https://arxiv.org/abs/2010.11929) (ICLR'2021) +- [x] [Swin-Transformer](https://arxiv.org/abs/2103.14030) (ICCV'2021) +- [x] [MLP-Mixer](https://arxiv.org/abs/2105.01601) (NIPS'2021) +- [x] [ConvMixer](https://arxiv.org/abs/2201.09792) (Openreview'2021) +- [x] [UniFormer](https://arxiv.org/abs/2201.09450) (ICLR'2022) +- [x] [PoolFormer](https://arxiv.org/abs/2111.11418) (CVPR'2022) +- [x] [ConvNeXt](https://arxiv.org/abs/2201.03545) (CVPR'2022) +- [x] [VAN](https://arxiv.org/abs/2202.09741) (ArXiv'2022) +- [x] [IncepU (SimVP.V1)](https://arxiv.org/abs/2206.05099) (CVPR'2022) +- [x] [gSTA (SimVP.V2)](https://arxiv.org/abs/2211.12509) (ArXiv'2022) +- [x] [HorNet](https://arxiv.org/abs/2207.14284) (NIPS'2022) +- [x] [MogaNet](https://arxiv.org/abs/2211.03295) (ArXiv'2022) + +
+ +## Moving MNIST Benchmarks + +We provide benchmark results on popular [Moving MNIST](http://arxiv.org/abs/1502.04681) dataset using $10\rightarrow 10$ frames prediction setting. Metrics (MSE, MAE, SSIM, pSNR) of the final models are reported in three trials. Parameters (M), FLOPs (G), inference FPS (s) are also reported for all methods. + +### **Benchmark of Video Prediction Methods** + +For fair comparison of different methods, we report final results when models are trained to convergence. We provide config file in [configs/mmnist](https://github.com/chengtan9907/SimVPv2/configs/mmnist). + +| Method | Params | FLOPs | FPS | MSE | MAE | SSIM | Download | +|--------------|:------:|:------:|:---:|:-----:|:------:|:-----:|:------------:| +| ConvLSTM-S | 15.0M | 56.8G | 113 | 46.26 | 142.18 | 0.878 | model \| log | +| ConvLSTM-L | 33.8M | 127.0G | 50 | 29.88 | 95.05 | 0.925 | model \| log | +| PhyDNet | 3.1M | 15.3G | 182 | 35.68 | 96.70 | 0.917 | model \| log | +| PredRNN | 23.8M | 116.0G | 54 | 25.04 | 76.26 | 0.944 | model \| log | +| PredRNN++ | 38.6M | 171.7G | 38 | 22.45 | 69.70 | 0.950 | model \| log | +| MIM | 38.0M | 179.2G | 37 | 23.66 | 74.37 | 0.946 | model \| log | +| E3D-LSTM | 51.0M | 298.9G | 18 | 36.19 | 78.64 | 0.932 | model \| log | +| CrevNet | 5.0M | 270.7G | 10 | 30.15 | 86.28 | 0.935 | model \| log | +| PredRNN.V2 | 23.9M | 116.6G | 52 | 27.73 | 82.17 | 0.937 | model \| log | +| SimVP+IncepU | 58.0M | 19.4G | 209 | 26.69 | 77.19 | 0.940 | model \| log | +| SimVP+gSTA-S | 46.8M | 16.5G | 282 | 15.05 | 49.80 | 0.967 | model \| log | + +### **Benchmark of MetaFormers on SimVP** + +Since the hidden Translator in [SimVP](https://arxiv.org/abs/2211.12509) can be replaced by any [Metaformer](https://arxiv.org/abs/2111.11418) block which achieves `token mixing` and `channel mixing`, we benchmark popular Metaformer architectures on SimVP with training times of 200-epoch and 2000-epoch. We provide config file in [configs/mmnist/simvp](https://github.com/chengtan9907/SimVPv2/configs/mmnist/simvp/). + +| MetaFormer | Setting | Params | FLOPs | FPS | MSE | MAE | SSIM | PSNR | Download | +|------------------|:----------:|:------:|:------:|:----:|:-----:|:-----:|:------:|:-----:|:------------:| +| IncepU (SimVPv1) | 200 epoch | 58.0M | 19.4G | 209s | 32.15 | 89.05 | 0.9268 | 37.97 | model \| log | +| gSTA (SimVPv2) | 200 epoch | 46.8M | 16.5G | 282s | 26.69 | 77.19 | 0.9402 | 38.3 | model \| log | +| ViT | 200 epoch | 46.1M | 16.9.G | 290s | 35.15 | 95.87 | 0.9139 | 37.79 | model \| log | +| Swin Transformer | 200 epoch | 46.1M | 16.4G | 294s | 29.70 | 84.05 | 0.9331 | 38.14 | model \| log | +| Uniformer | 200 epoch | 44.8M | 16.5G | 296s | 30.38 | 85.87 | 0.9308 | 38.11 | model \| log | +| MLP-Mixer | 200 epoch | 38.2M | 14.7G | 334s | 29.52 | 83.36 | 0.9338 | 38.19 | model \| log | +| ConvMixer | 200 epoch | 3.9M | 5.5G | 658s | 32.09 | 88.93 | 0.9259 | 37.97 | model \| log | +| Poolformer | 200 epoch | 37.1M | 14.1G | 341s | 31.79 | 88.48 | 0.9271 | 38.06 | model \| log | +| ConvNeXt | 200 epoch | 37.3M | 14.1G | 344s | 26.94 | 77.23 | 0.9397 | 38.34 | model \| log | +| VAN | 200 epoch | 44.5M | 16.0G | 288s | 26.10 | 76.11 | 0.9417 | 38.39 | model \| log | +| HorNet | 200 epoch | 45.7M | 16.3G | 287s | 29.64 | 83.26 | 0.9331 | 38.16 | model \| log | +| MogaNet | 200 epoch | 46.8M | 16.5G | 255s | 25.57 | 75.19 | 0.9429 | 38.41 | model \| log | +| IncepU (SimVPv1) | 2000 epoch | 58.0M | 19.4G | 209s | - | - | - | - | - | +| gSTA (SimVPv2) | 2000 epoch | 46.8M | 16.5G | 282s | 15.05 | 49.80 | 0.9670 | - | model \| log | +| ViT | 2000 epoch | 46.1M | 16.9.G | 290s | 19.74 | 61.65 | 0.9539 | 38.96 | model \| log | +| Swin Transformer | 2000 epoch | 46.1M | 16.4G | 294s | 19.11 | 59.84 | 0.9584 | 39.03 | model \| log | +| Uniformer | 2000 epoch | 44.8M | 16.5G | 296s | 18.01 | 57.52 | 0.9609 | 39.11 | model \| log | +| MLP-Mixer | 2000 epoch | 38.2M | 14.7G | 334s | 18.85 | 59.86 | 0.9589 | 38.98 | model \| log | +| ConvMixer | 2000 epoch | 3.9M | 5.5G | 658s | 22.30 | 67.37 | 0.9507 | 38.67 | model \| log | +| Poolformer | 2000 epoch | 37.1M | 14.1G | 341s | 20.96 | 64.31 | 0.9539 | 38.86 | model \| log | +| ConvNeXt | 2000 epoch | 37.3M | 14.1G | 344s | 17.58 | 55.76 | 0.9617 | 39.19 | model \| log | +| VAN | 2000 epoch | 44.5M | 16.0G | 288s | 16.21 | 53.57 | 0.9646 | 39.26 | model \| log | +| HorNet | 2000 epoch | 45.7M | 16.3G | 287s | 17.40 | 55.70 | 0.9624 | 39.19 | model \| log | +| MogaNet | 2000 epoch | 46.8M | 16.5G | 255s | 15.67 | 51.84 | 0.9661 | 39.35 | model \| log | diff --git a/docs/en/switch_language.md b/docs/en/switch_language.md new file mode 100644 index 00000000..4cf942c1 --- /dev/null +++ b/docs/en/switch_language.md @@ -0,0 +1 @@ +## English diff --git a/simvp/api/train.py b/simvp/api/train.py index abb89de2..4bcaeb48 100644 --- a/simvp/api/train.py +++ b/simvp/api/train.py @@ -12,7 +12,7 @@ from simvp.core import metric, Recorder from simvp.methods import method_maps from simvp.utils import (set_seed, print_log, output_namespace, check_dir, - get_dataset) + get_dataset, measure_throughput) try: import nni @@ -35,34 +35,36 @@ def __init__(self, args): T, C, H, W = self.args.in_shape if self.args.method == 'simvp': - _tmp_input = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device) - flops = FlopCountAnalysis(self.method.model, _tmp_input) + input_dummy = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device) elif self.args.method == 'crevnet': # crevnet must use the batchsize rather than 1 - _tmp_input = torch.ones(self.args.batch_size, 20, C, H, W).to(self.device) - flops = FlopCountAnalysis(self.method.model, _tmp_input) + input_dummy = torch.ones(self.args.batch_size, 20, C, H, W).to(self.device) elif self.args.method == 'phydnet': _tmp_input1 = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device) _tmp_input2 = torch.ones(1, self.args.aft_seq_length, C, H, W).to(self.device) _tmp_constraints = torch.zeros((49, 7, 7)).to(self.device) - flops = FlopCountAnalysis(self.method.model, (_tmp_input1, _tmp_input2, _tmp_constraints)) + input_dummy = (_tmp_input1, _tmp_input2, _tmp_constraints) elif self.args.method in ['convlstm', 'predrnnpp', 'predrnn', 'mim', 'e3dlstm', 'mau']: Hp, Wp = H // self.args.patch_size, W // self.args.patch_size Cp = self.args.patch_size ** 2 * C _tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device) _tmp_flag = torch.ones(1, self.args.aft_seq_length - 1, Hp, Wp, Cp).to(self.device) - flops = FlopCountAnalysis(self.method.model, (_tmp_input, _tmp_flag)) + input_dummy = (_tmp_input, _tmp_flag) elif self.args.method == 'predrnnv2': Hp, Wp = H // self.args.patch_size, W // self.args.patch_size Cp = self.args.patch_size ** 2 * C _tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device) _tmp_flag = torch.ones(1, self.args.total_length - 2, Hp, Wp, Cp).to(self.device) - flops = FlopCountAnalysis(self.method.model, (_tmp_input, _tmp_flag)) + input_dummy = (_tmp_input, _tmp_flag) else: raise ValueError(f'Invalid method name {self.args.method}') print_log(self.method.model) + flops = FlopCountAnalysis(self.method.model, input_dummy) print_log(flop_count_table(flops)) + if args.fps: + fps = measure_throughput(self.method.model, input_dummy) + print_log('Throughputs of {}: {:.3f}'.format(self.args.method, fps)) def _acquire_device(self): if self.args.use_gpu: diff --git a/simvp/datasets/dataloader_kitticaltech.py b/simvp/datasets/dataloader_kitticaltech.py index 3ca5d2b7..9c6f50e7 100644 --- a/simvp/datasets/dataloader_kitticaltech.py +++ b/simvp/datasets/dataloader_kitticaltech.py @@ -121,7 +121,7 @@ def load_data(batch_size, val_batch_size, data_root, } input_handle = DataProcess(input_param) train_data, train_idx = input_handle.load_data('train') - test_data, test_idx = input_handle.load_data('val') + test_data, test_idx = input_handle.load_data('test') elif os.path.exists(osp.join(data_root, 'kitticaltech_npy')): train_data = np.load(osp.join(data_root, 'kitticaltech_npy', 'train_data.npy')) train_idx = np.load(osp.join(data_root, 'kitticaltech_npy', 'train_idx.npy')) diff --git a/simvp/datasets/dataloader_weather.py b/simvp/datasets/dataloader_weather.py index 04ab5d6a..e0657cc2 100644 --- a/simvp/datasets/dataloader_weather.py +++ b/simvp/datasets/dataloader_weather.py @@ -143,8 +143,8 @@ def load_data(batch_size, val_batch_size, data_root, num_workers=2, - data_name=['t'], - train_time=['2010', '2015'], + data_name='t2m', + train_time=['1979', '2015'], val_time=['2016', '2016'], test_time=['2017', '2018'], idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0], diff --git a/simvp/modules/simvp_modules.py b/simvp/modules/simvp_modules.py index 46390fed..53de7ec6 100644 --- a/simvp/modules/simvp_modules.py +++ b/simvp/modules/simvp_modules.py @@ -514,6 +514,7 @@ class ViTSubBlock(ViTBlock): def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1): super().__init__(dim=dim, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True, drop=drop, drop_path=drop_path, act_layer=nn.GELU, norm_layer=nn.LayerNorm) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): diff --git a/simvp/utils/__init__.py b/simvp/utils/__init__.py index 418f11e4..e8df9449 100644 --- a/simvp/utils/__init__.py +++ b/simvp/utils/__init__.py @@ -2,7 +2,7 @@ from .config_utils import Config, check_file_exist from .main_utils import (set_seed, print_log, output_namespace, check_dir, get_dataset, - count_parameters, load_config, update_config) + count_parameters, measure_throughput, load_config, update_config) from .parser import create_parser from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling, reshape_patch, reshape_patch_back) @@ -10,6 +10,6 @@ __all__ = [ 'Config', 'check_file_exist', 'create_parser', 'set_seed', 'print_log', 'output_namespace', 'check_dir', 'get_dataset', 'count_parameters', - 'load_config', 'update_config', + 'measure_throughput', 'load_config', 'update_config', 'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back', ] \ No newline at end of file diff --git a/simvp/utils/main_utils.py b/simvp/utils/main_utils.py index d72cd909..42a03b01 100644 --- a/simvp/utils/main_utils.py +++ b/simvp/utils/main_utils.py @@ -47,6 +47,35 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) +def measure_throughput(model, input_dummy): + bs = 100 + repetitions = 100 + if isinstance(input_dummy, tuple): + input_dummy = list(input_dummy) + _, T, C, H, W = input_dummy[0].shape + _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device) + input_dummy[0] = _input + input_dummy = tuple(input_dummy) + else: + _, T, C, H, W = input_dummy.shape + input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device) + total_time = 0 + with torch.no_grad(): + for _ in range(repetitions): + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + if isinstance(input_dummy, tuple): + _ = model(*input_dummy) + else: + _ = model(input_dummy) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) / 1000 + total_time += curr_time + Throughput = (repetitions * bs) / total_time + return Throughput + + def load_config(filename:str = None): ''' load and print config diff --git a/simvp/utils/parser.py b/simvp/utils/parser.py index c52e3350..6bcaa0d0 100644 --- a/simvp/utils/parser.py +++ b/simvp/utils/parser.py @@ -15,18 +15,20 @@ def create_parser(): parser.add_argument('--use_gpu', default=True, type=bool) parser.add_argument('--gpu', default=0, type=int) parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--fps', action='store_true', default=False, + help='Whether to measure inference speed (FPS)') # dataset parameters - parser.add_argument('--batch_size', '-b', default=16, type=int, help="Training batch size") - parser.add_argument('--val_batch_size', '-vb', default=4, type=int, help="Validation batch size") + parser.add_argument('--batch_size', '-b', default=16, type=int, help='Training batch size') + parser.add_argument('--val_batch_size', '-vb', default=4, type=int, help='Validation batch size') parser.add_argument('--num_workers', default=8, type=int) parser.add_argument('--data_root', default='./data/') parser.add_argument('--dataname', '-d', default='mmnist', type=str, choices=['mmnist', 'kitticaltech', 'kth', 'kth40', 'taxibj', 'weather'], help='Dataset name (default: "mmnist")') - parser.add_argument('--pre_seq_length', default=None, type=int, help="Sequence length before prediction") - parser.add_argument('--aft_seq_length', default=None, type=int, help="Sequence length after prediction") - parser.add_argument('--total_length', default=None, type=int, help="Total Sequence length for prediction") + parser.add_argument('--pre_seq_length', default=None, type=int, help='Sequence length before prediction') + parser.add_argument('--aft_seq_length', default=None, type=int, help='Sequence length after prediction') + parser.add_argument('--total_length', default=None, type=int, help='Total Sequence length for prediction') # method parameters parser.add_argument('--method', '-m', default='SimVP', type=str,