From 03c543b4b2a51d0eb3d1a0c832eb90858150f6bc Mon Sep 17 00:00:00 2001
From: Lupin1998 <1070535169@qq.com>
Date: Tue, 18 Apr 2023 23:54:30 +0000
Subject: [PATCH] code refactor for OpenSTL V0.2.0
---
.coveragerc | 5 +
README.md | 77 ++---
.../SimVP_ConvMixer.py | 0
.../SimVP_ConvNeXt.py | 0
.../SimVP_HorNet.py | 0
.../SimVP_IncepU.py | 0
.../SimVP_MLPMixer.py | 0
.../SimVP_MogaNet.py | 0
.../SimVP_Poolformer.py | 0
.../{simvp_r_5_625 => r_5_625}/SimVP_Swin.py | 0
.../SimVP_Uniformer.py | 0
.../{simvp_r_5_625 => r_5_625}/SimVP_VAN.py | 0
.../weather/{simvp => r_5_625}/SimVP_ViT.py | 0
.../weather/{simvp => r_5_625}/SimVP_gSTA.py | 0
.../weather/simvp_t2m_5_625/SimVP_ConvNeXt.py | 15 -
.../weather/simvp_t2m_5_625/SimVP_HorNet.py | 15 -
.../weather/simvp_t2m_5_625/SimVP_MLPMixer.py | 15 -
.../weather/simvp_t2m_5_625/SimVP_MogaNet.py | 15 -
.../simvp_t2m_5_625/SimVP_Poolformer.py | 15 -
configs/weather/simvp_t2m_5_625/SimVP_VAN.py | 15 -
configs/weather/simvp_t2m_5_625/SimVP_ViT.py | 15 -
configs/weather/simvp_t2m_5_625/SimVP_gSTA.py | 14 -
.../weather/simvp_tcc_5_625/SimVP_IncepU.py | 15 -
configs/weather/simvp_tcc_5_625/SimVP_Swin.py | 15 -
.../simvp_tcc_5_625/SimVP_Uniformer.py | 15 -
.../simvp_uv10_5_625/SimVP_ConvMixer.py | 15 -
.../{simvp => t2m_5_625}/SimVP_ConvMixer.py | 0
.../{simvp => t2m_5_625}/SimVP_ConvNeXt.py | 0
.../{simvp => t2m_5_625}/SimVP_HorNet.py | 0
.../{simvp => t2m_5_625}/SimVP_IncepU.py | 0
.../{simvp => t2m_5_625}/SimVP_MLPMixer.py | 0
.../{simvp => t2m_5_625}/SimVP_MogaNet.py | 0
.../{simvp => t2m_5_625}/SimVP_Poolformer.py | 0
.../{simvp => t2m_5_625}/SimVP_Swin.py | 0
.../{simvp => t2m_5_625}/SimVP_Uniformer.py | 0
.../weather/{simvp => t2m_5_625}/SimVP_VAN.py | 0
.../{simvp_r_5_625 => t2m_5_625}/SimVP_ViT.py | 0
.../SimVP_gSTA.py | 0
configs/weather/tcc_5_625/ConvLSTM.py | 19 +
configs/weather/tcc_5_625/PredRNN.py | 19 +
configs/weather/tcc_5_625/PredRNNpp.py | 19 +
configs/weather/tcc_5_625/PredRNNv2.py | 20 ++
.../SimVP_ConvMixer.py | 0
.../SimVP_ConvNeXt.py | 0
.../SimVP_HorNet.py | 0
.../SimVP_IncepU.py | 0
.../SimVP_MLPMixer.py | 0
.../SimVP_MogaNet.py | 0
.../SimVP_Poolformer.py | 0
.../SimVP_Swin.py | 0
.../SimVP_Uniformer.py | 0
.../SimVP_VAN.py | 0
.../SimVP_ViT.py | 0
.../SimVP_gSTA.py | 0
configs/weather/uv10_5_625/ConvLSTM.py | 19 +
configs/weather/uv10_5_625/PredRNN.py | 19 +
configs/weather/uv10_5_625/PredRNNpp.py | 19 +
configs/weather/uv10_5_625/PredRNNv2.py | 20 ++
.../SimVP_ConvMixer.py | 0
.../SimVP_ConvNeXt.py | 0
.../SimVP_HorNet.py | 0
.../SimVP_IncepU.py | 0
.../SimVP_MLPMixer.py | 0
.../SimVP_MogaNet.py | 0
.../SimVP_Poolformer.py | 0
.../SimVP_Swin.py | 0
.../SimVP_Uniformer.py | 0
.../SimVP_VAN.py | 0
.../SimVP_ViT.py | 0
.../SimVP_gSTA.py | 0
docs/en/changelog.md | 2 +-
docs/en/conf.py | 2 +-
docs/en/get_started.md | 8 +-
docs/en/install.md | 13 +-
environment.yml | 4 +-
{simvp => openstl}/__init__.py | 0
openstl/api/__init__.py | 7 +
{simvp => openstl}/api/train.py | 252 +++++++++-----
{simvp => openstl}/core/__init__.py | 0
{simvp => openstl}/core/ema_hook.py | 4 +
{simvp => openstl}/core/hooks.py | 0
{simvp => openstl}/core/metrics.py | 0
{simvp => openstl}/core/optim_constant.py | 0
{simvp => openstl}/core/optim_scheduler.py | 0
{simvp => openstl}/core/recorder.py | 0
{simvp => openstl}/datasets/__init__.py | 3 +-
{simvp => openstl}/datasets/dataloader.py | 18 +-
.../datasets/dataloader_kitticaltech.py | 29 +-
{simvp => openstl}/datasets/dataloader_kth.py | 25 +-
.../datasets/dataloader_moving_mnist.py | 33 +-
openstl/datasets/dataloader_taxibj.py | 52 +++
.../datasets/dataloader_weather.py | 45 +--
.../datasets/dataset_constant.py | 0
openstl/datasets/utils.py | 194 +++++++++++
{simvp => openstl}/methods/__init__.py | 0
openstl/methods/base_method.py | 184 ++++++++++
{simvp => openstl}/methods/convlstm.py | 2 +-
{simvp => openstl}/methods/crevnet.py | 8 +-
{simvp => openstl}/methods/e3dlstm.py | 2 +-
{simvp => openstl}/methods/mau.py | 4 +-
{simvp => openstl}/methods/mim.py | 2 +-
{simvp => openstl}/methods/phydnet.py | 6 +-
openstl/methods/predrnn.py | 108 ++++++
{simvp => openstl}/methods/predrnnpp.py | 2 +-
{simvp => openstl}/methods/predrnnv2.py | 4 +-
{simvp => openstl}/methods/simvp.py | 94 +++--
{simvp => openstl}/models/__init__.py | 0
{simvp => openstl}/models/convlstm_model.py | 2 +-
{simvp => openstl}/models/crevnet_model.py | 2 +-
{simvp => openstl}/models/e3dlstm_model.py | 2 +-
{simvp => openstl}/models/mau_model.py | 2 +-
{simvp => openstl}/models/mim_model.py | 2 +-
{simvp => openstl}/models/phydnet_model.py | 2 +-
{simvp => openstl}/models/predrnn_model.py | 2 +-
{simvp => openstl}/models/predrnnpp_model.py | 2 +-
{simvp => openstl}/models/predrnnv2_model.py | 2 +-
{simvp => openstl}/models/simvp_model.py | 6 +-
{simvp => openstl}/modules/__init__.py | 0
.../modules/convlstm_modules.py | 0
{simvp => openstl}/modules/crevnet_modules.py | 0
{simvp => openstl}/modules/e3dlstm_modules.py | 0
{simvp => openstl}/modules/layers/__init__.py | 0
{simvp => openstl}/modules/layers/hornet.py | 0
{simvp => openstl}/modules/layers/moganet.py | 0
.../modules/layers/poolformer.py | 0
.../modules/layers/uniformer.py | 0
{simvp => openstl}/modules/layers/van.py | 0
{simvp => openstl}/modules/mau_modules.py | 0
{simvp => openstl}/modules/mim_modules.py | 0
{simvp => openstl}/modules/phydnet_modules.py | 0
{simvp => openstl}/modules/predrnn_modules.py | 0
.../modules/predrnnpp_modules.py | 0
.../modules/predrnnv2_modules.py | 0
{simvp => openstl}/modules/simvp_modules.py | 0
openstl/utils/__init__.py | 24 ++
openstl/utils/collect.py | 210 ++++++++++++
{simvp => openstl}/utils/config_utils.py | 0
openstl/utils/main_utils.py | 307 +++++++++++++++++
{simvp => openstl}/utils/parser.py | 40 ++-
{simvp => openstl}/utils/predrnn_utils.py | 0
openstl/utils/progressbar.py | 324 ++++++++++++++++++
{simvp => openstl}/version.py | 2 +-
setup.py | 15 +-
simvp/api/__init__.py | 5 -
simvp/datasets/dataloader_taxibj.py | 47 ---
simvp/methods/base_method.py | 70 ----
simvp/methods/predrnn.py | 161 ---------
simvp/utils/__init__.py | 16 -
simvp/utils/main_utils.py | 168 ---------
tools/dist_test.sh | 23 ++
tools/dist_train.sh | 23 ++
tools/{non_dist_test.py => test.py} | 14 +-
tools/{non_dist_train.py => train.py} | 17 +-
153 files changed, 2049 insertions(+), 963 deletions(-)
create mode 100644 .coveragerc
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_ConvMixer.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_ConvNeXt.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_HorNet.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_IncepU.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_MLPMixer.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_MogaNet.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_Poolformer.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_Swin.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_Uniformer.py (100%)
rename configs/weather/{simvp_r_5_625 => r_5_625}/SimVP_VAN.py (100%)
rename configs/weather/{simvp => r_5_625}/SimVP_ViT.py (100%)
rename configs/weather/{simvp => r_5_625}/SimVP_gSTA.py (100%)
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_ConvNeXt.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_HorNet.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_MLPMixer.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_MogaNet.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_Poolformer.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_VAN.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_ViT.py
delete mode 100644 configs/weather/simvp_t2m_5_625/SimVP_gSTA.py
delete mode 100644 configs/weather/simvp_tcc_5_625/SimVP_IncepU.py
delete mode 100644 configs/weather/simvp_tcc_5_625/SimVP_Swin.py
delete mode 100644 configs/weather/simvp_tcc_5_625/SimVP_Uniformer.py
delete mode 100644 configs/weather/simvp_uv10_5_625/SimVP_ConvMixer.py
rename configs/weather/{simvp => t2m_5_625}/SimVP_ConvMixer.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_ConvNeXt.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_HorNet.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_IncepU.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_MLPMixer.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_MogaNet.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_Poolformer.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_Swin.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_Uniformer.py (100%)
rename configs/weather/{simvp => t2m_5_625}/SimVP_VAN.py (100%)
rename configs/weather/{simvp_r_5_625 => t2m_5_625}/SimVP_ViT.py (100%)
rename configs/weather/{simvp_r_5_625 => t2m_5_625}/SimVP_gSTA.py (100%)
create mode 100644 configs/weather/tcc_5_625/ConvLSTM.py
create mode 100644 configs/weather/tcc_5_625/PredRNN.py
create mode 100644 configs/weather/tcc_5_625/PredRNNpp.py
create mode 100644 configs/weather/tcc_5_625/PredRNNv2.py
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_ConvMixer.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_ConvNeXt.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_HorNet.py (100%)
rename configs/weather/{simvp_t2m_5_625 => tcc_5_625}/SimVP_IncepU.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_MLPMixer.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_MogaNet.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_Poolformer.py (100%)
rename configs/weather/{simvp_t2m_5_625 => tcc_5_625}/SimVP_Swin.py (100%)
rename configs/weather/{simvp_t2m_5_625 => tcc_5_625}/SimVP_Uniformer.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_VAN.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_ViT.py (100%)
rename configs/weather/{simvp_tcc_5_625 => tcc_5_625}/SimVP_gSTA.py (100%)
create mode 100644 configs/weather/uv10_5_625/ConvLSTM.py
create mode 100644 configs/weather/uv10_5_625/PredRNN.py
create mode 100644 configs/weather/uv10_5_625/PredRNNpp.py
create mode 100644 configs/weather/uv10_5_625/PredRNNv2.py
rename configs/weather/{simvp_t2m_5_625 => uv10_5_625}/SimVP_ConvMixer.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_ConvNeXt.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_HorNet.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_IncepU.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_MLPMixer.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_MogaNet.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_Poolformer.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_Swin.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_Uniformer.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_VAN.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_ViT.py (100%)
rename configs/weather/{simvp_uv10_5_625 => uv10_5_625}/SimVP_gSTA.py (100%)
rename {simvp => openstl}/__init__.py (100%)
create mode 100644 openstl/api/__init__.py
rename {simvp => openstl}/api/train.py (62%)
rename {simvp => openstl}/core/__init__.py (100%)
rename {simvp => openstl}/core/ema_hook.py (99%)
rename {simvp => openstl}/core/hooks.py (100%)
rename {simvp => openstl}/core/metrics.py (100%)
rename {simvp => openstl}/core/optim_constant.py (100%)
rename {simvp => openstl}/core/optim_scheduler.py (100%)
rename {simvp => openstl}/core/recorder.py (100%)
rename {simvp => openstl}/datasets/__init__.py (84%)
rename {simvp => openstl}/datasets/dataloader.py (66%)
rename {simvp => openstl}/datasets/dataloader_kitticaltech.py (82%)
rename {simvp => openstl}/datasets/dataloader_kth.py (91%)
rename {simvp => openstl}/datasets/dataloader_moving_mnist.py (82%)
create mode 100644 openstl/datasets/dataloader_taxibj.py
rename {simvp => openstl}/datasets/dataloader_weather.py (83%)
rename {simvp => openstl}/datasets/dataset_constant.py (100%)
create mode 100644 openstl/datasets/utils.py
rename {simvp => openstl}/methods/__init__.py (100%)
create mode 100644 openstl/methods/base_method.py
rename {simvp => openstl}/methods/convlstm.py (94%)
rename {simvp => openstl}/methods/crevnet.py (95%)
rename {simvp => openstl}/methods/e3dlstm.py (94%)
rename {simvp => openstl}/methods/mau.py (98%)
rename {simvp => openstl}/methods/mim.py (96%)
rename {simvp => openstl}/methods/phydnet.py (96%)
create mode 100644 openstl/methods/predrnn.py
rename {simvp => openstl}/methods/predrnnpp.py (94%)
rename {simvp => openstl}/methods/predrnnv2.py (94%)
rename {simvp => openstl}/methods/simvp.py (50%)
rename {simvp => openstl}/models/__init__.py (100%)
rename {simvp => openstl}/models/convlstm_model.py (98%)
rename {simvp => openstl}/models/crevnet_model.py (98%)
rename {simvp => openstl}/models/e3dlstm_model.py (98%)
rename {simvp => openstl}/models/mau_model.py (99%)
rename {simvp => openstl}/models/mim_model.py (98%)
rename {simvp => openstl}/models/phydnet_model.py (97%)
rename {simvp => openstl}/models/predrnn_model.py (98%)
rename {simvp => openstl}/models/predrnnpp_model.py (98%)
rename {simvp => openstl}/models/predrnnv2_model.py (98%)
rename {simvp => openstl}/models/simvp_model.py (96%)
rename {simvp => openstl}/modules/__init__.py (100%)
rename {simvp => openstl}/modules/convlstm_modules.py (100%)
rename {simvp => openstl}/modules/crevnet_modules.py (100%)
rename {simvp => openstl}/modules/e3dlstm_modules.py (100%)
rename {simvp => openstl}/modules/layers/__init__.py (100%)
rename {simvp => openstl}/modules/layers/hornet.py (100%)
rename {simvp => openstl}/modules/layers/moganet.py (100%)
rename {simvp => openstl}/modules/layers/poolformer.py (100%)
rename {simvp => openstl}/modules/layers/uniformer.py (100%)
rename {simvp => openstl}/modules/layers/van.py (100%)
rename {simvp => openstl}/modules/mau_modules.py (100%)
rename {simvp => openstl}/modules/mim_modules.py (100%)
rename {simvp => openstl}/modules/phydnet_modules.py (100%)
rename {simvp => openstl}/modules/predrnn_modules.py (100%)
rename {simvp => openstl}/modules/predrnnpp_modules.py (100%)
rename {simvp => openstl}/modules/predrnnv2_modules.py (100%)
rename {simvp => openstl}/modules/simvp_modules.py (100%)
create mode 100644 openstl/utils/__init__.py
create mode 100644 openstl/utils/collect.py
rename {simvp => openstl}/utils/config_utils.py (100%)
create mode 100644 openstl/utils/main_utils.py
rename {simvp => openstl}/utils/parser.py (69%)
rename {simvp => openstl}/utils/predrnn_utils.py (100%)
create mode 100644 openstl/utils/progressbar.py
rename {simvp => openstl}/version.py (97%)
delete mode 100644 simvp/api/__init__.py
delete mode 100644 simvp/datasets/dataloader_taxibj.py
delete mode 100644 simvp/methods/base_method.py
delete mode 100644 simvp/methods/predrnn.py
delete mode 100644 simvp/utils/__init__.py
delete mode 100644 simvp/utils/main_utils.py
create mode 100644 tools/dist_test.sh
create mode 100644 tools/dist_train.sh
rename tools/{non_dist_test.py => test.py} (68%)
rename tools/{non_dist_train.py => train.py} (66%)
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 00000000..8a85b545
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,5 @@
+[report]
+exclude_lines =
+ @abstractmethod
+ @abc.abstractmethod
+ raise NotImplementedError
diff --git a/README.md b/README.md
index ebfa9ee5..2b210a29 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,15 @@
-# SimVP: Towards Simple yet Powerful Spatiotemporal Predictive learning
+# OpenSTL: Open-source Toolbox for SpatioTemporal Predictive Learning
-
+
-
+
-
+
@@ -25,47 +25,35 @@ This repository is an open-source project for video prediction benchmarks, which
## Introduction
-This is the journal version of our previous conference work ([SimVP: Simpler yet Better Video Prediction](https://arxiv.org/abs/2206.05099), In CVPR 2022).
-
-It is worth noticing that the hidden Translator $h$ in SimVP can be replaced by any [MetaFormer](https://arxiv.org/abs/2111.11418) block (satisfying the macro design of `token mixing` and `channel mixing`).
+This is the journal version of our previous conference work ([SimVP: Simpler yet Better Video Prediction](https://arxiv.org/abs/2206.05099), in CVPR 2022). It is worth noting that the hidden Translator $h$ in SimVP can be replaced by any [MetaFormer](https://arxiv.org/abs/2111.11418) block (satisfying the macro design of `token mixing` and `channel mixing`).
-The performance of SimVPs on the Moving MNIST dataset. For the training time, the less the better. For the inference efficiency (frames per second), the more the better.
-
-
-
-
-Quantitative results of different methods on the Moving MNIST dataset ($10 \rightarrow 10$ frames).
-
-
-
-
(back to top)
## Overview
-* `simvp/api` contains an experiment runner.
-* `simvp/core` contains core training plugins and metrics.
-* `simvp/datasets` contains datasets and dataloaders.
-* `simvp/methods/` contains training methods for various video prediction methods.
-* `simvp/models/` contains the main network architectures of various video prediction methods.
-* `simvp/modules/` contains network modules and layers.
-* `tools/non_dist_train.py` is the executable python file with possible arguments for training, validating, and testing pipelines.
+* `openstl/api` contains an experiment runner.
+* `openstl/core` contains core training plugins and metrics.
+* `openstl/datasets` contains datasets and dataloaders.
+* `openstl/methods/` contains training methods for various video prediction methods.
+* `openstl/models/` contains the main network architectures of various video prediction methods.
+* `openstl/modules/` contains network modules and layers.
+* `tools/train.py` and `tools/test.py` are the executable python files with possible arguments for training, validating, and testing pipelines.
## News and Updates
-[2023-02-18] `SimVP` v0.1.0 is released. Benchmark results and config files are updated for MMNIST, TaxiBJ, and WeatherBench datasets.
+[2023-04-19] `OpenSTL` v0.2.0 is released.
## Installation
This project provides a conda environment setting file; users can easily reproduce the environment with the following commands:
```shell
-git clone https://github.com/chengtan9907/SimVPv2
-cd SimVPv2
+git clone https://github.com/chengtan9907/OpenSTL
+cd OpenSTL
conda env create -f environment.yml
-conda activate SimVP
+conda activate OpenSTL
python setup.py develop
```
@@ -76,21 +64,22 @@ python setup.py develop
* fvcore
* numpy
* hickle
-* scikit-image=0.16.2
+* scikit-image
* scikit-learn
* torch
* timm
* tqdm
+* xarray
Please refer to [install.md](docs/en/install.md) for more detailed instructions.
## Getting Started
-Please see [get_started.md](docs/en/get_started.md) for the basic usage. Here is an example of single GPU training SimVP+gSTA on Moving MNIST dataset.
+Please see [get_started.md](docs/en/get_started.md) for the basic usage. Here is an example of single-GPU non-distributed training of SimVP+gSTA on the Moving MNIST dataset.
```shell
bash tools/prepare_data/download_mmnist.sh
-python tools/non_dist_train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
+python tools/train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
```
(back to top)
@@ -140,11 +129,11 @@ We support various video prediction methods and will provide benchmarks on vario
Currently supported datasets
- - [x] [KTH Action](https://ieeexplore.ieee.org/document/1334462) (ICPR'2004) [[download](https://www.csc.kth.se/cvap/actions/)] [[config](https://github.com/chengtan9907/SimVPv2/configs/kth)]
- - [x] [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) (IJRR'2013) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)] [[config](https://github.com/chengtan9907/SimVPv2/configs/kitticaltech)]
- - [x] [Moving MNIST](http://arxiv.org/abs/1502.04681) (ICML'2015) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)] [[config](https://github.com/chengtan9907/SimVPv2/configs/mmnist)]
- - [x] [TaxiBJ](https://arxiv.org/abs/1610.00081) (AAAI'2017) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)] [[config](https://github.com/chengtan9907/SimVPv2/configs/taxibj)]
- - [x] [WeatherBench](https://arxiv.org/abs/2002.00469) (ArXiv'2020) [[download](https://github.com/pangeo-data/WeatherBench)] [[config](https://github.com/chengtan9907/SimVPv2/configs/weather)]
+ - [x] [KTH Action](https://ieeexplore.ieee.org/document/1334462) (ICPR'2004) [[download](https://www.csc.kth.se/cvap/actions/)] [[config](configs/kth)]
+ - [x] [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) (IJRR'2013) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)] [[config](configs/kitticaltech)]
+ - [x] [Moving MNIST](http://arxiv.org/abs/1502.04681) (ICML'2015) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)] [[config](configs/mmnist)]
+ - [x] [TaxiBJ](https://arxiv.org/abs/1610.00081) (AAAI'2017) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)] [[config](configs/taxibj)]
+ - [x] [WeatherBench](https://arxiv.org/abs/2002.00469) (ArXiv'2020) [[download](https://github.com/pangeo-data/WeatherBench)] [[config](configs/weather)]
@@ -160,23 +149,31 @@ SimVPv2 is an open-source project for video prediction methods created by resear
## Citation
-If you are interested in our repository and our paper, please cite the following paper:
+If you are interested in our repository or our paper, please cite the following papers:
```
@article{tan2022simvp,
title={SimVP: Towards Simple yet Powerful Spatiotemporal Predictive Learning},
- author={Tan, Cheng and Gao, Zhangyang and Li, Stan Z},
+ author={Tan, Cheng and Gao, Zhangyang and Li, Siyuan and Li, Stan Z},
journal={arXiv preprint arXiv:2211.12509},
year={2022}
}
```
+```
+@misc{tan2023openstl,
+ title={OpenSTL: Open-source Toolbox for SpatioTemporal Predictive Learning},
+ author={Tan, Cheng and Li, Siyuan and Gao, Zhangyang and Li, Stan Z},
+ howpublished = {\url{https://github.com/chengtan9907/OpenSTL}},
+ year={2023}
+}
+```
## Contribution and Contact
-For adding new features, looking for helps, or reporting bugs associated with SimVPv2, please open a [GitHub issue](https://github.com/chengtan9907/SimVPv2/issues) and [pull request](https://github.com/chengtan9907/SimVPv2/pulls) with the tag "help wanted" or "enhancement". Feel free to contact us through email if you have any questions. Enjoy!
+For adding new features, looking for help, or reporting bugs associated with `OpenSTL`, please open a [GitHub issue](https://github.com/chengtan9907/OpenSTL/issues) or [pull request](https://github.com/chengtan9907/OpenSTL/pulls) with the tag "help wanted" or "enhancement". Feel free to contact us through email if you have any questions. Enjoy!
-- Cheng Tan (tancheng@westlake.edu.cn), Westlake University & Zhejiang University
- Siyuan Li (lisiyuan@westlake.edu.cn), Westlake University & Zhejiang University
+- Cheng Tan (tancheng@westlake.edu.cn), Westlake University & Zhejiang University
- Zhangyang Gao (gaozhangyang@westlake.edu.cn), Westlake University & Zhejiang University
(back to top)
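
The README's claim above, that the hidden Translator $h$ accepts any MetaFormer-style block, boils down to the `token mixing` + `channel mixing` macro design. Below is a minimal sketch of that macro design for illustration only; it is not the project's `openstl/modules/simvp_modules.py`, and the class and argument names are hypothetical:

```python
import torch
import torch.nn as nn

class MetaFormerBlock(nn.Module):
    """Macro design only: residual token mixing followed by residual channel mixing."""

    def __init__(self, dim: int, token_mixer: nn.Module, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mixer = token_mixer  # e.g. attention, pooling, or a large-kernel conv
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mixer = nn.Sequential(  # per-token MLP
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, C) tokens
        x = x + self.token_mixer(self.norm1(x))
        x = x + self.channel_mixer(self.norm2(x))
        return x

# any module mapping (B, N, C) -> (B, N, C) can serve as the token mixer;
# nn.Identity() degenerates the block into a plain residual MLP
block = MetaFormerBlock(dim=64, token_mixer=nn.Identity())
out = block(torch.randn(2, 16, 64))
```
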
diff --git a/configs/weather/simvp_r_5_625/SimVP_ConvMixer.py b/configs/weather/r_5_625/SimVP_ConvMixer.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_ConvMixer.py
rename to configs/weather/r_5_625/SimVP_ConvMixer.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_ConvNeXt.py b/configs/weather/r_5_625/SimVP_ConvNeXt.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_ConvNeXt.py
rename to configs/weather/r_5_625/SimVP_ConvNeXt.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_HorNet.py b/configs/weather/r_5_625/SimVP_HorNet.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_HorNet.py
rename to configs/weather/r_5_625/SimVP_HorNet.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_IncepU.py b/configs/weather/r_5_625/SimVP_IncepU.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_IncepU.py
rename to configs/weather/r_5_625/SimVP_IncepU.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_MLPMixer.py b/configs/weather/r_5_625/SimVP_MLPMixer.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_MLPMixer.py
rename to configs/weather/r_5_625/SimVP_MLPMixer.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_MogaNet.py b/configs/weather/r_5_625/SimVP_MogaNet.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_MogaNet.py
rename to configs/weather/r_5_625/SimVP_MogaNet.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_Poolformer.py b/configs/weather/r_5_625/SimVP_Poolformer.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_Poolformer.py
rename to configs/weather/r_5_625/SimVP_Poolformer.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_Swin.py b/configs/weather/r_5_625/SimVP_Swin.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_Swin.py
rename to configs/weather/r_5_625/SimVP_Swin.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_Uniformer.py b/configs/weather/r_5_625/SimVP_Uniformer.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_Uniformer.py
rename to configs/weather/r_5_625/SimVP_Uniformer.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_VAN.py b/configs/weather/r_5_625/SimVP_VAN.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_VAN.py
rename to configs/weather/r_5_625/SimVP_VAN.py
diff --git a/configs/weather/simvp/SimVP_ViT.py b/configs/weather/r_5_625/SimVP_ViT.py
similarity index 100%
rename from configs/weather/simvp/SimVP_ViT.py
rename to configs/weather/r_5_625/SimVP_ViT.py
diff --git a/configs/weather/simvp/SimVP_gSTA.py b/configs/weather/r_5_625/SimVP_gSTA.py
similarity index 100%
rename from configs/weather/simvp/SimVP_gSTA.py
rename to configs/weather/r_5_625/SimVP_gSTA.py
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_ConvNeXt.py b/configs/weather/simvp_t2m_5_625/SimVP_ConvNeXt.py
deleted file mode 100644
index 93a809c2..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_ConvNeXt.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'convnext'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-2
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_HorNet.py b/configs/weather/simvp_t2m_5_625/SimVP_HorNet.py
deleted file mode 100644
index b96fa6d7..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_HorNet.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'hornet'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-3
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_MLPMixer.py b/configs/weather/simvp_t2m_5_625/SimVP_MLPMixer.py
deleted file mode 100644
index 954c69ec..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_MLPMixer.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'mlp'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-3
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_MogaNet.py b/configs/weather/simvp_t2m_5_625/SimVP_MogaNet.py
deleted file mode 100644
index f99111ec..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_MogaNet.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'moga'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 5e-3
-batch_size = 16
-drop_path = 0.2
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_Poolformer.py b/configs/weather/simvp_t2m_5_625/SimVP_Poolformer.py
deleted file mode 100644
index 1cea4e85..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_Poolformer.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'poolformer'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 5e-4
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_VAN.py b/configs/weather/simvp_t2m_5_625/SimVP_VAN.py
deleted file mode 100644
index 885d5e21..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_VAN.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'van'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 5e-3
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_ViT.py b/configs/weather/simvp_t2m_5_625/SimVP_ViT.py
deleted file mode 100644
index e711cee3..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_ViT.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'vit'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-3
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_gSTA.py b/configs/weather/simvp_t2m_5_625/SimVP_gSTA.py
deleted file mode 100644
index 8eefd720..00000000
--- a/configs/weather/simvp_t2m_5_625/SimVP_gSTA.py
+++ /dev/null
@@ -1,14 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'gSTA'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 5e-3
-batch_size = 16
-drop_path = 0.1
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_IncepU.py b/configs/weather/simvp_tcc_5_625/SimVP_IncepU.py
deleted file mode 100644
index e33989f1..00000000
--- a/configs/weather/simvp_tcc_5_625/SimVP_IncepU.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'IncepU' # SimVP.V1
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-2
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_Swin.py b/configs/weather/simvp_tcc_5_625/SimVP_Swin.py
deleted file mode 100644
index bdc7e545..00000000
--- a/configs/weather/simvp_tcc_5_625/SimVP_Swin.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'swin'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-3
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_Uniformer.py b/configs/weather/simvp_tcc_5_625/SimVP_Uniformer.py
deleted file mode 100644
index 698a6860..00000000
--- a/configs/weather/simvp_tcc_5_625/SimVP_Uniformer.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'uniformer'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 5e-3
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_ConvMixer.py b/configs/weather/simvp_uv10_5_625/SimVP_ConvMixer.py
deleted file mode 100644
index 96a4edca..00000000
--- a/configs/weather/simvp_uv10_5_625/SimVP_ConvMixer.py
+++ /dev/null
@@ -1,15 +0,0 @@
-method = 'SimVP'
-# model
-spatio_kernel_enc = 3
-spatio_kernel_dec = 3
-model_type = 'convmixer'
-hid_S = 32
-hid_T = 256
-N_T = 8
-N_S = 2
-# training
-lr = 1e-2
-batch_size = 16
-drop_path = 0.1
-sched = 'cosine'
-warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/simvp/SimVP_ConvMixer.py b/configs/weather/t2m_5_625/SimVP_ConvMixer.py
similarity index 100%
rename from configs/weather/simvp/SimVP_ConvMixer.py
rename to configs/weather/t2m_5_625/SimVP_ConvMixer.py
diff --git a/configs/weather/simvp/SimVP_ConvNeXt.py b/configs/weather/t2m_5_625/SimVP_ConvNeXt.py
similarity index 100%
rename from configs/weather/simvp/SimVP_ConvNeXt.py
rename to configs/weather/t2m_5_625/SimVP_ConvNeXt.py
diff --git a/configs/weather/simvp/SimVP_HorNet.py b/configs/weather/t2m_5_625/SimVP_HorNet.py
similarity index 100%
rename from configs/weather/simvp/SimVP_HorNet.py
rename to configs/weather/t2m_5_625/SimVP_HorNet.py
diff --git a/configs/weather/simvp/SimVP_IncepU.py b/configs/weather/t2m_5_625/SimVP_IncepU.py
similarity index 100%
rename from configs/weather/simvp/SimVP_IncepU.py
rename to configs/weather/t2m_5_625/SimVP_IncepU.py
diff --git a/configs/weather/simvp/SimVP_MLPMixer.py b/configs/weather/t2m_5_625/SimVP_MLPMixer.py
similarity index 100%
rename from configs/weather/simvp/SimVP_MLPMixer.py
rename to configs/weather/t2m_5_625/SimVP_MLPMixer.py
diff --git a/configs/weather/simvp/SimVP_MogaNet.py b/configs/weather/t2m_5_625/SimVP_MogaNet.py
similarity index 100%
rename from configs/weather/simvp/SimVP_MogaNet.py
rename to configs/weather/t2m_5_625/SimVP_MogaNet.py
diff --git a/configs/weather/simvp/SimVP_Poolformer.py b/configs/weather/t2m_5_625/SimVP_Poolformer.py
similarity index 100%
rename from configs/weather/simvp/SimVP_Poolformer.py
rename to configs/weather/t2m_5_625/SimVP_Poolformer.py
diff --git a/configs/weather/simvp/SimVP_Swin.py b/configs/weather/t2m_5_625/SimVP_Swin.py
similarity index 100%
rename from configs/weather/simvp/SimVP_Swin.py
rename to configs/weather/t2m_5_625/SimVP_Swin.py
diff --git a/configs/weather/simvp/SimVP_Uniformer.py b/configs/weather/t2m_5_625/SimVP_Uniformer.py
similarity index 100%
rename from configs/weather/simvp/SimVP_Uniformer.py
rename to configs/weather/t2m_5_625/SimVP_Uniformer.py
diff --git a/configs/weather/simvp/SimVP_VAN.py b/configs/weather/t2m_5_625/SimVP_VAN.py
similarity index 100%
rename from configs/weather/simvp/SimVP_VAN.py
rename to configs/weather/t2m_5_625/SimVP_VAN.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_ViT.py b/configs/weather/t2m_5_625/SimVP_ViT.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_ViT.py
rename to configs/weather/t2m_5_625/SimVP_ViT.py
diff --git a/configs/weather/simvp_r_5_625/SimVP_gSTA.py b/configs/weather/t2m_5_625/SimVP_gSTA.py
similarity index 100%
rename from configs/weather/simvp_r_5_625/SimVP_gSTA.py
rename to configs/weather/t2m_5_625/SimVP_gSTA.py
diff --git a/configs/weather/tcc_5_625/ConvLSTM.py b/configs/weather/tcc_5_625/ConvLSTM.py
new file mode 100644
index 00000000..da05e091
--- /dev/null
+++ b/configs/weather/tcc_5_625/ConvLSTM.py
@@ -0,0 +1,19 @@
+method = 'ConvLSTM'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 0
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+# training
+lr = 1e-4
\ No newline at end of file
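
The weather configs created and renamed in this patch are flat Python files whose module-level variables are the hyper-parameters. As a minimal sketch of how such a file can be turned into a plain dict, assuming nothing about the project's own loader (which lives in `openstl/utils/config_utils.py`):

```python
import runpy

def load_flat_config(path: str) -> dict:
    """Execute a flat variable-style config file and keep its plain variables."""
    namespace = runpy.run_path(path)
    # drop runpy's dunder entries (__name__, __file__, __builtins__, ...)
    return {k: v for k, v in namespace.items() if not k.startswith('__')}

cfg = load_flat_config('configs/weather/tcc_5_625/ConvLSTM.py')
assert cfg['method'] == 'ConvLSTM' and cfg['lr'] == 1e-4
```
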
diff --git a/configs/weather/tcc_5_625/PredRNN.py b/configs/weather/tcc_5_625/PredRNN.py
new file mode 100644
index 00000000..5912706d
--- /dev/null
+++ b/configs/weather/tcc_5_625/PredRNN.py
@@ -0,0 +1,19 @@
+method = 'PredRNN'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 0
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+# training
+lr = 5e-4
\ No newline at end of file
diff --git a/configs/weather/tcc_5_625/PredRNNpp.py b/configs/weather/tcc_5_625/PredRNNpp.py
new file mode 100644
index 00000000..8c95dbd0
--- /dev/null
+++ b/configs/weather/tcc_5_625/PredRNNpp.py
@@ -0,0 +1,19 @@
+method = 'PredRNNpp'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 0
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+# training
+lr = 1e-3
\ No newline at end of file
diff --git a/configs/weather/tcc_5_625/PredRNNv2.py b/configs/weather/tcc_5_625/PredRNNv2.py
new file mode 100644
index 00000000..c99c16d3
--- /dev/null
+++ b/configs/weather/tcc_5_625/PredRNNv2.py
@@ -0,0 +1,20 @@
+method = 'PredRNNv2'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 1
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+decouple_beta = 0.1
+# training
+lr = 1e-3
\ No newline at end of file
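
The scheduled-sampling fields shared by these RNN configs describe a linear decay of the teacher-forcing probability: starting from `sampling_start_value`, it drops by `sampling_changing_rate` per iteration until `sampling_stop_iter`. A hedged sketch of that schedule follows; it illustrates the convention and is not the project's helper in `openstl/utils/predrnn_utils.py`:

```python
def teacher_forcing_eta(itr: int, start: float = 1.0,  # sampling_start_value
                        rate: float = 2e-5,             # sampling_changing_rate
                        stop: int = 50000) -> float:    # sampling_stop_iter
    """Probability of feeding ground-truth frames at training iteration `itr`."""
    if itr >= stop:
        return 0.0
    return max(start - rate * itr, 0.0)

# with the values above, the ratio reaches 0 exactly at sampling_stop_iter
assert teacher_forcing_eta(0) == 1.0
assert abs(teacher_forcing_eta(25000) - 0.5) < 1e-9
assert teacher_forcing_eta(50000) == 0.0
```

`reverse_scheduled_sampling = 1` in the PredRNNv2 config enables the reverse variant, which additionally applies scheduled sampling to the input frames, controlled by `r_sampling_step_1/2` and `r_exp_alpha`.
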
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_ConvMixer.py b/configs/weather/tcc_5_625/SimVP_ConvMixer.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_ConvMixer.py
rename to configs/weather/tcc_5_625/SimVP_ConvMixer.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_ConvNeXt.py b/configs/weather/tcc_5_625/SimVP_ConvNeXt.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_ConvNeXt.py
rename to configs/weather/tcc_5_625/SimVP_ConvNeXt.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_HorNet.py b/configs/weather/tcc_5_625/SimVP_HorNet.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_HorNet.py
rename to configs/weather/tcc_5_625/SimVP_HorNet.py
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_IncepU.py b/configs/weather/tcc_5_625/SimVP_IncepU.py
similarity index 100%
rename from configs/weather/simvp_t2m_5_625/SimVP_IncepU.py
rename to configs/weather/tcc_5_625/SimVP_IncepU.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_MLPMixer.py b/configs/weather/tcc_5_625/SimVP_MLPMixer.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_MLPMixer.py
rename to configs/weather/tcc_5_625/SimVP_MLPMixer.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_MogaNet.py b/configs/weather/tcc_5_625/SimVP_MogaNet.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_MogaNet.py
rename to configs/weather/tcc_5_625/SimVP_MogaNet.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_Poolformer.py b/configs/weather/tcc_5_625/SimVP_Poolformer.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_Poolformer.py
rename to configs/weather/tcc_5_625/SimVP_Poolformer.py
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_Swin.py b/configs/weather/tcc_5_625/SimVP_Swin.py
similarity index 100%
rename from configs/weather/simvp_t2m_5_625/SimVP_Swin.py
rename to configs/weather/tcc_5_625/SimVP_Swin.py
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_Uniformer.py b/configs/weather/tcc_5_625/SimVP_Uniformer.py
similarity index 100%
rename from configs/weather/simvp_t2m_5_625/SimVP_Uniformer.py
rename to configs/weather/tcc_5_625/SimVP_Uniformer.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_VAN.py b/configs/weather/tcc_5_625/SimVP_VAN.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_VAN.py
rename to configs/weather/tcc_5_625/SimVP_VAN.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_ViT.py b/configs/weather/tcc_5_625/SimVP_ViT.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_ViT.py
rename to configs/weather/tcc_5_625/SimVP_ViT.py
diff --git a/configs/weather/simvp_tcc_5_625/SimVP_gSTA.py b/configs/weather/tcc_5_625/SimVP_gSTA.py
similarity index 100%
rename from configs/weather/simvp_tcc_5_625/SimVP_gSTA.py
rename to configs/weather/tcc_5_625/SimVP_gSTA.py
diff --git a/configs/weather/uv10_5_625/ConvLSTM.py b/configs/weather/uv10_5_625/ConvLSTM.py
new file mode 100644
index 00000000..fbe59ad8
--- /dev/null
+++ b/configs/weather/uv10_5_625/ConvLSTM.py
@@ -0,0 +1,19 @@
+method = 'ConvLSTM'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 0
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+# training
+lr = 5e-4
\ No newline at end of file
diff --git a/configs/weather/uv10_5_625/PredRNN.py b/configs/weather/uv10_5_625/PredRNN.py
new file mode 100644
index 00000000..5912706d
--- /dev/null
+++ b/configs/weather/uv10_5_625/PredRNN.py
@@ -0,0 +1,19 @@
+method = 'PredRNN'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 0
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+# training
+lr = 5e-4
\ No newline at end of file
diff --git a/configs/weather/uv10_5_625/PredRNNpp.py b/configs/weather/uv10_5_625/PredRNNpp.py
new file mode 100644
index 00000000..af4b0179
--- /dev/null
+++ b/configs/weather/uv10_5_625/PredRNNpp.py
@@ -0,0 +1,19 @@
+method = 'PredRNNpp'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 0
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+# training
+lr = 5e-4
\ No newline at end of file
diff --git a/configs/weather/uv10_5_625/PredRNNv2.py b/configs/weather/uv10_5_625/PredRNNv2.py
new file mode 100644
index 00000000..c665d0b6
--- /dev/null
+++ b/configs/weather/uv10_5_625/PredRNNv2.py
@@ -0,0 +1,20 @@
+method = 'PredRNNv2'
+# reverse scheduled sampling
+reverse_scheduled_sampling = 1
+r_sampling_step_1 = 25000
+r_sampling_step_2 = 50000
+r_exp_alpha = 5000
+# scheduled sampling
+scheduled_sampling = 1
+sampling_stop_iter = 50000
+sampling_start_value = 1.0
+sampling_changing_rate = 0.00002
+# model
+num_hidden = '128,128,128,128'
+filter_size = 5
+stride = 1
+patch_size = 2
+layer_norm = 0
+decouple_beta = 0.1
+# training
+lr = 5e-4
\ No newline at end of file
diff --git a/configs/weather/simvp_t2m_5_625/SimVP_ConvMixer.py b/configs/weather/uv10_5_625/SimVP_ConvMixer.py
similarity index 100%
rename from configs/weather/simvp_t2m_5_625/SimVP_ConvMixer.py
rename to configs/weather/uv10_5_625/SimVP_ConvMixer.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_ConvNeXt.py b/configs/weather/uv10_5_625/SimVP_ConvNeXt.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_ConvNeXt.py
rename to configs/weather/uv10_5_625/SimVP_ConvNeXt.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_HorNet.py b/configs/weather/uv10_5_625/SimVP_HorNet.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_HorNet.py
rename to configs/weather/uv10_5_625/SimVP_HorNet.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_IncepU.py b/configs/weather/uv10_5_625/SimVP_IncepU.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_IncepU.py
rename to configs/weather/uv10_5_625/SimVP_IncepU.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_MLPMixer.py b/configs/weather/uv10_5_625/SimVP_MLPMixer.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_MLPMixer.py
rename to configs/weather/uv10_5_625/SimVP_MLPMixer.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_MogaNet.py b/configs/weather/uv10_5_625/SimVP_MogaNet.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_MogaNet.py
rename to configs/weather/uv10_5_625/SimVP_MogaNet.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_Poolformer.py b/configs/weather/uv10_5_625/SimVP_Poolformer.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_Poolformer.py
rename to configs/weather/uv10_5_625/SimVP_Poolformer.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_Swin.py b/configs/weather/uv10_5_625/SimVP_Swin.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_Swin.py
rename to configs/weather/uv10_5_625/SimVP_Swin.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_Uniformer.py b/configs/weather/uv10_5_625/SimVP_Uniformer.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_Uniformer.py
rename to configs/weather/uv10_5_625/SimVP_Uniformer.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_VAN.py b/configs/weather/uv10_5_625/SimVP_VAN.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_VAN.py
rename to configs/weather/uv10_5_625/SimVP_VAN.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_ViT.py b/configs/weather/uv10_5_625/SimVP_ViT.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_ViT.py
rename to configs/weather/uv10_5_625/SimVP_ViT.py
diff --git a/configs/weather/simvp_uv10_5_625/SimVP_gSTA.py b/configs/weather/uv10_5_625/SimVP_gSTA.py
similarity index 100%
rename from configs/weather/simvp_uv10_5_625/SimVP_gSTA.py
rename to configs/weather/uv10_5_625/SimVP_gSTA.py
diff --git a/docs/en/changelog.md b/docs/en/changelog.md
index cbbe915e..0de2139e 100644
--- a/docs/en/changelog.md
+++ b/docs/en/changelog.md
@@ -19,4 +19,4 @@ Release version to V0.1.0 with code refactoring.
* Upload `readthedocs` documents. Summarize video prediction benchmark results on MMNIST in [video_benchmarks.md](https://github.com/chengtan9907/SimVPv2/docs/en/model_zoos/video_benchmarks.md).
* Update benchmark results of video prediction baselines and MetaFormer architectures based on SimVP on MMNIST, TaxiBJ, and WeatherBench datasets.
-* Update README and add license.
+* Update README and add a license.
diff --git a/docs/en/conf.py b/docs/en/conf.py
index 0ac9233c..6f58b39e 100644
--- a/docs/en/conf.py
+++ b/docs/en/conf.py
@@ -79,7 +79,7 @@
'menu': [
{
'name': 'GitHub',
- 'url': 'https://github.com/chengtan9907/SimVPv2'
+ 'url': 'https://github.com/chengtan9907/OpenSTL'
},
{
'name':
diff --git a/docs/en/get_started.md b/docs/en/get_started.md
index 25c959a1..f5ed2598 100644
--- a/docs/en/get_started.md
+++ b/docs/en/get_started.md
@@ -4,10 +4,10 @@ This page provides basic tutorials about the usage of SimVP. For installation in
## Training and Testing with a Single GPU
-You can perform single GPU training and testing with `tools/non_dist_train.py` and `tools/non_dist_test.py`. We provide descriptions of some essential arguments.
+You can perform single- or multi-GPU training and testing with `tools/train.py` and `tools/test.py`. We provide descriptions of some essential arguments.
```bash
-python tools/non_dist_train.py \
+python tools/train.py \
--dataname ${DATASET_NAME} \
--method ${METHOD_NAME} \
--config_file ${CONFIG_FILE} \
@@ -29,10 +29,10 @@ python tools/non_dist_train.py \
An example of single-GPU training with SimVP+gSTA on the Moving MNIST dataset.
```shell
bash tools/prepare_data/download_mmnist.sh
-python tools/non_dist_train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
+python tools/train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
```
An example of single-GPU testing with SimVP+gSTA on the Moving MNIST dataset.
```shell
-python tools/non_dist_test.py -d mmnist -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
+python tools/test.py -d mmnist -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
```
diff --git a/docs/en/install.md b/docs/en/install.md
index 47a8335c..5674d29c 100644
--- a/docs/en/install.md
+++ b/docs/en/install.md
@@ -4,10 +4,10 @@
This project provides a conda environment setting file; users can easily reproduce the environment with the following commands:
```shell
-git clone https://github.com/chengtan9907/SimVPv2
-cd SimVPv2
+git clone https://github.com/chengtan9907/OpenSTL
+cd OpenSTL
conda env create -f environment.yml
-conda activate SimVP
+conda activate OpenSTL
python setup.py develop # or `pip install -e .`
```
@@ -18,18 +18,19 @@ python setup.py develop # or `pip install -e .`
* fvcore
* numpy
* hickle
-* scikit-image=0.16.2
+* scikit-image
* scikit-learn
* torch
* timm
* tqdm
+* xarray
**Note:**
1. Some errors might occur with `hickle` and `xarray` when using the KittiCaltech and WeatherBench datasets. As for KittiCaltech, you can solve the issues by installing additional packages according to the output message. As for WeatherBench, you can install the required version of `xarray` to solve the errors, i.e., `pip install git+https://github.com/pydata/xarray/@v2022.03.0`, and then install the required packages according to the error messages.
-2. Following the above instructions, SimVPv2 is installed on `dev` mode, any local modifications made to the code will take effect. You can install it by `pip install .` to use it as a PyPi package, and you should reinstall it to make the local modifications effect.
+2. Following the above instructions, OpenSTL is installed in `dev` mode, so any local modifications made to the code will take effect. You can instead install it with `pip install .` to use it as a PyPI package, but then you need to reinstall it for local modifications to take effect.
## Prepare datasets
@@ -38,7 +39,7 @@ It is recommended to symlink your dataset root (assuming `$YOUR_DATA_ROOT`) to `
We support the following datasets: [KTH Action](https://ieeexplore.ieee.org/document/1334462) [[download](https://www.csc.kth.se/cvap/actions/)], [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297) [[download](https://figshare.com/articles/dataset/KITTI_hkl_files/7985684)], [Moving MNIST](http://arxiv.org/abs/1502.04681) [[download](http://www.cs.toronto.edu/~nitish/unsupervised_video/)], [TaxiBJ](https://arxiv.org/abs/1610.00081) [[download](https://github.com/TolicWang/DeepST/tree/master/data/TaxiBJ)], [WeatherBench](https://arxiv.org/abs/2002.00469) [[download](https://github.com/pangeo-data/WeatherBench)]. You can also download the version we used in experiments from [**Baidu Cloud**](https://pan.baidu.com/s/1fudsBHyrf3nbt-7d42YWWg?pwd=kjfk) (kjfk). Please do not distribute the datasets and only use them for research.
```
-SimVPv2
+OpenSTL
├── configs
└── data
├── caltech
diff --git a/environment.yml b/environment.yml
index fe025385..888c6317 100644
--- a/environment.yml
+++ b/environment.yml
@@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
+ - nni
- numpy
- hickle
- pip
@@ -13,6 +14,5 @@ dependencies:
- pip:
- scikit-image
- timm
- - nni
- tqdm
-prefix: /opt/anaconda3/envs/simvp
+prefix: /opt/anaconda3/envs/simvp
\ No newline at end of file
diff --git a/simvp/__init__.py b/openstl/__init__.py
similarity index 100%
rename from simvp/__init__.py
rename to openstl/__init__.py
diff --git a/openstl/api/__init__.py b/openstl/api/__init__.py
new file mode 100644
index 00000000..219e4e04
--- /dev/null
+++ b/openstl/api/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) CAIRI AI Lab. All rights reserved
+
+from .train import BaseExperiment
+
+__all__ = [
+ 'BaseExperiment',
+]
\ No newline at end of file
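
`BaseExperiment` becomes the single public entry point of the API package (replacing `NonDistExperiment`, as the next diff shows). A minimal usage sketch follows; the import of `create_parser` is an assumption about `openstl/utils/parser.py`, whose full diff is not reproduced in this section:

```python
from openstl.api import BaseExperiment
from openstl.utils import create_parser  # assumed export of the CLI argument parser

if __name__ == '__main__':
    args = create_parser().parse_args()
    exp = BaseExperiment(args)  # builds dataloaders, the method, and hooks
    exp.train()                 # training loop with hooks and checkpointing
    exp.test()                  # evaluation on the test split
```

`tools/train.py` in this patch is the authoritative wiring; the sketch above only shows the shape of the API.
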
diff --git a/simvp/api/train.py b/openstl/api/train.py
similarity index 62%
rename from simvp/api/train.py
rename to openstl/api/train.py
index fe1d5052..cbd5d7f9 100644
--- a/simvp/api/train.py
+++ b/openstl/api/train.py
@@ -1,18 +1,22 @@
# Copyright (c) CAIRI AI Lab. All rights reserved
+import os
+import os.path as osp
import time
import logging
import json
-import torch
import numpy as np
-import os.path as osp
from typing import Dict, List
from fvcore.nn import FlopCountAnalysis, flop_count_table
-from simvp.core import Hook, metric, Recorder, get_priority, hook_maps
-from simvp.methods import method_maps
-from simvp.utils import (set_seed, print_log, output_namespace, check_dir,
- get_dataset, get_dist_info, measure_throughput, weights_to_cpu)
+import torch
+import torch.distributed as dist
+
+from openstl.core import Hook, metric, Recorder, get_priority, hook_maps
+from openstl.methods import method_maps
+from openstl.utils import (set_seed, print_log, output_namespace, check_dir, collect_env,
+ init_dist, init_random_seed,
+ get_dataset, get_dist_info, measure_throughput, weights_to_cpu)
try:
import nni
@@ -21,13 +25,15 @@
has_nni = False
-class NonDistExperiment(object):
- """ Experiment with non-dist PyTorch training and evaluation """
+class BaseExperiment(object):
+ """The basic class of PyTorch training and evaluation."""
def __init__(self, args):
+ """Initialize experiments (non-dist as an example)"""
self.args = args
self.config = self.args.__dict__
- self.device = self._acquire_device()
+ self.device = self.args.device
+ self.method = None
self.args.method = self.args.method.lower()
self._epoch = 0
self._iter = 0
@@ -35,62 +41,41 @@ def __init__(self, args):
self._max_epochs = self.config['epoch']
self._max_iters = None
self._hooks: List[Hook] = []
+ self._rank = 0
+ self._world_size = 1
+ self._dist = self.args.dist
self._preparation()
- print_log(output_namespace(self.args))
-
- T, C, H, W = self.args.in_shape
- if self.args.method == 'simvp':
- input_dummy = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
- elif self.args.method == 'crevnet':
- # crevnet must use the batchsize rather than 1
- input_dummy = torch.ones(self.args.batch_size, 20, C, H, W).to(self.device)
- elif self.args.method == 'phydnet':
- _tmp_input1 = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
- _tmp_input2 = torch.ones(1, self.args.aft_seq_length, C, H, W).to(self.device)
- _tmp_constraints = torch.zeros((49, 7, 7)).to(self.device)
- input_dummy = (_tmp_input1, _tmp_input2, _tmp_constraints)
- elif self.args.method in ['convlstm', 'predrnnpp', 'predrnn', 'mim', 'e3dlstm', 'mau']:
- Hp, Wp = H // self.args.patch_size, W // self.args.patch_size
- Cp = self.args.patch_size ** 2 * C
- _tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device)
- _tmp_flag = torch.ones(1, self.args.aft_seq_length - 1, Hp, Wp, Cp).to(self.device)
- input_dummy = (_tmp_input, _tmp_flag)
- elif self.args.method == 'predrnnv2':
- Hp, Wp = H // self.args.patch_size, W // self.args.patch_size
- Cp = self.args.patch_size ** 2 * C
- _tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device)
- _tmp_flag = torch.ones(1, self.args.total_length - 2, Hp, Wp, Cp).to(self.device)
- input_dummy = (_tmp_input, _tmp_flag)
- else:
- raise ValueError(f'Invalid method name {self.args.method}')
-
- print_log(self.method.model)
- flops = FlopCountAnalysis(self.method.model, input_dummy)
- print_log(flop_count_table(flops))
- if args.fps:
- fps = measure_throughput(self.method.model, input_dummy)
- print_log('Throughputs of {}: {:.3f}'.format(self.args.method, fps))
+ if self._rank == 0:
+ print_log(output_namespace(self.args))
+ self.display_method_info()
def _acquire_device(self):
+ """Setup devices"""
if self.args.use_gpu:
if self.args.dist:
- self._rank, self._world_size = get_dist_info()
- self.device = f'cuda:{self._rank}'
- print(f'Use GPU: local rank={self._rank}')
+ device = f'cuda:{self._rank}'
+ torch.cuda.set_device(self._rank)
+ print(f'Use distributed mode with GPUs: local rank={self._rank}')
else:
device = torch.device('cuda:0')
- print('Use GPU:', device)
+ print('Use non-distributed mode with GPU:', device)
else:
device = torch.device('cpu')
print('Use CPU')
+ if self.args.dist:
+ assert False, "Distributed training requires GPUs"
return device
def _preparation(self):
- # seed
- set_seed(self.args.seed)
+ """Preparation of basic experiment setups"""
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(self.args.local_rank)
+
# log and checkpoint
- self.path = osp.join(self.args.res_dir, self.args.ex_name)
+ base_dir = self.args.res_dir if self.args.res_dir is not None else 'work_dirs'
+        self.path = osp.join(base_dir, self.args.ex_name if not self.args.ex_name.startswith(base_dir) \
+            else self.args.ex_name.split(base_dir+'/')[-1])
check_dir(self.path)
self.checkpoints_path = osp.join(self.path, 'checkpoints')
@@ -107,6 +92,36 @@ def _preparation(self):
logging.basicConfig(level=logging.INFO,
filename=osp.join(self.path, '{}_{}.log'.format(prefix, timestamp)),
filemode='a', format='%(asctime)s - %(message)s')
+
+ # init distributed env first, since logger depends on the dist info.
+ if self.args.launcher != 'none' or self.args.dist:
+ self._dist = True
+ if self._dist:
+ assert self.args.launcher != 'none'
+ dist_params = dict(backend='nccl', init_method='env://')
+ if self.args.launcher == 'slurm':
+ dist_params['port'] = self.args.port
+ init_dist(self.args.launcher, **dist_params)
+ self._rank, self._world_size = get_dist_info()
+ # re-set gpu_ids with distributed training mode
+ self._gpu_ids = range(self._world_size)
+ self.device = self._acquire_device()
+
+ # log env info
+ env_info_dict = collect_env()
+ env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+ dash_line = '-' * 60 + '\n'
+ if self._rank == 0:
+ print_log('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
+
+ # set random seeds
+ if self._dist:
+ seed = init_random_seed(self.args.seed)
+ seed = seed + dist.get_rank() if self.args.diff_seed else seed
+ else:
+ seed = self.args.seed
+ set_seed(seed)
+
# prepare data
self._get_data()
# build the method
@@ -121,8 +136,15 @@ def _preparation(self):
self.call_hook('before_run')
def _build_method(self):
- steps_per_epoch = len(self.train_loader)
- self.method = method_maps[self.args.method](self.args, self.device, steps_per_epoch)
+ self.steps_per_epoch = len(self.train_loader)
+ self.method = method_maps[self.args.method](self.args, self.device, self.steps_per_epoch)
+ self.method.model.eval()
+ # setup ddp training
+ if self._dist:
+ self.method.model.cuda()
+ if self.args.torchscript:
+ self.method.model = torch.jit.script(self.method.model)
+ self.method._init_distributed()
def _build_hook(self):
for k in self.args.__dict__:
@@ -144,15 +166,10 @@ def _build_hook(self):
self._hooks.insert(0, hook)
def call_hook(self, fn_name: str) -> None:
+ """Run hooks by the registered names"""
for hook in self._hooks:
getattr(hook, fn_name)(self)
- def _get_data(self):
- self.train_loader, self.vali_loader, self.test_loader = get_dataset(self.args.dataname, self.config)
- if self.vali_loader is None:
- self.vali_loader = self.test_loader
- self._max_iters = self._max_epochs * len(self.train_loader)
-
def _get_hook_info(self):
# Get hooks info in each stage
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
@@ -173,7 +190,15 @@ def _get_hook_info(self):
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)
+ def _get_data(self):
+ self.train_loader, self.vali_loader, self.test_loader = \
+ get_dataset(self.args.dataname, self.config)
+ if self.vali_loader is None:
+ self.vali_loader = self.test_loader
+ self._max_iters = self._max_epochs * len(self.train_loader)
+
def _save(self, name=''):
+ """Saving models and meta data to checkpoints"""
checkpoint = {
'epoch': self._epoch + 1,
'optimizer': self.method.model_optim.state_dict(),
@@ -182,6 +207,7 @@ def _save(self, name=''):
torch.save(checkpoint, osp.join(self.checkpoints_path, name + '.pth'))
def _load(self, name=''):
+ """Loading models from the checkpoint"""
filename = name if osp.isfile(name) else osp.join(self.checkpoints_path, name + '.pth')
try:
checkpoint = torch.load(filename)
@@ -190,29 +216,60 @@ def _load(self, name=''):
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(f'No state_dict found in checkpoint file {filename}')
- self.method.model.load_state_dict(checkpoint['state_dict'])
+ if self._dist:
+ self.method.model.module.load_state_dict(checkpoint['state_dict'])
+ else:
+ self.method.model.load_state_dict(checkpoint['state_dict'])
if checkpoint.get('epoch', None) is not None:
self._epoch = checkpoint['epoch']
self.method.model_optim.load_state_dict(checkpoint['optimizer'])
self.method.scheduler.load_state_dict(checkpoint['scheduler'])
+ def display_method_info(self):
+ """Plot the basic infomation of supported methods"""
+ T, C, H, W = self.args.in_shape
+ if self.args.method == 'simvp':
+ input_dummy = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
+ elif self.args.method == 'crevnet':
+            # CrevNet requires the full batch size rather than 1
+ input_dummy = torch.ones(self.args.batch_size, 20, C, H, W).to(self.device)
+ elif self.args.method == 'phydnet':
+ _tmp_input1 = torch.ones(1, self.args.pre_seq_length, C, H, W).to(self.device)
+ _tmp_input2 = torch.ones(1, self.args.aft_seq_length, C, H, W).to(self.device)
+ _tmp_constraints = torch.zeros((49, 7, 7)).to(self.device)
+ input_dummy = (_tmp_input1, _tmp_input2, _tmp_constraints)
+ elif self.args.method in ['convlstm', 'predrnnpp', 'predrnn', 'mim', 'e3dlstm', 'mau']:
+ Hp, Wp = H // self.args.patch_size, W // self.args.patch_size
+ Cp = self.args.patch_size ** 2 * C
+ _tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device)
+ _tmp_flag = torch.ones(1, self.args.aft_seq_length - 1, Hp, Wp, Cp).to(self.device)
+ input_dummy = (_tmp_input, _tmp_flag)
+ elif self.args.method == 'predrnnv2':
+ Hp, Wp = H // self.args.patch_size, W // self.args.patch_size
+ Cp = self.args.patch_size ** 2 * C
+ _tmp_input = torch.ones(1, self.args.total_length, Hp, Wp, Cp).to(self.device)
+ _tmp_flag = torch.ones(1, self.args.total_length - 2, Hp, Wp, Cp).to(self.device)
+ input_dummy = (_tmp_input, _tmp_flag)
+ else:
+ raise ValueError(f'Invalid method name {self.args.method}')
+
+ print_log(self.method.model)
+ flops = FlopCountAnalysis(self.method.model, input_dummy)
+ print_log(flop_count_table(flops))
+ if self.args.fps:
+ fps = measure_throughput(self.method.model, input_dummy)
+ print_log('Throughputs of {}: {:.3f}'.format(self.args.method, fps))
+
def train(self):
recorder = Recorder(verbose=True)
- num_updates = 0
+ num_updates = self._epoch * self.steps_per_epoch
self.call_hook('before_train_epoch')
- # constants for other methods:
- eta = 1.0 # PredRNN
+
+ eta = 1.0 # PredRNN variants
for epoch in range(self._epoch, self._max_epochs):
- loss_mean = 0.0
-
- if self.args.method in ['simvp', 'crevnet', 'phydnet']:
- num_updates, loss_mean = self.method.train_one_epoch(
- self, self.train_loader, epoch, num_updates, loss_mean)
- elif self.args.method in ['convlstm', 'predrnnpp', 'predrnn', 'predrnnv2', 'mim', 'e3dlstm', 'mau']:
- num_updates, loss_mean, eta = self.method.train_one_epoch(
- self, self.train_loader, epoch, num_updates, loss_mean, eta)
- else:
- raise ValueError(f'Invalid method name {self.args.method}')
+
+ num_updates, loss_mean, eta = self.method.train_one_epoch(self, self.train_loader,
+                                                                      epoch, num_updates, eta=eta)
self._epoch = epoch
if epoch % self.args.log_step == 0:
@@ -221,15 +278,19 @@ def train(self):
with torch.no_grad():
vali_loss = self.vali(self.vali_loader)
- print_log('Epoch: {0}, Steps: {1} | Lr: {2:.7f} | Train Loss: {3:.7f} | Vali Loss: {4:.7f}\n'.format(
- epoch + 1, len(self.train_loader), cur_lr, loss_mean, vali_loss))
- recorder(vali_loss, self.method.model, self.path)
- self._save(name='latest')
+ if self._rank == 0:
+ print_log('Epoch: {0}, Steps: {1} | Lr: {2:.7f} | Train Loss: {3:.7f} | Vali Loss: {4:.7f}\n'.format(
+ epoch + 1, len(self.train_loader), cur_lr, loss_mean.avg, vali_loss))
+ recorder(vali_loss, self.method.model, self.path)
+ self._save(name='latest')
if not check_dir(self.path): # exit training when work_dir is removed
assert False and "Exit training because work_dir is removed"
best_model_path = osp.join(self.path, 'checkpoint.pth')
- self.method.model.load_state_dict(torch.load(best_model_path))
+ if self._dist:
+ self.method.model.module.load_state_dict(torch.load(best_model_path))
+ else:
+ self.method.model.load_state_dict(torch.load(best_model_path))
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
@@ -238,22 +299,27 @@ def vali(self, vali_loader):
preds, trues, val_loss = self.method.vali_one_epoch(self, self.vali_loader)
self.call_hook('after_val_epoch')
- if 'weather' in self.args.dataname:
- metric_list, spatial_norm = ['mse', 'rmse', 'mae'], True
- else:
- metric_list, spatial_norm = ['mse', 'mae'], False
- eval_res, eval_log = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std,
- metrics=metric_list, spatial_norm=spatial_norm)
- print_log('val\t '+eval_log)
- if has_nni:
- nni.report_intermediate_result(eval_res['mse'])
+ if self._rank == 0:
+ if 'weather' in self.args.dataname:
+ metric_list, spatial_norm = ['mse', 'rmse', 'mae'], True
+ else:
+ metric_list, spatial_norm = ['mse', 'mae'], False
+ eval_res, eval_log = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std,
+ metrics=metric_list, spatial_norm=spatial_norm)
+
+ print_log('val\t '+eval_log)
+ if has_nni:
+ nni.report_intermediate_result(eval_res['mse'])
return val_loss
def test(self):
if self.args.test:
best_model_path = osp.join(self.path, 'checkpoint.pth')
- self.method.model.load_state_dict(torch.load(best_model_path))
+ if self._dist:
+ self.method.model.module.load_state_dict(torch.load(best_model_path))
+ else:
+ self.method.model.load_state_dict(torch.load(best_model_path))
self.call_hook('before_val_epoch')
inputs, trues, preds = self.method.test_one_epoch(self, self.test_loader)
@@ -266,11 +332,13 @@ def test(self):
eval_res, eval_log = metric(preds, trues, self.test_loader.dataset.mean, self.test_loader.dataset.std,
metrics=metric_list, spatial_norm=spatial_norm)
metrics = np.array([eval_res['mae'], eval_res['mse']])
- print_log(eval_log)
- folder_path = osp.join(self.path, 'saved')
- check_dir(folder_path)
+ if self._rank == 0:
+ print_log(eval_log)
+ folder_path = osp.join(self.path, 'saved')
+ check_dir(folder_path)
-        for np_data in ['metrics', 'inputs', 'trues', 'preds']:
-            np.save(osp.join(folder_path, np_data + '.npy'), vars()[np_data])
+
+            for np_data in ['metrics', 'inputs', 'trues', 'preds']:
+                np.save(osp.join(folder_path, np_data + '.npy'), vars()[np_data])
return eval_res['mse']
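For context, the refactored `BaseExperiment` is driven end-to-end much like the old `NonDistExperiment`. A minimal (hypothetical) driver is sketched below; the exact flag set comes from `openstl.utils.create_parser` and is not reproduced here:

    # sketch: end-to-end usage of the refactored experiment runner
    from openstl.api import BaseExperiment
    from openstl.utils import create_parser

    args = create_parser().parse_args()
    exp = BaseExperiment(args)  # runs _preparation(): data, method, hooks
    exp.train()                 # checkpoints land under <res_dir>/<ex_name>
    mse = exp.test()            # reloads checkpoint.pth when args.test is set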
diff --git a/simvp/core/__init__.py b/openstl/core/__init__.py
similarity index 100%
rename from simvp/core/__init__.py
rename to openstl/core/__init__.py
diff --git a/simvp/core/ema_hook.py b/openstl/core/ema_hook.py
similarity index 99%
rename from simvp/core/ema_hook.py
rename to openstl/core/ema_hook.py
index 3f69b44d..2e6e0224 100644
--- a/simvp/core/ema_hook.py
+++ b/openstl/core/ema_hook.py
@@ -83,6 +83,8 @@ def before_run(self, runner):
Register ema parameter as ``named_buffer`` to model
"""
model = runner.method.model
+ if runner._dist:
+ model = model.module
self.param_ema_buffer = {}
if self.full_params_ema:
self.model_parameters = dict(model.state_dict())
@@ -227,6 +229,8 @@ def before_run(self, runner):
Register ema parameter as ``named_buffer`` to model
"""
model = runner.method.model
+ if runner._dist:
+ model = model.module
self.param_ema_buffer = {}
if self.full_params_ema:
self.model_parameters = dict(model.state_dict())
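Both hunks exist because `DistributedDataParallel` wraps the network, so the EMA buffers must be registered on the inner module rather than the wrapper. A minimal illustration in plain PyTorch (not OpenSTL-specific; assumes an initialized process group):

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    net = nn.Linear(8, 8).cuda()
    ddp = DDP(net, device_ids=[0])  # needs torch.distributed.init_process_group first
    assert ddp.module is net        # without .module, state_dict keys gain a 'module.' prefix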
diff --git a/simvp/core/hooks.py b/openstl/core/hooks.py
similarity index 100%
rename from simvp/core/hooks.py
rename to openstl/core/hooks.py
diff --git a/simvp/core/metrics.py b/openstl/core/metrics.py
similarity index 100%
rename from simvp/core/metrics.py
rename to openstl/core/metrics.py
diff --git a/simvp/core/optim_constant.py b/openstl/core/optim_constant.py
similarity index 100%
rename from simvp/core/optim_constant.py
rename to openstl/core/optim_constant.py
diff --git a/simvp/core/optim_scheduler.py b/openstl/core/optim_scheduler.py
similarity index 100%
rename from simvp/core/optim_scheduler.py
rename to openstl/core/optim_scheduler.py
diff --git a/simvp/core/recorder.py b/openstl/core/recorder.py
similarity index 100%
rename from simvp/core/recorder.py
rename to openstl/core/recorder.py
diff --git a/simvp/datasets/__init__.py b/openstl/datasets/__init__.py
similarity index 84%
rename from simvp/datasets/__init__.py
rename to openstl/datasets/__init__.py
index 662e0339..3eea09f4 100644
--- a/simvp/datasets/__init__.py
+++ b/openstl/datasets/__init__.py
@@ -7,8 +7,9 @@
from .dataloader_weather import ClimateDataset
from .dataloader import load_data
from .dataset_constant import dataset_parameters
+from .utils import create_loader
__all__ = [
'KittiCaltechDataset', 'KTHDataset', 'MovingMNIST', 'TaxibjDataset', 'ClimateDataset',
- 'load_data', 'dataset_parameters'
+ 'load_data', 'dataset_parameters', 'create_loader',
]
\ No newline at end of file
diff --git a/simvp/datasets/dataloader.py b/openstl/datasets/dataloader.py
similarity index 66%
rename from simvp/datasets/dataloader.py
rename to openstl/datasets/dataloader.py
index 4f238ace..ad18dcc2 100644
--- a/simvp/datasets/dataloader.py
+++ b/openstl/datasets/dataloader.py
@@ -1,22 +1,26 @@
+# Copyright (c) CAIRI AI Lab. All rights reserved
-
-def load_data(dataname, batch_size, val_batch_size, num_workers, data_root, **kwargs):
+def load_data(dataname, batch_size, val_batch_size, num_workers, data_root, distributed=False, **kwargs):
pre_seq_length = kwargs.get('pre_seq_length', 10)
aft_seq_length = kwargs.get('aft_seq_length', 10)
if dataname == 'kitticaltech':
from .dataloader_kitticaltech import load_data
- return load_data(batch_size, val_batch_size, data_root, num_workers, pre_seq_length, aft_seq_length)
+ return load_data(batch_size, val_batch_size, data_root, num_workers,
+ pre_seq_length, aft_seq_length, distributed=distributed)
elif 'kth' in dataname: # 'kth', 'kth20', 'kth40'
from .dataloader_kth import load_data
- return load_data(batch_size, val_batch_size, data_root, num_workers, pre_seq_length, aft_seq_length)
+ return load_data(batch_size, val_batch_size, data_root, num_workers,
+ pre_seq_length, aft_seq_length, distributed=distributed)
elif dataname == 'mmnist':
from .dataloader_moving_mnist import load_data
- return load_data(batch_size, val_batch_size, data_root, num_workers, pre_seq_length, aft_seq_length)
+ return load_data(batch_size, val_batch_size, data_root, num_workers,
+ pre_seq_length, aft_seq_length, distributed=distributed)
elif dataname == 'taxibj':
from .dataloader_taxibj import load_data
- return load_data(batch_size, val_batch_size, data_root, num_workers, pre_seq_length, aft_seq_length)
+ return load_data(batch_size, val_batch_size, data_root, num_workers,
+ pre_seq_length, aft_seq_length, distributed=distributed)
elif 'weather' in dataname: # 'weather', 'weather_t2m', etc.
from .dataloader_weather import load_data
- return load_data(batch_size, val_batch_size, data_root, num_workers, **kwargs)
+        return load_data(batch_size, val_batch_size, data_root, num_workers, distributed=distributed, **kwargs)
else:
raise ValueError(f'Dataname {dataname} is unsupported')
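A quick sanity check of the dispatcher above, assuming the Moving MNIST files already live under `data_root` (keyword names follow the signatures in this diff):

    from openstl.datasets.dataloader import load_data

    train_loader, vali_loader, test_loader = load_data(
        'mmnist', batch_size=16, val_batch_size=16, num_workers=4,
        data_root='./data', distributed=False,
        pre_seq_length=10, aft_seq_length=10)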
diff --git a/simvp/datasets/dataloader_kitticaltech.py b/openstl/datasets/dataloader_kitticaltech.py
similarity index 82%
rename from simvp/datasets/dataloader_kitticaltech.py
rename to openstl/datasets/dataloader_kitticaltech.py
index 3c5ceb8a..a40c73b3 100644
--- a/simvp/datasets/dataloader_kitticaltech.py
+++ b/openstl/datasets/dataloader_kitticaltech.py
@@ -6,6 +6,8 @@
from torch.utils.data import Dataset
from skimage.transform import resize
+from .utils import create_loader
+
try:
import hickle as hkl
except ImportError:
@@ -109,8 +111,8 @@ def load_data(self, mode='train'):
return data, indices
-def load_data(batch_size, val_batch_size, data_root,
- num_workers=4, pre_seq_length=10, aft_seq_length=1):
+def load_data(batch_size, val_batch_size, data_root, num_workers=4,
+ pre_seq_length=10, aft_seq_length=1, distributed=False):
if os.path.exists(osp.join(data_root, 'kitti_hkl')):
input_param = {
@@ -135,17 +137,16 @@ def load_data(batch_size, val_batch_size, data_root,
test_set = KittiCaltechDataset(
test_data, test_idx, pre_seq_length, aft_seq_length)
- dataloader_train = torch.utils.data.DataLoader(train_set,
- batch_size=batch_size, shuffle=True,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_vali = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_test = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
+ dataloader_train = create_loader(train_set,
+ batch_size=batch_size,
+ shuffle=True, is_training=True,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_vali = None
+ dataloader_test = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
return dataloader_train, dataloader_vali, dataloader_test
diff --git a/simvp/datasets/dataloader_kth.py b/openstl/datasets/dataloader_kth.py
similarity index 91%
rename from simvp/datasets/dataloader_kth.py
rename to openstl/datasets/dataloader_kth.py
index bf575a68..2a78e1af 100644
--- a/simvp/datasets/dataloader_kth.py
+++ b/openstl/datasets/dataloader_kth.py
@@ -7,6 +7,8 @@
from torch.utils.data import Dataset
from PIL import Image
+from .utils import create_loader
+
logger = logging.getLogger(__name__)
@@ -224,7 +226,8 @@ def get_test_input_handle(self):
return InputHandle(test_data, test_indices, self.input_param)
-def load_data(batch_size, val_batch_size, data_root, num_workers=4, pre_seq_length=10, aft_seq_length=20):
+def load_data(batch_size, val_batch_size, data_root, num_workers=4,
+ pre_seq_length=10, aft_seq_length=20, distributed=False):
img_width = 128
# pre_seq_length, aft_seq_length = 10, 10
@@ -249,10 +252,16 @@ def load_data(batch_size, val_batch_size, data_root, num_workers=4, pre_seq_leng
pre_seq_length,
aft_seq_length)
- dataloader_train = torch.utils.data.DataLoader(
- train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
- dataloader_validation = None
- dataloader_test = torch.utils.data.DataLoader(
- test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
-
- return dataloader_train, dataloader_validation, dataloader_test
+ dataloader_train = create_loader(train_set,
+ batch_size=batch_size,
+ shuffle=True, is_training=True,
+ pin_memory=True, num_workers=num_workers,
+ distributed=distributed)
+ dataloader_vali = None
+ dataloader_test = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, num_workers=num_workers,
+ distributed=distributed)
+
+ return dataloader_train, dataloader_vali, dataloader_test
diff --git a/simvp/datasets/dataloader_moving_mnist.py b/openstl/datasets/dataloader_moving_mnist.py
similarity index 82%
rename from simvp/datasets/dataloader_moving_mnist.py
rename to openstl/datasets/dataloader_moving_mnist.py
index 2fbdd683..078fee1e 100644
--- a/simvp/datasets/dataloader_moving_mnist.py
+++ b/openstl/datasets/dataloader_moving_mnist.py
@@ -5,6 +5,8 @@
import torch
from torch.utils.data import Dataset
+from .utils import create_loader
+
def load_mnist(root):
# Load MNIST dataset for generating training data.
@@ -150,8 +152,8 @@ def __len__(self):
return self.length
-def load_data(batch_size, val_batch_size, data_root,
- num_workers=4, pre_seq_length=10, aft_seq_length=10):
+def load_data(batch_size, val_batch_size, data_root, num_workers=4,
+ pre_seq_length=10, aft_seq_length=10, distributed=False):
train_set = MovingMNIST(root=data_root, is_train=True,
n_frames_input=pre_seq_length,
@@ -160,17 +162,20 @@ def load_data(batch_size, val_batch_size, data_root,
n_frames_input=pre_seq_length,
n_frames_output=aft_seq_length, num_objects=[2])
- dataloader_train = torch.utils.data.DataLoader(train_set,
- batch_size=batch_size, shuffle=True,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_vali = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_test = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
+ dataloader_train = create_loader(train_set,
+ batch_size=batch_size,
+ shuffle=True, is_training=True,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_vali = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_test = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
return dataloader_train, dataloader_vali, dataloader_test
diff --git a/openstl/datasets/dataloader_taxibj.py b/openstl/datasets/dataloader_taxibj.py
new file mode 100644
index 00000000..b8033fc9
--- /dev/null
+++ b/openstl/datasets/dataloader_taxibj.py
@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+
+from .utils import create_loader
+
+
+class TaxibjDataset(Dataset):
+ """Taxibj `_ Dataset"""
+
+ def __init__(self, X, Y):
+ super(TaxibjDataset, self).__init__()
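+        # TaxiBJ arrays are stored in [-1, 1]; map them to [0, 1]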
+ self.X = (X+1)/2
+ self.Y = (Y+1)/2
+ self.mean = 0
+ self.std = 1
+
+ def __len__(self):
+ return self.X.shape[0]
+
+ def __getitem__(self, index):
+ data = torch.tensor(self.X[index, ::]).float()
+ labels = torch.tensor(self.Y[index, ::]).float()
+ return data, labels
+
+
+def load_data(batch_size, val_batch_size, data_root, num_workers=4,
+ pre_seq_length=None, aft_seq_length=None, distributed=False):
+
+ dataset = np.load(data_root+'taxibj/dataset.npz')
+ X_train, Y_train, X_test, Y_test = dataset['X_train'], dataset[
+ 'Y_train'], dataset['X_test'], dataset['Y_test']
+ train_set = TaxibjDataset(X=X_train, Y=Y_train)
+ test_set = TaxibjDataset(X=X_test, Y=Y_test)
+
+ dataloader_train = create_loader(train_set,
+ batch_size=batch_size,
+ shuffle=True, is_training=True,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_vali = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_test = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+
+ return dataloader_train, dataloader_vali, dataloader_test
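For reference, `load_data` above expects a `taxibj/dataset.npz` archive with four arrays; the shapes below are illustrative of the usual TaxiBJ split (4 input frames, 4 target frames, 2 flow channels on a 32x32 grid) and should be checked against the actual file:

    import numpy as np

    dataset = np.load('./data/taxibj/dataset.npz')
    for key in ('X_train', 'Y_train', 'X_test', 'Y_test'):
        print(key, dataset[key].shape)  # e.g. X_train -> (N, 4, 2, 32, 32)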
diff --git a/simvp/datasets/dataloader_weather.py b/openstl/datasets/dataloader_weather.py
similarity index 83%
rename from simvp/datasets/dataloader_weather.py
rename to openstl/datasets/dataloader_weather.py
index 7a338253..bc82e6be 100644
--- a/simvp/datasets/dataloader_weather.py
+++ b/openstl/datasets/dataloader_weather.py
@@ -5,6 +5,7 @@
import os.path as osp
import torch
from torch.utils.data import Dataset
+from .utils import create_loader
try:
import xarray as xr
@@ -157,6 +158,7 @@ def load_data(batch_size,
idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
step=1,
+ distributed=False,
**kwargs):
weather_dataroot = osp.join(data_root, 'weather')
@@ -167,14 +169,14 @@ def load_data(batch_size,
idx_in=idx_in,
idx_out=idx_out,
step=step)
- validation_set = ClimateDataset(weather_dataroot,
- data_name,
- val_time,
- idx_in,
- idx_out,
- step,
- mean=train_set.mean,
- std=train_set.std)
+ vali_set = ClimateDataset(weather_dataroot,
+ data_name,
+ val_time,
+ idx_in,
+ idx_out,
+ step,
+ mean=train_set.mean,
+ std=train_set.std)
test_set = ClimateDataset(weather_dataroot,
data_name,
test_time,
@@ -184,18 +186,21 @@ def load_data(batch_size,
mean=train_set.mean,
std=train_set.std)
- dataloader_train = torch.utils.data.DataLoader(train_set,
- batch_size=batch_size, shuffle=True,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_vali = torch.utils.data.DataLoader(test_set, # validation_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_test = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
+ dataloader_train = create_loader(train_set,
+ batch_size=batch_size,
+ shuffle=True, is_training=True,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_vali = create_loader(test_set, # validation_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
+ dataloader_test = create_loader(test_set,
+ batch_size=val_batch_size,
+ shuffle=False, is_training=False,
+ pin_memory=True, drop_last=True,
+ num_workers=num_workers, distributed=distributed)
return dataloader_train, dataloader_vali, dataloader_test
diff --git a/simvp/datasets/dataset_constant.py b/openstl/datasets/dataset_constant.py
similarity index 100%
rename from simvp/datasets/dataset_constant.py
rename to openstl/datasets/dataset_constant.py
diff --git a/openstl/datasets/utils.py b/openstl/datasets/utils.py
new file mode 100644
index 00000000..e2379269
--- /dev/null
+++ b/openstl/datasets/utils.py
@@ -0,0 +1,194 @@
+import random
+from functools import partial
+from itertools import repeat
+from typing import Callable
+from timm.data.distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
+
+import torch.utils.data
+import numpy as np
+
+
+def worker_init(worker_id, worker_seeding='all'):
+ worker_info = torch.utils.data.get_worker_info()
+ assert worker_info.id == worker_id
+ if isinstance(worker_seeding, Callable):
+ seed = worker_seeding(worker_info)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed % (2 ** 32 - 1))
+ else:
+ assert worker_seeding in ('all', 'part')
+ # random / torch seed already called in dataloader iter class w/ worker_info.seed
+ # to reproduce some old results (same seed + hparam combo), partial seeding
+ # is required (skip numpy re-seed)
+ if worker_seeding == 'all':
+ np.random.seed(worker_info.seed % (2 ** 32 - 1))
+
+
+def fast_collate(batch):
+ """ A fast collation function optimized for uint8 images (np array or torch)
+ and int64 targets (labels)"""
+ assert isinstance(batch[0], tuple)
+ batch_size = len(batch)
+ if isinstance(batch[0][0], tuple):
+ # This branch 'deinterleaves' and flattens tuples of input tensors into
+ # one tensor ordered by position such that all tuple of position n will end up
+ # in a torch.split(tensor, batch_size) in nth position
+ inner_tuple_size = len(batch[0][0])
+ flattened_batch_size = batch_size * inner_tuple_size
+ targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
+ tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ # all input tensor tuples must be same length
+ assert len(batch[i][0]) == inner_tuple_size
+ for j in range(inner_tuple_size):
+ targets[i + j * batch_size] = batch[i][1]
+ tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+ return tensor, targets
+ elif isinstance(batch[0][0], np.ndarray):
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+ assert len(targets) == batch_size
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ tensor[i] += torch.from_numpy(batch[i][0])
+ return tensor, targets
+ elif isinstance(batch[0][0], torch.Tensor):
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+ assert len(targets) == batch_size
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ tensor[i].copy_(batch[i][0])
+ return tensor, targets
+ else:
+ assert False
+
+
+def expand_to_chs(x, n):
+ if not isinstance(x, (tuple, list)):
+ x = tuple(repeat(x, n))
+ elif len(x) == 1:
+ x = x * n
+ else:
+ assert len(x) == n, 'normalization stats must match image channels'
+ return x
+
+
+class PrefetchLoader:
+
+ def __init__(self,
+ loader,
+ mean,
+ std,
+ channels=3,
+ fp16=False):
+
+ mean = expand_to_chs(mean, channels)
+ std = expand_to_chs(std, channels)
+ normalization_shape = (1, channels, 1, 1)
+
+ self.loader = loader
+ self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
+ self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
+ self.fp16 = fp16
+ if fp16:
+ self.mean = self.mean.half()
+ self.std = self.std.half()
+
+ def __iter__(self):
+ stream = torch.cuda.Stream()
+ first = True
+
+ for next_input, next_target in self.loader:
+ with torch.cuda.stream(stream):
+ next_input = next_input.cuda(non_blocking=True)
+ next_target = next_target.cuda(non_blocking=True)
+ if self.fp16:
+ next_input = next_input.half().sub_(self.mean).div_(self.std)
+ else:
+ next_input = next_input.float().sub_(self.mean).div_(self.std)
+
+ if not first:
+ yield input, target
+ else:
+ first = False
+
+ torch.cuda.current_stream().wait_stream(stream)
+ input = next_input
+ target = next_target
+
+ yield input, target
+
+ def __len__(self):
+ return len(self.loader)
+
+ @property
+ def sampler(self):
+ return self.loader.sampler
+
+ @property
+ def dataset(self):
+ return self.loader.dataset
+
+
+def create_loader(dataset,
+ batch_size,
+ shuffle=True,
+ is_training=False,
+ mean=None,
+ std=None,
+ num_workers=1,
+ num_aug_repeats=0,
+ input_channels=1,
+ use_prefetcher=False,
+ distributed=False,
+ pin_memory=False,
+ drop_last=False,
+ fp16=False,
+ collate_fn=None,
+ persistent_workers=True,
+ worker_seeding='all'):
+ sampler = None
+ if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
+ if is_training:
+ if num_aug_repeats:
+ sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
+ else:
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+ else:
+ # This will add extra duplicate entries to result in equal num
+ # of samples per-process, will slightly alter validation results
+ sampler = OrderedDistributedSampler(dataset)
+ else:
+ assert num_aug_repeats==0, "RepeatAugment is not supported in non-distributed or IterableDataset"
+
+ if collate_fn is None:
+ collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
+ loader_class = torch.utils.data.DataLoader
+
+ loader_args = dict(
+ batch_size=batch_size,
+        shuffle=shuffle and sampler is None and is_training and not isinstance(dataset, torch.utils.data.IterableDataset),  # shuffle conflicts with an explicit sampler
+ num_workers=num_workers,
+ sampler=sampler,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=drop_last,
+ worker_init_fn=partial(worker_init, worker_seeding=worker_seeding),
+ persistent_workers=persistent_workers
+ )
+ try:
+ loader = loader_class(dataset, **loader_args)
+    except TypeError:
+        loader_args.pop('persistent_workers')  # 'persistent_workers' only exists in PyTorch 1.7+
+ loader = loader_class(dataset, **loader_args)
+
+ if use_prefetcher:
+ loader = PrefetchLoader(
+ loader,
+ mean=mean,
+ std=std,
+ channels=input_channels,
+ fp16=fp16,
+ )
+
+ return loader
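`create_loader` mirrors timm's loader factory. A short usage sketch with a stand-in dataset (the distributed branch additionally requires an initialized process group, e.g. via `init_dist` in `openstl.utils`):

    import torch
    from torch.utils.data import TensorDataset
    from openstl.datasets.utils import create_loader

    train_set = TensorDataset(torch.randn(64, 10, 1, 32, 32),
                              torch.randn(64, 10, 1, 32, 32))

    # single process: a plain shuffled DataLoader
    loader = create_loader(train_set, batch_size=16, shuffle=True,
                           is_training=True, num_workers=2, distributed=False)

    # distributed=True swaps in a DistributedSampler, which then owns shuffling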
diff --git a/simvp/methods/__init__.py b/openstl/methods/__init__.py
similarity index 100%
rename from simvp/methods/__init__.py
rename to openstl/methods/__init__.py
diff --git a/openstl/methods/base_method.py b/openstl/methods/base_method.py
new file mode 100644
index 00000000..cbe0a5c9
--- /dev/null
+++ b/openstl/methods/base_method.py
@@ -0,0 +1,184 @@
+from typing import Dict, List, Union
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as NativeDDP
+from contextlib import suppress
+from timm.utils import NativeScaler
+from timm.utils.agc import adaptive_clip_grad
+
+from openstl.core.optim_scheduler import get_optim_scheduler
+from openstl.utils import dist_forward_collect, nondist_forward_collect
+
+has_native_amp = False
+try:
+ if getattr(torch.cuda.amp, 'autocast') is not None:
+ has_native_amp = True
+except AttributeError:
+ pass
+
+
+class Base_method(object):
+ """Base Method.
+
+ This class defines the basic functions of a video prediction (VP)
+ method training and testing. Any VP method that inherits this class
+ should at least define its own `train_one_epoch`, `vali_one_epoch`,
+ and `test_one_epoch` function.
+
+ """
+
+ def __init__(self, args, device, steps_per_epoch):
+ super(Base_method, self).__init__()
+ self.args = args
+ self.dist = args.dist
+ self.device = device
+ self.config = args.__dict__
+ self.criterion = None
+ self.model_optim = None
+ self.scheduler = None
+ if self.dist:
+ self.rank = int(device.split(':')[-1])
+ else:
+ self.rank = 0
+ self.clip_value = self.args.clip_grad
+ self.clip_mode = self.args.clip_mode if self.clip_value is not None else None
+ # setup automatic mixed-precision (AMP) loss scaling and op casting
+ self.amp_autocast = suppress # do nothing
+ self.loss_scaler = None
+
+ def _build_model(self, **kwargs):
+ raise NotImplementedError
+
+ def _init_optimizer(self, steps_per_epoch):
+ return get_optim_scheduler(
+ self.args, self.args.epoch, self.model, steps_per_epoch)
+
+ def _init_distributed(self):
+ """Initialize DDP training"""
+ if self.args.fp16 and has_native_amp:
+ self.amp_autocast = torch.cuda.amp.autocast
+ self.loss_scaler = NativeScaler()
+            if self.rank == 0:
+ print('Using native PyTorch AMP. Training in mixed precision (fp16).')
+ else:
+ print('AMP not enabled. Training in float32.')
+ self.model = NativeDDP(self.model, device_ids=[self.rank], broadcast_buffers=True)
+
+ def train_one_epoch(self, runner, train_loader, **kwargs):
+ """Train the model with train_loader.
+
+ Args:
+            runner: the experiment runner driving the methods.
+            train_loader: dataloader of the training set.
+ """
+ raise NotImplementedError
+
+ def _predict(self, **kwargs):
+ """Forward the model.
+
+ Args:
+            batch_x, batch_y: input samples and ground truth.
+ """
+ raise NotImplementedError
+
+ def forward_test(self, batch_x, batch_y):
+ """Evaluate the model.
+
+ Args:
+            batch_x, batch_y: testing samples and ground truth.
+
+ Returns:
+ dict(tensor): The concatenated outputs with keys.
+ """
+ batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+
+ with self.amp_autocast():
+ pred_y = self._predict(batch_x)
+
+ return dict(zip(['inputs', 'preds', 'trues'],
+ [batch_x, pred_y, batch_y]))
+
+ def vali_one_epoch(self, runner, vali_loader, **kwargs):
+ """Evaluate the model with val_loader.
+
+ Args:
+            runner: the experiment runner driving the methods.
+            val_loader: dataloader of the validation set.
+
+ Returns:
+            list(tensor, ...): The list of predictions, ground truths, and losses.
+ """
+ self.model.eval()
+ func = lambda *x: self.forward_test(*x)
+ if self.dist:
+ results = dist_forward_collect(func, vali_loader, self.rank,
+ len(vali_loader.dataset), to_numpy=True)
+ else:
+ results = nondist_forward_collect(func, vali_loader,
+ len(vali_loader.dataset), to_numpy=True)
+
+ preds = torch.tensor(results['preds']).to(self.device)
+ trues = torch.tensor(results['trues']).to(self.device)
+ losses_m = self.criterion(preds, trues).cpu().numpy()
+ return results['preds'], results['trues'], losses_m
+
+ def test_one_epoch(self, runner, test_loader, **kwargs):
+ """Evaluate the model with test_loader.
+
+ Args:
+            runner: the experiment runner driving the methods.
+            test_loader: dataloader of the test set.
+
+ Returns:
+            list(tensor, ...): The list of inputs, predictions, and ground truths.
+ """
+ self.model.eval()
+ func = lambda *x: self.forward_test(*x)
+ if self.dist:
+ results = dist_forward_collect(func, test_loader, self.rank,
+ len(test_loader.dataset), to_numpy=True)
+ else:
+ results = nondist_forward_collect(func, test_loader,
+ len(test_loader.dataset), to_numpy=True)
+
+ return results['inputs'], results['preds'], results['trues']
+
+ def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
+ """Get current learning rates.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current learning rates of all
+ param groups. If the runner has a dict of optimizers, this method
+ will return a dict.
+ """
+ lr: Union[List[float], Dict[str, List[float]]]
+ if isinstance(self.model_optim, torch.optim.Optimizer):
+ lr = [group['lr'] for group in self.model_optim.param_groups]
+ elif isinstance(self.model_optim, dict):
+ lr = dict()
+ for name, optim in self.model_optim.items():
+ lr[name] = [group['lr'] for group in optim.param_groups]
+ else:
+ raise RuntimeError(
+ 'lr is not applicable because optimizer does not exist.')
+ return lr
+
+ def clip_grads(self, params, norm_type: float = 2.0):
+ """ Dispatch to gradient clipping method
+
+ Args:
+            params (Iterable): model parameters to clip
+            norm_type (float): p-norm, default 2.0
+        The clipping value and mode ('norm', 'value', or 'agc') are taken
+        from self.clip_value and self.clip_mode.
+ """
+ if self.clip_mode is None:
+ return
+ if self.clip_mode == 'norm':
+ torch.nn.utils.clip_grad_norm_(params, self.clip_value, norm_type=norm_type)
+ elif self.clip_mode == 'value':
+ torch.nn.utils.clip_grad_value_(params, self.clip_value)
+ elif self.clip_mode == 'agc':
+ adaptive_clip_grad(params, self.clip_value, norm_type=norm_type)
+ else:
+ assert False, f"Unknown clip mode ({self.clip_mode})."
diff --git a/simvp/methods/convlstm.py b/openstl/methods/convlstm.py
similarity index 94%
rename from simvp/methods/convlstm.py
rename to openstl/methods/convlstm.py
index f2613d06..ef1d13f5 100644
--- a/simvp/methods/convlstm.py
+++ b/openstl/methods/convlstm.py
@@ -1,6 +1,6 @@
import torch.nn as nn
-from simvp.models import ConvLSTM_Model
+from openstl.models import ConvLSTM_Model
from .predrnn import PredRNN
diff --git a/simvp/methods/crevnet.py b/openstl/methods/crevnet.py
similarity index 95%
rename from simvp/methods/crevnet.py
rename to openstl/methods/crevnet.py
index aa61010e..f7d1535c 100644
--- a/simvp/methods/crevnet.py
+++ b/openstl/methods/crevnet.py
@@ -4,8 +4,8 @@
from tqdm import tqdm
from timm.utils import AverageMeter
-from simvp.core.optim_scheduler import get_optim_scheduler
-from simvp.models import CrevNet_Model
+from openstl.core.optim_scheduler import get_optim_scheduler
+from openstl.models import CrevNet_Model
from .base_method import Base_method
@@ -33,7 +33,7 @@ def _init_optimizer(self, steps_per_epoch):
self.model_optim2, self.scheduler2, self.by_epoch_2 = get_optim_scheduler(
self.args, self.args.epoch, self.model.encoder, steps_per_epoch)
- def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, **kwargs):
+    def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean=0.0, eta=None, **kwargs):
losses_m = AverageMeter()
self.model.train()
if self.by_epoch_1:
@@ -67,7 +67,7 @@ def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, **kwargs)
if hasattr(self.model_optim, 'sync_lookahead'):
self.model_optim.sync_lookahead()
- return num_updates, loss_mean
+        return num_updates, losses_m, eta
def vali_one_epoch(self, runner, vali_loader, **kwargs):
self.model.eval()
diff --git a/simvp/methods/e3dlstm.py b/openstl/methods/e3dlstm.py
similarity index 94%
rename from simvp/methods/e3dlstm.py
rename to openstl/methods/e3dlstm.py
index b93d535c..b0b5499c 100644
--- a/simvp/methods/e3dlstm.py
+++ b/openstl/methods/e3dlstm.py
@@ -1,6 +1,6 @@
import torch.nn as nn
-from simvp.models import E3DLSTM_Model
+from openstl.models import E3DLSTM_Model
from .predrnn import PredRNN
diff --git a/simvp/methods/mau.py b/openstl/methods/mau.py
similarity index 98%
rename from simvp/methods/mau.py
rename to openstl/methods/mau.py
index 7f263990..d683f106 100644
--- a/simvp/methods/mau.py
+++ b/openstl/methods/mau.py
@@ -4,8 +4,8 @@
from timm.utils import AverageMeter
from tqdm import tqdm
-from simvp.models import MAU_Model
-from simvp.utils import schedule_sampling
+from openstl.models import MAU_Model
+from openstl.utils import schedule_sampling
from .base_method import Base_method
diff --git a/simvp/methods/mim.py b/openstl/methods/mim.py
similarity index 96%
rename from simvp/methods/mim.py
rename to openstl/methods/mim.py
index 9229398e..0a239ffc 100644
--- a/simvp/methods/mim.py
+++ b/openstl/methods/mim.py
@@ -1,6 +1,6 @@
import torch.nn as nn
-from simvp.models import MIM_Model
+from openstl.models import MIM_Model
from .predrnn import PredRNN
diff --git a/simvp/methods/phydnet.py b/openstl/methods/phydnet.py
similarity index 96%
rename from simvp/methods/phydnet.py
rename to openstl/methods/phydnet.py
index 7dce8c6c..466e05f5 100644
--- a/simvp/methods/phydnet.py
+++ b/openstl/methods/phydnet.py
@@ -4,7 +4,7 @@
from timm.utils import AverageMeter
from tqdm import tqdm
-from simvp.models import PhyDNet_Model
+from openstl.models import PhyDNet_Model
from .base_method import Base_method
@@ -36,7 +36,7 @@ def _get_constraints(self):
ind +=1
return constraints
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, **kwargs):
+    def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean=0.0, eta=None, **kwargs):
losses_m = AverageMeter()
self.model.train()
if self.by_epoch:
@@ -64,7 +64,7 @@ def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, *
if hasattr(self.model_optim, 'sync_lookahead'):
self.model_optim.sync_lookahead()
- return num_updates, loss_mean
+        return num_updates, losses_m, eta
def vali_one_epoch(self, runner, vali_loader, **kwargs):
self.model.eval()
diff --git a/openstl/methods/predrnn.py b/openstl/methods/predrnn.py
new file mode 100644
index 00000000..b5ba3f38
--- /dev/null
+++ b/openstl/methods/predrnn.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from timm.utils import AverageMeter
+from tqdm import tqdm
+
+from openstl.models import PredRNN_Model
+from openstl.utils import (reshape_patch, reshape_patch_back,
+ reserve_schedule_sampling_exp, schedule_sampling)
+from .base_method import Base_method
+
+
+class PredRNN(Base_method):
+ r"""PredRNN
+
+ Implementation of `PredRNN: A Recurrent Neural Network for Spatiotemporal
+    Predictive Learning`_.
+
+ """
+
+ def __init__(self, args, device, steps_per_epoch):
+ Base_method.__init__(self, args, device, steps_per_epoch)
+ self.model = self._build_model(self.args)
+ self.model_optim, self.scheduler, self.by_epoch = self._init_optimizer(steps_per_epoch)
+ self.criterion = nn.MSELoss()
+
+ def _build_model(self, args):
+ num_hidden = [int(x) for x in self.args.num_hidden.split(',')]
+ num_layers = len(num_hidden)
+ return PredRNN_Model(num_layers, num_hidden, args).to(self.device)
+
+ def _predict(self, batch_x, batch_y):
+ """Forward the model."""
+ # reverse schedule sampling
+ if self.args.reverse_scheduled_sampling == 1:
+ mask_input = 1
+ else:
+ mask_input = self.args.pre_seq_length
+ _, img_channel, img_height, img_width = self.args.in_shape
+
+ # preprocess
+ test_ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
+ test_dat = reshape_patch(test_ims, self.args.patch_size)
+ test_ims = test_ims[:, :, :, :, :img_channel]
+
+ real_input_flag = torch.zeros(
+ (batch_x.shape[0],
+ self.args.total_length - mask_input - 1,
+ img_height // self.args.patch_size,
+ img_width // self.args.patch_size,
+ self.args.patch_size ** 2 * img_channel)).to(self.device)
+
+ if self.args.reverse_scheduled_sampling == 1:
+ real_input_flag[:, :self.args.pre_seq_length - 1, :, :] = 1.0
+
+ img_gen, _ = self.model(test_dat, real_input_flag)
+ img_gen = reshape_patch_back(img_gen, self.args.patch_size)
+ pred_y = img_gen[:, -self.args.aft_seq_length:].permute(0, 1, 4, 2, 3).contiguous()
+
+ return pred_y
+
+ def forward_test(self, batch_x, batch_y):
+ """Evaluate the model"""
+ batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+
+ with self.amp_autocast():
+ pred_y = self._predict(batch_x, batch_y)
+
+ return dict(zip(['inputs', 'preds', 'trues'],
+ [batch_x, pred_y, batch_y]))
+
+    def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ losses_m = AverageMeter()
+ self.model.train()
+ if self.by_epoch:
+ self.scheduler.step(epoch)
+
+ train_pbar = tqdm(train_loader)
+ for batch_x, batch_y in train_pbar:
+ self.model_optim.zero_grad()
+ batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+
+ # preprocess
+ ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
+ ims = reshape_patch(ims, self.args.patch_size)
+ if self.args.reverse_scheduled_sampling == 1:
+ real_input_flag = reserve_schedule_sampling_exp(
+ num_updates, ims.shape[0], self.args)
+ else:
+ eta, real_input_flag = schedule_sampling(
+ eta, num_updates, ims.shape[0], self.args)
+
+ img_gen, loss = self.model(ims, real_input_flag)
+ loss.backward()
+ self.model_optim.step()
+ if not self.by_epoch:
+                self.scheduler.step()
+
+            num_updates += 1
+            losses_m.update(loss.item(), batch_x.size(0))
+
+ train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
+
+ if hasattr(self.model_optim, 'sync_lookahead'):
+ self.model_optim.sync_lookahead()
+
+        return num_updates, losses_m, eta
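`schedule_sampling` (from `openstl.utils`) implements PredRNN's scheduled-sampling curriculum: `eta` is the probability of feeding ground truth rather than the model's own prediction, and it decays over training. A sketch of the idea; the decay parameters `sampling_stop_iter` and `sampling_changing_rate` are assumptions here, not verified against the real helper:

    import numpy as np

    def schedule_sampling_sketch(eta, itr, batch_size, args):
        # linearly decay the teacher-forcing probability, then clamp to 0
        if itr < args.sampling_stop_iter:
            eta -= args.sampling_changing_rate
        else:
            eta = 0.0
        # per-position mask: 1 -> feed ground truth, 0 -> feed prediction
        T, C, H, W = args.in_shape
        shape = (batch_size, args.total_length - args.pre_seq_length - 1,
                 H // args.patch_size, W // args.patch_size,
                 args.patch_size ** 2 * C)
        real_input_flag = (np.random.rand(*shape) < eta).astype(np.float32)
        return eta, real_input_flag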
diff --git a/simvp/methods/predrnnpp.py b/openstl/methods/predrnnpp.py
similarity index 94%
rename from simvp/methods/predrnnpp.py
rename to openstl/methods/predrnnpp.py
index 0b026ef1..3a79f27a 100644
--- a/simvp/methods/predrnnpp.py
+++ b/openstl/methods/predrnnpp.py
@@ -1,6 +1,6 @@
import torch.nn as nn
-from simvp.models import PredRNNpp_Model
+from openstl.models import PredRNNpp_Model
from .predrnn import PredRNN
diff --git a/simvp/methods/predrnnv2.py b/openstl/methods/predrnnv2.py
similarity index 94%
rename from simvp/methods/predrnnv2.py
rename to openstl/methods/predrnnv2.py
index c5469594..87811c34 100644
--- a/simvp/methods/predrnnv2.py
+++ b/openstl/methods/predrnnv2.py
@@ -3,8 +3,8 @@
from timm.utils import AverageMeter
from tqdm import tqdm
-from simvp.models import PredRNNv2_Model
-from simvp.utils import reshape_patch, reserve_schedule_sampling_exp, schedule_sampling
+from openstl.models import PredRNNv2_Model
+from openstl.utils import reshape_patch, reserve_schedule_sampling_exp, schedule_sampling
from .predrnn import PredRNN
diff --git a/simvp/methods/simvp.py b/openstl/methods/simvp.py
similarity index 50%
rename from simvp/methods/simvp.py
rename to openstl/methods/simvp.py
index 0ea52635..f1b03192 100644
--- a/simvp/methods/simvp.py
+++ b/openstl/methods/simvp.py
@@ -1,10 +1,11 @@
+import time
import torch
import torch.nn as nn
-import numpy as np
from tqdm import tqdm
from timm.utils import AverageMeter
-from simvp.models import SimVP_Model
+from openstl.models import SimVP_Model
+from openstl.utils import reduce_tensor
from .base_method import Base_method
@@ -26,6 +27,7 @@ def _build_model(self, config):
return SimVP_Model(**config).to(self.device)
def _predict(self, batch_x):
+ """Forward the model."""
if self.args.aft_seq_length == self.args.pre_seq_length:
pred_y = self.model(batch_x)
elif self.args.aft_seq_length < self.args.pre_seq_length:
@@ -48,75 +50,59 @@ def _predict(self, batch_x):
pred_y = torch.cat(pred_y, dim=1)
return pred_y
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, **kwargs):
+ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ """Train the model with train_loader."""
+ data_time_m = AverageMeter()
losses_m = AverageMeter()
self.model.train()
if self.by_epoch:
self.scheduler.step(epoch)
-        train_pbar = tqdm(train_loader)
+        train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
+        end = time.time()
for batch_x, batch_y in train_pbar:
+ data_time_m.update(time.time() - end)
self.model_optim.zero_grad()
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
runner.call_hook('before_train_iter')
- pred_y = self._predict(batch_x)
- loss = self.criterion(pred_y, batch_y)
- loss.backward()
+ with self.amp_autocast():
+ pred_y = self._predict(batch_x)
+ loss = self.criterion(pred_y, batch_y)
+
+ if not self.dist:
+ losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.loss_scaler is not None:
+ self.loss_scaler(
+ loss, self.model_optim,
+ clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+ parameters=self.model.parameters())
+ else:
+ loss.backward()
+ self.clip_grads(self.model.parameters())
+
self.model_optim.step()
+ torch.cuda.synchronize()
+ num_updates += 1
+
+ if self.dist:
+ losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
if not self.by_epoch:
self.scheduler.step()
-
- num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
runner.call_hook('after_train_iter')
runner._iter += 1
- train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
+ if self.rank == 0:
+ log_buffer = 'train loss: {:.4f}'.format(loss.item())
+ log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+ train_pbar.set_description(log_buffer)
+
+            end = time.time()  # restart the data-time clock for the next iteration
if hasattr(self.model_optim, 'sync_lookahead'):
self.model_optim.sync_lookahead()
- return num_updates, loss_mean
-
- def vali_one_epoch(self, runner, vali_loader, **kwargs):
- self.model.eval()
- preds_lst, trues_lst, total_loss = [], [], []
- vali_pbar = tqdm(vali_loader)
- for i, (batch_x, batch_y) in enumerate(vali_pbar):
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
- runner.call_hook('before_val_iter')
- pred_y = self._predict(batch_x)
- loss = self.criterion(pred_y, batch_y)
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()
- ), [pred_y, batch_y], [preds_lst, trues_lst]))
- runner.call_hook('after_val_iter')
- if i * batch_x.shape[0] > 1000:
- break
-
- vali_pbar.set_description('vali loss: {:.4f}'.format(loss.mean().item()))
- total_loss.append(loss.mean().item())
-
- total_loss = np.average(total_loss)
-
- preds = np.concatenate(preds_lst, axis=0)
- trues = np.concatenate(trues_lst, axis=0)
- return preds, trues, total_loss
-
- def test_one_epoch(self, runner, test_loader, **kwargs):
- self.model.eval()
- inputs_lst, trues_lst, preds_lst = [], [], []
- test_pbar = tqdm(test_loader)
- for batch_x, batch_y in test_pbar:
- runner.call_hook('before_val_iter')
- pred_y = self._predict(batch_x.to(self.device))
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
- batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
- runner.call_hook('after_val_iter')
-
- inputs, trues, preds = map(
- lambda data: np.concatenate(data, axis=0), [inputs_lst, trues_lst, preds_lst])
- return inputs, trues, preds
+ return num_updates, losses_m, eta
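The new `loss_scaler` branch is timm's `NativeScaler`, a thin wrapper over the standard PyTorch AMP recipe; the equivalent raw-PyTorch pattern is shown below for reference (requires a CUDA device):

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()

    x, y = torch.randn(4, 16).cuda(), torch.randn(4, 16).cuda()
    with torch.cuda.amp.autocast():
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # skips the update if gradients overflowed
    scaler.update()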
diff --git a/simvp/models/__init__.py b/openstl/models/__init__.py
similarity index 100%
rename from simvp/models/__init__.py
rename to openstl/models/__init__.py
diff --git a/simvp/models/convlstm_model.py b/openstl/models/convlstm_model.py
similarity index 98%
rename from simvp/models/convlstm_model.py
rename to openstl/models/convlstm_model.py
index 4742c5fe..3414447e 100644
--- a/simvp/models/convlstm_model.py
+++ b/openstl/models/convlstm_model.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from simvp.modules import ConvLSTMCell
+from openstl.modules import ConvLSTMCell
class ConvLSTM_Model(nn.Module):
diff --git a/simvp/models/crevnet_model.py b/openstl/models/crevnet_model.py
similarity index 98%
rename from simvp/models/crevnet_model.py
rename to openstl/models/crevnet_model.py
index 3369ed21..fbdc2e20 100644
--- a/simvp/models/crevnet_model.py
+++ b/openstl/models/crevnet_model.py
@@ -2,7 +2,7 @@
from torch import nn
from torch.autograd import Variable
-from simvp.modules import zig_rev_predictor, autoencoder
+from openstl.modules import zig_rev_predictor, autoencoder
class CrevNet_Model(nn.Module):
diff --git a/simvp/models/e3dlstm_model.py b/openstl/models/e3dlstm_model.py
similarity index 98%
rename from simvp/models/e3dlstm_model.py
rename to openstl/models/e3dlstm_model.py
index f09a6a6f..9dda637f 100644
--- a/simvp/models/e3dlstm_model.py
+++ b/openstl/models/e3dlstm_model.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from simvp.modules import Eidetic3DLSTMCell
+from openstl.modules import Eidetic3DLSTMCell
class E3DLSTM_Model(nn.Module):
diff --git a/simvp/models/mau_model.py b/openstl/models/mau_model.py
similarity index 99%
rename from simvp/models/mau_model.py
rename to openstl/models/mau_model.py
index 1bb79c3a..976f19d3 100644
--- a/simvp/models/mau_model.py
+++ b/openstl/models/mau_model.py
@@ -2,7 +2,7 @@
import torch
import torch.nn as nn
-from simvp.modules import MAUCell
+from openstl.modules import MAUCell
class MAU_Model(nn.Module):
diff --git a/simvp/models/mim_model.py b/openstl/models/mim_model.py
similarity index 98%
rename from simvp/models/mim_model.py
rename to openstl/models/mim_model.py
index 4dcd4cf0..367201b5 100644
--- a/simvp/models/mim_model.py
+++ b/openstl/models/mim_model.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from simvp.modules import SpatioTemporalLSTMCell, MIMBlock, MIMN
+from openstl.modules import SpatioTemporalLSTMCell, MIMBlock, MIMN
class MIM_Model(nn.Module):
diff --git a/simvp/models/phydnet_model.py b/openstl/models/phydnet_model.py
similarity index 97%
rename from simvp/models/phydnet_model.py
rename to openstl/models/phydnet_model.py
index 24ad01ac..e87c142a 100644
--- a/simvp/models/phydnet_model.py
+++ b/openstl/models/phydnet_model.py
@@ -2,7 +2,7 @@
import torch
from torch import nn
-from simvp.modules import PhyCell, PhyD_ConvLSTM, PhyD_EncoderRNN, K2M
+from openstl.modules import PhyCell, PhyD_ConvLSTM, PhyD_EncoderRNN, K2M
class PhyDNet_Model(nn.Module):
diff --git a/simvp/models/predrnn_model.py b/openstl/models/predrnn_model.py
similarity index 98%
rename from simvp/models/predrnn_model.py
rename to openstl/models/predrnn_model.py
index 055184b7..37f20501 100644
--- a/simvp/models/predrnn_model.py
+++ b/openstl/models/predrnn_model.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from simvp.modules import SpatioTemporalLSTMCell
+from openstl.modules import SpatioTemporalLSTMCell
class PredRNN_Model(nn.Module):
diff --git a/simvp/models/predrnnpp_model.py b/openstl/models/predrnnpp_model.py
similarity index 98%
rename from simvp/models/predrnnpp_model.py
rename to openstl/models/predrnnpp_model.py
index ec2f49e1..142d5b30 100644
--- a/simvp/models/predrnnpp_model.py
+++ b/openstl/models/predrnnpp_model.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from simvp.modules import CausalLSTMCell, GHU
+from openstl.modules import CausalLSTMCell, GHU
class PredRNNpp_Model(nn.Module):
diff --git a/simvp/models/predrnnv2_model.py b/openstl/models/predrnnv2_model.py
similarity index 98%
rename from simvp/models/predrnnv2_model.py
rename to openstl/models/predrnnv2_model.py
index 32c8ffb8..037f4eb3 100644
--- a/simvp/models/predrnnv2_model.py
+++ b/openstl/models/predrnnv2_model.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from simvp.modules import SpatioTemporalLSTMCellv2
+from openstl.modules import SpatioTemporalLSTMCellv2
class PredRNNv2_Model(nn.Module):
diff --git a/simvp/models/simvp_model.py b/openstl/models/simvp_model.py
similarity index 96%
rename from simvp/models/simvp_model.py
rename to openstl/models/simvp_model.py
index b6b05fdf..7eca2231 100644
--- a/simvp/models/simvp_model.py
+++ b/openstl/models/simvp_model.py
@@ -1,9 +1,9 @@
import torch
from torch import nn
-from simvp.modules import (ConvSC, ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
- HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
- SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock)
+from openstl.modules import (ConvSC, ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
+ HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
+ SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock)
class SimVP_Model(nn.Module):
diff --git a/simvp/modules/__init__.py b/openstl/modules/__init__.py
similarity index 100%
rename from simvp/modules/__init__.py
rename to openstl/modules/__init__.py
diff --git a/simvp/modules/convlstm_modules.py b/openstl/modules/convlstm_modules.py
similarity index 100%
rename from simvp/modules/convlstm_modules.py
rename to openstl/modules/convlstm_modules.py
diff --git a/simvp/modules/crevnet_modules.py b/openstl/modules/crevnet_modules.py
similarity index 100%
rename from simvp/modules/crevnet_modules.py
rename to openstl/modules/crevnet_modules.py
diff --git a/simvp/modules/e3dlstm_modules.py b/openstl/modules/e3dlstm_modules.py
similarity index 100%
rename from simvp/modules/e3dlstm_modules.py
rename to openstl/modules/e3dlstm_modules.py
diff --git a/simvp/modules/layers/__init__.py b/openstl/modules/layers/__init__.py
similarity index 100%
rename from simvp/modules/layers/__init__.py
rename to openstl/modules/layers/__init__.py
diff --git a/simvp/modules/layers/hornet.py b/openstl/modules/layers/hornet.py
similarity index 100%
rename from simvp/modules/layers/hornet.py
rename to openstl/modules/layers/hornet.py
diff --git a/simvp/modules/layers/moganet.py b/openstl/modules/layers/moganet.py
similarity index 100%
rename from simvp/modules/layers/moganet.py
rename to openstl/modules/layers/moganet.py
diff --git a/simvp/modules/layers/poolformer.py b/openstl/modules/layers/poolformer.py
similarity index 100%
rename from simvp/modules/layers/poolformer.py
rename to openstl/modules/layers/poolformer.py
diff --git a/simvp/modules/layers/uniformer.py b/openstl/modules/layers/uniformer.py
similarity index 100%
rename from simvp/modules/layers/uniformer.py
rename to openstl/modules/layers/uniformer.py
diff --git a/simvp/modules/layers/van.py b/openstl/modules/layers/van.py
similarity index 100%
rename from simvp/modules/layers/van.py
rename to openstl/modules/layers/van.py
diff --git a/simvp/modules/mau_modules.py b/openstl/modules/mau_modules.py
similarity index 100%
rename from simvp/modules/mau_modules.py
rename to openstl/modules/mau_modules.py
diff --git a/simvp/modules/mim_modules.py b/openstl/modules/mim_modules.py
similarity index 100%
rename from simvp/modules/mim_modules.py
rename to openstl/modules/mim_modules.py
diff --git a/simvp/modules/phydnet_modules.py b/openstl/modules/phydnet_modules.py
similarity index 100%
rename from simvp/modules/phydnet_modules.py
rename to openstl/modules/phydnet_modules.py
diff --git a/simvp/modules/predrnn_modules.py b/openstl/modules/predrnn_modules.py
similarity index 100%
rename from simvp/modules/predrnn_modules.py
rename to openstl/modules/predrnn_modules.py
diff --git a/simvp/modules/predrnnpp_modules.py b/openstl/modules/predrnnpp_modules.py
similarity index 100%
rename from simvp/modules/predrnnpp_modules.py
rename to openstl/modules/predrnnpp_modules.py
diff --git a/simvp/modules/predrnnv2_modules.py b/openstl/modules/predrnnv2_modules.py
similarity index 100%
rename from simvp/modules/predrnnv2_modules.py
rename to openstl/modules/predrnnv2_modules.py
diff --git a/simvp/modules/simvp_modules.py b/openstl/modules/simvp_modules.py
similarity index 100%
rename from simvp/modules/simvp_modules.py
rename to openstl/modules/simvp_modules.py
diff --git a/openstl/utils/__init__.py b/openstl/utils/__init__.py
new file mode 100644
index 00000000..9be4b5cc
--- /dev/null
+++ b/openstl/utils/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) CAIRI AI Lab. All rights reserved
+
+from .collect import (gather_tensors, gather_tensors_batch, nondist_forward_collect,
+ dist_forward_collect, collect_results_gpu)
+from .config_utils import Config, check_file_exist
+from .main_utils import (set_seed, setup_multi_processes, print_log, output_namespace,
+ collect_env, check_dir, get_dataset, count_parameters, measure_throughput,
+ load_config, update_config, weights_to_cpu,
+ init_dist, init_random_seed, get_dist_info, reduce_tensor)
+from .parser import create_parser
+from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling, reshape_patch,
+ reshape_patch_back)
+from .progressbar import ProgressBar, Timer
+
+__all__ = [
+ 'collect_results_gpu', 'gather_tensors', 'gather_tensors_batch',
+ 'nondist_forward_collect', 'dist_forward_collect',
+ 'Config', 'check_file_exist', 'create_parser',
+ 'set_seed', 'setup_multi_processes', 'print_log', 'output_namespace', 'collect_env', 'check_dir',
+ 'get_dataset', 'count_parameters', 'measure_throughput', 'load_config', 'update_config', 'weights_to_cpu',
+ 'init_dist', 'init_random_seed', 'get_dist_info', 'reduce_tensor',
+ 'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back',
+ 'ProgressBar', 'Timer',
+]
\ No newline at end of file
diff --git a/openstl/utils/collect.py b/openstl/utils/collect.py
new file mode 100644
index 00000000..a6ea3b90
--- /dev/null
+++ b/openstl/utils/collect.py
@@ -0,0 +1,210 @@
+import numpy as np
+import pickle
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from .main_utils import get_dist_info
+from .progressbar import ProgressBar
+
+
+def gather_tensors(input_array):
+ """Gather tensor from all GPUs."""
+ world_size = dist.get_world_size()
+ # gather shapes first
+ myshape = input_array.shape
+ mycount = input_array.size
+ shape_tensor = torch.Tensor(np.array(myshape)).cuda()
+ all_shape = [
+ torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)
+ ]
+ dist.all_gather(all_shape, shape_tensor)
+ # compute largest shapes
+ all_shape = [x.cpu().numpy() for x in all_shape]
+ all_count = [int(x.prod()) for x in all_shape]
+ all_shape = [list(map(int, x)) for x in all_shape]
+ max_count = max(all_count)
+ # padding tensors and gather them
+ output_tensors = [
+ torch.Tensor(max_count).cuda() for i in range(world_size)
+ ]
+ padded_input_array = np.zeros(max_count)
+ padded_input_array[:mycount] = input_array.reshape(-1)
+ input_tensor = torch.Tensor(padded_input_array).cuda()
+ dist.all_gather(output_tensors, input_tensor)
+ # unpadding gathered tensors
+ padded_output = [x.cpu().numpy() for x in output_tensors]
+ output = [
+ x[:all_count[i]].reshape(all_shape[i])
+ for i, x in enumerate(padded_output)
+ ]
+ return output
+
+
+def gather_tensors_batch(input_array, part_size=100, ret_rank=-1):
+ """batch-wise gathering to avoid CUDA out of memory."""
+ rank = dist.get_rank()
+ all_features = []
+ part_num = input_array.shape[0] // part_size + 1 if input_array.shape[
+ 0] % part_size != 0 else input_array.shape[0] // part_size
+ for i in range(part_num):
+ part_feat = input_array[i *
+ part_size:min((i + 1) *
+ part_size, input_array.shape[0]),
+ ...]
+ assert part_feat.shape[
+            0] > 0, f'rank: {rank}, length of part features should be > 0'
+ gather_part_feat = gather_tensors(part_feat)
+ all_features.append(gather_part_feat)
+ if ret_rank == -1:
+ all_features = [
+ np.concatenate([all_features[i][j] for i in range(part_num)],
+ axis=0) for j in range(len(all_features[0]))
+ ]
+ return all_features
+ else:
+ if rank == ret_rank:
+ all_features = [
+ np.concatenate([all_features[i][j] for i in range(part_num)],
+ axis=0) for j in range(len(all_features[0]))
+ ]
+ return all_features
+ else:
+ return None
+
+
+def nondist_forward_collect(func, data_loader, length, to_numpy=False):
+ """Forward and collect network outputs.
+
+ This function performs forward propagation and collects outputs.
+ It can be used to collect results, features, losses, etc.
+
+ Args:
+        func (function): The function to process data. The output must be
+            a dict of CPU tensors.
+ length (int): Expected length of output arrays.
+        to_numpy (bool): Whether to convert tensors to numpy arrays.
+
+ Returns:
+ results_all (dict(np.ndarray)): The concatenated outputs.
+ """
+ results = []
+ prog_bar = ProgressBar(len(data_loader))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = func(*data) # list{tensor, ...}
+ results.append(result)
+ prog_bar.update()
+
+ results_all = {}
+ for k in results[0].keys():
+ if to_numpy:
+ results_all[k] = np.concatenate(
+ [batch[k].cpu().numpy() for batch in results], axis=0)
+ else:
+ results_all[k] = torch.cat(
+ [batch[k] for batch in results], dim=0)
+ assert results_all[k].shape[0] == length
+ return results_all
+
+
+def dist_forward_collect(func, data_loader, rank, length, ret_rank=-1, to_numpy=False):
+ """Forward and collect network outputs in a distributed manner.
+
+ This function performs forward propagation and collects outputs.
+ It can be used to collect results, features, losses, etc.
+
+ Args:
+        func (function): The function to process data. The output must be
+            a dict of CPU tensors.
+ rank (int): This process id.
+ length (int): Expected length of output arrays.
+ ret_rank (int): The process that returns.
+ Other processes will return None.
+        to_numpy (bool): Whether to convert tensors to numpy arrays.
+
+ Returns:
+ results_all (dict(np.ndarray)): The concatenated outputs.
+ """
+    assert to_numpy is True, 'dist_forward_collect requires to_numpy=True'
+ results = []
+ if rank == 0:
+ prog_bar = ProgressBar(len(data_loader))
+ for idx, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = func(*data) # list{tensor, ...}
+ results.append(result)
+
+ if rank == 0:
+ prog_bar.update()
+
+ results_all = {}
+ for k in results[0].keys():
+ results_cat = np.concatenate([batch[k].cpu().numpy() for batch in results],
+ axis=0)
+ if ret_rank == -1:
+ results_gathered = gather_tensors_batch(results_cat, part_size=20)
+ results_strip = np.concatenate(results_gathered, axis=0)[:length]
+ else:
+ results_gathered = gather_tensors_batch(
+ results_cat, part_size=20, ret_rank=ret_rank)
+ if rank == ret_rank:
+ results_strip = np.concatenate(
+ results_gathered, axis=0)[:length]
+ else:
+ results_strip = None
+ results_all[k] = results_strip
+ return results_all
+
+
+def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
+ """Collect results under gpu mode.
+
+ On gpu mode, this function will encode results to gpu tensors and use gpu
+ communication for results collection.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+        size (int): Size of the results, commonly equal to the length of
+            the results.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+ # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
+ else:
+ return None
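
As a usage illustration (not part of the patch itself): a minimal single-process sketch of driving `nondist_forward_collect`, assuming a hypothetical `model` and a `loader` that yields `(batch_x, batch_y)` pairs:

    import torch
    from openstl.utils import nondist_forward_collect

    def forward_func(batch_x, batch_y):
        # the callable must return a dict of CPU tensors, one entry per output
        pred = model(batch_x.cuda())
        return {'preds': pred.cpu(), 'trues': batch_y}

    results = nondist_forward_collect(
        forward_func, loader, length=len(loader.dataset), to_numpy=True)
    # results['preds'] is a single np.ndarray concatenated over all batches

`dist_forward_collect` follows the same contract per rank, then merges the per-rank arrays via `gather_tensors_batch`.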
diff --git a/simvp/utils/config_utils.py b/openstl/utils/config_utils.py
similarity index 100%
rename from simvp/utils/config_utils.py
rename to openstl/utils/config_utils.py
diff --git a/openstl/utils/main_utils.py b/openstl/utils/main_utils.py
new file mode 100644
index 00000000..98d756c7
--- /dev/null
+++ b/openstl/utils/main_utils.py
@@ -0,0 +1,307 @@
+# Copyright (c) CAIRI AI Lab. All rights reserved
+
+import cv2
+import os
+import logging
+import platform
+import random
+import subprocess
+import sys
+import warnings
+import numpy as np
+from collections import defaultdict, OrderedDict
+from typing import Tuple
+
+import torch
+import torchvision
+import torch.multiprocessing as mp
+from torch import distributed as dist
+
+import openstl
+from .config_utils import Config
+
+
+def set_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ else:
+ torch.backends.cudnn.benchmark = True
+
+
+def setup_multi_processes(cfg):
+ """Setup multi-processing environment variables."""
+ # set multi-process start method as `fork` to speed up the training
+ if platform.system() != 'Windows':
+ mp_start_method = cfg.get('mp_start_method', 'fork')
+ current_method = mp.get_start_method(allow_none=True)
+ if current_method is not None and current_method != mp_start_method:
+ warnings.warn(
+ f'Multi-processing start method `{mp_start_method}` is '
+                f'different from the previous setting `{current_method}`. '
+                f'It will be forcibly set to `{mp_start_method}`. You can change '
+ f'this behavior by changing `mp_start_method` in your config.')
+ mp.set_start_method(mp_start_method, force=True)
+
+ # disable opencv multithreading to avoid system being overloaded
+ opencv_num_threads = cfg.get('opencv_num_threads', 0)
+ cv2.setNumThreads(opencv_num_threads)
+
+ # setup OMP threads
+    # This code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
+ if 'OMP_NUM_THREADS' not in os.environ and cfg['num_workers'] > 1:
+ omp_num_threads = 1
+        warnings.warn(
+            f'Setting OMP_NUM_THREADS environment variable for each process '
+            f'to {omp_num_threads} by default. To avoid overloading your '
+            f'system, please further tune this variable for optimal '
+            f'performance in your application as needed.')
+ os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+ # setup MKL threads
+ if 'MKL_NUM_THREADS' not in os.environ and cfg['num_workers'] > 1:
+ mkl_num_threads = 1
+        warnings.warn(
+            f'Setting MKL_NUM_THREADS environment variable for each process '
+            f'to {mkl_num_threads} by default. To avoid overloading your '
+            f'system, please further tune this variable for optimal '
+            f'performance in your application as needed.')
+ os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = {}
+ env_info['sys.platform'] = sys.platform
+ env_info['Python'] = sys.version.replace('\n', '')
+
+ cuda_available = torch.cuda.is_available()
+ env_info['CUDA available'] = cuda_available
+
+ if cuda_available:
+ from torch.utils.cpp_extension import CUDA_HOME
+ env_info['CUDA_HOME'] = CUDA_HOME
+
+ if CUDA_HOME is not None and os.path.isdir(CUDA_HOME):
+ try:
+ nvcc = os.path.join(CUDA_HOME, 'bin/nvcc')
+ nvcc = subprocess.check_output(
+ '"{}" -V | tail -n1'.format(nvcc), shell=True)
+ nvcc = nvcc.decode('utf-8').strip()
+ except subprocess.SubprocessError:
+ nvcc = 'Not Available'
+ env_info['NVCC'] = nvcc
+
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, devids in devices.items():
+ env_info['GPU ' + ','.join(devids)] = name
+
+ gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
+ gcc = gcc.decode('utf-8').strip()
+ env_info['GCC'] = gcc
+
+ env_info['PyTorch'] = torch.__version__
+ env_info['PyTorch compiling details'] = torch.__config__.show()
+ env_info['TorchVision'] = torchvision.__version__
+ env_info['OpenCV'] = cv2.__version__
+
+ env_info['openstl'] = openstl.__version__
+
+ return env_info
+
+
+def print_log(message):
+ print(message)
+ logging.info(message)
+
+
+def output_namespace(namespace):
+ configs = namespace.__dict__
+ message = ''
+ for k, v in configs.items():
+ message += '\n' + k + ': \t' + str(v) + '\t'
+ return message
+
+
+def check_dir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+ return False
+ return True
+
+
+def get_dataset(dataname, config):
+ from openstl.datasets import dataset_parameters
+ from openstl.datasets import load_data
+ config.update(dataset_parameters[dataname])
+ return load_data(**config)
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def measure_throughput(model, input_dummy):
+ bs = 100
+ repetitions = 100
+ if isinstance(input_dummy, tuple):
+ input_dummy = list(input_dummy)
+ _, T, C, H, W = input_dummy[0].shape
+ _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device)
+ input_dummy[0] = _input
+ input_dummy = tuple(input_dummy)
+ else:
+ _, T, C, H, W = input_dummy.shape
+ input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device)
+ total_time = 0
+ with torch.no_grad():
+ for _ in range(repetitions):
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+ if isinstance(input_dummy, tuple):
+ _ = model(*input_dummy)
+ else:
+ _ = model(input_dummy)
+ ender.record()
+ torch.cuda.synchronize()
+ curr_time = starter.elapsed_time(ender) / 1000
+ total_time += curr_time
+    throughput = (repetitions * bs) / total_time
+    return throughput
+
+
+def load_config(filename: str = None):
+    """Load and print the config."""
+ print('loading config from ' + filename + ' ...')
+ try:
+ configfile = Config(filename=filename)
+ config = configfile._cfg_dict
+ except (FileNotFoundError, IOError):
+ config = dict()
+        print('warning: failed to load the config!')
+ return config
+
+
+def update_config(args, config, exclude_keys=list()):
+ """update the args dict with a new config"""
+ assert isinstance(args, dict) and isinstance(config, dict)
+ for k in config.keys():
+ if args.get(k, False):
+ if args[k] != config[k] and k not in exclude_keys:
+ print(f'overwrite config key -- {k}: {config[k]} -> {args[k]}')
+ else:
+ args[k] = config[k]
+ else:
+ args[k] = config[k]
+ return args
+
+
+def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+        OrderedDict: Model weights on CPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ # Keep metadata in state_dict
+ state_dict_cpu._metadata = getattr( # type: ignore
+ state_dict, '_metadata', OrderedDict())
+ return state_dict_cpu
+
+
+def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def init_random_seed(seed=None, device='cuda'):
+ """Initialize random seed.
+
+ If the seed is not set, the seed will be automatically randomized,
+ and then broadcast to all processes to prevent some potential bugs.
+ Args:
+ seed (int, Optional): The seed. Default to None.
+ device (str): The device where the seed will be put on.
+ Default to 'cuda'.
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is not None:
+ return seed
+
+ # Make sure all ranks share the same random seed to prevent
+ # some potential bugs. Please refer to
+ # https://github.com/open-mmlab/mmdetection/issues/6339
+ rank, world_size = get_dist_info()
+ seed = np.random.randint(2**31)
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+ dist.broadcast(random_num, src=0)
+ return random_num.item()
+
+
+def _init_dist_pytorch(backend: str, **kwargs) -> None:
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend: str, **kwargs) -> None:
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ torch.cuda.set_device(local_rank)
+ if 'MASTER_PORT' not in os.environ:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ if 'MASTER_ADDR' not in os.environ:
+ raise KeyError('The environment variable MASTER_ADDR is not set')
+ os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+ os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def get_dist_info() -> Tuple[int, int]:
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def reduce_tensor(tensor):
+ rt = tensor.data.clone()
+ dist.all_reduce(rt.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return rt
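
A hedged sketch of how these helpers compose at startup; the flow mirrors `tools/train.py` below, and all values are illustrative:

    from openstl.utils import (create_parser, load_config, update_config,
                               set_seed, init_random_seed, get_dist_info)

    args = create_parser().parse_args().__dict__
    config = update_config(args, load_config(args['config_file']),
                           exclude_keys=['batch_size'])

    # one seed is drawn and broadcast across ranks unless --diff_seed is used
    seed = init_random_seed(config['seed'])
    set_seed(seed, deterministic=config.get('deterministic', False))
    rank, world_size = get_dist_info()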
diff --git a/simvp/utils/parser.py b/openstl/utils/parser.py
similarity index 69%
rename from simvp/utils/parser.py
rename to openstl/utils/parser.py
index c5b59a96..7f50087a 100644
--- a/simvp/utils/parser.py
+++ b/openstl/utils/parser.py
@@ -4,7 +4,8 @@
def create_parser():
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(
+ description='OpenSTL train/test a model')
# Set-up parameters
parser.add_argument('--device', default='cuda', type=str,
help='Name of device to use for tensor computations (cuda/cpu)')
@@ -12,23 +13,36 @@ def create_parser():
help='Whether to use distributed training (DDP)')
parser.add_argument('--display_step', default=10, type=int,
help='Interval in batches between display of training metrics')
- parser.add_argument('--res_dir', default='./results', type=str)
- parser.add_argument('--ex_name', default='Debug', type=str)
+ parser.add_argument('--res_dir', default='work_dirs', type=str)
+ parser.add_argument('--ex_name', '-ex', default='Debug', type=str)
parser.add_argument('--use_gpu', default=True, type=bool)
- parser.add_argument('--gpu', default=0, type=int)
+ parser.add_argument('--fp16', action='store_true', default=False,
+                        help='Whether to use Native AMP for mixed precision training (PyTorch>=1.6.0)')
+ parser.add_argument('--torchscript', action='store_true', default=False,
+ help='Whether to use torchscripted model')
parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--diff_seed', action='store_true', default=False,
+                        help='Whether or not to set different seeds for different ranks')
parser.add_argument('--fps', action='store_true', default=False,
help='Whether to measure inference speed (FPS)')
parser.add_argument('--resume_from', type=str, default=None, help='the checkpoint file to resume from')
parser.add_argument('--auto_resume', action='store_true', default=False,
                         help='When training was interrupted, resume from the latest checkpoint')
parser.add_argument('--test', action='store_true', default=False, help='Only performs testing')
+ parser.add_argument('--deterministic', action='store_true', default=False,
+                        help='Whether to set deterministic options for CUDNN backend (reproducible)')
+ parser.add_argument('--launcher', default='none', type=str,
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ help='job launcher for distributed training')
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument('--port', type=int, default=29500,
+ help='port only works when launcher=="slurm"')
# dataset parameters
parser.add_argument('--batch_size', '-b', default=16, type=int, help='Training batch size')
parser.add_argument('--val_batch_size', '-vb', default=4, type=int, help='Validation batch size')
- parser.add_argument('--num_workers', default=8, type=int)
- parser.add_argument('--data_root', default='./data/')
+ parser.add_argument('--num_workers', default=4, type=int)
+ parser.add_argument('--data_root', default='./data')
parser.add_argument('--dataname', '-d', default='mmnist', type=str,
help='Dataset name (default: "mmnist")')
parser.add_argument('--pre_seq_length', default=None, type=int, help='Sequence length before prediction')
@@ -42,15 +56,15 @@ def create_parser():
'PredRNN', 'predrnn', 'PredRNNpp', 'predrnnpp', 'PredRNNv2', 'predrnnv2',
'SimVP', 'simvp'],
help='Name of video prediction method to train (default: "SimVP")')
- parser.add_argument('--config_file', '-c', default='./configs/mmnist/simvp/SimVP_gSTA.py', type=str,
+ parser.add_argument('--config_file', '-c', default='configs/mmnist/simvp/SimVP_gSTA.py', type=str,
help='Path to the default config file')
parser.add_argument('--model_type', default=None, type=str,
help='Name of model for SimVP (default: None)')
parser.add_argument('--drop', type=float, default=0.0, help='Dropout rate(default: 0.)')
- parser.add_argument('--drop_path', type=float, default=0.1, help='Drop path rate for SimVP (default: 0.1)')
+ parser.add_argument('--drop_path', type=float, default=0.0, help='Drop path rate for SimVP (default: 0.)')
- # Training parameters
- parser.add_argument('--epoch', default=200, type=int, help='end epochs')
+ # Training parameters (optimizer)
+ parser.add_argument('--epoch', '-e', default=200, type=int, help='end epochs')
parser.add_argument('--log_step', default=1, type=int, help='Log interval by step')
parser.add_argument('--opt', default='adam', type=str, metavar='OPTIMIZER',
                         help='Optimizer (default: "adam")')
@@ -61,6 +75,12 @@ def create_parser():
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer sgd momentum (default: 0.9)')
parser.add_argument('--weight_decay', default=0., type=float, help='Weight decay')
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--clip_mode', type=str, default='norm',
+ help='Gradient clipping mode. One of ("norm", "value", "agc")')
+
+ # Training parameters (scheduler)
parser.add_argument('--sched', default='onecycle', type=str, metavar='SCHEDULER',
                         help='LR scheduler (default: "onecycle")')
parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
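
To see the new defaults and flags in one place, a quick interactive sketch (the argv values are illustrative):

    from openstl.utils import create_parser

    args = create_parser().parse_args(
        ['-d', 'mmnist', '-c', 'configs/mmnist/simvp/SimVP_gSTA.py',
         '--fp16', '--clip_grad', '1.0'])
    print(args.res_dir)      # 'work_dirs' (new default, was './results')
    print(args.num_workers)  # 4 (new default, was 8)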
diff --git a/simvp/utils/predrnn_utils.py b/openstl/utils/predrnn_utils.py
similarity index 100%
rename from simvp/utils/predrnn_utils.py
rename to openstl/utils/predrnn_utils.py
diff --git a/openstl/utils/progressbar.py b/openstl/utils/progressbar.py
new file mode 100644
index 00000000..01aeb38f
--- /dev/null
+++ b/openstl/utils/progressbar.py
@@ -0,0 +1,324 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import sys
+from time import time
+from collections.abc import Iterable
+from multiprocessing import Pool
+from shutil import get_terminal_size
+
+
+class ProgressBar:
+ """A progress bar which can print the progress."""
+
+ def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
+ self.task_num = task_num
+ self.bar_width = bar_width
+ self.completed = 0
+ self.file = file
+ if start:
+ self.start()
+
+ @property
+ def terminal_width(self):
+ width, _ = get_terminal_size()
+ return width
+
+ def start(self):
+ if self.task_num > 0:
+ self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
+ 'elapsed: 0s, ETA:')
+ else:
+ self.file.write('completed: 0, elapsed: 0s')
+ self.file.flush()
+ self.timer = Timer()
+
+ def update(self, num_tasks=1):
+ assert num_tasks > 0
+ self.completed += num_tasks
+ elapsed = self.timer.since_start()
+ if elapsed > 0:
+ fps = self.completed / elapsed
+ else:
+ fps = float('inf')
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
+ f'ETA: {eta:5}s'
+
+ bar_width = min(self.bar_width,
+ int(self.terminal_width - len(msg)) + 2,
+ int(self.terminal_width * 0.6))
+ bar_width = max(2, bar_width)
+ mark_width = int(bar_width * percentage)
+ bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
+ self.file.write(msg.format(bar_chars))
+ else:
+ self.file.write(
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
+ f' {fps:.1f} tasks/s')
+ self.file.flush()
+
+
+def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
+ """Track the progress of tasks execution with a progress bar.
+
+ Tasks are done with a simple for-loop.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ results = []
+ for task in tasks:
+ results.append(func(task, **kwargs))
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ return results
+
+
+def init_pool(process_num, initializer=None, initargs=None):
+ if initializer is None:
+ return Pool(process_num)
+ elif initargs is None:
+ return Pool(process_num, initializer)
+ else:
+ if not isinstance(initargs, tuple):
+ raise TypeError('"initargs" must be a tuple')
+ return Pool(process_num, initializer, initargs)
+
+
+def track_parallel_progress(func,
+ tasks,
+ nproc,
+ initializer=None,
+ initargs=None,
+ bar_width=50,
+ chunksize=1,
+ skip_first=False,
+ keep_order=True,
+ file=sys.stdout):
+ """Track the progress of parallel task execution with a progress bar.
+
+ The built-in :mod:`multiprocessing` module is used for process pools and
+ tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ nproc (int): Process (worker) number.
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+ for details.
+ initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+ details.
+ chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+ bar_width (int): Width of progress bar.
+ skip_first (bool): Whether to skip the first sample for each worker
+            when estimating fps, since the initialization step may take
+ longer.
+ keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+ :func:`Pool.imap_unordered` is used.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ pool = init_pool(nproc, initializer, initargs)
+ start = not skip_first
+ task_num -= nproc * chunksize * int(skip_first)
+ prog_bar = ProgressBar(task_num, bar_width, start, file=file)
+ results = []
+ if keep_order:
+ gen = pool.imap(func, tasks, chunksize)
+ else:
+ gen = pool.imap_unordered(func, tasks, chunksize)
+ for result in gen:
+ results.append(result)
+ if skip_first:
+ if len(results) < nproc * chunksize:
+ continue
+ elif len(results) == nproc * chunksize:
+ prog_bar.start()
+ continue
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ pool.close()
+ pool.join()
+ return results
+
+
+def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
+ """Track the progress of tasks iteration or enumeration with a progress
+ bar.
+
+ Tasks are yielded with a simple for-loop.
+
+ Args:
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Yields:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ for task in tasks:
+ yield task
+ prog_bar.update()
+ prog_bar.file.write('\n')
+
+
+class TimerError(Exception):
+
+ def __init__(self, message):
+ self.message = message
+ super().__init__(message)
+
+
+class Timer:
+ """A flexible Timer class.
+
+ Examples:
+ >>> import time
+        >>> from openstl.utils import Timer
+        >>> with Timer():
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ 1.000
+        >>> with Timer(print_tmpl='it takes {:.1f} seconds'):
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ it takes 1.0 seconds
+        >>> timer = Timer()
+ >>> time.sleep(0.5)
+ >>> print(timer.since_start())
+ 0.500
+ >>> time.sleep(0.5)
+ >>> print(timer.since_last_check())
+ 0.500
+ >>> print(timer.since_start())
+ 1.000
+ """
+
+ def __init__(self, start=True, print_tmpl=None):
+ self._is_running = False
+ self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
+ if start:
+ self.start()
+
+ @property
+ def is_running(self):
+ """bool: indicate whether the timer is running"""
+ return self._is_running
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ print(self.print_tmpl.format(self.since_last_check()))
+ self._is_running = False
+
+ def start(self):
+ """Start the timer."""
+ if not self._is_running:
+ self._t_start = time()
+ self._is_running = True
+ self._t_last = time()
+
+ def since_start(self):
+ """Total time since the timer is started.
+
+ Returns:
+ float: Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ self._t_last = time()
+ return self._t_last - self._t_start
+
+ def since_last_check(self):
+ """Time since the last checking.
+
+ Either :func:`since_start` or :func:`since_last_check` is a checking
+ operation.
+
+ Returns:
+ float: Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ dur = time() - self._t_last
+ self._t_last = time()
+ return dur
+
+
+_g_timers = {} # global timers
+
+
+def check_time(timer_id):
+ """Add check points in a single line.
+
+ This method is suitable for running a task on a list of items. A timer will
+ be registered when the method is called for the first time.
+
+ Examples:
+ >>> import time
+        >>> from openstl.utils.progressbar import check_time
+ >>> for i in range(1, 6):
+ >>> # simulate a code block
+ >>> time.sleep(i)
+        >>>     check_time('task1')
+ 2.000
+ 3.000
+ 4.000
+ 5.000
+
+ Args:
+        timer_id (str): Timer identifier.
+ """
+ if timer_id not in _g_timers:
+ _g_timers[timer_id] = Timer()
+ return 0
+ else:
+ return _g_timers[timer_id].since_last_check()
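
A small sketch of the two exported classes in use (the loop body is a stand-in for real work):

    from openstl.utils import ProgressBar, Timer

    timer = Timer()                      # starts timing on construction
    prog_bar = ProgressBar(task_num=100)
    results = []
    for i in range(100):
        results.append(i * i)            # stand-in for real work
        prog_bar.update()                # advances the bar and ETA estimate
    print(f'\ntotal: {timer.since_start():.3f}s')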
diff --git a/simvp/version.py b/openstl/version.py
similarity index 97%
rename from simvp/version.py
rename to openstl/version.py
index 8b2f7aaa..429c1441 100644
--- a/simvp/version.py
+++ b/openstl/version.py
@@ -1,6 +1,6 @@
# Copyright (c) CAIRI AI Lab. All rights reserved
-__version__ = '0.1.0'
+__version__ = '0.2.0'
def parse_version_info(version_str):
diff --git a/setup.py b/setup.py
index 85b41593..309ac880 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@ def readme():
def get_version():
- version_file = 'simvp/version.py'
+ version_file = 'openstl/version.py'
with open(version_file, 'r', encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
@@ -96,15 +96,16 @@ def gen_packages_items():
if __name__ == '__main__':
setup(
- name='SimVP',
+ name='OpenSTL',
version=get_version(),
- description='SimVP: Towards Simple yet Powerful Spatiotemporal Predictive learning',
+ description='OpenSTL: Open-source Toolbox for SpatioTemporal Predictive Learning',
long_description=readme(),
long_description_content_type='text/markdown',
author='CAIRI Westlake University Contributors',
author_email='lisiyuan@westlake.edu.com',
- keywords='video prediction, unsupervised spatiotemporal learning',
- url='https://github.com/chengtan9907/SimVPv2',
+ keywords='spatiotemporal predictive learning, video prediction, '
+ 'unsupervised spatiotemporal learning',
+ url='https://github.com/chengtan9907/OpenSTL',
packages=find_packages(exclude=('configs', 'tools', 'demo')),
classifiers=[
'Development Status :: 4 - Beta',
@@ -114,6 +115,9 @@ def gen_packages_items():
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
license='Apache License 2.0',
tests_require=parse_requirements('requirements/tests.txt'),
@@ -121,5 +125,6 @@ def gen_packages_items():
extras_require={
'all': parse_requirements('requirements.txt'),
'tests': parse_requirements('requirements/tests.txt'),
+ 'optional': parse_requirements('requirements/optional.txt'),
},
zip_safe=False)
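
With the renamed package metadata, an editable install from the repository root would look like the following (assuming pip and the requirements files listed above are present):

    pip install -e .              # core OpenSTL package
    pip install -e ".[optional]"  # additionally pulls requirements/optional.txt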
diff --git a/simvp/api/__init__.py b/simvp/api/__init__.py
deleted file mode 100644
index f171ae3f..00000000
--- a/simvp/api/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) CAIRI AI Lab. All rights reserved
-
-from .train import NonDistExperiment
-
-__all__ = ['NonDistExperiment']
\ No newline at end of file
diff --git a/simvp/datasets/dataloader_taxibj.py b/simvp/datasets/dataloader_taxibj.py
deleted file mode 100644
index 6faeb99b..00000000
--- a/simvp/datasets/dataloader_taxibj.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-import numpy as np
-from torch.utils.data import Dataset
-
-
-class TaxibjDataset(Dataset):
- """Taxibj `_ Dataset"""
-
- def __init__(self, X, Y):
- super(TaxibjDataset, self).__init__()
- self.X = (X+1)/2
- self.Y = (Y+1)/2
- self.mean = 0
- self.std = 1
-
- def __len__(self):
- return self.X.shape[0]
-
- def __getitem__(self, index):
- data = torch.tensor(self.X[index, ::]).float()
- labels = torch.tensor(self.Y[index, ::]).float()
- return data, labels
-
-
-def load_data(batch_size, val_batch_size, data_root,
- num_workers=4, pre_seq_length=None, aft_seq_length=None):
-
- dataset = np.load(data_root+'taxibj/dataset.npz')
- X_train, Y_train, X_test, Y_test = dataset['X_train'], dataset[
- 'Y_train'], dataset['X_test'], dataset['Y_test']
- train_set = TaxibjDataset(X=X_train, Y=Y_train)
- test_set = TaxibjDataset(X=X_test, Y=Y_test)
-
- dataloader_train = torch.utils.data.DataLoader(train_set,
- batch_size=batch_size, shuffle=True,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_vali = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
- dataloader_test = torch.utils.data.DataLoader(test_set,
- batch_size=val_batch_size, shuffle=False,
- pin_memory=True, drop_last=True,
- num_workers=num_workers)
-
- return dataloader_train, dataloader_vali, dataloader_test
diff --git a/simvp/methods/base_method.py b/simvp/methods/base_method.py
deleted file mode 100644
index 0cdec27b..00000000
--- a/simvp/methods/base_method.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-from typing import Dict, List, Union
-
-from simvp.core.optim_scheduler import get_optim_scheduler
-
-
-class Base_method(object):
- """Base Method.
-
- This class defines the basic functions of a video prediction (VP)
- method training and testing. Any VP method that inherits this class
- should at least define its own `train_one_epoch`, `vali_one_epoch`,
- and `test_one_epoch` function.
-
- """
-
- def __init__(self, args, device, steps_per_epoch):
- super(Base_method, self).__init__()
- self.args = args
- self.device = device
- self.config = args.__dict__
- self.criterion = None
- self.model_optim = None
- self.scheduler = None
-
- def _build_model(self, **kwargs):
- raise NotImplementedError
-
- def _init_optimizer(self, steps_per_epoch):
- return get_optim_scheduler(
- self.args, self.args.epoch, self.model, steps_per_epoch)
-
- def train_one_epoch(self, runner, train_loader, **kwargs):
- '''
- Train the model with train_loader.
- Input params:
- train_loader: dataloader of train.
- '''
- raise NotImplementedError
-
- def vali_one_epoch(self, runner, vali_loader, **kwargs):
- '''
- Evaluate the model with val_loader.
- Input params:
- val_loader: dataloader of validation.
- '''
- raise NotImplementedError
-
- def test_one_epoch(self, runner, test_loader, **kwargs):
- raise NotImplementedError
-
- def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
- """Get current learning rates.
-
- Returns:
- list[float] | dict[str, list[float]]: Current learning rates of all
- param groups. If the runner has a dict of optimizers, this method
- will return a dict.
- """
- lr: Union[List[float], Dict[str, List[float]]]
- if isinstance(self.model_optim, torch.optim.Optimizer):
- lr = [group['lr'] for group in self.model_optim.param_groups]
- elif isinstance(self.model_optim, dict):
- lr = dict()
- for name, optim in self.model_optim.items():
- lr[name] = [group['lr'] for group in optim.param_groups]
- else:
- raise RuntimeError(
- 'lr is not applicable because optimizer does not exist.')
- return lr
diff --git a/simvp/methods/predrnn.py b/simvp/methods/predrnn.py
deleted file mode 100644
index 1e535fe8..00000000
--- a/simvp/methods/predrnn.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import torch
-import torch.nn as nn
-import numpy as np
-from timm.utils import AverageMeter
-from tqdm import tqdm
-
-from simvp.models import PredRNN_Model
-from simvp.utils import (reshape_patch, reshape_patch_back,
- reserve_schedule_sampling_exp, schedule_sampling)
-from .base_method import Base_method
-
-
-class PredRNN(Base_method):
- r"""PredRNN
-
- Implementation of `PredRNN: A Recurrent Neural Network for Spatiotemporal
- Predictive Learning `_.
-
- """
-
- def __init__(self, args, device, steps_per_epoch):
- Base_method.__init__(self, args, device, steps_per_epoch)
- self.model = self._build_model(self.args)
- self.model_optim, self.scheduler, self.by_epoch = self._init_optimizer(steps_per_epoch)
- self.criterion = nn.MSELoss()
-
- def _build_model(self, args):
- num_hidden = [int(x) for x in self.args.num_hidden.split(',')]
- num_layers = len(num_hidden)
- return PredRNN_Model(num_layers, num_hidden, args).to(self.device)
-
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, eta, **kwargs):
- losses_m = AverageMeter()
- self.model.train()
- if self.by_epoch:
- self.scheduler.step(epoch)
-
- train_pbar = tqdm(train_loader)
- for batch_x, batch_y in train_pbar:
- self.model_optim.zero_grad()
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
-
- # preprocess
- ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
- ims = reshape_patch(ims, self.args.patch_size)
- if self.args.reverse_scheduled_sampling == 1:
- real_input_flag = reserve_schedule_sampling_exp(
- num_updates, ims.shape[0], self.args)
- else:
- eta, real_input_flag = schedule_sampling(
- eta, num_updates, ims.shape[0], self.args)
-
- img_gen, loss = self.model(ims, real_input_flag)
- loss.backward()
- self.model_optim.step()
- if not self.by_epoch:
- self.scheduler.step(epoch)
-
- num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
-
- train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
-
- if hasattr(self.model_optim, 'sync_lookahead'):
- self.model_optim.sync_lookahead()
-
- return num_updates, loss_mean, eta
-
- def vali_one_epoch(self, runner, vali_loader, **kwargs):
- self.model.eval()
- preds_lst, trues_lst, total_loss = [], [], []
- vali_pbar = tqdm(vali_loader)
-
- # reverse schedule sampling
- if self.args.reverse_scheduled_sampling == 1:
- mask_input = 1
- else:
- mask_input = self.args.pre_seq_length
-
- _, img_channel, img_height, img_width = self.args.in_shape
-
- for i, (batch_x, batch_y) in enumerate(vali_pbar):
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
-
- # preprocess
- test_ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
- test_dat = reshape_patch(test_ims, self.args.patch_size)
- test_ims = test_ims[:, :, :, :, :img_channel]
-
- real_input_flag = torch.zeros(
- (batch_x.shape[0],
- self.args.total_length - mask_input - 1,
- img_height // self.args.patch_size,
- img_width // self.args.patch_size,
- self.args.patch_size ** 2 * img_channel)).to(self.device)
-
- if self.args.reverse_scheduled_sampling == 1:
- real_input_flag[:, :self.args.pre_seq_length - 1, :, :] = 1.0
-
- img_gen, loss = self.model(test_dat, real_input_flag)
-
- img_gen = reshape_patch_back(img_gen, self.args.patch_size)
- pred_y = img_gen[:, -self.args.aft_seq_length:].permute(0, 1, 4, 2, 3).contiguous()
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()
- ), [pred_y, batch_y], [preds_lst, trues_lst]))
-
- if i * batch_x.shape[0] > 1000:
- break
-
- vali_pbar.set_description('vali loss: {:.4f}'.format(loss.mean().item()))
- total_loss.append(loss.mean().item())
-
- total_loss = np.average(total_loss)
-
- preds = np.concatenate(preds_lst, axis=0)
- trues = np.concatenate(trues_lst, axis=0)
- return preds, trues, total_loss
-
- def test_one_epoch(self, runner, test_loader, **kwargs):
- self.model.eval()
- inputs_lst, trues_lst, preds_lst = [], [], []
- test_pbar = tqdm(test_loader)
-
- # reverse schedule sampling
- if self.args.reverse_scheduled_sampling == 1:
- mask_input = 1
- else:
- mask_input = self.args.pre_seq_length
-
- _, img_channel, img_height, img_width = self.args.in_shape
-
- for batch_x, batch_y in test_pbar:
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
-
- # preprocess
- test_ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
- test_dat = reshape_patch(test_ims, self.args.patch_size)
- test_ims = test_ims[:, :, :, :, :img_channel]
-
- real_input_flag = torch.zeros(
- (batch_x.shape[0],
- self.args.total_length - mask_input - 1,
- img_height // self.args.patch_size,
- img_width // self.args.patch_size,
- self.args.patch_size ** 2 * img_channel)).to(self.device)
-
- if self.args.reverse_scheduled_sampling == 1:
- real_input_flag[:, :self.args.pre_seq_length - 1, :, :] = 1.0
-
- img_gen, _ = self.model(test_dat, real_input_flag)
- img_gen = reshape_patch_back(img_gen, self.args.patch_size)
- pred_y = img_gen[:, -self.args.aft_seq_length:].permute(0, 1, 4, 2, 3).contiguous()
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
- batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
-
- inputs, trues, preds = map(
- lambda data: np.concatenate(data, axis=0), [inputs_lst, trues_lst, preds_lst])
- return inputs, trues, preds
diff --git a/simvp/utils/__init__.py b/simvp/utils/__init__.py
deleted file mode 100644
index 00cc3fd7..00000000
--- a/simvp/utils/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) CAIRI AI Lab. All rights reserved
-
-from .config_utils import Config, check_file_exist
-from .main_utils import (set_seed, print_log, output_namespace, check_dir, get_dataset,
- count_parameters, measure_throughput, load_config, update_config, weights_to_cpu,
- init_dist, get_dist_info)
-from .parser import create_parser
-from .predrnn_utils import (reserve_schedule_sampling_exp, schedule_sampling, reshape_patch,
- reshape_patch_back)
-
-__all__ = [
- 'Config', 'check_file_exist', 'create_parser',
- 'set_seed', 'print_log', 'output_namespace', 'check_dir', 'get_dataset', 'count_parameters',
- 'measure_throughput', 'load_config', 'update_config', 'weights_to_cpu', 'init_dist', 'get_dist_info',
- 'reserve_schedule_sampling_exp', 'schedule_sampling', 'reshape_patch', 'reshape_patch_back',
-]
\ No newline at end of file
diff --git a/simvp/utils/main_utils.py b/simvp/utils/main_utils.py
deleted file mode 100644
index 0aacaf13..00000000
--- a/simvp/utils/main_utils.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# Copyright (c) CAIRI AI Lab. All rights reserved
-
-import os
-import logging
-import numpy as np
-import torch
-import random
-import torch.backends.cudnn as cudnn
-from collections import OrderedDict
-from typing import Tuple
-from .config_utils import Config
-
-import torch
-import torch.multiprocessing as mp
-from torch import distributed as dist
-
-
-def set_seed(seed):
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- cudnn.deterministic = True
-
-
-def print_log(message):
- print(message)
- logging.info(message)
-
-
-def output_namespace(namespace):
- configs = namespace.__dict__
- message = ''
- for k, v in configs.items():
- message += '\n' + k + ': \t' + str(v) + '\t'
- return message
-
-
-def check_dir(path):
- if not os.path.exists(path):
- os.makedirs(path)
- return False
- return True
-
-
-def get_dataset(dataname, config):
- from simvp.datasets import dataset_parameters
- from simvp.datasets import load_data
- config.update(dataset_parameters[dataname])
- return load_data(**config)
-
-
-def count_parameters(model):
- return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-def measure_throughput(model, input_dummy):
- bs = 100
- repetitions = 100
- if isinstance(input_dummy, tuple):
- input_dummy = list(input_dummy)
- _, T, C, H, W = input_dummy[0].shape
- _input = torch.rand(bs, T, C, H, W).to(input_dummy[0].device)
- input_dummy[0] = _input
- input_dummy = tuple(input_dummy)
- else:
- _, T, C, H, W = input_dummy.shape
- input_dummy = torch.rand(bs, T, C, H, W).to(input_dummy.device)
- total_time = 0
- with torch.no_grad():
- for _ in range(repetitions):
- starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
- starter.record()
- if isinstance(input_dummy, tuple):
- _ = model(*input_dummy)
- else:
- _ = model(input_dummy)
- ender.record()
- torch.cuda.synchronize()
- curr_time = starter.elapsed_time(ender) / 1000
- total_time += curr_time
- Throughput = (repetitions * bs) / total_time
- return Throughput
-
-
-def load_config(filename:str = None):
- """load and print config"""
- print('loading config from ' + filename + ' ...')
- try:
- configfile = Config(filename=filename)
- config = configfile._cfg_dict
- except (FileNotFoundError, IOError):
- config = dict()
- print('warning: fail to load the config!')
- return config
-
-
-def update_config(args, config, exclude_keys=list()):
- """update the args dict with a new config"""
- assert isinstance(args, dict) and isinstance(config, dict)
- for k in config.keys():
- if args.get(k, False):
- if args[k] != config[k] and k not in exclude_keys:
- print(f'overwrite config key -- {k}: {config[k]} -> {args[k]}')
- else:
- args[k] = config[k]
- else:
- args[k] = config[k]
- return args
-
-
-def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
- """Copy a model state_dict to cpu.
-
- Args:
- state_dict (OrderedDict): Model weights on GPU.
-
- Returns:
- OrderedDict: Model weights on GPU.
- """
- state_dict_cpu = OrderedDict()
- for key, val in state_dict.items():
- state_dict_cpu[key] = val.cpu()
- # Keep metadata in state_dict
- state_dict_cpu._metadata = getattr( # type: ignore
- state_dict, '_metadata', OrderedDict())
- return state_dict_cpu
-
-
-def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
- if mp.get_start_method(allow_none=True) is None:
- mp.set_start_method('spawn')
- if launcher == 'pytorch':
- _init_dist_pytorch(backend, **kwargs)
- elif launcher == 'mpi':
- _init_dist_mpi(backend, **kwargs)
- else:
- raise ValueError(f'Invalid launcher type: {launcher}')
-
-
-def _init_dist_pytorch(backend: str, **kwargs) -> None:
- # TODO: use local_rank instead of rank % num_gpus
- rank = int(os.environ['RANK'])
- num_gpus = torch.cuda.device_count()
- torch.cuda.set_device(rank % num_gpus)
- dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_mpi(backend: str, **kwargs) -> None:
- local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
- torch.cuda.set_device(local_rank)
- if 'MASTER_PORT' not in os.environ:
- # 29500 is torch.distributed default port
- os.environ['MASTER_PORT'] = '29500'
- if 'MASTER_ADDR' not in os.environ:
- raise KeyError('The environment variable MASTER_ADDR is not set')
- os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
- os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
- dist.init_process_group(backend=backend, **kwargs)
-
-
-def get_dist_info() -> Tuple[int, int]:
- if dist.is_available() and dist.is_initialized():
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- else:
- rank = 0
- world_size = 1
- return rank, world_size
diff --git a/tools/dist_test.sh b/tools/dist_test.sh
new file mode 100644
index 00000000..c2e06b09
--- /dev/null
+++ b/tools/dist_test.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
+set -x
+
+CFG=$1
+GPUS=$2
+CHECKPOINT=$3
+PY_ARGS=${@:4}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+# test
+python -m torch.distributed.launch \
+ --nnodes=$NNODES \
+ --node_rank=$NODE_RANK \
+ --master_addr=$MASTER_ADDR \
+ --nproc_per_node=$GPUS \
+ --master_port=$PORT \
+ tools/test.py --dist \
+ --config_file $CFG \
+ --ex_name $CHECKPOINT --launcher="pytorch" ${PY_ARGS}
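
Example invocation of the new test script (config path and work dir are illustrative; the third argument is forwarded as --ex_name):

    bash tools/dist_test.sh configs/mmnist/simvp/SimVP_gSTA.py 4 \
        work_dirs/mmnist_simvp_gsta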
diff --git a/tools/dist_train.sh b/tools/dist_train.sh
new file mode 100644
index 00000000..d36ba849
--- /dev/null
+++ b/tools/dist_train.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+PYTHON=${PYTHON:-"python"}
+
+CFG=$1
+GPUS=$2
+PY_ARGS=${@:3}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
+
+$PYTHON -m torch.distributed.launch \
+ --nnodes=$NNODES \
+ --node_rank=$NODE_RANK \
+ --master_addr=$MASTER_ADDR \
+ --nproc_per_node=$GPUS \
+ --master_port=$PORT \
+ tools/train.py --dist \
+ --config_file $CFG \
+ --ex_name $WORK_DIR \
+ --seed 42 --launcher="pytorch" ${PY_ARGS}
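
Example launches (addresses and ports illustrative); NNODES, NODE_RANK, PORT, and MASTER_ADDR are read from the environment, falling back to the defaults above:

    # single node, 8 GPUs
    bash tools/dist_train.sh configs/mmnist/simvp/SimVP_gSTA.py 8

    # first of two nodes
    NNODES=2 NODE_RANK=0 MASTER_ADDR=10.0.0.1 PORT=29501 \
        bash tools/dist_train.sh configs/mmnist/simvp/SimVP_gSTA.py 8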
diff --git a/tools/non_dist_test.py b/tools/test.py
similarity index 68%
rename from tools/non_dist_test.py
rename to tools/test.py
index 89d6f940..92884f61 100644
--- a/tools/non_dist_test.py
+++ b/tools/test.py
@@ -1,11 +1,11 @@
# Copyright (c) CAIRI AI Lab. All rights reserved
-import os.path as osp
import warnings
warnings.filterwarnings('ignore')
-from simvp.api import NonDistExperiment
-from simvp.utils import create_parser, load_config, update_config
+from openstl.api import BaseExperiment
+from openstl.utils import (create_parser, get_dist_info, load_config,
+ setup_multi_processes, update_config)
try:
import nni
@@ -27,9 +27,13 @@
exclude_keys=['method', 'batch_size', 'val_batch_size'])
config['test'] = True
- exp = NonDistExperiment(args)
+ # set multi-process settings
+ setup_multi_processes(config)
print('>'*35 + ' testing ' + '<'*35)
+ exp = BaseExperiment(args)
+ rank, _ = get_dist_info()
+
mse = exp.test()
- if has_nni:
+ if rank == 0 and has_nni:
nni.report_final_result(mse)
diff --git a/tools/non_dist_train.py b/tools/train.py
similarity index 66%
rename from tools/non_dist_train.py
rename to tools/train.py
index 5603858f..23d3d562 100644
--- a/tools/non_dist_train.py
+++ b/tools/train.py
@@ -4,8 +4,9 @@
import warnings
warnings.filterwarnings('ignore')
-from simvp.api import NonDistExperiment
-from simvp.utils import create_parser, load_config, update_config
+from openstl.api import BaseExperiment
+from openstl.utils import (create_parser, get_dist_info, load_config,
+ setup_multi_processes, update_config)
try:
import nni
@@ -27,11 +28,17 @@
config = update_config(config, load_config(cfg_path),
exclude_keys=['method', 'batch_size', 'val_batch_size', 'sched'])
- exp = NonDistExperiment(args)
+ # set multi-process settings
+ setup_multi_processes(config)
+
print('>'*35 + ' training ' + '<'*35)
+ exp = BaseExperiment(args)
+ rank, _ = get_dist_info()
exp.train()
- print('>'*35 + ' testing ' + '<'*35)
+ if rank == 0:
+ print('>'*35 + ' testing ' + '<'*35)
mse = exp.test()
- if has_nni:
+
+ if rank == 0 and has_nni:
nni.report_final_result(mse)
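
For the common single-process path, the renamed entry points are invoked directly (experiment name illustrative):

    python tools/train.py -d mmnist -c configs/mmnist/simvp/SimVP_gSTA.py \
        --ex_name mmnist_simvp_gsta
    python tools/test.py -d mmnist -c configs/mmnist/simvp/SimVP_gSTA.py \
        --ex_name mmnist_simvp_gsta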