forked from chengtan9907/OpenSTL
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
release and fix OpenSTL V0.2.0 (issue chengtan9907#20)
- Loading branch information
Showing
39 changed files
with
640 additions
and
304 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'convmixer' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-2 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'convnext' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-2 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'hornet' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'IncepU' # SimVP.V1 | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-2 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'mlp' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'moga' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 5e-3 | ||
batch_size = 16 | ||
drop_path = 0.2 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'poolformer' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 5e-4 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'swin' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'uniformer' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 5e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'van' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 5e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'vit' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 1e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
sched = 'cosine' | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
method = 'SimVP' | ||
# model | ||
spatio_kernel_enc = 3 | ||
spatio_kernel_dec = 3 | ||
model_type = 'gSTA' | ||
hid_S = 32 | ||
hid_T = 256 | ||
N_T = 8 | ||
N_S = 2 | ||
# training | ||
lr = 5e-3 | ||
batch_size = 16 | ||
drop_path = 0.1 | ||
warmup_epoch = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,81 @@ | ||
# Getting Started | ||
|
||
This page provides basic tutorials about the usage of SimVP. For installation instructions, please see [Install](docs/en/install.md). | ||
This page provides basic tutorials about the usage of OpenSTL with various spatioTemporal predictive learning (STL) tasks. For installation instructions, please see [Install](docs/en/install.md). | ||
|
||
## Training and Testing with a Single GPU | ||
|
||
You can perform single/multiple GPU training and testing with `tools/train.py` and `tools/test.py`. We provide descriptions of some essential arguments. | ||
You can perform single GPU training and testing with `tools/train.py` and `tools/test.py` with non-distributed and distributed (DDP) modes. Non-distributed mode is recommanded for the single GPU training (a bit faster than DDP). We provide descriptions of some essential arguments. Other arguments related to datasets, optimizers, methods can be found in [parser.py](https://github.com/chengtan9907/OpenSTL/tree/master/openstl/utils/parser.py). | ||
|
||
```bash | ||
python tools/train.py \ | ||
--dataname ${DATASET_NAME} \ | ||
--method ${METHOD_NAME} \ | ||
--config_file ${CONFIG_FILE} \ | ||
--ex_name ${EXP_NAME} \ | ||
--resume_from ${CHECKPOINT_FILE} \ | ||
--auto_resume \ | ||
--batch_size ${BATCH_SIZE} \ | ||
--lr ${LEARNING_RATE} \ | ||
--dist \ | ||
--fp16 \ | ||
--seed ${SEED} \ | ||
--clip_grad ${VALUE} \ | ||
--find_unused_parameters \ | ||
--deterministic \ | ||
``` | ||
|
||
**Description of arguments**: | ||
- `--dataname (-d)` : The name of dataset, default to be `mmnist`. | ||
- `--method (-m)` : The name of the video prediction method to train or test, default to be `SimVP`. | ||
- `--config_file (-c)` : The path of a model config file, which will provide detailed settings for a video prediction method. | ||
- `--config_file (-c)` : The path of a model config file, which will provide detailed settings for a STL method. | ||
- `--ex_name` : The name of the experiment under the `res_dir`. Default to be `Debug`. | ||
- `--resume_from ${CHECKPOINT_FILE}`: Resume from a previous checkpoint file. Or you can use `--auto_resume` to resume from `latest.pth` automatically. | ||
- `--auto_resume` : Whether to automatically resume training when the experiment was interrupted. | ||
- `--batch_size (-b)` : Training batch size, default to 16. | ||
- `--lr` : The basic training learning rate, defaults to 0.001. | ||
- `--dist`: Whether to use distributed training (DDP). | ||
- `--fp16`: Whether to use Native AMP for mixed precision training (PyTorch=>1.6.0). | ||
- `--seed ${SEED}`: Setup all random seeds to a certain number (defaults to 42). | ||
- `--clip_grad ${VALUE}`: Clip gradient norm value (default: None, no clipping). | ||
- `--find_unused_parameters`: Whether to find unused parameters in forward during DDP training. | ||
- `--deterministic`: Switch on "deterministic" mode, which slows down training while the results are reproducible. | ||
|
||
An example of single GPU training with SimVP+gSTA on Moving MNIST dataset. | ||
An example of single GPU (non-distributed) training with SimVP+gSTA on Moving MNIST dataset. | ||
```shell | ||
bash tools/prepare_data/download_mmnist.sh | ||
python tools/train.py -d mmnist --lr 1e-3 -c ./configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta | ||
python tools/train.py -d mmnist --lr 1e-3 -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta | ||
``` | ||
|
||
An example of single GPU testing with SimVP+gSTA on Moving MNIST dataset. | ||
```shell | ||
python tools/test.py -d mmnist -c configs/mmnist/simvp/SimVP_gSTA.py --ex_name mmnist_simvp_gsta | ||
``` | ||
|
||
## Training and Testing with Multiple GPUs | ||
|
||
For larger STL tasks (e.g., high resolutions), you can also perform multiple GPUs training and testing with `tools/dist_train.sh` and `tools/dist_test.sh` with DDP mode. The bash files will call `tools/train.py` and `tools/test.py` with the necessary arguments. | ||
|
||
```shell | ||
bash tools/dist_train.sh ${CONFIG_FILE} ${GPUS} [optional arguments] | ||
``` | ||
**Description of arguments**: | ||
- `${CONFIG_FILE}` : The path of a model config file, which will provide detailed settings for a STL method. | ||
- `${GPUS}` : The number of GPUs for DDP training. | ||
|
||
Examples of multiple GPUs training on Moving MNIST dataset with a machine with 8 GPUs. | ||
```shell | ||
PORT=29001 CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_train.sh configs/mmnist/simvp/SimVP_gSTA.py 2 -d mmnist --lr 1e-3 --batch_size 8 | ||
PORT=29002 CUDA_VISIBLE_DEVICES=2,3 bash tools/dist_train.sh configs/mmnist/PredRNN.py 2 -d mmnist --lr 1e-3 --batch_size 8 | ||
PORT=29003 CUDA_VISIBLE_DEVICES=4,5,6,7 bash tools/dist_train.sh configs/mmnist/PredRNNpp.py 4 -d mmnist --lr 1e-3 --batch_size 4 | ||
``` | ||
|
||
An example of multiple GPUs testing on Moving MNIST dataset. The bash script is `bash tools/dist_train.sh ${CONFIG_FILE} ${GPUS} ${CHECKPOINT} [optional arguments]`. | ||
```shell | ||
PORT=29001 CUDA_VISIBLE_DEVICES=0,1 bash tools/dist_test.sh configs/mmnist/simvp/SimVP_gSTA.py 2 work_dirs/mmnist/simvp/SimVP_gSTA -d mmnist | ||
``` | ||
|
||
**Note**: During DDP training, the number of GPUS `ngpus` should be provided and checkpoints and logs are saved in the same folder structure as the config file under `work_dirs/` (it will be the default setting if `--ex_name` is not specified). The default learning rate `lr` and the batch size `bs` in config files are for a single GPU. If using a different number GPUs, the total batch size will change in proportion, you have to scale the learning rate following `lr = base_lr * ngpus` and `bs = base_bs * ngpus`. Other arguments should be added as the single GPU training. | ||
|
||
## Mixed Precision Training | ||
|
||
We support Mixed Precision Training implemented by PyTorch AMP. If you want to use Mixed Precision Training, you can add `--fp16` in the arguments. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.