From 1f3d9ba80b80ed971893601ab03ef8efaeba5bb7 Mon Sep 17 00:00:00 2001
From: Lupin1998 <1070535169@qq.com>
Date: Mon, 20 Feb 2023 20:45:59 +0000
Subject: [PATCH] update docs and mmnist benchmarks
---
.gitignore | 1 +
.readthedocs.yml | 9 +++
.style.yapf | 4 +
README.md | 13 +++-
configs/mmnist/simvp/SimVP_ConvMixer.py | 11 +++
configs/mmnist/simvp/SimVP_ConvNeXt.py | 2 +
configs/mmnist/simvp/SimVP_HorNet.py | 2 +
configs/mmnist/simvp/SimVP_MLPMixer.py | 11 +++
configs/mmnist/simvp/SimVP_MogaNet.py | 2 +
configs/mmnist/simvp/SimVP_Poolformer.py | 2 +
configs/mmnist/simvp/SimVP_Swin.py | 2 +
configs/mmnist/simvp/SimVP_Uniformer.py | 2 +
configs/mmnist/simvp/SimVP_VAN.py | 11 +++
configs/mmnist/simvp/SimVP_ViT.py | 2 +
configs/taxibj/SimVP.py | 8 ++
docs/en/Makefile | 20 +++++
docs/en/changelog.md | 20 +++++
docs/en/get_started.md | 9 +++
docs/en/index.rst | 38 ++++++++++
docs/en/install.md | 22 ++++++
docs/en/model_zoos/video_benchmarks.md | 90 +++++++++++++++++++++++
docs/en/switch_language.md | 1 +
simvp/api/train.py | 18 +++--
simvp/datasets/dataloader_kitticaltech.py | 2 +-
simvp/datasets/dataloader_weather.py | 4 +-
simvp/modules/simvp_modules.py | 1 +
simvp/utils/__init__.py | 4 +-
simvp/utils/main_utils.py | 29 ++++++++
simvp/utils/parser.py | 12 +--
29 files changed, 330 insertions(+), 22 deletions(-)
create mode 100644 .readthedocs.yml
create mode 100644 .style.yapf
create mode 100644 configs/mmnist/simvp/SimVP_ConvMixer.py
create mode 100644 configs/mmnist/simvp/SimVP_MLPMixer.py
create mode 100644 configs/mmnist/simvp/SimVP_VAN.py
create mode 100644 configs/taxibj/SimVP.py
create mode 100644 docs/en/Makefile
create mode 100644 docs/en/changelog.md
create mode 100644 docs/en/get_started.md
create mode 100644 docs/en/index.rst
create mode 100644 docs/en/install.md
create mode 100644 docs/en/model_zoos/video_benchmarks.md
create mode 100644 docs/en/switch_language.md
diff --git a/.gitignore b/.gitignore
index 00768af7..479a2c21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,3 +135,4 @@ figs
# temp
configs/kitticaltech/simvp
+configs/kth/simvp
diff --git a/.readthedocs.yml b/.readthedocs.yml
new file mode 100644
index 00000000..332647c9
--- /dev/null
+++ b/.readthedocs.yml
@@ -0,0 +1,9 @@
+version: 2
+
+formats: []
+
+python:
+ version: 3.7
+ install:
+ - requirements: requirements/docs.txt
+ - requirements: requirements/readthedocs.txt
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 00000000..286a3f1d
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+BASED_ON_STYLE = pep8
+BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
+SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
diff --git a/README.md b/README.md
index 01a06bd4..a43f6d26 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,10 @@
+[🛠️Installation](docs/en/install.md) |
+[🚀Model Zoo](docs/en/model_zoos/video_benchmarks.md) |
+[🆕News](docs/en/changelog.md)
+
This repository is an open-source project for video prediction benchmarks, which contains the implementation code for paper:
**SimVP: Towards Simple yet Powerful Spatiotemporal Predictive learning**
@@ -125,10 +129,11 @@ We support various video prediction methods and will provide benchmarks on vario
Currently supported datasets
- - [x] [KTH Action](https://ieeexplore.ieee.org/document/1334462) (ICPR'2004) [[download](https://www.csc.kth.se/cvap/actions/)]
- - [x] [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) (IJRR'2013) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)]
- - [x] [Moving MNIST](http://arxiv.org/abs/1502.04681) (ICML'2015) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)]
- - [x] [TaxiBJ](https://arxiv.org/abs/1610.00081) (AAAI'2017) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)]
+ - [x] [KTH Action](https://ieeexplore.ieee.org/document/1334462) (ICPR'2004) [[download](https://www.csc.kth.se/cvap/actions/)] [[config](https://github.com/chengtan9907/SimVPv2/configs/kth)]
+ - [x] [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) (IJRR'2013) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)] [[config](https://github.com/chengtan9907/SimVPv2/configs/kitticaltech)]
+ - [x] [Moving MNIST](http://arxiv.org/abs/1502.04681) (ICML'2015) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)] [[config](https://github.com/chengtan9907/SimVPv2/configs/mmnist)]
+ - [x] [TaxiBJ](https://arxiv.org/abs/1610.00081) (AAAI'2017) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)] [[config](https://github.com/chengtan9907/SimVPv2/configs/taxibj)]
+ - [x] [WeatherBench](https://arxiv.org/abs/2002.00469) (ArXiv'2020) [[download](https://github.com/pangeo-data/WeatherBench)] [[config](https://github.com/chengtan9907/SimVPv2/configs/weather)]
diff --git a/configs/mmnist/simvp/SimVP_ConvMixer.py b/configs/mmnist/simvp/SimVP_ConvMixer.py
new file mode 100644
index 00000000..5121eb68
--- /dev/null
+++ b/configs/mmnist/simvp/SimVP_ConvMixer.py
@@ -0,0 +1,11 @@
+method = 'SimVP'
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'convmixer'
+hid_S = 64
+hid_T = 512
+N_T = 8
+N_S = 4
+lr = 1e-2
+batch_size = 16
+drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_ConvNeXt.py b/configs/mmnist/simvp/SimVP_ConvNeXt.py
index cb1822e8..f7c95c57 100644
--- a/configs/mmnist/simvp/SimVP_ConvNeXt.py
+++ b/configs/mmnist/simvp/SimVP_ConvNeXt.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 1e-2
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_HorNet.py b/configs/mmnist/simvp/SimVP_HorNet.py
index 96e76850..d896f059 100644
--- a/configs/mmnist/simvp/SimVP_HorNet.py
+++ b/configs/mmnist/simvp/SimVP_HorNet.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 1e-3
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_MLPMixer.py b/configs/mmnist/simvp/SimVP_MLPMixer.py
new file mode 100644
index 00000000..8b3cf0ab
--- /dev/null
+++ b/configs/mmnist/simvp/SimVP_MLPMixer.py
@@ -0,0 +1,11 @@
+method = 'SimVP'
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'mlp'
+hid_S = 64
+hid_T = 512
+N_T = 8
+N_S = 4
+lr = 1e-3
+batch_size = 16
+drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_MogaNet.py b/configs/mmnist/simvp/SimVP_MogaNet.py
index 638d8476..d8d39061 100644
--- a/configs/mmnist/simvp/SimVP_MogaNet.py
+++ b/configs/mmnist/simvp/SimVP_MogaNet.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 1e-3
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_Poolformer.py b/configs/mmnist/simvp/SimVP_Poolformer.py
index 21463b64..3305cae2 100644
--- a/configs/mmnist/simvp/SimVP_Poolformer.py
+++ b/configs/mmnist/simvp/SimVP_Poolformer.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 1e-3
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_Swin.py b/configs/mmnist/simvp/SimVP_Swin.py
index e433852f..79ffcc51 100644
--- a/configs/mmnist/simvp/SimVP_Swin.py
+++ b/configs/mmnist/simvp/SimVP_Swin.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 1e-3
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_Uniformer.py b/configs/mmnist/simvp/SimVP_Uniformer.py
index 99bfb060..c7f10531 100644
--- a/configs/mmnist/simvp/SimVP_Uniformer.py
+++ b/configs/mmnist/simvp/SimVP_Uniformer.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 5e-4
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_VAN.py b/configs/mmnist/simvp/SimVP_VAN.py
new file mode 100644
index 00000000..3b5632e0
--- /dev/null
+++ b/configs/mmnist/simvp/SimVP_VAN.py
@@ -0,0 +1,11 @@
+method = 'SimVP'
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'convmixer'
+hid_S = 64
+hid_T = 512
+N_T = 8
+N_S = 4
+lr = 1e-3
+batch_size = 16
+drop_path = 0
\ No newline at end of file
diff --git a/configs/mmnist/simvp/SimVP_ViT.py b/configs/mmnist/simvp/SimVP_ViT.py
index fadbf20a..1035acd8 100644
--- a/configs/mmnist/simvp/SimVP_ViT.py
+++ b/configs/mmnist/simvp/SimVP_ViT.py
@@ -6,4 +6,6 @@
hid_T = 512
N_T = 8
N_S = 4
+lr = 1e-3
+batch_size = 16
drop_path = 0
\ No newline at end of file
diff --git a/configs/taxibj/SimVP.py b/configs/taxibj/SimVP.py
new file mode 100644
index 00000000..8a63cb8c
--- /dev/null
+++ b/configs/taxibj/SimVP.py
@@ -0,0 +1,8 @@
+method = 'SimVP'
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+# model_type = None
+hid_S = 64
+hid_T = 512
+N_T = 8
+N_S = 4
\ No newline at end of file
diff --git a/docs/en/Makefile b/docs/en/Makefile
new file mode 100644
index 00000000..d4bb2cbb
--- /dev/null
+++ b/docs/en/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/en/changelog.md b/docs/en/changelog.md
new file mode 100644
index 00000000..94a9fcd0
--- /dev/null
+++ b/docs/en/changelog.md
@@ -0,0 +1,20 @@
+## Changelog
+
+### v0.1.0 (18/02/2023)
+
+Release version to V0.1.0 with code refactoring.
+
+#### Code Refactoring
+
+* Refactor code structures as `simvp/api`, `simvp/core`, `simvp/datasets`, `simvp/methods`, `simvp/models`, `simvp/modules`. We support non-distributed training and evaluation by the executable python file `tools/non_dist_train.py`. Refactor config files for SimVP models.
+* Fix bugs in tools/nondist_train.py, simvp/utils, environment.yml, and .gitignore, etc.
+
+#### New Features
+
+* Support Timm optimizers and schedulers.
+* Update popular Metaformer models as the hidden Translator $h$ in SimVP, supporting [ViT](https://arxiv.org/abs/2010.11929), [Swin-Transformer](https://arxiv.org/abs/2103.14030), [MLP-Mixer](https://arxiv.org/abs/2105.01601), [ConvMixer](https://arxiv.org/abs/2201.09792), [UniFormer](https://arxiv.org/abs/2201.09450), [PoolFormer](https://arxiv.org/abs/2111.11418), [ConvNeXt](https://arxiv.org/abs/2201.03545), [VAN](https://arxiv.org/abs/2202.09741), [HorNet](https://arxiv.org/abs/2207.14284), and [MogaNet](https://arxiv.org/abs/2211.03295).
+* Update implementations of dataset and dataloader, supporting [KTH Action](https://ieeexplore.ieee.org/document/1334462), [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297), [Moving MNIST](http://arxiv.org/abs/1502.04681), [TaxiBJ](https://arxiv.org/abs/1610.00081), and [WeatherBench](https://arxiv.org/abs/2002.00469).
+
+### Update Documents
+
+* Upload `readthedocs` documents. Summarize video prediction benchmark results on MMNIST in [video_benchmarks.md](https://github.com/chengtan9907/SimVPv2/docs/en/model_zoos/video_benchmarks.md).
diff --git a/docs/en/get_started.md b/docs/en/get_started.md
new file mode 100644
index 00000000..d467eaac
--- /dev/null
+++ b/docs/en/get_started.md
@@ -0,0 +1,9 @@
+# Getting Started
+
+This page provides basic tutorials about the usage of SimVP. For installation instructions, please see [Install](docs/en/install.md).
+
+An example of single GPU training SimVP+gSTA on Moving MNIST dataset.
+```shell
+bash tools/prepare_data/download_mmnist.sh
+python tools/non_dist_train.py -d mmnist -m SimVP --model_type gsta --lr 1e-3 --ex_name mmnist_simvp_gsta
+```
diff --git a/docs/en/index.rst b/docs/en/index.rst
new file mode 100644
index 00000000..93c09f01
--- /dev/null
+++ b/docs/en/index.rst
@@ -0,0 +1,38 @@
+.. SimVP documentation master file, created by
+ sphinx-quickstart on Thu June 15 05:11:34 2022.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to SimVP's documentation!
+=====================================
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Getting Started
+
+ install.md
+ get_started.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Model Zoos
+
+ model_zoos/video_benchmarks.md
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Notes
+
+ changelog.md
+
+.. toctree::
+ :caption: Switch Language
+
+ switch_language.md
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`search`
diff --git a/docs/en/install.md b/docs/en/install.md
new file mode 100644
index 00000000..ebe57f8d
--- /dev/null
+++ b/docs/en/install.md
@@ -0,0 +1,22 @@
+# Installation
+
+This project has provided an environment setting file of conda, users can easily reproduce the environment by the following commands:
+```shell
+git clone https://github.com/chengtan9907/SimVPv2
+cd SimVPv2
+conda env create -f environment.yml
+conda activate SimVP
+python setup.py develop
+```
+
+
+Dependencies
+
+* argparse
+* numpy
+* hickle
+* scikit-image=0.16.2
+* torch
+* timm
+* tqdm
+
diff --git a/docs/en/model_zoos/video_benchmarks.md b/docs/en/model_zoos/video_benchmarks.md
new file mode 100644
index 00000000..85449ca2
--- /dev/null
+++ b/docs/en/model_zoos/video_benchmarks.md
@@ -0,0 +1,90 @@
+# Video Prediction Benchmarks
+
+**We provide benchmark results of video prediction methods on video datasets. More video prediction methods will be supported in the future. Issues and PRs are welcome!**
+
+
+Currently supported video prediction methods
+
+- [x] [ConvLSTM](https://arxiv.org/abs/1506.04214) (NIPS'2015)
+- [x] [PredRNN](https://dl.acm.org/doi/abs/10.5555/3294771.3294855) (NIPS'2017)
+- [x] [PredRNN++](https://arxiv.org/abs/1804.06300) (ICML'2018)
+- [x] [E3D-LSTM](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2018)
+- [x] [MAU](https://arxiv.org/abs/1811.07490) (CVPR'2019)
+- [x] [CrevNet](https://openreview.net/forum?id=B1lKS2AqtX) (ICLR'2020)
+- [x] [PhyDNet](https://arxiv.org/abs/2003.01460) (CVPR'2020)
+- [x] [PredRNN.V2](https://arxiv.org/abs/2103.09504v4) (TPAMI'2022)
+- [x] [SimVP](https://arxiv.org/abs/2206.05099) (CVPR'2022)
+- [x] [SimVP.V2](https://arxiv.org/abs/2211.12509) (ArXiv'2022)
+
+
+
+
+Currently supported MetaFormer models for SimVP
+
+- [x] [ViT](https://arxiv.org/abs/2010.11929) (ICLR'2021)
+- [x] [Swin-Transformer](https://arxiv.org/abs/2103.14030) (ICCV'2021)
+- [x] [MLP-Mixer](https://arxiv.org/abs/2105.01601) (NIPS'2021)
+- [x] [ConvMixer](https://arxiv.org/abs/2201.09792) (Openreview'2021)
+- [x] [UniFormer](https://arxiv.org/abs/2201.09450) (ICLR'2022)
+- [x] [PoolFormer](https://arxiv.org/abs/2111.11418) (CVPR'2022)
+- [x] [ConvNeXt](https://arxiv.org/abs/2201.03545) (CVPR'2022)
+- [x] [VAN](https://arxiv.org/abs/2202.09741) (ArXiv'2022)
+- [x] [IncepU (SimVP.V1)](https://arxiv.org/abs/2206.05099) (CVPR'2022)
+- [x] [gSTA (SimVP.V2)](https://arxiv.org/abs/2211.12509) (ArXiv'2022)
+- [x] [HorNet](https://arxiv.org/abs/2207.14284) (NIPS'2022)
+- [x] [MogaNet](https://arxiv.org/abs/2211.03295) (ArXiv'2022)
+
+
+
+## Moving MNIST Benchmarks
+
+We provide benchmark results on popular [Moving MNIST](http://arxiv.org/abs/1502.04681) dataset using $10\rightarrow 10$ frames prediction setting. Metrics (MSE, MAE, SSIM, pSNR) of the final models are reported in three trials. Parameters (M), FLOPs (G), inference FPS (s) are also reported for all methods.
+
+### **Benchmark of Video Prediction Methods**
+
+For fair comparison of different methods, we report final results when models are trained to convergence. We provide config file in [configs/mmnist](https://github.com/chengtan9907/SimVPv2/configs/mmnist).
+
+| Method | Params | FLOPs | FPS | MSE | MAE | SSIM | Download |
+|--------------|:------:|:------:|:---:|:-----:|:------:|:-----:|:------------:|
+| ConvLSTM-S | 15.0M | 56.8G | 113 | 46.26 | 142.18 | 0.878 | model \| log |
+| ConvLSTM-L | 33.8M | 127.0G | 50 | 29.88 | 95.05 | 0.925 | model \| log |
+| PhyDNet | 3.1M | 15.3G | 182 | 35.68 | 96.70 | 0.917 | model \| log |
+| PredRNN | 23.8M | 116.0G | 54 | 25.04 | 76.26 | 0.944 | model \| log |
+| PredRNN++ | 38.6M | 171.7G | 38 | 22.45 | 69.70 | 0.950 | model \| log |
+| MIM | 38.0M | 179.2G | 37 | 23.66 | 74.37 | 0.946 | model \| log |
+| E3D-LSTM | 51.0M | 298.9G | 18 | 36.19 | 78.64 | 0.932 | model \| log |
+| CrevNet | 5.0M | 270.7G | 10 | 30.15 | 86.28 | 0.935 | model \| log |
+| PredRNN.V2 | 23.9M | 116.6G | 52 | 27.73 | 82.17 | 0.937 | model \| log |
+| SimVP+IncepU | 58.0M | 19.4G | 209 | 26.69 | 77.19 | 0.940 | model \| log |
+| SimVP+gSTA-S | 46.8M | 16.5G | 282 | 15.05 | 49.80 | 0.967 | model \| log |
+
+### **Benchmark of MetaFormers on SimVP**
+
+Since the hidden Translator in [SimVP](https://arxiv.org/abs/2211.12509) can be replaced by any [Metaformer](https://arxiv.org/abs/2111.11418) block which achieves `token mixing` and `channel mixing`, we benchmark popular Metaformer architectures on SimVP with training times of 200-epoch and 2000-epoch. We provide config file in [configs/mmnist/simvp](https://github.com/chengtan9907/SimVPv2/configs/mmnist/simvp/).
+
+| MetaFormer | Setting | Params | FLOPs | FPS | MSE | MAE | SSIM | PSNR | Download |
+|------------------|:----------:|:------:|:------:|:----:|:-----:|:-----:|:------:|:-----:|:------------:|
+| IncepU (SimVPv1) | 200 epoch | 58.0M | 19.4G | 209s | 32.15 | 89.05 | 0.9268 | 37.97 | model \| log |
+| gSTA (SimVPv2) | 200 epoch | 46.8M | 16.5G | 282s | 26.69 | 77.19 | 0.9402 | 38.3 | model \| log |
+| ViT | 200 epoch | 46.1M | 16.9.G | 290s | 35.15 | 95.87 | 0.9139 | 37.79 | model \| log |
+| Swin Transformer | 200 epoch | 46.1M | 16.4G | 294s | 29.70 | 84.05 | 0.9331 | 38.14 | model \| log |
+| Uniformer | 200 epoch | 44.8M | 16.5G | 296s | 30.38 | 85.87 | 0.9308 | 38.11 | model \| log |
+| MLP-Mixer | 200 epoch | 38.2M | 14.7G | 334s | 29.52 | 83.36 | 0.9338 | 38.19 | model \| log |
+| ConvMixer | 200 epoch | 3.9M | 5.5G | 658s | 32.09 | 88.93 | 0.9259 | 37.97 | model \| log |
+| Poolformer | 200 epoch | 37.1M | 14.1G | 341s | 31.79 | 88.48 | 0.9271 | 38.06 | model \| log |
+| ConvNeXt | 200 epoch | 37.3M | 14.1G | 344s | 26.94 | 77.23 | 0.9397 | 38.34 | model \| log |
+| VAN | 200 epoch | 44.5M | 16.0G | 288s | 26.10 | 76.11 | 0.9417 | 38.39 | model \| log |
+| HorNet | 200 epoch | 45.7M | 16.3G | 287s | 29.64 | 83.26 | 0.9331 | 38.16 | model \| log |
+| MogaNet | 200 epoch | 46.8M | 16.5G | 255s | 25.57 | 75.19 | 0.9429 | 38.41 | model \| log |
+| IncepU (SimVPv1) | 2000 epoch | 58.0M | 19.4G | 209s | - | - | - | - | - |
+| gSTA (SimVPv2) | 2000 epoch | 46.8M | 16.5G | 282s | 15.05 | 49.80 | 0.9670 | - | model \| log |
+| ViT | 2000 epoch | 46.1M | 16.9.G | 290s | 19.74 | 61.65 | 0.9539 | 38.96 | model \| log |
+| Swin Transformer | 2000 epoch | 46.1M | 16.4G | 294s | 19.11 | 59.84 | 0.9584 | 39.03 | model \| log |
+| Uniformer | 2000 epoch | 44.8M | 16.5G | 296s | 18.01 | 57.52 | 0.9609 | 39.11 | model \| log |
+| MLP-Mixer | 2000 epoch | 38.2M | 14.7G | 334s | 18.85 | 59.86 | 0.9589 | 38.98 | model \| log |
+| ConvMixer | 2000 epoch | 3.9M | 5.5G | 658s | 22.30 | 67.37 | 0.9507 | 38.67 | model \| log |
+| Poolformer | 2000 epoch | 37.1M | 14.1G | 341s | 20.96 | 64.31 | 0.9539 | 38.86 | model \| log |
+| ConvNeXt | 2000 epoch | 37.3M | 14.1G | 344s | 17.58 | 55.76 | 0.9617 | 39.19 | model \| log |
+| VAN | 2000 epoch | 44.5M | 16.0G | 288s | 16.21 | 53.57 | 0.9646 | 39.26 | model \| log |
+| HorNet | 2000 epoch | 45.7M | 16.3G | 287s | 17.40 | 55.70 | 0.9624 | 39.19 | model \| log |
+| MogaNet | 2000 epoch | 46.8M | 16.5G | 255s | 15.67 | 51.84 | 0.9661 | 39.35 | model \| log |
diff --git a/docs/en/switch_language.md b/docs/en/switch_language.md
new file mode 100644
index 00000000..4cf942c1
--- /dev/null
+++ b/docs/en/switch_language.md
@@ -0,0 +1 @@
+## English
diff --git a/simvp/api/train.py b/simvp/api/train.py
index abb89de2..4bcaeb48 100644
--- a/simvp/api/train.py
+++ b/simvp/api/train.py
@@ -12,7 +12,7 @@
from simvp.core import metric, Recorder
from simvp.methods import method_maps
from simvp.utils import (set_seed, print_log, output_namespace, check_dir,
- get_dataset)
+ get_dataset, measure_throughput)
try:
import nni
@@ -35,34 +35,36 @@ def __init__(self, args):
T, C, H, W = self.args.in_shape
if self.args.method == 'simvp':
- _tmp_input = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
- flops = FlopCountAnalysis(self.method.model, _tmp_input)
+ input_dummy = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
elif self.args.method == 'crevnet':
# crevnet must use the batchsize rather than 1
- _tmp_input = torch.ones(self.args.batch_size, 20, C, H, W).to(self.device)
- flops = FlopCountAnalysis(self.method.model, _tmp_input)
+ input_dummy = torch.ones(self.args.batch_size, 20, C, H, W).to(self.device)
elif self.args.method == 'phydnet':
_tmp_input1 = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
_tmp_input2 = torch.ones(1, self.args.aft_seq_length, C, H, W).to(self.device)
_tmp_constraints = torch.zeros((49, 7, 7)).to(self.device)
- flops = FlopCountAnalysis(self.method.model, (_tmp_input1, _tmp_input2, _tmp_constraints))
+ input_dummy = (_tmp_input1, _tmp_input2, _tmp_constraints)
elif self.args.method in ['convlstm', 'predrnnpp', 'predrnn', 'mim', 'e3dlstm', 'mau']:
Hp, Wp = H // self.args.patch_size, W // self.args.patch_size
Cp = self.args.patch_size ** 2 * C
_tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device)
_tmp_flag = torch.ones(1, self.args.aft_seq_length - 1, Hp, Wp, Cp).to(self.device)
- flops = FlopCountAnalysis(self.method.model, (_tmp_input, _tmp_flag))
+ input_dummy = (_tmp_input, _tmp_flag)
elif self.args.method == 'predrnnv2':
Hp, Wp = H // self.args.patch_size, W // self.args.patch_size
Cp = self.args.patch_size ** 2 * C
_tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device)
_tmp_flag = torch.ones(1, self.args.total_length - 2, Hp, Wp, Cp).to(self.device)
- flops = FlopCountAnalysis(self.method.model, (_tmp_input, _tmp_flag))
+ input_dummy = (_tmp_input, _tmp_flag)
else:
raise ValueError(f'Invalid method name {self.args.method}')
print_log(self.method.model)
+ flops = FlopCountAnalysis(self.method.model, input_dummy)
print_log(flop_count_table(flops))
+ if args.fps:
+ fps = measure_throughput(self.method.model, input_dummy)
+ print_log('Throughputs of {}: {:.3f}'.format(self.args.method, fps))
def _acquire_device(self):
if self.args.use_gpu:
diff --git a/simvp/datasets/dataloader_kitticaltech.py b/simvp/datasets/dataloader_kitticaltech.py
index 3ca5d2b7..9c6f50e7 100644
--- a/simvp/datasets/dataloader_kitticaltech.py
+++ b/simvp/datasets/dataloader_kitticaltech.py
@@ -121,7 +121,7 @@ def load_data(batch_size, val_batch_size, data_root,
}
input_handle = DataProcess(input_param)
train_data, train_idx = input_handle.load_data('train')
- test_data, test_idx = input_handle.load_data('val')
+ test_data, test_idx = input_handle.load_data('test')
elif os.path.exists(osp.join(data_root, 'kitticaltech_npy')):
train_data = np.load(osp.join(data_root, 'kitticaltech_npy', 'train_data.npy'))
train_idx = np.load(osp.join(data_root, 'kitticaltech_npy', 'train_idx.npy'))
diff --git a/simvp/datasets/dataloader_weather.py b/simvp/datasets/dataloader_weather.py
index 04ab5d6a..e0657cc2 100644
--- a/simvp/datasets/dataloader_weather.py
+++ b/simvp/datasets/dataloader_weather.py
@@ -143,8 +143,8 @@ def load_data(batch_size,
val_batch_size,
data_root,
num_workers=2,
- data_name=['t'],
- train_time=['2010', '2015'],
+ data_name='t2m',
+ train_time=['1979', '2015'],
val_time=['2016', '2016'],
test_time=['2017', '2018'],
idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
diff --git a/simvp/modules/simvp_modules.py b/simvp/modules/simvp_modules.py
index 46390fed..53de7ec6 100644
--- a/simvp/modules/simvp_modules.py
+++ b/simvp/modules/simvp_modules.py
@@ -514,6 +514,7 @@ class ViTSubBlock(ViTBlock):
def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
super().__init__(dim=dim, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
drop=drop, drop_path=drop_path, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
diff --git a/simvp/utils/__init__.py b/simvp/utils/__init__.py
index 418f11e4..e8df9449 100644
--- a/simvp/utils/__init__.py
+++ b/simvp/utils/__init__.py
@@ -2,7 +2,7 @@
from .config_utils import Config, check_file_exist
from .main_utils import (set_seed, print_log, output_namespace, check_dir, get_dataset,
- count_parameters, load_config, update_config)
+ count_parameters, measure_throughput, load_config, update_config)
from .parser import create_parser
from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling, reshape_patch,
reshape_patch_back)
@@ -10,6 +10,6 @@
__all__ = [
'Config', 'check_file_exist', 'create_parser',
'set_seed', 'print_log', 'output_namespace', 'check_dir', 'get_dataset', 'count_parameters',
- 'load_config', 'update_config',
+ 'measure_throughput', 'load_config', 'update_config',
'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back',
]
\ No newline at end of file
diff --git a/simvp/utils/main_utils.py b/simvp/utils/main_utils.py
index d72cd909..42a03b01 100644
--- a/simvp/utils/main_utils.py
+++ b/simvp/utils/main_utils.py
@@ -47,6 +47,35 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
+def measure_throughput(model, input_dummy):
+ bs = 100
+ repetitions = 100
+ if isinstance(input_dummy, tuple):
+ input_dummy = list(input_dummy)
+ _, T, C, H, W = input_dummy[0].shape
+ _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device)
+ input_dummy[0] = _input
+ input_dummy = tuple(input_dummy)
+ else:
+ _, T, C, H, W = input_dummy.shape
+ input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device)
+ total_time = 0
+ with torch.no_grad():
+ for _ in range(repetitions):
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+ if isinstance(input_dummy, tuple):
+ _ = model(*input_dummy)
+ else:
+ _ = model(input_dummy)
+ ender.record()
+ torch.cuda.synchronize()
+ curr_time = starter.elapsed_time(ender) / 1000
+ total_time += curr_time
+ Throughput = (repetitions * bs) / total_time
+ return Throughput
+
+
def load_config(filename:str = None):
'''
load and print config
diff --git a/simvp/utils/parser.py b/simvp/utils/parser.py
index c52e3350..6bcaa0d0 100644
--- a/simvp/utils/parser.py
+++ b/simvp/utils/parser.py
@@ -15,18 +15,20 @@ def create_parser():
parser.add_argument('--use_gpu', default=True, type=bool)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--fps', action='store_true', default=False,
+ help='Whether to measure inference speed (FPS)')
# dataset parameters
- parser.add_argument('--batch_size', '-b', default=16, type=int, help="Training batch size")
- parser.add_argument('--val_batch_size', '-vb', default=4, type=int, help="Validation batch size")
+ parser.add_argument('--batch_size', '-b', default=16, type=int, help='Training batch size')
+ parser.add_argument('--val_batch_size', '-vb', default=4, type=int, help='Validation batch size')
parser.add_argument('--num_workers', default=8, type=int)
parser.add_argument('--data_root', default='./data/')
parser.add_argument('--dataname', '-d', default='mmnist', type=str,
choices=['mmnist', 'kitticaltech', 'kth', 'kth40', 'taxibj', 'weather'],
help='Dataset name (default: "mmnist")')
- parser.add_argument('--pre_seq_length', default=None, type=int, help="Sequence length before prediction")
- parser.add_argument('--aft_seq_length', default=None, type=int, help="Sequence length after prediction")
- parser.add_argument('--total_length', default=None, type=int, help="Total Sequence length for prediction")
+ parser.add_argument('--pre_seq_length', default=None, type=int, help='Sequence length before prediction')
+ parser.add_argument('--aft_seq_length', default=None, type=int, help='Sequence length after prediction')
+ parser.add_argument('--total_length', default=None, type=int, help='Total Sequence length for prediction')
# method parameters
parser.add_argument('--method', '-m', default='SimVP', type=str,