From fc75a95a1d27b56f35fb0012070e53fa4910c692 Mon Sep 17 00:00:00 2001
From: Lupin1998 <1070535169@qq.com>
Date: Thu, 20 Apr 2023 20:23:44 +0000
Subject: [PATCH] release and fix OpenSTL V0.2.0 (issue #20)
---
README.md | 8 +-
.../weather/t2m_1_40625/SimVP_ConvMixer.py | 15 ++
configs/weather/t2m_1_40625/SimVP_ConvNeXt.py | 15 ++
configs/weather/t2m_1_40625/SimVP_HorNet.py | 15 ++
configs/weather/t2m_1_40625/SimVP_IncepU.py | 15 ++
configs/weather/t2m_1_40625/SimVP_MLPMixer.py | 15 ++
configs/weather/t2m_1_40625/SimVP_MogaNet.py | 15 ++
.../weather/t2m_1_40625/SimVP_Poolformer.py | 15 ++
configs/weather/t2m_1_40625/SimVP_Swin.py | 15 ++
.../weather/t2m_1_40625/SimVP_Uniformer.py | 15 ++
configs/weather/t2m_1_40625/SimVP_VAN.py | 15 ++
configs/weather/t2m_1_40625/SimVP_ViT.py | 15 ++
configs/weather/t2m_1_40625/SimVP_gSTA.py | 14 ++
docs/en/changelog.md | 25 +++-
docs/en/get_started.md | 53 ++++++-
docs/en/install.md | 17 ++-
environment.yml | 2 +-
openstl/api/train.py | 20 ++-
openstl/datasets/dataloader_weather.py | 66 +++++----
openstl/methods/base_method.py | 7 +-
openstl/methods/convlstm.py | 1 +
openstl/methods/crevnet.py | 97 ++++++------
openstl/methods/mau.py | 138 ++++++++----------
openstl/methods/phydnet.py | 99 +++++++------
openstl/methods/predrnn.py | 70 +++++----
openstl/methods/predrnnv2.py | 53 +++++--
openstl/methods/simvp.py | 4 +-
openstl/models/convlstm_model.py | 7 +-
openstl/models/crevnet_model.py | 12 +-
openstl/models/e3dlstm_model.py | 9 +-
openstl/models/mau_model.py | 8 +-
openstl/models/mim_model.py | 7 +-
openstl/models/phydnet_model.py | 17 ++-
openstl/models/predrnn_model.py | 7 +-
openstl/models/predrnnpp_model.py | 7 +-
openstl/models/predrnnv2_model.py | 21 ++-
openstl/models/simvp_model.py | 2 +-
openstl/utils/parser.py | 4 +-
requirements/runtime.txt | 4 +-
39 files changed, 640 insertions(+), 304 deletions(-)
create mode 100644 configs/weather/t2m_1_40625/SimVP_ConvMixer.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_ConvNeXt.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_HorNet.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_IncepU.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_MLPMixer.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_MogaNet.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_Poolformer.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_Swin.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_Uniformer.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_VAN.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_ViT.py
create mode 100644 configs/weather/t2m_1_40625/SimVP_gSTA.py
diff --git a/README.md b/README.md
index 2b210a29..173084d0 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ This is the journal version of our previous conference work ([SimVP: Simpler yet
## News and Updates
-[2023-04-19] `OpenSTL` v0.2.0 is released.
+[2023-04-19] `OpenSTL` v0.2.0 is released. The training loop and dataloaders are fixed.
## Installation
@@ -69,17 +69,17 @@ python setup.py develop
* torch
* timm
* tqdm
-* xarray
+* xarray==0.19.0
Please refer to [install.md](docs/en/install.md) for more detailed instructions.
## Getting Started
-Please see [get_started.md](docs/en/get_started.md) for the basic usage. Here is an example of single GPU non-dist training SimVP+gSTA on Moving MNIST dataset.
+Please see [get_started.md](docs/en/get_started.md) for the basic usage. Here is an example of single GPU non-distributed training SimVP+gSTA on Moving MNIST dataset.
```shell
bash tools/prepare_data/download_mmnist.sh
-python tools/train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
+python tools/train.py -d mmnist --lr 1e-3 -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
```
(back to top)
diff --git a/configs/weather/t2m_1_40625/SimVP_ConvMixer.py b/configs/weather/t2m_1_40625/SimVP_ConvMixer.py
new file mode 100644
index 00000000..4ab032c9
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_ConvMixer.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'convmixer'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-2
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
diff --git a/configs/weather/t2m_1_40625/SimVP_ConvNeXt.py b/configs/weather/t2m_1_40625/SimVP_ConvNeXt.py
new file mode 100644
index 00000000..93a809c2
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_ConvNeXt.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'convnext'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-2
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_HorNet.py b/configs/weather/t2m_1_40625/SimVP_HorNet.py
new file mode 100644
index 00000000..b96fa6d7
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_HorNet.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'hornet'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-3
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_IncepU.py b/configs/weather/t2m_1_40625/SimVP_IncepU.py
new file mode 100644
index 00000000..e33989f1
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_IncepU.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'IncepU' # SimVP.V1
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-2
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_MLPMixer.py b/configs/weather/t2m_1_40625/SimVP_MLPMixer.py
new file mode 100644
index 00000000..954c69ec
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_MLPMixer.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'mlp'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-3
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_MogaNet.py b/configs/weather/t2m_1_40625/SimVP_MogaNet.py
new file mode 100644
index 00000000..f99111ec
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_MogaNet.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'moga'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 5e-3
+batch_size = 16
+drop_path = 0.2
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_Poolformer.py b/configs/weather/t2m_1_40625/SimVP_Poolformer.py
new file mode 100644
index 00000000..1cea4e85
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_Poolformer.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'poolformer'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 5e-4
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_Swin.py b/configs/weather/t2m_1_40625/SimVP_Swin.py
new file mode 100644
index 00000000..bdc7e545
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_Swin.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'swin'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-3
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_Uniformer.py b/configs/weather/t2m_1_40625/SimVP_Uniformer.py
new file mode 100644
index 00000000..698a6860
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_Uniformer.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'uniformer'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 5e-3
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_VAN.py b/configs/weather/t2m_1_40625/SimVP_VAN.py
new file mode 100644
index 00000000..885d5e21
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_VAN.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'van'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 5e-3
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_ViT.py b/configs/weather/t2m_1_40625/SimVP_ViT.py
new file mode 100644
index 00000000..e711cee3
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_ViT.py
@@ -0,0 +1,15 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'vit'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 1e-3
+batch_size = 16
+drop_path = 0.1
+sched = 'cosine'
+warmup_epoch = 0
\ No newline at end of file
diff --git a/configs/weather/t2m_1_40625/SimVP_gSTA.py b/configs/weather/t2m_1_40625/SimVP_gSTA.py
new file mode 100644
index 00000000..8eefd720
--- /dev/null
+++ b/configs/weather/t2m_1_40625/SimVP_gSTA.py
@@ -0,0 +1,14 @@
+method = 'SimVP'
+# model
+spatio_kernel_enc = 3
+spatio_kernel_dec = 3
+model_type = 'gSTA'
+hid_S = 32
+hid_T = 256
+N_T = 8
+N_S = 2
+# training
+lr = 5e-3
+batch_size = 16
+drop_path = 0.1
+warmup_epoch = 0
\ No newline at end of file
diff --git a/docs/en/changelog.md b/docs/en/changelog.md
index 0de2139e..20d9953c 100644
--- a/docs/en/changelog.md
+++ b/docs/en/changelog.md
@@ -1,5 +1,28 @@
## Changelog
+### v0.1.0 (21/04/2023)
+
+Release version to OpenSTL V0.2.0 as [#20](https://github.com/chengtan9907/OpenSTL/issues/20).
+
+#### Code Refactoring
+
+* Rename the project to `OpenSTL` instead of `SimVPv2` with module name refactoring.
+* Refactor the code structure thoroughly to support non-distributed and distributed (DDP) training & testing with `tools/train.py` and `tools/test.py`.
+
+#### New Features
+
+* Update the Weather Bench dataloader with `5.625deg`, `2.8125deg`, and `1.40625deg` settings.
+
+#### Update Documents
+
+* Update documents of video prediction and weather prediction benchmarks. Provide config files for supported mixup methods.
+* Update `docs/en` documents for the basic usages and new features of V0.2.0.
+
+#### Fix Bugs
+
+* Fix bugs in training loops and validation loops to save GPU memory.
+* There might be some bugs in not using all parameters for calculating losses in ConvLSTM CrevNet, which should use `--find_unused_parameters` for DDP training.
+
### v0.1.0 (18/02/2023)
Release version to V0.1.0 with code refactoring.
@@ -15,7 +38,7 @@ Release version to V0.1.0 with code refactoring.
* Update popular Metaformer models as the hidden Translator $h$ in SimVP, supporting [ViT](https://arxiv.org/abs/2010.11929), [Swin-Transformer](https://arxiv.org/abs/2103.14030), [MLP-Mixer](https://arxiv.org/abs/2105.01601), [ConvMixer](https://arxiv.org/abs/2201.09792), [UniFormer](https://arxiv.org/abs/2201.09450), [PoolFormer](https://arxiv.org/abs/2111.11418), [ConvNeXt](https://arxiv.org/abs/2201.03545), [VAN](https://arxiv.org/abs/2202.09741), [HorNet](https://arxiv.org/abs/2207.14284), and [MogaNet](https://arxiv.org/abs/2211.03295).
* Update implementations of dataset and dataloader, supporting [KTH Action](https://ieeexplore.ieee.org/document/1334462), [KittiCaltech Pedestrian](https://dl.acm.org/doi/10.1177/0278364913491297), [Moving MNIST](http://arxiv.org/abs/1502.04681), [TaxiBJ](https://arxiv.org/abs/1610.00081), and [WeatherBench](https://arxiv.org/abs/2002.00469).
-### Update Documents
+#### Update Documents
* Upload `readthedocs` documents. Summarize video prediction benchmark results on MMNIST in [video_benchmarks.md](https://github.com/chengtan9907/SimVPv2/docs/en/model_zoos/video_benchmarks.md).
* Update benchmark results of video prediction baselines and MetaFormer architectures based on SimVP on MMNIST, TaxiBJ, and WeatherBench datasets.
diff --git a/docs/en/get_started.md b/docs/en/get_started.md
index f5ed2598..132ef89c 100644
--- a/docs/en/get_started.md
+++ b/docs/en/get_started.md
@@ -1,10 +1,10 @@
# Getting Started
-This page provides basic tutorials about the usage of SimVP. For installation instructions, please see [Install](docs/en/install.md).
+This page provides basic tutorials about the usage of OpenSTL with various spatioTemporal predictive learning (STL) tasks. For installation instructions, please see [Install](docs/en/install.md).
## Training and Testing with a Single GPU
-You can perform single/multiple GPU training and testing with `tools/train.py` and `tools/test.py`. We provide descriptions of some essential arguments.
+You can perform single GPU training and testing with `tools/train.py` and `tools/test.py` with non-distributed and distributed (DDP) modes. Non-distributed mode is recommanded for the single GPU training (a bit faster than DDP). We provide descriptions of some essential arguments. Other arguments related to datasets, optimizers, methods can be found in [parser.py](https://github.com/chengtan9907/OpenSTL/tree/master/openstl/utils/parser.py).
```bash
python tools/train.py \
@@ -12,27 +12,70 @@ python tools/train.py \
--method ${METHOD_NAME} \
--config_file ${CONFIG_FILE} \
--ex_name ${EXP_NAME} \
+ --resume_from ${CHECKPOINT_FILE} \
--auto_resume \
--batch_size ${BATCH_SIZE} \
--lr ${LEARNING_RATE} \
+ --dist \
+ --fp16 \
+ --seed ${SEED} \
+ --clip_grad ${VALUE} \
+ --find_unused_parameters \
+ --deterministic \
```
**Description of arguments**:
- `--dataname (-d)` : The name of dataset, default to be `mmnist`.
- `--method (-m)` : The name of the video prediction method to train or test, default to be `SimVP`.
-- `--config_file (-c)` : The path of a model config file, which will provide detailed settings for a video prediction method.
+- `--config_file (-c)` : The path of a model config file, which will provide detailed settings for a STL method.
- `--ex_name` : The name of the experiment under the `res_dir`. Default to be `Debug`.
+- `--resume_from ${CHECKPOINT_FILE}`: Resume from a previous checkpoint file. Or you can use `--auto_resume` to resume from `latest.pth` automatically.
- `--auto_resume` : Whether to automatically resume training when the experiment was interrupted.
- `--batch_size (-b)` : Training batch size, default to 16.
- `--lr` : The basic training learning rate, defaults to 0.001.
+- `--dist`: Whether to use distributed training (DDP).
+- `--fp16`: Whether to use Native AMP for mixed precision training (PyTorch=>1.6.0).
+- `--seed ${SEED}`: Setup all random seeds to a certain number (defaults to 42).
+- `--clip_grad ${VALUE}`: Clip gradient norm value (default: None, no clipping).
+- `--find_unused_parameters`: Whether to find unused parameters in forward during DDP training.
+- `--deterministic`: Switch on "deterministic" mode, which slows down training while the results are reproducible.
-An example of single GPU training with SimVP+gSTA on Moving MNIST dataset.
+An example of single GPU (non-distributed) training with SimVP+gSTA on Moving MNIST dataset.
```shell
bash tools/prepare_data/download_mmnist.sh
-python tools/train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
+python tools/train.py -d mmnist --lr 1e-3 -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
```
An example of single GPU testing with SimVP+gSTA on Moving MNIST dataset.
```shell
python tools/test.py -d mmnist -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta
```
+
+## Training and Testing with Multiple GPUs
+
+For larger STL tasks (e.g., high resolutions), you can also perform multiple GPUs training and testing with `tools/dist_train.sh` and `tools/dist_test.sh` with DDP mode. The bash files will call `tools/train.py` and `tools/test.py` with the necessary arguments.
+
+```shell
+bash tools/dist_train.sh ${CONFIG_FILE} ${GPUS} [optional arguments]
+```
+**Description of arguments**:
+- `${CONFIG_FILE}` : The path of a model config file, which will provide detailed settings for a STL method.
+- `${GPUS}` : The number of GPUs for DDP training.
+
+Examples of multiple GPUs training on Moving MNIST dataset with a machine with 8 GPUs.
+```shell
+PORT=29001 CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_train.sh configs/mmnist/simvp/SimVP_gSTA.py 2 -d mmnist --lr 1e-3 --batch_size 8
+PORT=29002 CUDA_VISIBLE_DEVICES=2,3 bash tools/dist_train.sh configs/mmnist/PredRNN.py 2 -d mmnist --lr 1e-3 --batch_size 8
+PORT=29003 CUDA_VISIBLE_DEVICES=4,5,6,7 bash tools/dist_train.sh configs/mmnist/PredRNNpp.py 4 -d mmnist --lr 1e-3 --batch_size 4
+```
+
+An example of multiple GPUs testing on Moving MNIST dataset. The bash script is `bash tools/dist_train.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} [optional arguments]`.
+```shell
+PORT=29001 CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_test.sh configs/mmnist/simvp/SimVP_gSTA.py 2 work_dirs/mmnist/simvp/SimVP_gSTA -d mmnist
+```
+
+**Note**: During DDP training, the number of GPUS `ngpus` should be provided and checkpoints and logs are saved in the same folder structure as the config file under `work_dirs/` (it will be the default setting if `--ex_name` is not specified). The default learning rate `lr` and the batch size `bs` in config files are for a single GPU. If using a different number GPUs, the total batch size will change in proportion, you have to scale the learning rate following `lr = base_lr * ngpus` and `bs = base_bs * ngpus`. Other arguments should be added as the single GPU training.
+
+## Mixed Precision Training
+
+We support Mixed Precision Training implemented by PyTorch AMP. If you want to use Mixed Precision Training, you can add `--fp16` in the arguments.
diff --git a/docs/en/install.md b/docs/en/install.md
index 5674d29c..c83d36b5 100644
--- a/docs/en/install.md
+++ b/docs/en/install.md
@@ -11,6 +11,16 @@ conda activate OpenSTL
python setup.py develop # or `pip install -e .`
```
+
+Requirements
+* Linux (Windows is not officially supported)
+* Python 3.7+
+* PyTorch 1.8 or higher
+* CUDA 10.1 or higher
+* NCCL 2
+* GCC 4.9 or higher
+
+
Dependencies
@@ -23,12 +33,12 @@ python setup.py develop # or `pip install -e .`
* torch
* timm
* tqdm
-* xarray
+* xarray==0.19.0
**Note:**
-1. Some errors might occur with `hickle` and `xarray` when using KittiCaltech and WeatherBench datasets. As for KittiCaltech, you can solve the issues by installing additional pacakges according to the output messeage. As for WeatherBench, you can install the latest version of `xarray` to solve the errors, i.e., `pip install git+https://github.com/pydata/xarray/@v2022.03.0` and then installing required pacakges according to error messages.
+1. Some errors might occur with `hickle` and `xarray` when using KittiCaltech and WeatherBench datasets. As for KittiCaltech, you can solve the issues by installing additional pacakges according to the output messeage. As for WeatherBench, you can install the latest version of `xarray` to solve the errors, i.e., `pip install xarray==0.19.0` and then installing required pacakges according to error messages.
2. Following the above instructions, OpenSTL is installed on `dev` mode, any local modifications made to the code will take effect. You can install it by `pip install .` to use it as a PyPi package, and you should reinstall it to make the local modifications effect.
@@ -61,4 +71,7 @@ OpenSTL
|── weather
| ├── 2m_temperature
| ├── ...
+ |── weather_1_40625deg
+ | ├── 2m_temperature
+ | ├── ...
```
diff --git a/environment.yml b/environment.yml
index 888c6317..28c84c4c 100644
--- a/environment.yml
+++ b/environment.yml
@@ -10,7 +10,7 @@ dependencies:
- pip
- python
- pytorch
- - xarray
+ - xarray==0.19.0
- pip:
- scikit-image
- timm
diff --git a/openstl/api/train.py b/openstl/api/train.py
index d82fee7d..57ca9099 100644
--- a/openstl/api/train.py
+++ b/openstl/api/train.py
@@ -70,7 +70,7 @@ def _acquire_device(self):
return device
def _preparation(self):
- """Preparation of basic experiment setups"""
+ """Preparation of environment and basic experiment setups"""
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(self.args.local_rank)
@@ -173,7 +173,7 @@ def call_hook(self, fn_name: str) -> None:
getattr(hook, fn_name)(self)
def _get_hook_info(self):
- # Get hooks info in each stage
+ """Get hook information in each stage"""
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self._hooks:
priority = hook.priority # type: ignore
@@ -193,6 +193,7 @@ def _get_hook_info(self):
return '\n'.join(stage_hook_infos)
def _get_data(self):
+ """Prepare datasets and dataloaders"""
self.train_loader, self.vali_loader, self.test_loader = \
get_dataset(self.args.dataname, self.config)
if self.vali_loader is None:
@@ -255,14 +256,19 @@ def display_method_info(self):
else:
raise ValueError(f'Invalid method name {self.args.method}')
- print_log(self.method.model)
+ dash_line = '-' * 80 + '\n'
+ info = self.method.model.__repr__()
flops = FlopCountAnalysis(self.method.model, input_dummy)
- print_log(flop_count_table(flops))
+ flops = flop_count_table(flops)
if self.args.fps:
fps = measure_throughput(self.method.model, input_dummy)
- print_log('Throughputs of {}: {:.3f}'.format(self.args.method, fps))
+ fps = 'Throughputs of {}: {:.3f}\n'.format(self.args.method, fps)
+ else:
+ fps = ''
+ print_log('Model info:\n' + info+'\n' + flops+'\n' + fps + dash_line)
def train(self):
+ """Training loops of STL methods"""
recorder = Recorder(verbose=True)
num_updates = self._epoch * self.steps_per_epoch
self.call_hook('before_train_epoch')
@@ -298,6 +304,7 @@ def train(self):
self.call_hook('after_run')
def vali(self, vali_loader):
+ """A validation loop during training"""
self.call_hook('before_val_epoch')
preds, trues, val_loss = self.method.vali_one_epoch(self, self.vali_loader)
self.call_hook('after_val_epoch')
@@ -310,13 +317,14 @@ def vali(self, vali_loader):
eval_res, eval_log = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std,
metrics=metric_list, spatial_norm=spatial_norm)
- print_log('val\t '+eval_log)
+ print_log('\nval\t '+eval_log)
if has_nni:
nni.report_intermediate_result(eval_res['mse'])
return val_loss
def test(self):
+ """A testing loop of STL methods"""
if self.args.test:
best_model_path = osp.join(self.path, 'checkpoint.pth')
if self._dist:
diff --git a/openstl/datasets/dataloader_weather.py b/openstl/datasets/dataloader_weather.py
index 17a3fdba..ae16d388 100644
--- a/openstl/datasets/dataloader_weather.py
+++ b/openstl/datasets/dataloader_weather.py
@@ -76,14 +76,12 @@ def __init__(self, data_root, data_name, training_time,
if data_name != 'uv10':
try:
- # dataset = xr.open_mfdataset(
- # data_root+'/{}/*.nc'.format(data_map[data_name]), combine='by_coords')
- print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[data_name]))
- dataset = xr.open_mfdataset(
- data_root+'/{}/*.nc'.format(data_map[data_name]), combine='by_coords', parallel=False, chunks={'time':168})
- except AttributeError:
- assert False and 'Please install the latest xarray, e.g.,' \
- 'pip install git+https://github.com/pydata/xarray/@v2022.03.0'
+ dataset = xr.open_mfdataset(data_root+'/{}/{}*.nc'.format(
+ data_map[data_name], data_map[data_name]), combine='by_coords')
+ except (AttributeError, ValueError):
+ assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \
+ 'pip install xarray==0.19.0,' \
+ 'pip install netcdf4 h5netcdf dask'
except OSError:
print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[data_name]))
assert False
@@ -107,14 +105,14 @@ def __init__(self, data_root, data_name, training_time,
input_datasets = []
for key in ['u10', 'v10']:
try:
- dataset = xr.open_mfdataset(
- data_root+'/{}/*.nc'.format(data_map[key]), combine='by_coords')
- except AttributeError:
- assert False and 'Please install the latest xarray, e.g.,' \
- 'pip install git+https://github.com/pydata/xarray/@v2022.03.0,' \
+ dataset = xr.open_mfdataset(data_root+'/{}/{}*.nc'.format(
+ data_map[key], data_map[key]), combine='by_coords')
+ except (AttributeError, ValueError):
+ assert False and 'Please install xarray and its dependency (e.g., netcdf4), ' \
+ 'pip install xarray==0.19.0,' \
'pip install netcdf4 h5netcdf dask'
except OSError:
- print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[data_name]))
+ print("OSError: Invalid path {}/{}/*.nc".format(data_root, data_map[key]))
assert False
dataset = dataset.sel(time=slice(*training_time))
dataset = dataset.isel(time=slice(None, -1, step))
@@ -222,21 +220,25 @@ def load_data(batch_size,
if __name__ == '__main__':
- dataloader_train, _, dataloader_test = \
- load_data(batch_size=128,
- val_batch_size=32,
- data_root='../../data',
- num_workers=2, data_name='t2m',
- data_split='1_40625',
- train_time=['1979', '2015'],
- val_time=['2016', '2016'],
- test_time=['2017', '2018'],
- idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
- idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], step=24)
-
- for item in dataloader_train:
- print(item[0].shape)
- break
- for item in dataloader_test:
- print(item[0].shape)
- break
+ data_split=['5_625', '1_40625']
+ data_name = 't2m'
+
+ for _split in data_split:
+ dataloader_train, _, dataloader_test = \
+ load_data(batch_size=128,
+ val_batch_size=32,
+ data_root='../../data',
+ num_workers=4, data_name=data_name,
+ data_split=_split,
+ train_time=['1979', '2015'],
+ val_time=['2016', '2016'],
+ test_time=['2017', '2018'],
+ idx_in=[-11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0],
+ idx_out=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], step=24)
+
+ for item in dataloader_train:
+ print(item[0].shape)
+ break
+ for item in dataloader_test:
+ print(item[0].shape)
+ break
diff --git a/openstl/methods/base_method.py b/openstl/methods/base_method.py
index 22a92277..e6ec76fb 100644
--- a/openstl/methods/base_method.py
+++ b/openstl/methods/base_method.py
@@ -64,7 +64,8 @@ def _init_distributed(self):
print('Using native PyTorch AMP. Training in mixed precision (fp16).')
else:
print('AMP not enabled. Training in float32.')
- self.model = NativeDDP(self.model, device_ids=[self.rank], broadcast_buffers=True)
+ self.model = NativeDDP(self.model, device_ids=[self.rank], broadcast_buffers=True,
+ find_unused_parameters=self.args.find_unused_parameters)
def train_one_epoch(self, runner, train_loader, **kwargs):
"""Train the model with train_loader.
@@ -166,8 +167,8 @@ def vali_one_epoch(self, runner, vali_loader, **kwargs):
else:
results = self._nondist_forward_collect(vali_loader, len(vali_loader.dataset))
- preds = torch.tensor(results['preds']).to(self.device)
- trues = torch.tensor(results['trues']).to(self.device)
+ preds = torch.tensor(results['preds'])
+ trues = torch.tensor(results['trues'])
losses_m = self.criterion(preds, trues).cpu().numpy()
return results['preds'], results['trues'], losses_m
diff --git a/openstl/methods/convlstm.py b/openstl/methods/convlstm.py
index ef1d13f5..07aa24ac 100644
--- a/openstl/methods/convlstm.py
+++ b/openstl/methods/convlstm.py
@@ -10,6 +10,7 @@ class ConvLSTM(PredRNN):
Implementation of `Convolutional LSTM Network: A Machine Learning Approach
for Precipitation Nowcasting `_.
+ Notice: ConvLSTM requires `find_unused_parameters=True` for DDP training.
"""
def __init__(self, args, device, steps_per_epoch):
diff --git a/openstl/methods/crevnet.py b/openstl/methods/crevnet.py
index f7d1535c..bd89363c 100644
--- a/openstl/methods/crevnet.py
+++ b/openstl/methods/crevnet.py
@@ -1,11 +1,12 @@
+import time
import torch
import torch.nn as nn
-import numpy as np
from tqdm import tqdm
from timm.utils import AverageMeter
from openstl.core.optim_scheduler import get_optim_scheduler
from openstl.models import CrevNet_Model
+from openstl.utils import reduce_tensor
from .base_method import Base_method
@@ -33,77 +34,73 @@ def _init_optimizer(self, steps_per_epoch):
self.model_optim2, self.scheduler2, self.by_epoch_2 = get_optim_scheduler(
self.args, self.args.epoch, self.model.encoder, steps_per_epoch)
- def train_one_epoch(self, train_loader, epoch, num_updates, loss_mean, eta=None, **kwargs):
+ def _predict(self, batch_x, batch_y, **kwargs):
+ """Forward the model"""
+ input = torch.cat([batch_x, batch_y], dim=1)
+ pred_y, _ = self.model(input, training=False, return_loss=False)
+ return pred_y
+
+ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ """Train the model with train_loader."""
+ data_time_m = AverageMeter()
losses_m = AverageMeter()
self.model.train()
if self.by_epoch_1:
self.scheduler.step(epoch)
if self.by_epoch_2:
self.scheduler2.step(epoch)
+ train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
- train_pbar = tqdm(train_loader)
+ end = time.time()
for batch_x, batch_y in train_pbar:
+ data_time_m.update(time.time() - end)
self.model_optim.zero_grad()
self.model_optim2.zero_grad()
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
input = torch.cat([batch_x, batch_y], dim=1)
- loss = self.model(input, training=True)
- loss.backward()
+ runner.call_hook('before_train_iter')
+
+ with self.amp_autocast():
+ loss = self.model(input, training=True)
+
+ if not self.dist:
+ losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.loss_scaler is not None:
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ raise ValueError("Inf or nan loss value. Please use fp32 training!")
+ self.loss_scaler(
+ loss, self.model_optim,
+ clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+ parameters=self.model.parameters())
+ else:
+ loss.backward()
+ self.clip_grads(self.model.parameters())
self.model_optim.step()
self.model_optim2.step()
-
+ torch.cuda.synchronize()
num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.dist:
+ losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
if not self.by_epoch_1:
self.scheduler.step()
if not self.by_epoch_2:
self.scheduler2.step()
- train_pbar.set_description('train loss: {:.4f}'.format(
- loss.item() / (self.args.pre_seq_length + self.args.aft_seq_length)))
+ runner.call_hook('after_train_iter')
+ runner._iter += 1
- if hasattr(self.model_optim, 'sync_lookahead'):
- self.model_optim.sync_lookahead()
+ if self.rank == 0:
+ log_buffer = 'train loss: {:.4f}'.format(loss.item())
+ log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+ train_pbar.set_description(log_buffer)
- return num_updates, loss_mean, eta
+ end = time.time() # end for
- def vali_one_epoch(self, runner, vali_loader, **kwargs):
- self.model.eval()
- preds_lst, trues_lst, total_loss = [], [], []
- vali_pbar = tqdm(vali_loader)
- for i, (batch_x, batch_y) in enumerate(vali_pbar):
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
- input = torch.cat([batch_x, batch_y], dim=1)
- pred_y, loss = self.model(input, training=False)
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()
- ), [pred_y, batch_y], [preds_lst, trues_lst]))
-
- if i * batch_x.shape[0] > 1000:
- break
-
- vali_pbar.set_description('vali loss: {:.4f}'.format(loss.mean().item()))
- total_loss.append(loss.mean().item())
-
- total_loss = np.average(total_loss)
-
- preds = np.concatenate(preds_lst, axis=0)
- trues = np.concatenate(trues_lst, axis=0)
- return preds, trues, total_loss
-
- def test_one_epoch(self, runner, test_loader, **kwargs):
- self.model.eval()
- inputs_lst, trues_lst, preds_lst = [], [], []
- test_pbar = tqdm(test_loader)
- for batch_x, batch_y in test_pbar:
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
- input = torch.cat([batch_x, batch_y], dim=1)
- pred_y, _ = self.model(input, training=False)
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
- batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
+ if hasattr(self.model_optim, 'sync_lookahead'):
+ self.model_optim.sync_lookahead()
- inputs, trues, preds = map(
- lambda data: np.concatenate(data, axis=0), [inputs_lst, trues_lst, preds_lst])
- return inputs, trues, preds
+ return num_updates, losses_m, eta
diff --git a/openstl/methods/mau.py b/openstl/methods/mau.py
index d683f106..ae1b1913 100644
--- a/openstl/methods/mau.py
+++ b/openstl/methods/mau.py
@@ -1,11 +1,11 @@
+import time
import torch
import torch.nn as nn
-import numpy as np
from timm.utils import AverageMeter
from tqdm import tqdm
from openstl.models import MAU_Model
-from openstl.utils import schedule_sampling
+from openstl.utils import reduce_tensor, schedule_sampling
from .base_method import Base_method
@@ -28,102 +28,82 @@ def _build_model(self, args):
num_layers = len(num_hidden)
return MAU_Model(num_layers, num_hidden, args).to(self.device)
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, eta, **kwargs):
+ def _predict(self, batch_x, batch_y, **kwargs):
+ """Forward the model."""
+ _, img_channel, img_height, img_width = self.args.in_shape
+
+ # preprocess
+ test_ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
+ real_input_flag = torch.zeros(
+ (batch_x.shape[0],
+ self.args.total_length - self.args.pre_seq_length - 1,
+ img_height // self.args.patch_size,
+ img_width // self.args.patch_size,
+ self.args.patch_size ** 2 * img_channel)).to(self.device)
+
+ img_gen, _ = self.model(test_ims, real_input_flag, return_loss=False)
+ pred_y = img_gen[:, -self.args.aft_seq_length:, :]
+
+ return pred_y
+
+ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ """Train the model with train_loader."""
+ data_time_m = AverageMeter()
losses_m = AverageMeter()
self.model.train()
if self.by_epoch:
self.scheduler.step(epoch)
+ train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
- train_pbar = tqdm(train_loader)
+ end = time.time()
for batch_x, batch_y in train_pbar:
+ data_time_m.update(time.time() - end)
self.model_optim.zero_grad()
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+ runner.call_hook('before_train_iter')
# preprocess
ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
eta, real_input_flag = schedule_sampling(eta, num_updates, ims.shape[0], self.args)
- img_gen, loss = self.model(ims, real_input_flag)
- loss.backward()
- self.model_optim.step()
+ with self.amp_autocast():
+ img_gen, loss = self.model(ims, real_input_flag)
- num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
- if not self.by_epoch:
- self.scheduler.step()
-
- train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
-
- if hasattr(self.model_optim, 'sync_lookahead'):
- self.model_optim.sync_lookahead()
+ if not self.dist:
+ losses_m.update(loss.item(), batch_x.size(0))
- return num_updates, loss_mean, eta
+ if self.loss_scaler is not None:
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ raise ValueError("Inf or nan loss value. Please use fp32 training!")
+ self.loss_scaler(
+ loss, self.model_optim,
+ clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+ parameters=self.model.parameters())
+ else:
+ loss.backward()
+ self.clip_grads(self.model.parameters())
- def vali_one_epoch(self, runner, vali_loader, **kwargs):
- self.model.eval()
- preds_lst, trues_lst, total_loss = [], [], []
- vali_pbar = tqdm(vali_loader)
+ self.model_optim.step()
+ torch.cuda.synchronize()
+ num_updates += 1
- _, img_channel, img_height, img_width = self.args.in_shape
+ if self.dist:
+ losses_m.update(reduce_tensor(loss), batch_x.size(0))
- for i, (batch_x, batch_y) in enumerate(vali_pbar):
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+ if not self.by_epoch:
+ self.scheduler.step()
+ runner.call_hook('after_train_iter')
+ runner._iter += 1
- # preprocess
- test_ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
-
- real_input_flag = torch.zeros(
- (batch_x.shape[0],
- self.args.total_length - self.args.pre_seq_length - 1,
- img_height // self.args.patch_size,
- img_width // self.args.patch_size,
- self.args.patch_size ** 2 * img_channel)).to(self.device)
-
- img_gen, loss = self.model(test_ims, real_input_flag)
- pred_y = img_gen[:, -self.args.aft_seq_length:, :]
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()
- ), [pred_y, batch_y], [preds_lst, trues_lst]))
-
- if i * batch_x.shape[0] > 1000:
- break
-
- vali_pbar.set_description('vali loss: {:.4f}'.format(loss.mean().item()))
- total_loss.append(loss.mean().item())
-
- total_loss = np.average(total_loss)
-
- preds = np.concatenate(preds_lst, axis=0)
- trues = np.concatenate(trues_lst, axis=0)
- return preds, trues, total_loss
-
- def test_one_epoch(self, runner, test_loader, **kwargs):
- self.model.eval()
- inputs_lst, trues_lst, preds_lst = [], [], []
- test_pbar = tqdm(test_loader)
+ if self.rank == 0:
+ log_buffer = 'train loss: {:.4f}'.format(loss.item())
+ log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+ train_pbar.set_description(log_buffer)
- _, img_channel, img_height, img_width = self.args.in_shape
+ end = time.time() # end for
- for batch_x, batch_y in test_pbar:
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+ if hasattr(self.model_optim, 'sync_lookahead'):
+ self.model_optim.sync_lookahead()
- # preprocess
- test_ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
-
- real_input_flag = torch.zeros(
- (batch_x.shape[0],
- self.args.total_length - self.args.pre_seq_length - 1,
- img_height // self.args.patch_size,
- img_width // self.args.patch_size,
- self.args.patch_size ** 2 * img_channel)).to(self.device)
-
- img_gen, _ = self.model(test_ims, real_input_flag)
- pred_y = img_gen[:, -self.args.aft_seq_length:, :]
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
- batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
-
- inputs, trues, preds = map(
- lambda data: np.concatenate(data, axis=0), [inputs_lst, trues_lst, preds_lst])
- return inputs, trues, preds
+ return num_updates, losses_m, eta
diff --git a/openstl/methods/phydnet.py b/openstl/methods/phydnet.py
index 466e05f5..6a9826db 100644
--- a/openstl/methods/phydnet.py
+++ b/openstl/methods/phydnet.py
@@ -1,3 +1,4 @@
+import time
import torch
import torch.nn as nn
import numpy as np
@@ -5,6 +6,7 @@
from tqdm import tqdm
from openstl.models import PhyDNet_Model
+from openstl.utils import reduce_tensor
from .base_method import Base_method
@@ -36,71 +38,68 @@ def _get_constraints(self):
ind +=1
return constraints
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, eta=None, **kwargs):
+ def _predict(self, batch_x, batch_y, **kwargs):
+ """Forward the model"""
+ pred_y, _ = self.model.inference(batch_x, batch_y, self.constraints, return_loss=False)
+ return pred_y
+
+ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ """Train the model with train_loader."""
+ data_time_m = AverageMeter()
losses_m = AverageMeter()
self.model.train()
if self.by_epoch:
self.scheduler.step(epoch)
+ train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
teacher_forcing_ratio = np.maximum(0 , 1 - epoch * 0.003)
- train_pbar = tqdm(train_loader)
+ end = time.time()
for batch_x, batch_y in train_pbar:
+ data_time_m.update(time.time() - end)
self.model_optim.zero_grad()
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
- pred_y = self.model(batch_x, batch_y, self.constraints, teacher_forcing_ratio)
- loss = self.criterion(pred_y, batch_y)
- loss.backward()
+ runner.call_hook('before_train_iter')
+
+ with self.amp_autocast():
+ pred_y = self.model(batch_x, batch_y, self.constraints, teacher_forcing_ratio)
+ loss = self.criterion(pred_y, batch_y)
+
+ if not self.dist:
+ losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.loss_scaler is not None:
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ raise ValueError("Inf or nan loss value. Please use fp32 training!")
+ self.loss_scaler(
+ loss, self.model_optim,
+ clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+ parameters=self.model.parameters())
+ else:
+ loss.backward()
+ self.clip_grads(self.model.parameters())
+
self.model_optim.step()
-
+ torch.cuda.synchronize()
num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
- if not self.by_epoch:
- self.scheduler.step()
- train_pbar.set_description('train loss: {:.4f}'.format(
- loss.item() / (self.args.pre_seq_length + self.args.aft_seq_length)))
- if hasattr(self.model_optim, 'sync_lookahead'):
- self.model_optim.sync_lookahead()
+ if self.dist:
+ losses_m.update(reduce_tensor(loss), batch_x.size(0))
- return num_updates, loss_mean, eta
-
- def vali_one_epoch(self, runner, vali_loader, **kwargs):
- self.model.eval()
- preds_lst, trues_lst, total_loss = [], [], []
- vali_pbar = tqdm(vali_loader)
- for i, (batch_x, batch_y) in enumerate(vali_pbar):
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
- pred_y, loss = self.model.inference(batch_x, batch_y, self.constraints)
- loss = self.criterion(pred_y, batch_y)
-
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()
- ), [pred_y, batch_y], [preds_lst, trues_lst]))
-
- if i * batch_x.shape[0] > 1000:
- break
-
- vali_pbar.set_description('vali loss: {:.4f}'.format(loss.mean().item()))
- total_loss.append(loss.mean().item())
-
- total_loss = np.average(total_loss)
+ if not self.by_epoch:
+ self.scheduler.step()
+ runner.call_hook('after_train_iter')
+ runner._iter += 1
- preds = np.concatenate(preds_lst, axis=0)
- trues = np.concatenate(trues_lst, axis=0)
- return preds, trues, total_loss
+ if self.rank == 0:
+ log_buffer = 'train loss: {:.4f}'.format(loss.item())
+ log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+ train_pbar.set_description(log_buffer)
- def test_one_epoch(self, runner, test_loader, **kwargs):
- self.model.eval()
- inputs_lst, trues_lst, preds_lst = [], [], []
- test_pbar = tqdm(test_loader)
- for batch_x, batch_y in test_pbar:
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
- pred_y, _ = self.model.inference(batch_x, batch_y, self.constraints)
+ end = time.time() # end for
- list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
- batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
+ if hasattr(self.model_optim, 'sync_lookahead'):
+ self.model_optim.sync_lookahead()
- inputs, trues, preds = map(
- lambda data: np.concatenate(data, axis=0), [inputs_lst, trues_lst, preds_lst])
- return inputs, trues, preds
+ return num_updates, losses_m, eta
diff --git a/openstl/methods/predrnn.py b/openstl/methods/predrnn.py
index b5ba3f38..ac4c3157 100644
--- a/openstl/methods/predrnn.py
+++ b/openstl/methods/predrnn.py
@@ -1,11 +1,11 @@
+import time
import torch
import torch.nn as nn
-import numpy as np
from timm.utils import AverageMeter
from tqdm import tqdm
from openstl.models import PredRNN_Model
-from openstl.utils import (reshape_patch, reshape_patch_back,
+from openstl.utils import (reduce_tensor, reshape_patch, reshape_patch_back,
reserve_schedule_sampling_exp, schedule_sampling)
from .base_method import Base_method
@@ -29,8 +29,8 @@ def _build_model(self, args):
num_layers = len(num_hidden)
return PredRNN_Model(num_layers, num_hidden, args).to(self.device)
- def _predict(self, batch_x, batch_y):
- """Forward the model."""
+ def _predict(self, batch_x, batch_y, **kwargs):
+ """Forward the model"""
# reverse schedule sampling
if self.args.reverse_scheduled_sampling == 1:
mask_input = 1
@@ -53,32 +53,28 @@ def _predict(self, batch_x, batch_y):
if self.args.reverse_scheduled_sampling == 1:
real_input_flag[:, :self.args.pre_seq_length - 1, :, :] = 1.0
- img_gen, _ = self.model(test_dat, real_input_flag)
+ img_gen, _ = self.model(test_dat, real_input_flag, return_loss=False)
img_gen = reshape_patch_back(img_gen, self.args.patch_size)
pred_y = img_gen[:, -self.args.aft_seq_length:].permute(0, 1, 4, 2, 3).contiguous()
return pred_y
- def forward_test(self, batch_x, batch_y):
- """Evaluate the model"""
- batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
-
- with self.amp_autocast():
- pred_y = self._predict(batch_x, batch_y)
-
- return dict(zip(['inputs', 'preds', 'trues'],
- [batch_x, pred_y, batch_y]))
-
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, eta, **kwargs):
+ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ """Train the model with train_loader."""
+ data_time_m = AverageMeter()
losses_m = AverageMeter()
self.model.train()
if self.by_epoch:
self.scheduler.step(epoch)
+ train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
- train_pbar = tqdm(train_loader)
+ end = time.time()
for batch_x, batch_y in train_pbar:
+ data_time_m.update(time.time() - end)
self.model_optim.zero_grad()
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+ runner.call_hook('before_train_iter')
# preprocess
ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
@@ -90,19 +86,43 @@ def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, e
eta, real_input_flag = schedule_sampling(
eta, num_updates, ims.shape[0], self.args)
- img_gen, loss = self.model(ims, real_input_flag)
- loss.backward()
+ with self.amp_autocast():
+ img_gen, loss = self.model(ims, real_input_flag)
+
+ if not self.dist:
+ losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.loss_scaler is not None:
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ raise ValueError("Inf or nan loss value. Please use fp32 training!")
+ self.loss_scaler(
+ loss, self.model_optim,
+ clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+ parameters=self.model.parameters())
+ else:
+ loss.backward()
+ self.clip_grads(self.model.parameters())
+
self.model_optim.step()
+ torch.cuda.synchronize()
+ num_updates += 1
+
+ if self.dist:
+ losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
if not self.by_epoch:
- self.scheduler.step(epoch)
+ self.scheduler.step()
+ runner.call_hook('after_train_iter')
+ runner._iter += 1
- num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
+ if self.rank == 0:
+ log_buffer = 'train loss: {:.4f}'.format(loss.item())
+ log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+ train_pbar.set_description(log_buffer)
- train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
+ end = time.time() # end for
if hasattr(self.model_optim, 'sync_lookahead'):
self.model_optim.sync_lookahead()
- return num_updates, loss_mean, eta
+ return num_updates, losses_m, eta
diff --git a/openstl/methods/predrnnv2.py b/openstl/methods/predrnnv2.py
index 87811c34..846a1265 100644
--- a/openstl/methods/predrnnv2.py
+++ b/openstl/methods/predrnnv2.py
@@ -1,10 +1,12 @@
+import time
import torch
import torch.nn as nn
from timm.utils import AverageMeter
from tqdm import tqdm
from openstl.models import PredRNNv2_Model
-from openstl.utils import reshape_patch, reserve_schedule_sampling_exp, schedule_sampling
+from openstl.utils import (reduce_tensor, reshape_patch,
+ reserve_schedule_sampling_exp, schedule_sampling)
from .predrnn import PredRNN
@@ -27,16 +29,22 @@ def _build_model(self, args):
num_layers = len(num_hidden)
return PredRNNv2_Model(num_layers, num_hidden, args).to(self.device)
- def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, eta, **kwargs):
+ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **kwargs):
+ """Train the model with train_loader."""
+ data_time_m = AverageMeter()
losses_m = AverageMeter()
self.model.train()
if self.by_epoch:
self.scheduler.step(epoch)
+ train_pbar = tqdm(train_loader) if self.rank == 0 else train_loader
- train_pbar = tqdm(train_loader)
+ end = time.time()
for batch_x, batch_y in train_pbar:
+ data_time_m.update(time.time() - end)
self.model_optim.zero_grad()
+
batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
+ runner.call_hook('before_train_iter')
# preprocess
ims = torch.cat([batch_x, batch_y], dim=1).permute(0, 1, 3, 4, 2).contiguous()
@@ -48,18 +56,43 @@ def train_one_epoch(self, runner, train_loader, epoch, num_updates, loss_mean, e
eta, real_input_flag = schedule_sampling(
eta, num_updates, ims.shape[0], self.args)
- img_gen, loss = self.model(ims, real_input_flag)
- loss.backward()
- self.model_optim.step()
+ with self.amp_autocast():
+ img_gen, loss = self.model(ims, real_input_flag)
+
+ if not self.dist:
+ losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.loss_scaler is not None:
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ raise ValueError("Inf or nan loss value. Please use fp32 training!")
+ self.loss_scaler(
+ loss, self.model_optim,
+ clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
+ parameters=self.model.parameters())
+ else:
+ loss.backward()
+ self.clip_grads(self.model.parameters())
+ self.model_optim.step()
+ torch.cuda.synchronize()
num_updates += 1
- loss_mean += loss.item()
- losses_m.update(loss.item(), batch_x.size(0))
+
+ if self.dist:
+ losses_m.update(reduce_tensor(loss), batch_x.size(0))
+
if not self.by_epoch:
self.scheduler.step()
- train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
+ runner.call_hook('after_train_iter')
+ runner._iter += 1
+
+ if self.rank == 0:
+ log_buffer = 'train loss: {:.4f}'.format(loss.item())
+ log_buffer += ' | data time: {:.4f}'.format(data_time_m.avg)
+ train_pbar.set_description(log_buffer)
+
+ end = time.time() # end for
if hasattr(self.model_optim, 'sync_lookahead'):
self.model_optim.sync_lookahead()
- return num_updates, loss_mean, eta
+ return num_updates, losses_m, eta
diff --git a/openstl/methods/simvp.py b/openstl/methods/simvp.py
index 692d5d0a..f6fc02dd 100644
--- a/openstl/methods/simvp.py
+++ b/openstl/methods/simvp.py
@@ -27,7 +27,7 @@ def _build_model(self, config):
return SimVP_Model(**config).to(self.device)
def _predict(self, batch_x, batch_y=None, **kwargs):
- """Forward the model."""
+ """Forward the model"""
if self.args.aft_seq_length == self.args.pre_seq_length:
pred_y = self.model(batch_x)
elif self.args.aft_seq_length < self.args.pre_seq_length:
@@ -76,7 +76,7 @@ def train_one_epoch(self, runner, train_loader, epoch, num_updates, eta=None, **
if self.loss_scaler is not None:
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
- raise ValueError("Inf or nan loss value. Please use fp32 training instead!")
+ raise ValueError("Inf or nan loss value. Please use fp32 training!")
self.loss_scaler(
loss, self.model_optim,
clip_grad=self.args.clip_grad, clip_mode=self.args.clip_mode,
diff --git a/openstl/models/convlstm_model.py b/openstl/models/convlstm_model.py
index 3414447e..bb739d97 100644
--- a/openstl/models/convlstm_model.py
+++ b/openstl/models/convlstm_model.py
@@ -36,7 +36,7 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.frame_channel,
kernel_size=1, stride=1, padding=0, bias=False)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -78,6 +78,9 @@ def forward(self, frames_tensor, mask_true):
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
- loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ if kwargs.get('return_loss', True):
+ loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/crevnet_model.py b/openstl/models/crevnet_model.py
index fbdc2e20..8a24037b 100644
--- a/openstl/models/crevnet_model.py
+++ b/openstl/models/crevnet_model.py
@@ -10,6 +10,8 @@ class CrevNet_Model(nn.Module):
Implementation of `Efficient and Information-Preserving Future Frame Prediction
and Beyond `_.
+
+ Notice: CrevNet Model requires `batch_size` == `val_batch_size`, or it will raise
"""
def __init__(self, in_shape, rnn_size, batch_size, predictor_rnn_layers,
@@ -31,7 +33,7 @@ def __init__(self, in_shape, rnn_size, batch_size, predictor_rnn_layers,
mult=2)
self.criterion = nn.MSELoss()
- def forward(self, x, training=True):
+ def forward(self, x, training=True, **kwargs):
B, T, C, H, W = x.shape
input = []
@@ -46,9 +48,13 @@ def forward(self, x, training=True):
memo = Variable(torch.zeros(B, self.rnn_size, 3, H // 8, W // 8).cuda())
for i in range(1, self.pre_seq_length + self.aft_seq_length):
h = self.encoder(input[i - 1], True)
- h_pred, memo = self.frame_predictor((h, memo))
+ try:
+ h_pred, memo = self.frame_predictor((h, memo))
+ except RuntimeError:
+ assert False and "CrevNet Model requires `batch_size` == `val_batch_size`"
x_pred = self.encoder(h_pred, False)
- loss += (self.criterion(x_pred, input[i]))
+ if kwargs.get('return_loss', True):
+ loss += (self.criterion(x_pred, input[i]))
if training is True:
return loss
diff --git a/openstl/models/e3dlstm_model.py b/openstl/models/e3dlstm_model.py
index 9dda637f..e028539e 100644
--- a/openstl/models/e3dlstm_model.py
+++ b/openstl/models/e3dlstm_model.py
@@ -41,7 +41,7 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
kernel_size=(self.window_length, 1, 1),
stride=(self.window_length, 1, 1), padding=0, bias=False)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -104,7 +104,10 @@ def forward(self, frames_tensor, mask_true):
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
- loss = self.MSE_criterion(
- next_frames, frames_tensor[:, 1:]) + self.L1_criterion(next_frames, frames_tensor[:, 1:])
+ if kwargs.get('return_loss', True):
+ loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:]) + \
+ self.L1_criterion(next_frames, frames_tensor[:, 1:])
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/mau_model.py b/openstl/models/mau_model.py
index 976f19d3..c371835d 100644
--- a/openstl/models/mau_model.py
+++ b/openstl/models/mau_model.py
@@ -105,7 +105,7 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
self.conv_last_sr = nn.Conv2d(
self.frame_channel * 2, self.frame_channel, kernel_size=1, stride=1, padding=0)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -161,7 +161,6 @@ def forward(self, frames_tensor, mask_true):
T_pre[i].append(T_t[i])
out = S_t
-
for i in range(len(self.decoders)):
out = self.decoders[i](out)
if self.configs.model_mode == 'recall':
@@ -172,6 +171,9 @@ def forward(self, frames_tensor, mask_true):
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4).contiguous()
- loss = self.MSE_criterion(next_frames, frames[:, 1:])
+ if kwargs.get('return_loss', True):
+ loss = self.MSE_criterion(next_frames, frames[:, 1:])
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/mim_model.py b/openstl/models/mim_model.py
index 367201b5..0244dbb6 100644
--- a/openstl/models/mim_model.py
+++ b/openstl/models/mim_model.py
@@ -48,7 +48,7 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.frame_channel,
kernel_size=1, stride=1, padding=0, bias=False)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -110,6 +110,9 @@ def forward(self, frames_tensor, mask_true):
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
- loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ if kwargs.get('return_loss', True):
+ loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/phydnet_model.py b/openstl/models/phydnet_model.py
index e87c142a..81573ce9 100644
--- a/openstl/models/phydnet_model.py
+++ b/openstl/models/phydnet_model.py
@@ -51,12 +51,14 @@ def forward(self, input_tensor, target_tensor, constraints, teacher_forcing_rati
return loss
- def inference(self, input_tensor, target_tensor, constraints):
+ def inference(self, input_tensor, target_tensor, constraints, **kwargs):
with torch.no_grad():
loss = 0
for ei in range(self.pre_seq_length - 1):
- encoder_output, encoder_hidden, output_image, _, _ = self.encoder(input_tensor[:,ei,:,:,:], (ei==0))
- loss += self.criterion(output_image, input_tensor[:,ei+1,:,:,:])
+ encoder_output, encoder_hidden, output_image, _, _ = \
+ self.encoder(input_tensor[:,ei,:,:,:], (ei==0))
+ if kwargs.get('return_loss', True):
+ loss += self.criterion(output_image, input_tensor[:,ei+1,:,:,:])
decoder_input = input_tensor[:,-1,:,:,:]
predictions = []
@@ -65,12 +67,13 @@ def inference(self, input_tensor, target_tensor, constraints):
_, _, output_image, _, _ = self.encoder(decoder_input, False, False)
decoder_input = output_image
predictions.append(output_image)
-
- loss += self.criterion(output_image, target_tensor[:,di,:,:,:])
+ if kwargs.get('return_loss', True):
+ loss += self.criterion(output_image, target_tensor[:,di,:,:,:])
for b in range(0, self.encoder.phycell.cell_list[0].input_dim):
filters = self.encoder.phycell.cell_list[0].F.conv1.weight[:,b,:,:]
m = self.k2m(filters.double()).float()
- loss += self.criterion(m, constraints)
-
+ if kwargs.get('return_loss', True):
+ loss += self.criterion(m, constraints)
+
return torch.stack(predictions, dim=1), loss
diff --git a/openstl/models/predrnn_model.py b/openstl/models/predrnn_model.py
index 37f20501..026f53a5 100644
--- a/openstl/models/predrnn_model.py
+++ b/openstl/models/predrnn_model.py
@@ -35,7 +35,7 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.frame_channel,
kernel_size=1, stride=1, padding=0, bias=False)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -81,6 +81,9 @@ def forward(self, frames_tensor, mask_true):
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
- loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ if kwargs.get('return_loss', True):
+ loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/predrnnpp_model.py b/openstl/models/predrnnpp_model.py
index 142d5b30..090bed72 100644
--- a/openstl/models/predrnnpp_model.py
+++ b/openstl/models/predrnnpp_model.py
@@ -38,7 +38,7 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
self.conv_last = nn.Conv2d(num_hidden[num_layers - 1], self.frame_channel,
kernel_size=1, stride=1, padding=0, bias=False)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -87,6 +87,9 @@ def forward(self, frames_tensor, mask_true):
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
- loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ if kwargs.get('return_loss', True):
+ loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/predrnnv2_model.py b/openstl/models/predrnnv2_model.py
index 037f4eb3..a560d737 100644
--- a/openstl/models/predrnnv2_model.py
+++ b/openstl/models/predrnnv2_model.py
@@ -41,7 +41,8 @@ def __init__(self, num_layers, num_hidden, configs, **kwargs):
self.adapter = nn.Conv2d(
adapter_num_hidden, adapter_num_hidden, 1, stride=1, padding=0, bias=False)
- def forward(self, frames_tensor, mask_true):
+ def forward(self, frames_tensor, mask_true, **kwargs):
+ return_loss = kwargs.get('return_loss', True)
# [batch, length, height, width, channel] -> [batch, length, channel, height, width]
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
@@ -102,16 +103,22 @@ def forward(self, frames_tensor, mask_true):
x_gen = self.conv_last(h_t[self.num_layers - 1])
next_frames.append(x_gen)
+
# decoupling loss
- for i in range(0, self.num_layers):
- decouple_loss.append(torch.mean(torch.abs(
- torch.cosine_similarity(delta_c_list[i], delta_m_list[i], dim=2))))
+ if return_loss:
+ for i in range(0, self.num_layers):
+ decouple_loss.append(torch.mean(torch.abs(
+ torch.cosine_similarity(delta_c_list[i], delta_m_list[i], dim=2))))
- decouple_loss = torch.mean(torch.stack(decouple_loss, dim=0))
+ if return_loss:
+ decouple_loss = torch.mean(torch.stack(decouple_loss, dim=0))
# [length, batch, channel, height, width] -> [batch, length, height, width, channel]
next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
- loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:]) + \
- self.configs.decouple_beta * decouple_loss
+ if return_loss:
+ loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:]) + \
+ self.configs.decouple_beta * decouple_loss
+ else:
+ loss = None
return next_frames, loss
diff --git a/openstl/models/simvp_model.py b/openstl/models/simvp_model.py
index 7eca2231..6b1400eb 100644
--- a/openstl/models/simvp_model.py
+++ b/openstl/models/simvp_model.py
@@ -32,7 +32,7 @@ def __init__(self, in_shape, hid_S=16, hid_T=256, N_S=4, N_T=4, model_type='gSTA
input_resolution=(H, W), model_type=model_type,
mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
- def forward(self, x_raw):
+ def forward(self, x_raw, **kwargs):
B, T, C, H, W = x_raw.shape
x = x_raw.view(B*T, C, H, W)
diff --git a/openstl/utils/parser.py b/openstl/utils/parser.py
index cdd05754..847e739a 100644
--- a/openstl/utils/parser.py
+++ b/openstl/utils/parser.py
@@ -26,7 +26,9 @@ def create_parser():
parser.add_argument('--fps', action='store_true', default=False,
help='Whether to measure inference speed (FPS)')
parser.add_argument('--empty_cache', action='store_true', default=True,
- help='Whether to empty cuda cache after training')
+ help='Whether to empty cuda cache after GPU training')
+ parser.add_argument('--find_unused_parameters', action='store_true', default=False,
+ help='Whether to find unused parameters in forward during DDP training')
parser.add_argument('--resume_from', type=str, default=None, help='the checkpoint file to resume from')
parser.add_argument('--auto_resume', action='store_true', default=False,
help='When training was interupted, resume from the latest checkpoint')
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index bae4441a..f5d435f9 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,7 +1,9 @@
+dask
future
fvcore
matplotlib
nni
+netcdf4
numpy
hickle
packaging
@@ -10,4 +12,4 @@ six
scikit-learn
timm
tqdm
-xarray
+xarray==0.19.0