From 874d27cb55b9d7e9df6cd0881e2d7fe9f262532b Mon Sep 17 00:00:00 2001 From: Kalyan Vasudev Alwala Date: Wed, 19 Jan 2022 15:47:55 -0800 Subject: [PATCH] PytorchVideo - Lightning Training pipeline (#158) Summary: Pull Request resolved: https://github.com/facebookresearch/pytorchvideo/pull/158 Pull Request resolved: https://github.com/fairinternal/pytorchvideo/pull/47 1. Support Video classification 2. Support Video SSL - SimCLR, BYOL, MoCo Reviewed By: haooooooqi Differential Revision: D33431232 fbshipit-source-id: 47ad9c35d45e4c8f9ac95e497dd7b582cb4084a9 --- pytorchvideo/__init__.py | 2 +- pytorchvideo/models/resnet.py | 9 +- pytorchvideo_trainer/README.md | 39 ++ .../pytorchvideo_trainer/__init__.py | 16 + .../callbacks/__init__.py | 8 + .../callbacks/precise_batchnorm.py | 70 +++ .../pytorchvideo_trainer/conf/__init__.py | 8 + .../conf/byol_train_app_conf.yaml | 28 + .../conf/callbacks/precise_bn.yaml | 3 + .../conf/classification_mvit_16x4.yaml | 72 +++ .../conf/classification_slow_8x8_r50.yaml | 46 ++ .../conf/classification_slowfast_8x8_r50.yaml | 46 ++ .../conf/classification_x3d_xs.yaml | 65 +++ .../dataloader/kinetics_classification.yaml | 43 ++ .../dataloader/kinetics_contrastive.yaml | 41 ++ .../kinetics_classification_mvit_16x4.yaml | 70 +++ .../kinetics_classification_slow.yaml | 51 ++ .../kinetics_classification_slowfast.yaml | 60 ++ .../kinetics_classification_x3d_xs.yaml | 51 ++ .../transforms/kinetics_contrastive.yaml | 56 ++ .../transforms/kinetics_moco_v2.yaml | 56 ++ .../pytorchvideo_trainer/conf/logger/ptl.yaml | 4 + .../conf/moco_v2_train_app_conf.yaml | 31 ++ .../conf/module/knn_memory/kinetics_k400.yaml | 7 + .../conf/module/loss/contrastive.yaml | 2 + .../conf/module/loss/cross_entropy.yaml | 2 + .../conf/module/loss/nt_xent.yaml | 3 + .../conf/module/loss/similarity.yaml | 3 + .../conf/module/loss/soft_cross_entropy.yaml | 2 + .../lr_scheduler/cosine_with_warmup.yaml | 7 + .../conf/module/metrics/accuracy.yaml | 8 + .../module/metrics/average_precision.yaml | 3 + .../model/from_lightning_checkpoint.yaml | 2 + .../model/from_model_zoo_checkpoint.yaml | 5 + .../module/model/from_ssl_checkpoint.yaml | 11 + .../conf/module/model/mvit_base_16x4.yaml | 32 ++ .../conf/module/model/slow_r50.yaml | 7 + .../conf/module/model/slow_r50_byol.yaml | 3 + .../conf/module/model/slow_r50_moco_v2.yaml | 3 + .../conf/module/model/slow_r50_simclr.yaml | 4 + .../conf/module/model/slowfast_r50.yaml | 6 + .../conf/module/model/x3d_xs.yaml | 8 + .../conf/module/optim/adam.yaml | 3 + .../conf/module/optim/adamw.yaml | 3 + .../conf/module/optim/sgd.yaml | 5 + .../conf/module/optim/sgd_ssl.yaml | 5 + .../conf/simclr_train_app_conf.yaml | 25 + .../conf/submitit_conf/fair_cluster.yaml | 9 + .../conf/trainer/cpu.yaml | 2 + .../conf/trainer/multi_gpu.yaml | 6 + .../conf/trainer/single_gpu.yaml | 3 + .../datamodule/__init__.py | 8 + .../datamodule/collators.py | 46 ++ .../datamodule/datamodule.py | 226 ++++++++ .../datamodule/rand_erase_transform.py | 196 +++++++ .../datamodule/transforms.py | 287 ++++++++++ .../pytorchvideo_trainer/module/__init__.py | 13 + .../pytorchvideo_trainer/module/byol.py | 329 +++++++++++ .../module/distributed_utils.py | 330 +++++++++++ .../pytorchvideo_trainer/module/losses.py | 135 +++++ .../pytorchvideo_trainer/module/lr_policy.py | 156 ++++++ .../pytorchvideo_trainer/module/moco_v2.py | 456 +++++++++++++++ .../pytorchvideo_trainer/module/optimizer.py | 257 +++++++++ .../pytorchvideo_trainer/module/simclr.py | 229 ++++++++ 
.../pytorchvideo_trainer/module/ssl_helper.py | 473 ++++++++++++++++ .../module/video_classification.py | 518 ++++++++++++++++++ .../pytorchvideo_trainer/train_app.py | 300 ++++++++++ pytorchvideo_trainer/setup.py | 38 ++ pytorchvideo_trainer/tests/__init__.py | 1 + .../tests/test_conf_datamodule.py | 28 + .../tests/test_conf_module.py | 62 +++ pytorchvideo_trainer/tests/test_task_byol.py | 63 +++ .../tests/test_task_moco_v2.py | 64 +++ .../tests/test_task_module_all.py | 129 +++++ .../tests/test_task_simclr.py | 63 +++ .../tests/test_task_video_classification.py | 92 ++++ pytorchvideo_trainer/tests/util.py | 163 ++++++ 77 files changed, 5714 insertions(+), 2 deletions(-) create mode 100644 pytorchvideo_trainer/README.md create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/__init__.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml create 
mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py create mode 100644 pytorchvideo_trainer/pytorchvideo_trainer/train_app.py create mode 100644 pytorchvideo_trainer/setup.py create mode 100644 pytorchvideo_trainer/tests/__init__.py create mode 100644 pytorchvideo_trainer/tests/test_conf_datamodule.py create mode 100644 pytorchvideo_trainer/tests/test_conf_module.py create mode 100644 pytorchvideo_trainer/tests/test_task_byol.py create 
mode 100644 pytorchvideo_trainer/tests/test_task_moco_v2.py create mode 100644 pytorchvideo_trainer/tests/test_task_module_all.py create mode 100644 pytorchvideo_trainer/tests/test_task_simclr.py create mode 100644 pytorchvideo_trainer/tests/test_task_video_classification.py create mode 100644 pytorchvideo_trainer/tests/util.py diff --git a/pytorchvideo/__init__.py b/pytorchvideo/__init__.py index 36635173..d2b2f870 100644 --- a/pytorchvideo/__init__.py +++ b/pytorchvideo/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -__version__ = "0.1.3" +__version__ = "0.1.5" diff --git a/pytorchvideo/models/resnet.py b/pytorchvideo/models/resnet.py index 1aba7d90..22f60c01 100644 --- a/pytorchvideo/models/resnet.py +++ b/pytorchvideo/models/resnet.py @@ -316,6 +316,13 @@ def create_acoustic_bottleneck_block( ) +def _trivial_sum(x, y): + """ + Utility function used in lieu of lambdas, which are not picklable + """ + return x + y + + def create_res_block( *, # Bottleneck Block configs. dim_in: int, dim_out: int, bottleneck: Callable, use_shortcut: bool = False, - branch_fusion: Callable = lambda x, y: x + y, + branch_fusion: Callable = _trivial_sum, # Conv configs. conv_a_kernel_size: Tuple[int] = (3, 1, 1), conv_a_stride: Tuple[int] = (2, 1, 1), diff --git a/pytorchvideo_trainer/README.md b/pytorchvideo_trainer/README.md new file mode 100644 index 00000000..46886c33 --- /dev/null +++ b/pytorchvideo_trainer/README.md @@ -0,0 +1,39 @@ +## PyTorchVideo Trainer + +A [PyTorch-Lightning](https://github.com/PyTorchLightning/pytorch-lightning) based trainer supporting PyTorchVideo models and dataloaders for various video understanding tasks. + +Currently supported tasks include: + +- Video Action Recognition: ResNets, SlowFast models, X3D models, and MViT +- Video Self-Supervised Learning: SimCLR, BYOL, MoCo +- (Planned) Video Action Detection + +## Installation + +These instructions assume that both PyTorch and TorchVision are already installed +using the instructions in [INSTALL.md](https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md#requirements) + +Install the required additional dependency `recipes` by running the following command: ``` +pip install "git+https://github.com/facebookresearch/recipes.git" +``` + +Then install PyTorchVideo Trainer by running: ``` +git clone https://github.com/facebookresearch/pytorchvideo.git +cd pytorchvideo/pytorchvideo_trainer +pip install -e . + +# For developing and testing +pip install -e .[test,dev] +``` + +## Testing + +Before running the tests, please ensure that you have installed the necessary additional test dependencies. + +Use the following command to run the tests: ``` +# From the current directory +python -m unittest discover -v -s ./tests +``` diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/__init__.py b/pytorchvideo_trainer/pytorchvideo_trainer/__init__.py new file mode 100644 index 00000000..2e164fde --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + + +def register_components() -> None: + """ + Calls register_components() for all subfolders so we can register + subcomponents to Hydra's ConfigStore. 
+ """ + import pytorchvideo_trainer.datamodule.datamodule # noqa + import pytorchvideo_trainer.module.byol # noqa + import pytorchvideo_trainer.module.lr_policy # noqa + import pytorchvideo_trainer.module.moco_v2 # noqa + import pytorchvideo_trainer.module.optimizer # noqa + import pytorchvideo_trainer.module.simclr # noqa + import pytorchvideo_trainer.module.video_classification # noqa + import pytorchvideo_trainer.train_app # noqa diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py b/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py new file mode 100644 index 00000000..2cd87adf --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from .precise_batchnorm import PreciseBn # noqa + + +__all__ = [ + "PreciseBn", +] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py b/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py new file mode 100644 index 00000000..7b716ff8 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Generator + +import torch +from fvcore.nn.precise_bn import update_bn_stats +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader + + +class PreciseBn(Callback): + """ + Recompute and update the batch norm stats to make them more precise. During + training both BN stats and the weight are changing after every iteration, so + the running average can not precisely reflect the actual stats of the + current model. + In this callaback, the BN stats are recomputed with fixed weights, to make + the running average more precise during Training Phase. Specifically, it + computes the true average of per-batch mean/variance instead of the + running average. See Sec. 3 of the paper "Rethinking Batch in BatchNorm" + for details. + """ + + def __init__(self, num_batches: int) -> None: + """ + Args: + num_batches (int): Number of steps / mini-batches to + perform to sample for updating the precise batchnorm + stats. + """ + self.num_batches = num_batches + + def _get_precise_bn_loader( + self, data_loader: DataLoader, pl_module: LightningModule + ) -> Generator[torch.Tensor, None, None]: + for batch in data_loader: + inputs = batch[pl_module.modality_key] + if isinstance(inputs, list): + inputs = [x.to(pl_module.device) for x in inputs] + else: + inputs = inputs.to(pl_module.device) + yield inputs + + def on_train_epoch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + ) -> None: + """ + Called at the end of every epoch only during the training + phase. + + Args: + trainer (Trainer): A PyTorch-Lightning trainer object. + pl_module (LightningModule): A PyTorch-Lightning module. + Typically supported modules include - + pytorchvideo_trainer.module.VideoClassificationModule, etc. 
+ """ + # pyre-ignore[16] + dataloader = trainer.datamodule.train_dataloader() + precise_bn_loader = self._get_precise_bn_loader( + data_loader=dataloader, pl_module=pl_module + ) + update_bn_stats( + model=pl_module.model, # pyre-ignore[6] + data_loader=precise_bn_loader, + num_iters=self.num_batches, + ) diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py b/pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py new file mode 100644 index 00000000..f45051af --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import torchrecipes.core.conf # noqa + +# Components to register with this config +from pytorchvideo_trainer import register_components + +register_components() diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml new file mode 100644 index 00000000..9a53d1af --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml @@ -0,0 +1,28 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: byol_module_conf + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_contrastive + - logger: ptl + - datamodule/transforms: kinetics_contrastive + - module/knn_memory: kinetics_k400 + - module/model: slow_r50_byol + - module/loss: similarity + - module/optim: sgd_ssl + - module/metrics: accuracy + - schema/trainer: trainer + - trainer: cpu + - callbacks: null + - _self_ +trainer: + sync_batchnorm: false # set this to true for training + +module: + momentum_anneal_cosine: true + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml new file mode 100644 index 00000000..9b0934d8 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml @@ -0,0 +1,3 @@ +precise_bn: + _target_: pytorchvideo_trainer.callbacks.precise_batchnorm.PreciseBn + num_batches: null diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml new file mode 100644 index 00000000..25edcd5d --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml @@ -0,0 +1,72 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: video_classification_module_conf_vision_transformer + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_classification + - logger: ptl + - datamodule/transforms: kinetics_classification_mvit_16x4 + - module/model: mvit_base_16x4 + - module/loss: soft_cross_entropy + - module/optim: adamw + - module/metrics: accuracy + - module/lr_scheduler: cosine_with_warmup + - schema/trainer: trainer + - trainer: multi_gpu + - _self_ + +module: + clip_gradient_norm: 1.0 + ensemble_method: "sum" + lr_scheduler: + max_iters: 200 + warmup_start_lr: 1.6e-05 + warmup_iters: 30 + cosine_after_warmup: true + cosine_end_lr: 1.6e-05 + optim: + lr: 0.0016 + weight_decay: 0.05 + method: adamw + zero_weight_decay_1d_param: true + 
batch_transform: + _target_: pytorchvideo_trainer.datamodule.transforms.MixVideoBatchWrapper + mixup_alpha: 0.8 + cutmix_prob: 0.5 + cutmix_alpha: 1.0 + label_smoothing: 0.1 + +datamodule: + dataloader: + train: + batch_size: 2 + dataset: + clip_sampler: + clip_duration: 2.13 + collate_fn: + _target_: pytorchvideo_trainer.datamodule.collators.build_collator_from_name + name: multiple_samples_collate + val: + batch_size: 8 + dataset: + clip_sampler: + clip_duration: 2.13 + test: + batch_size: 8 + dataset: + clip_sampler: + clip_duration: 2.13 + +trainer: + num_nodes: 16 + gpus: 8 + log_gpu_memory: null + max_epochs: 200 + sync_batchnorm: False + replace_sampler_ddp: False + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml new file mode 100644 index 00000000..73a35a68 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml @@ -0,0 +1,46 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: video_classification_module_conf + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_classification + - logger: ptl + - datamodule/transforms: kinetics_classification_slow + - module/model: slow_r50 + - module/loss: cross_entropy + - module/optim: sgd + - module/metrics: accuracy + - module/lr_scheduler: cosine_with_warmup + - schema/trainer: trainer + - trainer: multi_gpu + - callbacks: precise_bn + - _self_ + +module: + ensemble_method: "sum" + lr_scheduler: + max_iters: 196 + warmup_start_lr: 0.01 + warmup_iters: 34 + optim: + lr: 0.8 + nesterov: true + +callbacks: + precise_bn: + num_batches: 200 + +trainer: + num_nodes: 8 + gpus: 8 + log_gpu_memory: null + max_epochs: 196 + sync_batchnorm: False + replace_sampler_ddp: False + + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml new file mode 100644 index 00000000..43dab71e --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml @@ -0,0 +1,46 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: video_classification_module_conf + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_classification + - logger: ptl + - datamodule/transforms: kinetics_classification_slowfast + - module/model: slowfast_r50 + - module/loss: cross_entropy + - module/optim: sgd + - module/metrics: accuracy + - module/lr_scheduler: cosine_with_warmup + - schema/trainer: trainer + - trainer: multi_gpu + - callbacks: precise_bn + - _self_ + +module: + ensemble_method: "sum" + lr_scheduler: + max_iters: 196 + warmup_start_lr: 0.01 + warmup_iters: 34 + optim: + lr: 0.8 + nesterov: true + +callbacks: + precise_bn: + num_batches: 200 + +trainer: + num_nodes: 8 + gpus: 8 + log_gpu_memory: null + max_epochs: 196 + sync_batchnorm: False + replace_sampler_ddp: False + + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git 
a/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml new file mode 100644 index 00000000..516796b4 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_x3d_xs.yaml @@ -0,0 +1,65 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: video_classification_module_conf + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_classification + - logger: ptl + - datamodule/transforms: kinetics_classification_x3d_xs + - module/model: x3d_xs + - module/loss: cross_entropy + - module/optim: sgd + - module/metrics: accuracy + - module/lr_scheduler: cosine_with_warmup + - schema/trainer: trainer + - trainer: multi_gpu + - callbacks: precise_bn + - _self_ + +module: + ensemble_method: "sum" + lr_scheduler: + max_iters: 300 + warmup_start_lr: 0.01 + warmup_iters: 35 + optim: + lr: 0.8 + nesterov: true + weight_decay: 5e-5 + +datamodule: + dataloader: + train: + batch_size: 16 + dataset: + clip_sampler: + clip_duration: 1.6 + val: + batch_size: 16 + dataset: + clip_sampler: + clip_duration: 1.6 + test: + batch_size: 16 + dataset: + clip_sampler: + clip_duration: 1.6 + +callbacks: + precise_bn: + num_batches: 200 + +trainer: + num_nodes: 8 + gpus: 8 + log_gpu_memory: null + max_epochs: 300 + sync_batchnorm: False + replace_sampler_ddp: False + + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml new file mode 100644 index 00000000..1d428802 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml @@ -0,0 +1,43 @@ +train: + dataset: + _target_: pytorchvideo.data.Kinetics + data_path: ??? + video_path_prefix: ??? + clip_sampler: + _target_: pytorchvideo.data.clip_sampling.RandomClipSampler + clip_duration: 2.13 + + shuffle: True + batch_size: 8 + num_workers: 8 + pin_memory: True + +val: + dataset: + _target_: pytorchvideo.data.Kinetics + data_path: ??? + video_path_prefix: ??? + clip_sampler: + _target_: pytorchvideo.data.clip_sampling.UniformClipSampler + clip_duration: 2.13 + + shuffle: False + batch_size: 8 + num_workers: 8 + pin_memory: True + +test: + dataset: + _target_: pytorchvideo.data.Kinetics + data_path: ??? + video_path_prefix: ??? + clip_sampler: + _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler + clip_duration: 2.13 + clips_per_video: 10 #num_ensemble_views + augs_per_clip: 3 # num_spatial_crops + + shuffle: False + batch_size: 8 + num_workers: 8 + pin_memory: True diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml new file mode 100644 index 00000000..4208b9a5 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_contrastive.yaml @@ -0,0 +1,41 @@ +train: + dataset: + _target_: pytorchvideo.data.Kinetics + data_path: ??? + video_path_prefix: ??? 
+ clip_sampler: + _target_: pytorchvideo.data.clip_sampling.RandomMultiClipSampler + clip_duration: 2.0 + num_clips: 2 + + shuffle: True + batch_size: 8 + num_workers: 8 + +val: + dataset: + _target_: pytorchvideo.data.Kinetics + data_path: ??? + video_path_prefix: ??? + clip_sampler: + _target_: pytorchvideo.data.clip_sampling.UniformClipSampler + clip_duration: 2.0 + + shuffle: False + batch_size: 8 + num_workers: 8 + +test: + dataset: + _target_: pytorchvideo.data.Kinetics + data_path: ??? + video_path_prefix: ??? + clip_sampler: + _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler + clip_duration: 2.0 + clips_per_video: 10 #num_ensemble_views + augs_per_clip: 3 # num_spatial_crops + + shuffle: False + batch_size: 8 + num_workers: 8 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml new file mode 100644 index 00000000..7924a22a --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml @@ -0,0 +1,70 @@ +train: + - _target_: pytorchvideo_trainer.datamodule.transforms.RepeatandConverttoList + repeat_num: 2 + - _target_: pytorchvideo_trainer.datamodule.transforms.ApplyTransformToKeyOnList + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 16 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Permute + dims: [1,0,2,3] + - _target_: pytorchvideo.transforms.rand_augment.RandAugment + magnitude: 7 + num_layers: 4 + - _target_: pytorchvideo.transforms.Permute + dims: [1,0,2,3] + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.RandomResizedCrop + target_height: 224 + target_width: 224 + scale: [0.08, 1.0] + aspect_ratio: [0.75, 1.3333] + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + - _target_: pytorchvideo.transforms.Permute + dims: [1,0,2,3] + - _target_: pytorchvideo_trainer.datamodule.rand_erase_transform.RandomErasing + probability: 0.25 + mode: "pixel" + max_count: 1 + num_splits: 1 + device: "cpu" + - _target_: pytorchvideo.transforms.Permute + dims: [1,0,2,3] + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +val: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 16 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 224 + - _target_: torchvision.transforms.CenterCrop + size: 224 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +test: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 16 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 224 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 224 + - _target_: pytorchvideo.transforms.RemoveKey + key: audio diff --git 
a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml new file mode 100644 index 00000000..cc69c698 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml @@ -0,0 +1,51 @@ +train: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.RandomShortSideScale + min_size: 256 + max_size: 320 + - _target_: torchvision.transforms.RandomCrop + size: 224 + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +val: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + - _target_: torchvision.transforms.CenterCrop + size: 256 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +test: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 256 + - _target_: pytorchvideo.transforms.RemoveKey + key: audio diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml new file mode 100644 index 00000000..5388fb05 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slowfast.yaml @@ -0,0 +1,60 @@ +train: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 32 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.RandomShortSideScale + min_size: 256 + max_size: 320 + - _target_: torchvision.transforms.RandomCrop + size: 224 + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + - _target_: pytorchvideo_trainer.datamodule.transforms.SlowFastPackPathway + alpha: 4 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +val: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 32 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + - _target_: torchvision.transforms.CenterCrop + size: 256 
+ - _target_: pytorchvideo_trainer.datamodule.transforms.SlowFastPackPathway + alpha: 4 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +test: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 32 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 256 + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo_trainer.datamodule.transforms.SlowFastPackPathway + alpha: 4 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml new file mode 100644 index 00000000..e80f20e6 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_x3d_xs.yaml @@ -0,0 +1,51 @@ +train: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 4 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.RandomShortSideScale + min_size: 182 + max_size: 228 + - _target_: torchvision.transforms.RandomCrop + size: 160 + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +val: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 4 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 182 + - _target_: torchvision.transforms.CenterCrop + size: 182 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +test: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 4 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 182 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 182 + - _target_: pytorchvideo.transforms.RemoveKey + key: audio diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml new file mode 100644 index 00000000..23897187 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_contrastive.yaml @@ -0,0 +1,56 @@ +train: + - _target_: pytorchvideo_trainer.datamodule.transforms.ApplyTransformToKeyOnList + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: 
pytorchvideo_trainer.datamodule.transforms.ColorJitterVideoSSl + bri_con_sat: [0.6, 0.6, 0.6] + hue: 0.15 + p_color_jitter: 0.8 + p_convert_gray: 0.2 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.RandomResizedCrop + target_height: 224 + target_width: 224 + scale: [0.2, 0.766] + aspect_ratio: [0.75, 1.3333] + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +val: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + - _target_: torchvision.transforms.CenterCrop + size: 256 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +test: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 256 + - _target_: pytorchvideo.transforms.RemoveKey + key: audio diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml new file mode 100644 index 00000000..2919905a --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/datamodule/transforms/kinetics_moco_v2.yaml @@ -0,0 +1,56 @@ +train: + - _target_: pytorchvideo_trainer.datamodule.transforms.ApplyTransformToKeyOnList + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo_trainer.datamodule.transforms.ColorJitterVideoSSl + bri_con_sat: [0.4, 0.4, 0.4] + hue: 0.4 + p_color_jitter: 0.8 + p_convert_gray: 0.2 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.RandomResizedCrop + target_height: 224 + target_width: 224 + scale: [0.2, 0.766] + aspect_ratio: [0.75, 1.3333] + - _target_: torchvision.transforms.RandomHorizontalFlip + p: 0.5 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +val: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 256 + - _target_: torchvision.transforms.CenterCrop + size: 256 + key: video + - _target_: pytorchvideo.transforms.RemoveKey + key: audio +test: + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 8 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: 
pytorchvideo.transforms.ShortSideScale + size: 256 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 256 + - _target_: pytorchvideo.transforms.RemoveKey + key: audio diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml new file mode 100644 index 00000000..352afc06 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/logger/ptl.yaml @@ -0,0 +1,4 @@ +_target_: pytorch_lightning.loggers.TensorBoardLogger +save_dir: ??? +name: default +version: null diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml new file mode 100644 index 00000000..d464b742 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/moco_v2_train_app_conf.yaml @@ -0,0 +1,31 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: moco_v2_module_conf + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_contrastive + - logger: ptl + - datamodule/transforms: kinetics_moco_v2 + - module/knn_memory: kinetics_k400 + - module/model: slow_r50_moco_v2 + - module/loss: contrastive + - module/optim: sgd_ssl + - module/metrics: accuracy + - schema/trainer: trainer + - trainer: cpu + - callbacks: null + - _self_ +trainer: + sync_batchnorm: false # set this to true for training + +module: + dim: ${module.model.backbone_embed_dim} + k: 65536 + batch_shuffle: true + local_shuffle_bn: true + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml new file mode 100644 index 00000000..edecec61 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/knn_memory/kinetics_k400.yaml @@ -0,0 +1,7 @@ +_target_: pytorchvideo_trainer.module.ssl_helper.KnnMemory +temperature: ${module.loss.temperature} +dim: ${module.model.backbone_embed_dim} +length: 239975 +downstream_classes: 400 +knn_k: 200 +momentum: 1.0 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml new file mode 100644 index 00000000..669a328d --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/contrastive.yaml @@ -0,0 +1,2 @@ +_target_: pytorchvideo_trainer.module.losses.ContrastiveLoss +temperature: 0.1 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml new file mode 100644 index 00000000..f381cd87 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/cross_entropy.yaml @@ -0,0 +1,2 @@ +# @package _group_ +_target_: torch.nn.CrossEntropyLoss diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml new file mode 100644 index 00000000..11df106a --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/nt_xent.yaml @@ -0,0 +1,3 @@ +# @package _group_ +_target_: pytorchvideo_trainer.module.losses.NtxentLoss +temperature: 0.1 diff --git 
a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml new file mode 100644 index 00000000..c483cfd7 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/similarity.yaml @@ -0,0 +1,3 @@ +# @package _group_ +_target_: pytorchvideo_trainer.module.losses.SimilarityLoss +temperature: 0.1 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml new file mode 100644 index 00000000..2df3319c --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/loss/soft_cross_entropy.yaml @@ -0,0 +1,2 @@ +# @package _group_ +_target_: pytorchvideo_trainer.module.losses.SoftTargetCrossEntropy diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml new file mode 100644 index 00000000..b307f9fb --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/lr_scheduler/cosine_with_warmup.yaml @@ -0,0 +1,7 @@ +lr_policy: 'cosine' +cosine_after_warmup: False +cosine_end_lr: 0 +warmup_iters: 34 +warmup_start_lr: 0.01 +max_iters: ${trainer.max_epochs} +lr: ${module.optim.lr} diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml new file mode 100644 index 00000000..8bbd4268 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/accuracy.yaml @@ -0,0 +1,8 @@ +- name: accuracy_top1 + config: + _target_: torchmetrics.Accuracy + top_k: 1 +- name: accuracy_top5 + config: + _target_: torchmetrics.Accuracy + top_k: 5 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml new file mode 100644 index 00000000..dd8ad7a6 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/metrics/average_precision.yaml @@ -0,0 +1,3 @@ +- name: average_precision + config: + _target_: torchmetrics.AveragePrecision diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml new file mode 100644 index 00000000..00e30353 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml @@ -0,0 +1,2 @@ +_target_: pytorchvideo_trainer.module.video_classification.create_classification_model_from_lightning +checkpoint_path: ??? 
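The module config groups above are plain Hydra nodes: every `_target_` entry names a class or factory that `hydra.utils.instantiate` can build, and the metrics group (e.g. `conf/module/metrics/accuracy.yaml`) is just a list of `name`/`config` pairs. Below is a minimal, illustrative sketch of how such a metrics list could be materialized into `torchmetrics` objects; the wiring is an assumption for demonstration and not the trainer module's actual code.

```python
# Illustrative sketch (not part of the patch): building the metric objects
# described by conf/module/metrics/accuracy.yaml via Hydra instantiation.
import hydra
from omegaconf import OmegaConf

# Same shape as the YAML config: a list of {name, config} entries, where each
# config node carries a `_target_` that Hydra can instantiate.
metrics_conf = OmegaConf.create(
    [
        {"name": "accuracy_top1", "config": {"_target_": "torchmetrics.Accuracy", "top_k": 1}},
        {"name": "accuracy_top5", "config": {"_target_": "torchmetrics.Accuracy", "top_k": 5}},
    ]
)

# `name` becomes the logging key; `config` is instantiated into a torchmetrics
# object (this assumes a torchmetrics version where Accuracy accepts `top_k`
# directly, matching the config in this patch).
metrics = {m.name: hydra.utils.instantiate(m.config) for m in metrics_conf}
print(metrics)
```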
diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml new file mode 100644 index 00000000..422bdbdf --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml @@ -0,0 +1,5 @@ +_target_: pytorchvideo_trainer.module.video_classification.create_classification_model_from_modelzoo +checkpoint_path: manifold://fair_logging/tree/kalyanv/hub_models/SLOW_8x8_R50.pyth +model: + _target_: pytorchvideo.models.hub.resnet.slow_r50 + pretrained: False diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml new file mode 100644 index 00000000..21c3b9d0 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/from_ssl_checkpoint.yaml @@ -0,0 +1,11 @@ +_target_: pytorchvideo_trainer.module.ssl_helper.create_classification_model_from_ssl_checkpoint +ssl_checkpoint_path: null +checkpoint_type: simclr +mlp: + _target_: pytorchvideo_trainer.module.byol.create_mlp_util + dim_in: null + dim_out: 400 + mlp_dim: 256 + num_layers: 1 + norm: null +detach_backbone: true diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml new file mode 100644 index 00000000..fa5d958a --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/mvit_base_16x4.yaml @@ -0,0 +1,32 @@ +_target_: pytorchvideo.models.vision_transformers.create_multiscale_vision_transformers +spatial_size: 224 +temporal_size: 16 +cls_embed_on: True +sep_pos_embed: True +depth: 16 +norm: "layernorm" +input_channels: 3 +patch_embed_dim: 96 +conv_patch_embed_kernel: [3, 7, 7] +conv_patch_embed_stride: [2, 4, 4] +conv_patch_embed_padding: [1, 3, 3] +enable_patch_embed_norm: False +use_2d_patch: False +# Attention block config. +num_heads: 1 +mlp_ratio: 4.0 +qkv_bias: True +dropout_rate_block: 0.0 +droppath_rate_block: 0.2 +pooling_mode: "conv" +pool_first: False +embed_dim_mul: [[1, 2.0], [3, 2.0], [14, 2.0]] +atten_head_mul: [[1, 2.0], [3, 2.0], [14, 2.0]] +pool_q_stride_size: [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]] +pool_kv_stride_size: null +pool_kv_stride_adaptive: [1, 8, 8] +pool_kvq_kernel: [3, 3, 3] +# Head config. 
+head_dropout_rate: 0.5 +head_activation: null +head_num_classes: 400 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml new file mode 100644 index 00000000..05ba9a58 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50.yaml @@ -0,0 +1,7 @@ +_target_: pytorchvideo.models.resnet.create_resnet +input_channel: 3 +model_depth: 50 +model_num_class: 400 +dropout_rate: 0.5 +stem_conv_kernel_size: [1, 7, 7] +head_pool_kernel_size: [8, 7, 7] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml new file mode 100644 index 00000000..daeb90bc --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_byol.yaml @@ -0,0 +1,3 @@ +_target_: pytorchvideo_trainer.module.byol.create_byol_resnet_50 +backbone_embed_dim: 128 +mmt: 0.996 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml new file mode 100644 index 00000000..aef3defe --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_moco_v2.yaml @@ -0,0 +1,3 @@ +_target_: pytorchvideo_trainer.module.moco_v2.create_moco_resnet_50 +backbone_embed_dim: 128 +mmt: 0.994 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml new file mode 100644 index 00000000..d4e9593f --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slow_r50_simclr.yaml @@ -0,0 +1,4 @@ +_target_: pytorchvideo_trainer.module.simclr.create_simclr_resnet_50 +backbone_embed_dim: 128 +mlp_depth: 1 +mlp_inner_dim: 2048 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml new file mode 100644 index 00000000..08e5335f --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/slowfast_r50.yaml @@ -0,0 +1,6 @@ +_target_: pytorchvideo.models.slowfast.create_slowfast +input_channels: [3,3] +model_depth: 50 +model_num_class: 400 +dropout_rate: 0.5 +slowfast_fusion_conv_kernel_size: [7, 1, 1] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml new file mode 100644 index 00000000..c645cda9 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/model/x3d_xs.yaml @@ -0,0 +1,8 @@ +_target_: pytorchvideo.models.x3d.create_x3d +input_channel: 3 +model_num_class: 400 +dropout_rate: 0.5 +input_clip_length: 4 +input_crop_size: 160 +depth_factor: 2.2 +head_activation: null diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml new file mode 100644 index 00000000..a9239a73 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adam.yaml @@ -0,0 +1,3 @@ +method: 'adam' +lr: 0.001 +weight_decay: 0 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml new file mode 100644 index 00000000..ef29b4da --- /dev/null +++ 
b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/adamw.yaml @@ -0,0 +1,3 @@ +method: 'adamw' +lr: 0.001 +weight_decay: 0.01 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml new file mode 100644 index 00000000..011eaef5 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd.yaml @@ -0,0 +1,5 @@ +method: 'sgd' +lr: 0.1 +weight_decay: 1e-4 +momentum: 0.9 +nesterov: True diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml new file mode 100644 index 00000000..a0dc8834 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/module/optim/sgd_ssl.yaml @@ -0,0 +1,5 @@ +method: 'sgd' +lr: 0.6 +weight_decay: 1e-6 +momentum: 0.9 +nesterov: True diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml new file mode 100644 index 00000000..319d1c3f --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/simclr_train_app_conf.yaml @@ -0,0 +1,25 @@ +_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp + +defaults: + - schema/module: simclr_module_conf + - schema/module/optim: optim_conf + - schema/datamodule: ptv_video_classification_data_module_conf + - datamodule/dataloader: kinetics_contrastive + - logger: ptl + - datamodule/transforms: kinetics_moco_v2 + - module/knn_memory: kinetics_k400 + - module/model: slow_r50_simclr + - module/loss: nt_xent + - module/optim: sgd_ssl + - module/metrics: accuracy + - schema/trainer: trainer + - trainer: cpu + - callbacks: null + - _self_ +trainer: + sync_batchnorm: false # set this to true for training + +hydra: + searchpath: + - pkg://pytorchvideo_trainer.conf + - pkg://torchrecipes.core.conf diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml new file mode 100644 index 00000000..87ca47f3 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/submitit_conf/fair_cluster.yaml @@ -0,0 +1,9 @@ +# @package _group_ +log_save_dir: null +name: "ptv_trainer_job" +time: "72:00:00" +cpus_per_task: 10 +partition: "learnlab" +mem: "470GB" +constraint: "volta32gb" +mode: "prod" diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml new file mode 100644 index 00000000..91fea054 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/cpu.yaml @@ -0,0 +1,2 @@ +# @package _group_ +max_epochs: 1 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml new file mode 100644 index 00000000..fa480123 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/multi_gpu.yaml @@ -0,0 +1,6 @@ +# @package _group_ +gpus: 8 +accelerator: ddp +max_epochs: 1 +num_sanity_val_steps: 0 +log_every_n_steps: 10 diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml b/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml new file mode 100644 index 00000000..431c0e35 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/conf/trainer/single_gpu.yaml @@ -0,0 +1,3 @@ +# @package _group_ +gpus: 1 +max_epochs: 1 diff 
--git a/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py new file mode 100644 index 00000000..729d1510 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from .datamodule import PyTorchVideoDataModule # noqa + + +__all__ = [ + "PyTorchVideoDataModule", +] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py new file mode 100644 index 00000000..28680de5 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/collators.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from typing import Dict, Any, List, Callable + +from torch.utils.data._utils.collate import default_collate + + +# pyre-ignore[2] +def multiple_samples_collate(batch: List[Dict[str, List[Any]]]) -> Dict[str, Any]: + """ + Collate function for repeated augmentation. Each instance in the batch has + more than one sample. + + To be used when working with, + `pytorchvideo_trainer.datamodule.transforms.RepeatandConverttoList` + """ + batch_dict = {} + for k in batch[0].keys(): + v_iter = [] + for sample_dict in batch: + v_iter += sample_dict[k] + batch_dict[k] = default_collate(v_iter) + + return batch_dict + + +# pyre-ignore[24] +_COLLATORS: Dict[str, Callable] = { + "multiple_samples_collate": multiple_samples_collate, +} + + +def build_collator_from_name(name: str) -> Callable: # pyre-ignore[24] + """ + A utility function that returns the function handles to specific collators + in `_COLLATORS` dictionary object based on the queried key. Used in + `pytorchvideo_trainer.datamodule.PyTorchVideoDataModule`, etc. + + Arg: + name (str): name of the qurried collators. The key should be present in + `_COLLATORS` dictionary object + """ + assert ( + name in _COLLATORS + ), f"Inavalid Collator method. Available methods are {_COLLATORS.keys()}" + return _COLLATORS[name] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py new file mode 100644 index 00000000..ee754bb8 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/datamodule.py @@ -0,0 +1,226 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Optional, Any, Callable, Dict, List + +import hydra +import pytorch_lightning as pl +import pytorchvideo.data +import torch +from hydra.core.config_store import ConfigStore + +# @manual "//github/third-party/omry/omegaconf:omegaconf" +from omegaconf import MISSING +from pytorchvideo_trainer.datamodule.transforms import build_transforms +from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data.distributed import DistributedSampler +from torchrecipes.core.conf import DataModuleConf +from torchrecipes.utils.config_utils import get_class_name_str + + +class PyTorchVideoDataModule(pl.LightningDataModule): + """ + A PyTorch-Lightning DataModule module supporting all the dataloaders + in PyTorchVideo for different phases (train, validation and testing) of + Lightning tranining. + + Supports loading any aribtrary iterable and map-style PyTorchVideo dataset + upon following the config schema detailed below. 
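A minimal sketch (not part of the patch) of how `multiple_samples_collate` behaves on a toy batch in which every sample carries a list of two repeated augmentations per key; the tensor sizes are made up for illustration.

```
import torch
from pytorchvideo_trainer.datamodule.collators import multiple_samples_collate

# Each sample holds a list of repeated augmentations per key
# (see RepeatandConverttoList in transforms.py below).
batch = [
    {"video": [torch.zeros(3, 4, 8, 8)] * 2, "label": [0, 0]},
    {"video": [torch.ones(3, 4, 8, 8)] * 2, "label": [1, 1]},
]
out = multiple_samples_collate(batch)
print(out["video"].shape)  # torch.Size([4, 3, 4, 8, 8])
print(out["label"])        # tensor([0, 0, 1, 1])
```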
+ + Args: + dataloader (DataLoaderConf): + An OmegaConf / Hydra Config object consisting of dataloder + config for each phase i.e, train, val and test. + + The Hydra schema for this config is as defined in + `pytorchvideo_trainer.datamodule.datamodule.DataLoaderConf` + + One such example config can be found at + `pytorchvideo_trainer/conf/datamodule/dataloader/kinetics_classification.yaml` + + transforms (TransformsConf): + An OmegaConf / Hydra Config object consisting of transforms + config for each phase i.e, train, val and test. + + The Hydra schema for this config is as defined in + `pytorchvideo_trainer.datamodule.datamodule.TransformsConf` + + One such example config used for Resnet50 video model traning can be found at + `pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_slow.yaml` + """ + + def __init__( + self, + dataloader: DataLoaderConf, + transforms: TransformsConf, + ) -> None: + super().__init__() + self.config: Dict[str, Any] = { + "train": dataloader.train, + "val": dataloader.val, + "test": dataloader.test, + } + self.transforms: Dict[str, Any] = { + "train": build_transforms(transforms.train), + "val": build_transforms(transforms.val), + "test": build_transforms(transforms.test), + } + self.datasets: dict[str, Any] = {"train": None, "val": None, "test": None} + + def setup(self, stage: Optional[str] = None) -> None: + + if stage == "fit" or stage is None: + self.datasets["train"] = self._get_dataset( + phase="train", transforms=self.transforms["train"] + ) + self.datasets["val"] = self._get_dataset( + phase="val", transforms=self.transforms["val"] + ) + if stage == "test" or stage is None: + self.datasets["test"] = self._get_dataset( + phase="test", transforms=self.transforms["test"] + ) + + def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: + """ + Defines the train DataLoader that the PyTorch Lightning Trainer uses. + """ + if ( + self.trainer + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + self.datasets["train"].video_sampler.set_epoch(self.trainer.current_epoch) + + return self._get_dataloader("train") + + def val_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: + """ + Defines the val DataLoader that the PyTorch Lightning Trainer uses. + """ + return self._get_dataloader("val") + + def test_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: + """ + Defines the test DataLoader that the PyTorch Lightning Trainer uses. + """ + return self._get_dataloader("test") + + def _get_dataloader(self, phase: str) -> DataLoader: + assert self.datasets[phase] is not None, "Failed to get the {} dataset!".format( + phase + ) + + if isinstance(self.datasets[phase], torch.utils.data.IterableDataset): + return torch.utils.data.DataLoader( + self.datasets[phase], + batch_size=self.config[phase].batch_size, + num_workers=self.config[phase].num_workers, + pin_memory=self.config[phase].pin_memory, + drop_last=self.config[phase].drop_last, + collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn), + worker_init_fn=hydra.utils.instantiate( + self.config[phase].worker_init_fn + ), + ) + else: + sampler = None + if torch.distributed.is_available() and torch.distributed.is_initialized(): + logging.info( + "Distributed Environmnet detected, using DistributedSampler for dataloader." 
+ ) + sampler = DistributedSampler(self.datasets[phase]) + + return torch.utils.data.DataLoader( + self.datasets[phase], + batch_size=self.config[phase].batch_size, + num_workers=self.config[phase].num_workers, + pin_memory=self.config[phase].pin_memory, + drop_last=self.config[phase].drop_last, + sampler=sampler, + shuffle=(False if sampler else self.config[phase].shuffle), + collate_fn=hydra.utils.instantiate(self.config[phase].collate_fn), + worker_init_fn=hydra.utils.instantiate( + self.config[phase].worker_init_fn + ), + ) + + def _get_dataset( + self, + phase: str, + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + ) -> pytorchvideo.data.LabeledVideoDataset: + + video_sampler = RandomSampler + if torch.distributed.is_available() and torch.distributed.is_initialized(): + logging.info( + "Distributed Environmnet detected, using DistributedSampler for dataset." + ) + video_sampler = DistributedSampler + + dataset = hydra.utils.instantiate( + self.config[phase].dataset, + transform=transforms, + video_sampler=video_sampler, + ) + return dataset + + +@dataclass +class PhaseDataLoaderConf: + + num_workers: int = 0 + pin_memory: bool = False + drop_last: bool = False + batch_size: int = MISSING + shuffle: bool = True + + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + collate_fn: Optional[Any] = None + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + worker_init_fn: Optional[Any] = None + + ## Dataset Related + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + dataset: Any = MISSING + + +@dataclass +class DataLoaderConf: + train: PhaseDataLoaderConf = MISSING + val: PhaseDataLoaderConf = MISSING + test: PhaseDataLoaderConf = MISSING + + +@dataclass +class TransformsConf: + + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + train: List[Any] = MISSING + + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + val: List[Any] = MISSING + + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + test: List[Any] = MISSING + + +@dataclass +class VideoClassificationDataModuleConf(DataModuleConf): + _target_: str = get_class_name_str(PyTorchVideoDataModule) + + dataloader: DataLoaderConf = MISSING + transforms: TransformsConf = MISSING + + +cs = ConfigStore() + +cs.store( + group="schema/datamodule", + name="ptv_video_classification_data_module_conf", + node=VideoClassificationDataModuleConf, + package="datamodule", +) diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py new file mode 100644 index 00000000..ecae54fd --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/rand_erase_transform.py @@ -0,0 +1,196 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +pulished under an Apache License 2.0. 
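For reference, a phase-level dataloader node satisfying `PhaseDataLoaderConf` might look like the sketch below. The dataset target, path, and clip sampler are hypothetical placeholders, not values taken from this patch; the shipped examples under `conf/datamodule/dataloader/` are authoritative.

```
from omegaconf import OmegaConf

train_loader_conf = OmegaConf.create(
    {
        "batch_size": 8,
        "num_workers": 4,
        "pin_memory": True,
        "drop_last": True,
        "shuffle": True,
        "collate_fn": None,
        "worker_init_fn": None,
        # Hypothetical dataset node; the real configs point at a PyTorchVideo
        # dataset constructor and are built via hydra.utils.instantiate.
        "dataset": {
            "_target_": "pytorchvideo.data.Kinetics",
            "data_path": "/path/to/train.csv",
            "clip_sampler": {
                "_target_": "pytorchvideo.data.clip_sampling.RandomClipSampler",
                "clip_duration": 2.0,
            },
            "decode_audio": False,
        },
    }
)
print(OmegaConf.to_yaml(train_loader_conf))
```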
+COMMENT FROM ORIGINAL: +Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 +Copyright Zhun Zhong & Liang Zheng +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import random +from typing import Optional, Tuple + +import torch + + +def _get_pixels( + per_pixel: bool, + rand_color: bool, + patch_size: Tuple[int], + dtype: torch.dtype = torch.float32, + device: str = "cuda", +) -> torch.Tensor: + """ + A utility function that generates image patches for RandomErasing transform + """ + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """ + This variant of RandomErasing is intended to be applied to a video tensor i.e, + batch of images after it has been normalized by dataset mean and std. + + Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + Args: + probability (float): Probability that the Random Erasing operation will be performed. + min_area (float): Minimum percentage of erased area wrt input image area. + max_area (float): Maximum percentage of erased area wrt input image area. + min_aspect (float): Minimum aspect ratio of erased area. + mode (str): pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count (int): maximum number of erasing blocks per image, area per box is scaled by + count. Per-image count is randomly chosen between 1 and this value. + min_count (int): minimum number of erasing blocks per image, area per box is scaled by + count. Per-image count is randomly chosen between 1 and this value. + device (str): Device to perform the transform on. 
+ """ + + def __init__( + self, + probability: float = 0.5, + min_area: float = 0.02, + max_area: float = 1 / 3, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None, + mode: str = "const", + min_count: int = 1, + max_count: Optional[int] = None, + num_splits: int = 0, + device: str = "cuda", + cube: bool = True, + ) -> None: + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio: Tuple[float, float] = ( + math.log(min_aspect), + math.log(max_aspect), + ) + self.min_count = min_count + self.max_count: int = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color: bool = False + self.per_pixel: bool = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase( + self, img: torch.Tensor, chan: int, height: int, width: int, dtype: torch.dtype + ) -> None: + if random.random() > self.probability: + return + area = height * width + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < width and h < height: + top = random.randint(0, height - h) + left = random.randint(0, width - w) + img[:, top : top + h, left : left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), # pyre-ignore[6] + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + video: torch.Tensor, + batch_start: int, + batch_size: int, + chan: int, + height: int, + width: int, + dtype: torch.dtype, + ) -> None: + if random.random() > self.probability: + return + area = height * width + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < width and h < height: + top = random.randint(0, height - h) + left = random.randint(0, width - w) + for i in range(batch_start, batch_size): + img_instance = video[i] + img_instance[:, top : top + h, left : left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), # pyre-ignore[6] + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, frames: torch.Tensor) -> torch.Tensor: + """ + Args: + frames (tensor): frames of images sampled from the video. The + dimension is `channel` x `num frames` x `height` x `width`. + Returns: + frames (tensor): frames of images sampled from the video. The + dimension is `channel` x `num frames` x `height` x `width`. 
+ """ + # Expects frames of shape T, C, H, W + batch_size, chan, height, width = frames.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + if self.cube: + self._erase_cube( + frames, + batch_start, + batch_size, + chan, + height, + width, + frames.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(frames[i], chan, height, width, frames.dtype) + return frames diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py new file mode 100644 index 00000000..effb3fde --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/datamodule/transforms.py @@ -0,0 +1,287 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import random +from typing import Iterable, Sequence, Any, Callable, Dict, Mapping, List + +import hydra +import torch +import torchvision +from PIL import Image, ImageFilter +from pytorchvideo.transforms import MixVideo +from torchvision.transforms import Compose + + +def build_transforms(transforms_config: Iterable[Mapping[str, Any]]) -> Compose: + """ + A utility function to build data transforsm from a list of Hydra/Omega Conf + objects. This utility method is called by + `pytorchvideo_trainer.datamodule.PyTorchVideoDataModule` class to build a + sequence of transforms applied during each phase(train, val and test). + + Uses torchvision.transforms.Compose to build a seuquence of transforms. + + Examples of config objects used by this method can be found in, + `pytorchvide_trainer/conf/datamodule/transforms/` + + Args: + transforms_config: A list of hydra config objects wherein, each element + represents config associated with a single transforms. + + An example of this would be, + ``` + - _target_: pytorchvideo.transforms.ApplyTransformToKey + transform: + - _target_: pytorchvideo.transforms.UniformTemporalSubsample + num_samples: 16 + - _target_: pytorchvideo.transforms.Div255 + - _target_: pytorchvideo.transforms.Normalize + mean: [0.45, 0.45, 0.45] + std: [0.225, 0.225, 0.225] + - _target_: pytorchvideo.transforms.ShortSideScale + size: 224 + key: video + - _target_: pytorchvideo.transforms.UniformCropVideo + size: 224 + - _target_: pytorchvideo.transforms.RemoveKey + key: audio + ``` + """ + transform_list = [build_single_transform(config) for config in transforms_config] + transform = Compose(transform_list) + return transform + + +def build_single_transform(config: Mapping[str, Any]) -> Callable[..., object]: + """ + A utility method to build a single transform from hydra / omega conf objects. + + If the key "transform" is present in the give config, it recursively builds + and composes transforms using the `torchvision.transforms.Compose` method. 
+ """ + config = dict(config) + if "transform" in config: + assert isinstance(config["transform"], Sequence) + transform_list = [ + build_single_transform(transform) for transform in config["transform"] + ] + transform = Compose(transform_list) + config.pop("transform") + return hydra.utils.instantiate(config, transform=transform) + return hydra.utils.instantiate(config) + + +class ApplyTransformToKeyOnList: + """ + Applies transform to key of dictionary input wherein input is a list + + Args: + key (str): the dictionary key the transform is applied to + transform (callable): the transform that is applied + + Example: + >>> transforms.ApplyTransformToKeyOnList( + >>> key='video', + >>> transform=UniformTemporalSubsample(num_video_samples), + >>> ) + """ + + def __init__(self, key: str, transform: Callable) -> None: # pyre-ignore[24] + self._key = key + self._transform = transform + + def __call__( + self, x: Dict[str, List[torch.Tensor]] + ) -> Dict[str, List[torch.Tensor]]: + x[self._key] = [self._transform(a) for a in x[self._key]] + return x + + +class SlowFastPackPathway: + """ + Transform for converting a video clip into a list of 2 clips with + different temporal granualirity as needed by the SlowFast video + model. + + For more details, refere to the paper, + Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. + "SlowFast networks for video recognition." + https://arxiv.org/pdf/1812.03982.pdf + + Args: + alpha (int): Number of frames to sub-sample from the given clip + to create the second clip. + """ + + def __init__(self, alpha: int) -> None: + super().__init__() + self.alpha = alpha + + def __call__(self, frames: torch.Tensor) -> List[torch.Tensor]: + """ + Args: + frames (tensor): frames of images sampled from the video. The + dimension is `channel` x `num frames` x `height` x `width`. + Returns: + frame_list (list): list of tensors with the dimension of + `channel` x `num frames` x `height` x `width`. + """ + fast_pathway = frames + # Perform temporal sampling from the fast pathway. + slow_pathway = torch.index_select( + frames, + 1, + torch.linspace( + 0, frames.shape[1] - 1, frames.shape[1] // self.alpha + ).long(), + ) + frame_list = [slow_pathway, fast_pathway] + return frame_list + + +class RepeatandConverttoList: + """ + An utility transform that repeats each value in a + key, value-style minibatch and replaces it with a list of values. + + Useful for performing multiple augmentations. + An example such usecase can be found in + `pytorchvideo_trainer/conf/datamodule/transforms/kinetics_classification_mvit_16x4.yaml` + + Args: + repead_num (int): Number of times to repeat each value. + """ + + def __init__(self, repeat_num: int) -> None: + super().__init__() + self.repeat_num = repeat_num + + # pyre-ignore[3] + def __call__(self, sample_dict: Dict[str, Any]) -> Dict[str, List[Any]]: + for k, v in sample_dict.items(): + sample_dict[k] = self.repeat_num * [v] + return sample_dict + + +class MixVideoBatchWrapper: + def __init__( + self, + mixup_alpha: float, + cutmix_prob: float, + cutmix_alpha: float, + label_smoothing: float, + ) -> None: + """ + A wrapper for MixVideo (CutMix or Mixup) tranform in pytorchvideo.transforms. + Extends the MixVideo transform to work on a batch dictionary objects. + + The dictionary object should consist of keys "video" and "label" representing + video clips and their associated labels. 
+ """ + + self.mix_video_transform = MixVideo( + mixup_alpha=mixup_alpha, + cutmix_prob=cutmix_prob, + cutmix_alpha=cutmix_alpha, + label_smoothing=label_smoothing, + ) + + def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]: + + batch["video"], batch["label"] = self.mix_video_transform( + batch["video"], batch["label"] + ) + return batch + + +class ColorJitterVideoSSl: + """ + A custom sequence of transforms that randomly performs Color jitter, + Gaussian Blur and Grayscaling on the given clip. + + Particularly useful for the SSL tasks like SimCLR, MoCoV2, BYOL, etc. + + Args: + bri_con_sat (list[float]): A list of 3 floats reprsenting brightness, + constrast and staturation coefficients to use for the + `torchvision.transforms.ColorJitter` transform. + hue (float): Heu value to use in the `torchvision.transforms.ColorJitter` + transform. + p_color_jitter (float): The probability with which the Color jitter transform + is randomly applied on the given clip. + p_convert_gray (float): The probability with which the given clip is randomly + coverted into grayscale. + p_gaussian_blur (float): The probability with which the Gaussian transform + is randomly applied on the given clip. + gaussian_blur_sigma (list[float]): A list of 2 floats with in which + the blur radius is randomly sampled for Gaussian blur transform. + """ + + def __init__( + self, + bri_con_sat: List[float], + hue: float, + p_color_jitter: float, + p_convert_gray: float, + p_gaussian_blur: float = 0.5, + gaussian_blur_sigma: List[float] = (0.1, 2.0), + ) -> None: + + self.color_jitter = torchvision.transforms.Compose( + [ + torchvision.transforms.ToPILImage(), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.ColorJitter( + bri_con_sat[0], bri_con_sat[1], bri_con_sat[2], hue + ) + ], + p=p_color_jitter, + ), + torchvision.transforms.RandomGrayscale(p=p_convert_gray), + torchvision.transforms.RandomApply( + [GaussianBlur(gaussian_blur_sigma)], p=p_gaussian_blur + ), + torchvision.transforms.ToTensor(), + ] + ) + + def __call__(self, frames: torch.Tensor) -> torch.Tensor: + """ + Args: + frames (tensor): frames of images sampled from the video. The + dimension is `channel` x `num frames` x `height` x `width`. + Returns: + frames (tensor): frames of images sampled from the video. The + dimension is `channel` x `num frames` x `height` x `width`. + """ + c, t, h, w = frames.shape + frames = frames.view(c, t * h, w) + frames = self.color_jitter(frames) # pyre-ignore[6,9] + frames = frames.view(c, t, h, w) + + return frames + + +class GaussianBlur(object): + """ + A PIL image version of Gaussian blur augmentation as + in SimCLR https://arxiv.org/abs/2002.05709 + + Args: + sigma (list[float]): A list of 2 floats with in which + the blur radius is randomly sampled during each step. + """ + + def __init__(self, sigma: List[float] = (0.1, 2.0)) -> None: + self.sigma = sigma + + def __call__(self, img: Image.Image) -> Image.Image: + """ + img (Image): A PIL image with single or 3 color channels. + """ + sigma = self.sigma[0] + if len(self.sigma) == 2: + sigma = random.uniform(self.sigma[0], self.sigma[1]) + + img = img.filter(ImageFilter.GaussianBlur(radius=sigma)) + return img diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py new file mode 100644 index 00000000..6f07c806 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. + +from .byol import BYOLModule # noqa +from .moco_v2 import MOCOV2Module # noqa +from .simclr import SimCLRModule # noqa +from .video_classification import VideoClassificationModule # noqa + +__all__ = [ + "VideoClassificationModule", + "SimCLRModule", + "BYOLModule", + "MOCOV2Module", +] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py new file mode 100644 index 00000000..550ba6d2 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/byol.py @@ -0,0 +1,329 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from dataclasses import dataclass +from typing import Optional, List, Callable, Union, Any, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING +from pytorchvideo.models.resnet import create_resnet +from pytorchvideo.models.weight_init import init_net_weights +from pytorchvideo_trainer.module.ssl_helper import SSLBaseModule, create_mlp_util +from pytorchvideo_trainer.module.video_classification import ( + EnsembleMethod, + BatchKey, + Batch, +) +from torchrecipes.core.conf import ModuleConf +from torchrecipes.utils.config_utils import get_class_name_str + + +class BYOL(nn.Module): + """ + Bootstrap Your Own Latent A New Approach to Self-Supervised Learning + Details can be found in: + https://arxiv.org/pdf/2006.07733.pdf + """ + + def __init__( + self, + mmt: float, + backbone: nn.Module, + predictor: nn.Module, + backbone_mmt: nn.Module, + projector: Optional[nn.Module] = None, + projector_mmt: Optional[nn.Module] = None, + ) -> None: + """ + Args: + backbone (nn.Module): backbone for byol, input shape depends on the forward + input size. Standard inputs include `B x C`, `B x C x H x W`, and + `B x C x T x H x W`. + projector (nn.Module): An mlp with 2 to 3 hidden layers, + with (synchronized) BatchNorm and ReLU activation. + backbone_mmt (nn.Module): backbone for byol, input shape depends on the forward + input size. Standard inputs include `B x C`, `B x C x H x W`, and + `B x C x T x H x W`. + projector_mmt (nn.Module): Am mlp with 2 to 3 hidden layers, + with (synchronized) BatchNorm and ReLU activation. + predictor (nn.Module): predictor MLP of BYOL of similar structure as the + projector MLP. + mmt (float): momentum update ratio for the momentum backbone. + """ + super().__init__() + + self.mmt: float = mmt + if projector is not None: + backbone = nn.Sequential( + backbone, + projector, + ) + init_net_weights(backbone) + self.backbone = backbone + + if projector_mmt is not None: + backbone_mmt = nn.Sequential( + backbone_mmt, + projector_mmt, + ) + init_net_weights(backbone_mmt) + self.backbone_mmt = backbone_mmt + + for p in self.backbone_mmt.parameters(): + p.requires_grad = False + + init_net_weights(predictor) + self.predictor = predictor + + self._copy_weights_to_backbone_mmt() + + def _copy_weights_to_backbone_mmt(self) -> None: + dist = {} + for name, p in self.backbone.named_parameters(): + dist[name] = p + for name, p in self.backbone_mmt.named_parameters(): + p.data.copy_(dist[name].data) + + @torch.no_grad() + def momentum_update_backbone(self) -> None: + """ + Momentum update on the backbone. 
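The update applied below is a standard exponential moving average of the online weights; a tiny numeric sketch of the rule, with made-up values:

```
import torch

m = 0.99                      # self.mmt
p_online = torch.tensor(2.0)  # a parameter of self.backbone
p_mmt = torch.tensor(0.0)     # the matching parameter of self.backbone_mmt

# p_mmt <- (1 - m) * p_online + m * p_mmt, i.e. a slow EMA of the online weights
p_mmt = p_online * (1.0 - m) + p_mmt * m
print(p_mmt)  # tensor(0.0200)
```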
+ """ + m = self.mmt + dist = {} + for name, p in self.backbone.named_parameters(): + dist[name] = p + for name, p in self.backbone_mmt.named_parameters(): + # pyre-ignore[41] + p.data = dist[name].data * (1.0 - m) + p.data * m + + @torch.no_grad() + def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward momentum backbone. + Args: + x (tensor): input to be forwarded of shape N x C x T x H x W + """ + with torch.no_grad(): + proj = self.backbone_mmt(x) + return F.normalize(proj, dim=1) + + def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """ + Args: + x (tensor): input to be forwarded of shape N x C x T x H x W + """ + if not self.training: + x = self.backbone(x) + x = F.normalize(x, dim=1) + return x + + proj = self.backbone(x) + pred = self.predictor(proj) + pred = F.normalize(pred, dim=1) + + out_proj = F.normalize(proj, dim=1) + + return out_proj, pred # pyre-ignore[7] + + +def create_byol_resnet_50( + # Backbone + backbone_creator: Callable = create_resnet, # pyre-ignore[24] + backbone_embed_dim: int = 128, + head_pool: Callable = nn.AdaptiveAvgPool3d, # pyre-ignore[24] + head_output_size: Tuple[int, int, int] = (1, 1, 1), + head_activation: Callable = None, # pyre-ignore[9,24] + dropout_rate: float = 0.0, + # Projector + projector_dim_in: int = 2048, + projector_inner_dim: int = 4096, + projector_depth: int = 2, + # Predictor + predictor_inner_dim: int = 4096, + predictor_depth: int = 2, + predictor_norm: Callable = nn.BatchNorm1d, # pyre-ignore[24] + projector_norm: Callable = nn.BatchNorm1d, # pyre-ignore[24] + mmt: float = 0.99, +) -> BYOL: + """ + Builds a Resnet video backbone, projector and predictors models for + BYOL SSL task. + """ + + def _make_bacbone_and_projector(): # pyre-ignore[3] + backbone = backbone_creator( + dropout_rate=dropout_rate, + head_activation=head_activation, + head_output_with_global_average=True, + head_pool=head_pool, + head_output_size=head_output_size, + ) + + backbone.blocks[-1].proj = None # Overwite head projection + projector = create_mlp_util( + projector_dim_in, + backbone_embed_dim, + projector_inner_dim, + projector_depth, + norm=projector_norm, + ) + return backbone, projector + + backbone, projector = _make_bacbone_and_projector() + backbone_mmt, projector_mmt = _make_bacbone_and_projector() + + predictor = create_mlp_util( + backbone_embed_dim, + backbone_embed_dim, + predictor_inner_dim, + predictor_depth, + norm=predictor_norm, + ) + byol_model = BYOL( + mmt=mmt, + backbone=backbone, + projector=projector, + predictor=predictor, + backbone_mmt=backbone_mmt, + projector_mmt=projector_mmt, + ) + return byol_model + + +class BYOLModule(SSLBaseModule): + """ + The Lightning Base module for BYOL SSL video task. + + For more details refer to, + 1. Bootstrap your own latent: A new approach to self-supervised Learning: + https://arxiv.org/abs/2006.07733 + 2. A Large-Scale Study on Unsupervised Spatiotemporal Representation Learning + + Args: + model (OmegaConf): An omega conf object intializing the neural-network modle. + Example configs can be found in `pytorchvideo_trainer/conf/module/model` + loss(OmegaConf): An omega conf object intializing the loss function. + Example configs can be found in `pytorchvideo_trainer/conf/module/loss` + optim (OmegaConf): An omega conf object for constructing the optimizer object. + The associated config schema can be found at + `pytorchvideo_trainer.module.optimizer.OptimizerConf`. 
+ Example configs can be found in `pytorchvideo_trainer/conf/module/optim` + metrics (OmegaConf): The metrics to track, which will be used for both train, + validation and test. Example configs can be found in + `pytorchvideo_trainer/conf/module/metricx` + lr_scheduler (OmegaConf): An omega conf object associated with learning rate + scheduler used during trainer. + The associated config schema can be found at + `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler` + modality_key (str): The modality key used in data processing, default: "video". + ensemble_method (str): The data ensembling method to control how we accumulate + the testing results at video level, which is optional. Users may choose from + ["sum", "max", None], If it is set to None, no data ensembling will be applied. + knn_memory (OmegaConf): An optional hydra / omeaga conf, if set, initializes KNN + Memory module to use. Example config can be found at, + `pytorchvideo_trainer/conf/module/knn_memory`. + momentum_anneal_cosine (bool): For MoCo and BYOL tasks, if set to true, cosine + anneals the momentum term used from updating the backbone-history model. + num_sync_devices (int): Number of gpus to sync bathcnorm over. Only works if + pytorch lightning trainer's sync_batchnorm parameter is to false. + """ + + def __init__( + self, + model: Any, # pyre-ignore[2] + loss: Any, # pyre-ignore[2] + optim: Any, # pyre-ignore[2] + metrics: List[Any], # pyre-ignore[2] + lr_scheduler: Optional[Any] = None, # pyre-ignore[2] + modality_key: BatchKey = "video", + ensemble_method: Optional[EnsembleMethod] = None, + knn_memory: Optional[Any] = None, # pyre-ignore[2] + momentum_anneal_cosine: bool = False, + num_sync_devices: int = 1, + ) -> None: + super().__init__( + model=model, + loss=loss, + optim=optim, + metrics=metrics, + lr_scheduler=lr_scheduler, + modality_key=modality_key, + ensemble_method=ensemble_method, + knn_memory=knn_memory, + momentum_anneal_cosine=momentum_anneal_cosine, + num_sync_devices=num_sync_devices, + ) + + def training_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> None: + self.cur_epoch_step += 1 # pyre-ignore[16] + + if self.momentum_anneal_cosine: + self._cosine_anneal_momentum() + + self.manual_zero_opt_grad() + self.manual_update_lr() + + inputs = batch[self.modality_key] # pyre-ignore[6] + + self.model.momentum_update_backbone() # pyre-ignore[29] + keys = self._compute_keys(inputs) + + partial_loss = 0.0 + for k, vids in enumerate(inputs): + other_keys = keys[:k] + keys[k + 1 :] + assert len(other_keys) > 0, "Length of keys cannot be zero" + + proj, pred = self.model(vids) + loss_k = self.loss(pred, other_keys[0]) + for i in range(1, len(other_keys)): + loss_k += self.loss(pred, other_keys[i]) + loss_k /= len(other_keys) + + self.manual_backward(loss_k) + partial_loss += loss_k.detach() + + if self.knn_memory is not None: + self.knn_memory.update(proj, batch["video_index"]) # pyre-ignore[29,61] + + partial_loss /= len(inputs) * 2.0 # to have same loss as symmetric loss + self.log("Losses/train_loss", partial_loss, on_step=True, on_epoch=True) + + self.manual_opt_step() + + @torch.no_grad() + def _compute_keys(self, x: torch.Tensor) -> List[torch.Tensor]: + keys = [] + for sub_x in x: + # pyre-ignore[29] + keys.append(self.model.forward_backbone_mmt(sub_x).detach()) + return keys + + +@dataclass +class BYOLModuleConf(ModuleConf): + _target_: str = get_class_name_str(BYOLModule) + model: Any = 
MISSING # pyre-ignore[4] + loss: Any = MISSING # pyre-ignore[4] + optim: Any = MISSING # pyre-ignore[4] + metrics: List[Any] = MISSING # pyre-ignore[4] + lr_scheduler: Optional[Any] = None # pyre-ignore[4] + modality_key: str = "video" + ensemble_method: Optional[str] = None + num_sync_devices: Optional[int] = 1 + knn_memory: Optional[Any] = None # pyre-ignore[4] + momentum_anneal_cosine: bool = False + + +cs = ConfigStore() +cs.store( + group="schema/module", + name="byol_module_conf", + node=BYOLModuleConf, + package="module", +) diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py new file mode 100644 index 00000000..1dcc8acb --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/distributed_utils.py @@ -0,0 +1,330 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Distributed helpers.""" +import functools +import logging +import pickle +from typing import Optional, List, Tuple, Any, TypeVar + +import torch +import torch.distributed as dist + +DistProcessGroup = TypeVar("ProcessGroup") + + +def all_gather(tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """ + All gathers the provided tensors from all processes across machines. + + Args: + tensors (list): tensors to perform all gather across all processes in + all machines. + """ + + gather_list = [] + output_tensor = [] + world_size = dist.get_world_size() + for tensor in tensors: + tensor_placeholder = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_placeholder, tensor, async_op=False) + gather_list.append(tensor_placeholder) + for gathered_tensor in gather_list: + output_tensor.append(torch.cat(gathered_tensor, dim=0)) + return output_tensor + + +def cat_all_gather( + tensors: torch.Tensor, process_group: Optional[DistProcessGroup] = None +) -> torch.Tensor: + """ + Performs the concatenated all_reduce operation on the provided tensors. + """ + if process_group is not None: + gather_sz = get_process_group_size(process_group) + else: + gather_sz = dist.get_world_size() + tensors_gather = [torch.ones_like(tensors) for _ in range(gather_sz)] + dist.all_gather( + tensors_gather, + tensors, + async_op=False, + group=process_group, + ) + output = torch.cat(tensors_gather, dim=0) + return output + + +def all_reduce(tensors: List[torch.Tensor], average: bool = True) -> List[torch.Tensor]: + """ + All reduce the provided tensors from all processes across machines. + + Args: + tensors (list): tensors to perform all reduce across all processes in + all machines. + average (bool): scales the reduced tensor by the number of overall + processes across all machines. + """ + + for tensor in tensors: + dist.all_reduce(tensor, async_op=False) + if average: + world_size = dist.get_world_size() + for tensor in tensors: + tensor.mul_(1.0 / world_size) + return tensors + + +def init_process_group( + local_rank: int, + local_world_size: int, + shard_id: int, + num_shards: int, + init_method: str, + dist_backend: str = "nccl", +) -> None: + """ + Initializes the default process group. + + Args: + local_rank (int): the rank on the current local machine. + local_world_size (int): the world size (number of processes running) on + the current local machine. + shard_id (int): the shard index (machine rank) of the current machine. + num_shards (int): number of shards for distributed training. 
+ init_method (string): supporting three different methods for + initializing process groups: + "file": use shared file system to initialize the groups across + different processes. + "tcp": use tcp address to initialize the groups across different + dist_backend (string): backend to use for distributed training. Options + includes gloo, mpi and nccl, the details can be found here: + https://pytorch.org/docs/stable/distributed.html + """ + # Sets the GPU to use. + torch.cuda.set_device(local_rank) + # Initialize the process group. + proc_rank = local_rank + shard_id * local_world_size + world_size = local_world_size * num_shards + dist.init_process_group( + backend=dist_backend, + init_method=init_method, + world_size=world_size, + rank=proc_rank, + ) + + +def get_world_size() -> int: + """ + Get the size of the world. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + """ + Get the rank of the current process. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize() -> None: + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group() -> List[int]: + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + + Returns: + (group): pytorch dist group. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +# pyre-ignore [2] +def _serialize_to_tensor(data: Any, group: List[int]) -> torch.Tensor: + """ + Seriialize the tensor to ByteTensor. Note that only `gloo` and `nccl` + backend is supported. + + Args: + data (data): data to be serialized. + group (group): pytorch dist group. + Returns: + tensor (ByteTensor): tensor that serialized. + """ + + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor( + tensor: torch.Tensor, group: List[int] +) -> Tuple[List[int], torch.Tensor]: + """ + Padding all the tensors from different GPUs to the largest ones. + + Args: + tensor (tensor): tensor to pad. + group (group): pytorch dist group. + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" 
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [ + torch.zeros([1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size) + ] + dist.all_gather(size_list, local_size, group=group) + size_list = [int(size.item()) for size in size_list] + + max_size = max(size_list) + + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + if local_size != max_size: + padding = torch.zeros( + (max_size - local_size,), dtype=torch.uint8, device=tensor.device + ) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + + +# pyre-ignore [2,3] +def all_gather_unaligned(data: Any, group: Optional[List[int]] = None) -> List[Any]: + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return [data] + + tensor = _serialize_to_tensor(data, group) + + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving Tensor from all ranks + tensor_list = [ + torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) + for _ in size_list + ] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def get_process_group_size(process_group: DistProcessGroup) -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=process_group) + + +def get_local_rank(process_group: DistProcessGroup) -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + + return dist.get_rank(group=process_group) + + +class AllGatherWithGradient(torch.autograd.Function): + """ + Support distributed all_gather for any arbitrary tensor while + preserving its gradient. + """ + + @staticmethod + # pyre-ignore [2,14] + def forward(ctx: Any, input: torch.Tensor) -> torch.Tensor: + world_size = dist.get_world_size() + x_gather = [torch.ones_like(input) for _ in range(world_size)] + dist.all_gather(x_gather, input, async_op=False) + x_gather = torch.cat(x_gather, dim=0) + return x_gather + + @staticmethod + # pyre-ignore [2,14] + def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: + + reduction = dist.all_reduce(grad_output, async_op=True) + reduction.wait() + + world_size = dist.get_world_size() + N = grad_output.size(0) + mini_batchsize = N // world_size + cur_gpu = dist.get_rank() + grad_output = grad_output[ + cur_gpu * mini_batchsize : (cur_gpu + 1) * mini_batchsize + ] + return grad_output diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py new file mode 100644 index 00000000..d3713072 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/losses.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved. + +from typing import List + +import pytorchvideo_trainer.module.distributed_utils as du +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorchvideo.layers.utils import set_attributes + + +class SoftTargetCrossEntropy(nn.Module): + """ + Cross entropy loss with soft target. + """ + + def __init__(self, reduction: str = "mean") -> None: + """ + Args: + reduction (str): specifies reduction to apply to the output. + It can be "mean" (default) or "none". + """ + super(SoftTargetCrossEntropy, self).__init__() + self.reduction = reduction + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + + loss = torch.sum(-y * F.log_softmax(x, dim=-1), dim=-1) + if self.reduction == "mean": + return loss.mean() + elif self.reduction == "none": + return loss + else: + raise NotImplementedError + + +class NtxentLoss(nn.Module): + """ + NT-Xent loss for SimCLR Self-Supervised learning approach - + https://arxiv.org/abs/2002.05709 + + Args: + temperature (float): scalar value to scale the loss by. + """ + + def __init__( + self, + temperature: float, + ) -> None: + super().__init__() + set_attributes(self, locals()) # pyre-ignore[6] + + def forward(self, x_list: List[torch.Tensor]) -> torch.Tensor: + """ + Args: + x_list (list[torch.tensor]): A list of two tensors of shape N x C. + Where, N is the batch size and C is the SSL model's embedding size. + """ + assert ( + len(x_list) == 2 + ), f"Invalid list input to SimCLR. Expected dimention 2 but received {len(x_list)}" + + out_1, out_2 = x_list + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + out_1 = du.AllGatherWithGradient.apply(out_1) # pyre-ignore[16] + out_2 = du.AllGatherWithGradient.apply(out_2) + out = torch.cat([out_1, out_2], dim=0) + # [2*B, 2*B] + sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / self.temperature) + mask = ( + torch.ones_like(sim_matrix) + - torch.eye(out.shape[0], device=sim_matrix.device) + ).bool() + # [2*B, 2*B-1] + sim_matrix = sim_matrix.masked_select(mask).view(out.shape[0], -1) + # compute loss + pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature) + # [2*B] + pos_sim = torch.cat([pos_sim, pos_sim], dim=0) + loss = (-torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean() + + return loss + + +class SimilarityLoss(nn.Module): + """ + Temperature-scaled Similarity loss for BYOL Self-Supervised learning + approach - https://arxiv.org/abs/2006.07733 + + Args: + temperature (float): scalar value to scale the loss by. + """ + + def __init__(self, temperature: float) -> None: + super().__init__() + self.temperature = temperature + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + """ + Args: + q and k (nn.tensor): inputs to calculate the similarity, expected to have + the same shape of `N x C`. Where N is the batch size and C + is the SSL model's embedding size. + """ + similarity = torch.einsum("nc,nc->n", [q, k]) + similarity /= self.temperature + loss = -similarity.mean() + return loss + + +class ContrastiveLoss(nn.Module): + """ + Temperature-scaled Contrastive loss for MoCo and other Self-Supervised learning + approaches - https://arxiv.org/abs/1911.05722 + + Args: + temperature (float): scalar value to scale the loss by. 
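As a sanity sketch (not from the patch), `SoftTargetCrossEntropy` reduces to the usual cross entropy when the soft target happens to be a one-hot vector:

```
import torch
import torch.nn.functional as F
from pytorchvideo_trainer.module.losses import SoftTargetCrossEntropy

logits = torch.randn(4, 10)
hard = torch.randint(0, 10, (4,))
soft = F.one_hot(hard, num_classes=10).float()

loss_soft = SoftTargetCrossEntropy()(logits, soft)
loss_hard = F.cross_entropy(logits, hard)
print(torch.allclose(loss_soft, loss_hard, atol=1e-6))  # True
```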
+ """ + + def __init__(self, reduction: str = "mean", temperature: float = 0.1) -> None: + super(ContrastiveLoss, self).__init__() + self.reduction = reduction + self.temperature = temperature + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs (nn.tensor): Expected to have the same shape of `N x C`. + Where, N is the batch size and C is the SSL model's embedding size. + """ + inputs = torch.div(inputs, self.temperature) + targets = torch.zeros(inputs.shape[0], dtype=torch.long).to(inputs.device) + loss = nn.CrossEntropyLoss(reduction=self.reduction).cuda()(inputs, targets) + return loss diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py new file mode 100644 index 00000000..62ef9f5d --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/lr_policy.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Learning rate policy.""" +import math +from dataclasses import dataclass +from typing import Callable, List + +import torch +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclass +class LRSchedulerConf: + # common + lr_policy: str = MISSING + lr: float = MISSING + max_iters: int = MISSING + warmup_iters: int = MISSING + warmup_start_lr: float = MISSING + + # cosine + cosine_end_lr: float = MISSING + cosine_after_warmup: bool = MISSING + + # LRS + steps: List[int] = MISSING + lrs: List[float] = MISSING + + +cs = ConfigStore() +cs.store( + group="schema/module/lr_scheduler", + name="lr_scheduler_conf", + node=LRSchedulerConf, + package="module.lr_scheduler", +) + + +def get_lr_at_epoch(cfg: LRSchedulerConf, cur_epoch: float) -> float: + """ + Retrieve the learning rate of the current epoch with the option to perform + warm up in the beginning of the training stage. + + Args: + cfg (LRSchedulerConf): Hydra / omega conf object associated with + Learningrate scheduler. The schema can be found in + `LRSchedulerConf` and the example configs can be found in + `pytorchvideo_trainer/conf/module/lr_scheduler`. + cur_epoch (float): the number of epoch of the current training stage. + """ + lr = get_lr_func(cfg.lr_policy)(cfg, cur_epoch) + # Perform warm up. + if cur_epoch < cfg.warmup_iters: + lr_start = cfg.warmup_start_lr + lr_end = get_lr_func(cfg.lr_policy)(cfg, cfg.warmup_iters) + alpha = (lr_end - lr_start) / cfg.warmup_iters + lr = cur_epoch * alpha + lr_start + return lr + + +def lr_func_cosine(cfg: LRSchedulerConf, cur_epoch: float) -> float: + """ + Retrieve the learning rate to specified values at specified epoch with the + cosine learning rate schedule. Details can be found in: + Ilya Loshchilov, and Frank Hutter ,SGDR: Stochastic Gradient + Descent With Warm Restarts. + + Args: + cfg (CfgNode): Hydra / omega conf object associated with + Learningrate scheduler. The schema can be found in + `LRSchedulerConf` and the example configs can be found in + `pytorchvideo_trainer/conf/module/lr_scheduler`. + cur_epoch (float): the number of epoch of the current training stage. 
+ """ + offset = cfg.warmup_iters if cfg.cosine_after_warmup else 0.0 + assert cfg.cosine_end_lr < cfg.lr + return ( + cfg.cosine_end_lr + + (cfg.lr - cfg.cosine_end_lr) + * (math.cos(math.pi * (cur_epoch - offset) / (cfg.max_iters - offset)) + 1.0) + * 0.5 + ) + + +def lr_func_steps_with_relative_lrs(cfg: LRSchedulerConf, cur_epoch: float) -> float: + """ + Retrieve the learning rate to specified values at specified epoch with the + steps with relative learning rate schedule. + + Args: + cfg (CfgNode): configs. Hydra / omega conf object associated with + Learningrate scheduler. The schema can be found in + `LRSchedulerConf` and the example configs can be found in + `pytorchvideo_trainer/conf/module/lr_scheduler`. + cur_epoch (float): the number of epoch of the current training stage. + """ + ind = get_step_index(cfg, cur_epoch) + return cfg.lrs[ind] * cfg.lr + + +def get_step_index(cfg: LRSchedulerConf, cur_epoch: float) -> int: + """ + Retrieves the lr step index for the given epoch. + + Args: + cfg (CfgNode): Hydra / omega conf object associated with + Learningrate scheduler. The schema can be found in + `LRSchedulerConf` and the example configs can be found in + `pytorchvideo_trainer/conf/module/lr_scheduler`. + cur_epoch (float): the number of epoch of the current training stage. + """ + steps = cfg.steps + [cfg.max_iters] + for ind, step in enumerate(steps): # NoQA + if cur_epoch < step: + break + return ind - 1 + + +def get_lr_func(lr_policy: str) -> Callable: # pyre-ignore[24] + """ + Given the configs, retrieve the specified lr policy function. + + Args: + lr_policy (string): the learning rate policy to use for the job. + """ + policy = "lr_func_" + lr_policy + if policy not in globals(): + raise NotImplementedError("Unknown LR policy: {}".format(lr_policy)) + else: + return globals()[policy] + + +def get_epoch_lr(cur_epoch: float, cfg: LRSchedulerConf) -> float: + """ + Retrieves the lr for the given epoch (as specified by the lr policy). + + Args: + cfg (config): Hydra / omega conf object associated with + Learningrate scheduler. The schema can be found in + `LRSchedulerConf` and the example configs can be found in + `pytorchvideo_trainer/conf/module/lr_scheduler`. + cur_epoch (float): the number of epoch of the current training stage. + """ + return get_lr_at_epoch(cfg, cur_epoch) + + +def set_lr(optimizer: torch.optim.Optimizer, new_lr: float) -> None: + """ + Sets the optimizer lr to the specified value. + Args: + optimizer (optim): the optimizer using to optimize the current network. + new_lr (float): the new learning rate to set. + """ + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py new file mode 100644 index 00000000..ef225ad8 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/moco_v2.py @@ -0,0 +1,456 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+ +import math +from dataclasses import dataclass +from typing import Optional, List, Callable, Union, Any, Tuple + +import pytorchvideo_trainer.module.distributed_utils as du +import torch +import torch.nn as nn +import torch.nn.functional as F +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING +from pytorchvideo.models.resnet import create_resnet +from pytorchvideo.models.weight_init import init_net_weights +from pytorchvideo_trainer.module.ssl_helper import SSLBaseModule, create_mlp_util +from pytorchvideo_trainer.module.video_classification import ( + EnsembleMethod, + BatchKey, + Batch, +) +from torchrecipes.core.conf import ModuleConf +from torchrecipes.utils.config_utils import get_class_name_str + + +def create_moco_resnet_50( + # Backbone + backbone_creator: Callable = create_resnet, # pyre-ignore[24] + backbone_embed_dim: int = 128, + head_pool: Callable = nn.AdaptiveAvgPool3d, # pyre-ignore[24] + head_output_size: Tuple[int, int, int] = (1, 1, 1), + head_activation: Callable = None, # pyre-ignore[9,24] + dropout_rate: float = 0.0, + # Projector + projector_dim_in: int = 2048, + projector_inner_dim: int = 2048, + projector_depth: int = 3, + projector_norm: Optional[Callable] = None, # pyre-ignore[24] + mmt: float = 0.994, +) -> nn.Module: + def _make_bacbone_and_projector(): # pyre-ignore[3] + backbone = backbone_creator( + dropout_rate=dropout_rate, + head_activation=head_activation, + head_output_with_global_average=True, + head_pool=head_pool, + head_output_size=head_output_size, + stem_conv_kernel_size=(1, 7, 7), + head_pool_kernel_size=(8, 7, 7), + ) + + backbone.blocks[-1].proj = None # Overwite head projection + projector = create_mlp_util( + projector_dim_in, + backbone_embed_dim, + projector_inner_dim, + projector_depth, + norm=projector_norm, # pyre-ignore[6] + ) + return backbone, projector + + backbone, projector = _make_bacbone_and_projector() + backbone_mmt, projector_mmt = _make_bacbone_and_projector() + + moco_model = MOCO( + mmt=mmt, + backbone=backbone, + projector=projector, + backbone_mmt=backbone_mmt, + projector_mmt=projector_mmt, + ) + return moco_model + + +class MOCO(nn.Module): + """ + Momentum Contrast for unsupervised Visual Representation Learning + Details can be found in: + https://arxiv.org/abs/1911.05722 + """ + + def __init__( + self, + mmt: float, + backbone: nn.Module, + backbone_mmt: nn.Module, + projector: Optional[nn.Module] = None, + projector_mmt: Optional[nn.Module] = None, + ) -> None: + """ + Args: + backbone (nn.Module): backbone for byol, input shape depends on the forward + input size. Standard inputs include `B x C`, `B x C x H x W`, and + `B x C x T x H x W`. + projector (nn.Module): An mlp with 2 to 3 hidden layers, + with (synchronized) BatchNorm and ReLU activation. + backbone_mmt (nn.Module): backbone for byol, input shape depends on the forward + input size. Standard inputs include `B x C`, `B x C x H x W`, and + `B x C x T x H x W`. + projector_mmt (nn.Module): Am mlp with 2 to 3 hidden layers, + with (synchronized) BatchNorm and ReLU activation. + mmt (float): momentum update ratio for the momentum backbone. 
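A toy check of the wiring implemented below, using tiny `nn.Linear` stand-ins instead of the video backbones so it runs instantly: the momentum branch starts as an exact copy of the online branch and then tracks it as an EMA governed by `mmt`.

```
import torch
import torch.nn as nn
from pytorchvideo_trainer.module.moco_v2 import MOCO

moco = MOCO(mmt=0.9, backbone=nn.Linear(4, 4), backbone_mmt=nn.Linear(4, 4))
w0 = moco.backbone.weight.detach().clone()
assert torch.allclose(moco.backbone_mmt.weight, w0)  # copied at construction

with torch.no_grad():
    moco.backbone.weight.add_(1.0)  # stand-in for one optimizer step
moco.momentum_update_backbone()
# the momentum copy moves only 10% of the way toward the new online weights
assert torch.allclose(moco.backbone_mmt.weight, 0.9 * w0 + 0.1 * (w0 + 1.0))
print("ok")
```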
+ """ + super().__init__() + + self.mmt: float = mmt + + if projector is not None: + backbone = nn.Sequential( + backbone, + projector, + ) + init_net_weights(backbone) + self.backbone = backbone + + if projector_mmt is not None: + backbone_mmt = nn.Sequential( + backbone_mmt, + projector_mmt, + ) + init_net_weights(backbone_mmt) + self.backbone_mmt = backbone_mmt + + for p in self.backbone_mmt.parameters(): + p.requires_grad = False + + self._copy_weights_to_backbone_mmt() + + def _copy_weights_to_backbone_mmt(self) -> None: + dist = {} + for name, p in self.backbone.named_parameters(): + dist[name] = p + for name, p in self.backbone_mmt.named_parameters(): + p.data.copy_(dist[name].data) + + @torch.no_grad() + def momentum_update_backbone(self) -> None: + """ + Momentum update on the backbone. + """ + m = self.mmt + dist = {} + for name, p in self.backbone.named_parameters(): + dist[name] = p + for name, p in self.backbone_mmt.named_parameters(): + # pyre-ignore[41] + p.data = dist[name].data * (1.0 - m) + p.data * m + + @torch.no_grad() + def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward momentum backbone. + Args: + x (tensor): input to be forwarded of shape N x C x T x H x W + """ + with torch.no_grad(): + proj = self.backbone_mmt(x) + out_proj = F.normalize(proj, dim=1) + return out_proj + + def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + """ + Args: + x (tensor): input to be forwarded of shape N x C x T x H x W + """ + proj = self.backbone(x) + out_proj = F.normalize(proj, dim=1) + return out_proj + + +class MOCOV2Module(SSLBaseModule): + """ + The Lightning Base module for MoCo SSL video task. + + For more details refer to, + 1. Momentum Contrast for unsupervised Visual Representation Learning: + https://arxiv.org/abs/1911.05722 + 2. A Large-Scale Study on Unsupervised Spatiotemporal Representation Learning + + Args: + model (OmegaConf): An omega conf object intializing the neural-network modle. + Example configs can be found in `pytorchvideo_trainer/conf/module/model` + loss(OmegaConf): An omega conf object intializing the loss function. + Example configs can be found in `pytorchvideo_trainer/conf/module/loss` + optim (OmegaConf): An omega conf object for constructing the optimizer object. + The associated config schema can be found at + `pytorchvideo_trainer.module.optimizer.OptimizerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/optim` + metrics (OmegaConf): The metrics to track, which will be used for both train, + validation and test. Example configs can be found in + `pytorchvideo_trainer/conf/module/metricx` + dim (int): Dimentionality of features in the stored queue. Set to be same as + embedding dimentions for the SSL model. + k (int): Queue size for stored features. + batch_suffle (bool): If true, performs shuffling of the computed keys. + local_shuffle_bn (bool): If true, only performs shuffling of keys with in the + current node. + lr_scheduler (OmegaConf): An omega conf object associated with learning rate + scheduler used during trainer. + The associated config schema can be found at + `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler` + modality_key (str): The modality key used in data processing, default: "video". + ensemble_method (str): The data ensembling method to control how we accumulate + the testing results at video level, which is optional. 
Users may choose from + ["sum", "max", None], If it is set to None, no data ensembling will be applied. + knn_memory (OmegaConf): An optional hydra / omeaga conf, if set, initializes KNN + Memory module to use. Example config can be found at, + `pytorchvideo_trainer/conf/module/knn_memory`. + momentum_anneal_cosine (bool): For MoCo and BYOL tasks, if set to true, cosine + anneals the momentum term used from updating the backbone-history model. + num_sync_devices (int): Number of gpus to sync bathcnorm over. Only works if + pytorch lightning trainer's sync_batchnorm parameter is to false. + """ + + def __init__( + self, + model: Any, # pyre-ignore[2] + loss: Any, # pyre-ignore[2] + optim: Any, # pyre-ignore[2] + metrics: List[Any], # pyre-ignore[2] + dim: int, + k: int, + batch_shuffle: bool, + local_shuffle_bn: bool, + lr_scheduler: Optional[Any] = None, # pyre-ignore[2] + modality_key: BatchKey = "video", + ensemble_method: Optional[EnsembleMethod] = None, + knn_memory: Optional[Any] = None, # pyre-ignore[2] + momentum_anneal_cosine: bool = False, + num_sync_devices: int = 1, + ) -> None: + super().__init__( + model=model, + loss=loss, + optim=optim, + metrics=metrics, + lr_scheduler=lr_scheduler, + modality_key=modality_key, + ensemble_method=ensemble_method, + knn_memory=knn_memory, + momentum_anneal_cosine=momentum_anneal_cosine, + num_sync_devices=num_sync_devices, + ) + + self.dim: int = dim + self.k: int = k + self.batch_shuffle_on = batch_shuffle + self.local_shuffle_bn = local_shuffle_bn + self.register_buffer("ptr", torch.tensor([0])) + self.ptr.requires_grad = False + stdv = 1.0 / math.sqrt(self.dim / 3) + self.register_buffer( + "queue_x", + torch.rand(self.k, self.dim).mul_(2 * stdv).add_(-stdv), + ) + self.queue_x.requires_grad = False + self.local_process_group = None # pyre-ignore[4] + + def on_fit_start(self) -> None: + """Called at the very beginning of fit. 
+ If on DDP it is called on every process + """ + dataloader = self.trainer.datamodule.train_dataloader() + if self.knn_memory is not None: + self.knn_memory.init_knn_labels(dataloader) # pyre-ignore[29] + + world_size = self.trainer.world_size + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and self.local_shuffle_bn + and self.batch_shuffle_on + ): + self._create_local_process_group() + + # TODO: For ad's dataloder this might be different + # pyre-ignore[16] + self.no_update_iters = self.k // world_size // dataloader.batch_size + + def _create_local_process_group(self) -> None: + assert self.trainer.num_gpus > 1, "Error creating local process group in MoCo" + + for i in range(self.trainer.num_nodes): + ranks_on_i = list( + range(i * self.trainer.num_gpus, (i + 1) * self.trainer.num_gpus) + ) + pg = torch.distributed.new_group(ranks=ranks_on_i) + if i == torch.distributed.get_rank() // self.trainer.num_gpus: + self.local_process_group = pg + + def training_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> None: + + self.cur_epoch_step += 1 # pyre-ignore[16] + + if self.momentum_anneal_cosine: + self._cosine_anneal_momentum() + + self.manual_zero_opt_grad() + self.manual_update_lr() + + inputs = batch[self.modality_key] # pyre-ignore[6] + + self.model.momentum_update_backbone() # pyre-ignore[29] + keys = self._compute_keys(inputs) + + partial_loss = 0.0 + for k, vids in enumerate(inputs): + other_keys = keys[:k] + keys[k + 1 :] + assert len(other_keys) > 0, "Length of keys cannot be zero" + + proj = self.model(vids) + q_knn = proj + queue_neg = torch.einsum("nc,kc->nk", [proj, self.queue_x.clone().detach()]) + + for k, key in enumerate(other_keys): + out_pos = torch.einsum("nc,nc->n", [proj, key]).unsqueeze(-1) + lgt_k = torch.cat([out_pos, queue_neg], dim=1) + if k == 0: + logits = lgt_k + else: + logits = torch.cat([logits, lgt_k], dim=0) + loss_k = self.loss(logits) # pyre-ignore[61] + + self.manual_backward(loss_k) + partial_loss += loss_k.detach() + + if self.knn_memory is not None: + self.knn_memory.update(q_knn, batch["video_index"]) # pyre-ignore[29,61] + + partial_loss /= len(inputs) * 2.0 # to have same loss as symmetric loss + self.log("Losses/train_loss", partial_loss, on_step=True, on_epoch=True) + self._dequeue_and_enqueue(keys) + + if ( + self.trainer.current_epoch == 0 + and self.cur_epoch_step < self.no_update_iters + ): + print( + f"No update: Epoch {self.trainer.current_epoch}" + + f" Step {self.cur_epoch_step}/{self.no_update_iters}" + ) + return + + self.manual_opt_step() + + @torch.no_grad() + def _compute_keys(self, x: torch.Tensor) -> List[torch.Tensor]: + keys = [] + for sub_x in x: + if self.batch_shuffle_on: + with torch.no_grad(): + sub_x, idx_restore = self._batch_shuffle(sub_x) + with torch.no_grad(): + # pyre-ignore[29] + key = self.model.forward_backbone_mmt(sub_x).detach() + + if self.batch_shuffle_on: + key = self._batch_unshuffle(key, idx_restore).detach() + keys.append(key) + return keys + + @torch.no_grad() + def _batch_shuffle(self, x: torch.Tensor): # pyre-ignore[3] + world_size = self.trainer.world_size + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if self.local_shuffle_bn: + assert self.local_process_group is not None + x = du.cat_all_gather(x, self.local_process_group) + gpu_idx = du.get_local_rank(self.local_process_group) + world_size = self.trainer.num_gpus + else: + x = du.cat_all_gather(x) + gpu_idx = torch.distributed.get_rank() + + idx_randperm = 
torch.randperm(x.shape[0]).to(self.device) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.broadcast(idx_randperm, src=0) + else: + gpu_idx = 0 + idx_randperm = idx_randperm.view(world_size, -1) + x = x[idx_randperm[gpu_idx, :]] # pyre-ignore[61] + idx_restore = torch.argsort(idx_randperm.view(-1)) + idx_restore = idx_restore.view(world_size, -1) + + return x, idx_restore + + @torch.no_grad() + def _batch_unshuffle( + self, x: torch.Tensor, idx_restore: torch.Tensor + ) -> torch.Tensor: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if self.local_shuffle_bn: + assert self.local_process_group is not None + x = du.cat_all_gather(x, self.local_process_group) + gpu_idx = du.get_local_rank(self.local_process_group) + else: + x = du.cat_all_gather(x) + gpu_idx = torch.distributed.get_rank() + else: + gpu_idx = 0 + + idx = idx_restore[gpu_idx, :] + x = x[idx] + return x + + @torch.no_grad() + def _dequeue_and_enqueue( + self, + keys: List[torch.Tensor], + ) -> None: + assert len(keys) > 0, "need to have multiple views for adding them to queue" + ptr = int(self.ptr.item()) + for key in keys: + # write the current feat into queue, at pointer + num_items = int(key.size(0)) + assert ( + self.k % num_items == 0 + ), "Queue size should be a multiple of batchsize" + assert ptr + num_items <= self.k + self.queue_x[ptr : ptr + num_items, :] = key + # move pointer + ptr += num_items + # reset pointer + if ptr == self.k: + ptr = 0 + self.ptr[0] = ptr + + +@dataclass +class MOCOV2ModuleConf(ModuleConf): + _target_: str = get_class_name_str(MOCOV2Module) + model: Any = MISSING # pyre-ignore[4] + loss: Any = MISSING # pyre-ignore[4] + optim: Any = MISSING # pyre-ignore[4] + metrics: List[Any] = MISSING # pyre-ignore[4] + lr_scheduler: Optional[Any] = None # pyre-ignore[4] + modality_key: str = "video" + ensemble_method: Optional[str] = None + num_sync_devices: Optional[int] = 1 + knn_memory: Optional[Any] = None # pyre-ignore[4] + momentum_anneal_cosine: bool = False + dim: int = MISSING + k: int = MISSING + batch_shuffle: bool = MISSING + local_shuffle_bn: bool = MISSING + + +cs = ConfigStore() +cs.store( + group="schema/module", + name="moco_v2_module_conf", + node=MOCOV2ModuleConf, + package="module", +) diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py new file mode 100644 index 00000000..2de8eb2f --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/optimizer.py @@ -0,0 +1,257 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +# pyre-ignore-all-errors + +from dataclasses import dataclass + +import torch +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclass +class OptimizerConf: + method: str = MISSING + lr: float = MISSING + weight_decay: float = 1e-4 + bn_weight_decay: float = 0.0 + momentum: float = 0.9 + dampening: float = 0.0 + nesterov: bool = True + zero_weight_decay_1d_param: bool = False + lars_on: bool = False + + +# TODO: Refactor contruct_optimer to torch.optim conf + construct_param_group +def construct_optimizer( + model: torch.nn.Module, cfg: OptimizerConf # noqa +) -> torch.optim.Optimizer: + """ + Constructs a stochastic gradient descent or ADAM (or ADAMw) optimizer + with momentum. i.e, constructs a torch.optim.Optimizer with zero-weight decay + Batchnorm and/or no-update 1-D parameters support, based on the config. 
+ + Supports wrapping the optimizer with Layer-wise Adaptive Rate Scaling + (LARS): https://arxiv.org/abs/1708.03888 + + Args: + model (nn.Module): model to perform stochastic gradient descent + optimization or ADAM optimization. + cfg (OptimizerConf): Hydra/Omega conf object consisting hyper-parameters + of SGD or ADAM, includes base learning rate, momentum, weight_decay, + dampening and etc. The supported config schema is `OptimizerConf`. + Example config files can be found at, + `pytorchvideo_trainer/conf/module/optim` + """ + bn_parameters = [] + non_bn_parameters = [] + zero_parameters = [] + no_grad_parameters = [] + skip = {} + + if hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() # pyre-ignore[29] + + for name, m in model.named_modules(): + is_bn = isinstance(m, torch.nn.modules.batchnorm._NormBase) + for p in m.parameters(recurse=False): + if not p.requires_grad: + no_grad_parameters.append(p) + elif is_bn: + bn_parameters.append(p) + elif name in skip: + zero_parameters.append(p) + elif cfg.zero_weight_decay_1d_param and ( + len(p.shape) == 1 or name.endswith(".bias") + ): + zero_parameters.append(p) + else: + non_bn_parameters.append(p) + + optim_params = [ + { + "params": bn_parameters, + "weight_decay": cfg.bn_weight_decay, + "apply_LARS": False, + }, + { + "params": non_bn_parameters, + "weight_decay": cfg.weight_decay, + "apply_LARS": cfg.lars_on, + }, + { + "params": zero_parameters, + "weight_decay": 0.0, + "apply_LARS": cfg.lars_on, + }, + ] + optim_params = [x for x in optim_params if len(x["params"])] # pyre-ignore[6] + + # Check all parameters will be passed into optimizer. + assert len(list(model.parameters())) == len(non_bn_parameters) + len( + bn_parameters + ) + len(zero_parameters) + len( + no_grad_parameters + ), "parameter size does not match: {} + {} + {} + {} != {}".format( + len(non_bn_parameters), + len(bn_parameters), + len(zero_parameters), + len(no_grad_parameters), + len(list(model.parameters())), + ) + print( + "bn {}, non bn {}, zero {} no grad {}".format( + len(bn_parameters), + len(non_bn_parameters), + len(zero_parameters), + len(no_grad_parameters), + ) + ) + + if cfg.method == "sgd": + optimizer = torch.optim.SGD( + optim_params, + lr=cfg.lr, + momentum=cfg.momentum, + weight_decay=cfg.weight_decay, + dampening=cfg.dampening, + nesterov=cfg.nesterov, + ) + elif cfg.method == "adam": + optimizer = torch.optim.Adam( + optim_params, + lr=cfg.lr, + betas=(0.9, 0.999), + weight_decay=cfg.weight_decay, + ) + elif cfg.method == "adamw": + optimizer = torch.optim.AdamW( + optim_params, + lr=cfg.lr, + eps=1e-08, + weight_decay=cfg.weight_decay, + ) + else: + raise NotImplementedError("Does not support {} optimizer".format(cfg.method)) + + if cfg.lars_on: + optimizer = LARS(optimizer=optimizer, trust_coefficient=0.001, clip=False) + return optimizer + + +cs = ConfigStore() +cs.store( + group="schema/module/optim", + name="optim_conf", + node=OptimizerConf, + package="module.optim", +) + + +class LARS(torch.optim.Optimizer): + """ + This class is adapted from + https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py to + include ignoring LARS application specific parameters (e.g. 1D params) + + Args: + optimizer (torch.optim): Pytorch optimizer to wrap and modify learning rate for. + trust_coefficient: Trust coefficient for calculating the lr. + See https://arxiv.org/abs/1708.03888 + clip (bool): Decides between clipping or scaling mode of LARS. 
If `clip=True` the + learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. + If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. + eps (float): epsilon kludge to help with numerical stability while calculating + adaptive_lr. + ignore_1d_param (float): If true, does not update 1 dimentional parameters. + """ + + def __init__( + self, + optimizer, + trust_coefficient=0.02, + clip=True, + eps=1e-8, + ignore_1d_param=True, + ) -> None: + self.optim = optimizer + self.trust_coefficient = trust_coefficient + self.eps = eps + self.clip = clip + self.ignore_1d_param = ignore_1d_param + + self.defaults = self.optim.defaults + + def __getstate__(self): + return self.optim.__getstate__() + + def __setstate__(self, state): + self.optim.__setstate__(state) + + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + @property + def param_groups(self): + return self.optim.param_groups + + @param_groups.setter + def param_groups(self, value): + self.optim.param_groups = value + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + + def zero_grad(self): + self.optim.zero_grad() + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) + + def step(self, closure=None): + with torch.no_grad(): + weight_decays = [] + for group in self.optim.param_groups: + # absorb weight decay control from optimizer + weight_decay = group["weight_decay"] if "weight_decay" in group else 0 + weight_decays.append(weight_decay) + apply_LARS = group["apply_LARS"] if "apply_LARS" in group else True + if not apply_LARS: + continue + group["weight_decay"] = 0 + for p in group["params"]: + if p.grad is None: + continue + if self.ignore_1d_param and p.ndim == 1: # ignore bias + continue + param_norm = torch.norm(p.data) + grad_norm = torch.norm(p.grad.data) + + if param_norm != 0 and grad_norm != 0: + # calculate adaptive lr + weight decay + adaptive_lr = ( + self.trust_coefficient + * (param_norm) + / (grad_norm + param_norm * weight_decay + self.eps) + ) + + # clip learning rate for LARS + if self.clip: + # calculation of adaptive_lr so that when multiplied + # by lr it equals `min(adaptive_lr, lr)` + adaptive_lr = min(adaptive_lr / group["lr"], 1) + + p.grad.data += weight_decay * p.data + p.grad.data *= adaptive_lr + + self.optim.step() + # return weight decay control to optimizer + for i, group in enumerate(self.optim.param_groups): + group["weight_decay"] = weight_decays[i] diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py new file mode 100644 index 00000000..643c7b11 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/simclr.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
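[Editor's note, not part of the patch] construct_optimizer and the LARS wrapper above can be used directly on any nn.Module. A minimal sketch follows, assuming a hypothetical toy 3D-conv model and illustrative hyper-parameters; the returned LARS object exposes the usual Optimizer interface (zero_grad/step), so a training loop does not need to special-case it.

import torch.nn as nn
from pytorchvideo_trainer.module.optimizer import OptimizerConf, construct_optimizer

# Toy model: BatchNorm params fall into the bn_weight_decay group and skip LARS,
# all other params get weight_decay and, because lars_on=True, LARS lr scaling.
model = nn.Sequential(
    nn.Conv3d(3, 8, kernel_size=3),
    nn.BatchNorm3d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
    nn.Linear(8, 400),
)
cfg = OptimizerConf(method="sgd", lr=0.8, weight_decay=1e-4, lars_on=True)
optimizer = construct_optimizer(model, cfg)  # SGD param groups wrapped in LARS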
+ +from dataclasses import dataclass +from typing import Optional, List, Callable, Union, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING +from pytorchvideo.models.resnet import create_resnet +from pytorchvideo.models.weight_init import init_net_weights +from pytorchvideo_trainer.module.byol import create_mlp_util +from pytorchvideo_trainer.module.ssl_helper import SSLBaseModule +from pytorchvideo_trainer.module.video_classification import ( + EnsembleMethod, + BatchKey, + Batch, +) +from torchrecipes.core.conf import ModuleConf +from torchrecipes.utils.config_utils import get_class_name_str + + +class SimCLR(nn.Module): + """ + Skeletal NN.Module for the SimCLR model that supports + arbitrary bacbone and projector models. + """ + + def __init__( + self, + backbone: nn.Module, + projector: Optional[nn.Module] = None, + ) -> None: + """ + Args: + backbone (nn.Module): backbone for simclr, input shape depends on the forward + input size. Standard inputs include `B x C`, `B x C x H x W`, and + `B x C x T x H x W`. + projector (nn.Module): An mlp with 2 to 3 hidden layers, + with (synchronized) BatchNorm and ReLU activation. + """ + super().__init__() + + if projector is not None: + backbone = nn.Sequential( + backbone, + projector, + ) + init_net_weights(backbone) + self.backbone = backbone + + def forward( + self, x_list: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x_list (list(tensor) or tensor): Expects a list of 2 tensors + for trainin phase and single tensor for the train and val + phases. Here all tensors are expected to be of the shape, + N x C x T x H x W. + """ + if not self.training: + assert isinstance( + x_list, torch.Tensor + ), "Expected tensor for test/val phase in SimCLR" + if self.backbone is not None: + x_list = self.backbone(x_list) + x_list = F.normalize(x_list, p=2, dim=1) + return x_list + + assert ( + isinstance(x_list, list) and len(x_list) == 2 + ), f"Invalid list input to SimCLR. Expected len 2 but received {len(x_list)}" + + for i, x in enumerate(x_list): + if self.backbone is not None: + x = self.backbone(x) + x = F.normalize(x, p=2, dim=1) + x_list[i] = x + + return x_list + + +def create_simclr_resnet_50( + # Backbone + backbone_creator: Callable = create_resnet, # pyre-ignore[24] + backbone_embed_dim: int = 128, + dim_in: int = 2048, + # Projector + # TODO: Standardize projector conf across all SSL tasks + mlp_activation: Callable = nn.ReLU, # pyre-ignore[24] + mlp_inner_dim: int = 2048, + mlp_depth: int = 1, + mlp_norm: Optional[Callable] = None, # pyre-ignore[24] +) -> SimCLR: + """ + Builds a Resnet video model with a projector for SimCLR + SSL traning task. + """ + backbone = backbone_creator( + model_num_class=backbone_embed_dim, + dropout_rate=0.0, + ) + backbone.blocks[-1].proj = None + projector = create_mlp_util( + dim_in, + backbone_embed_dim, + mlp_inner_dim, + mlp_depth, + norm=mlp_norm, # pyre-ignore[6] + ) + simclr = SimCLR( + backbone=backbone, + projector=projector, + ) + return simclr + + +class SimCLRModule(SSLBaseModule): + """ + The Lightning Base module for SimCLR SSL video task. + + For more details refer to, + 1. A Simple Framework for Contrastive Learning of Visual Representations : + https://arxiv.org/abs/2002.05709 + 2. 
A Large-Scale Study on Unsupervised Spatiotemporal Representation Learning + + Args: + model (OmegaConf): An omega conf object initializing the neural-network model. + Example configs can be found in `pytorchvideo_trainer/conf/module/model` + loss (OmegaConf): An omega conf object initializing the loss function. + Example configs can be found in `pytorchvideo_trainer/conf/module/loss` + optim (OmegaConf): An omega conf object for constructing the optimizer object. + The associated config schema can be found at + `pytorchvideo_trainer.module.optimizer.OptimizerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/optim` + metrics (OmegaConf): The metrics to track, which will be used for train, + validation and test. Example configs can be found in + `pytorchvideo_trainer/conf/module/metrics` + lr_scheduler (OmegaConf): An omega conf object associated with the learning rate + scheduler used during training. + The associated config schema can be found at + `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler` + modality_key (str): The modality key used in data processing, default: "video". + ensemble_method (str): The optional data ensembling method that controls how we accumulate + the testing results at video level. Users may choose from + ["sum", "max", None]. If it is set to None, no data ensembling will be applied. + knn_memory (OmegaConf): An optional hydra / omega conf; if set, initializes the KNN + Memory module to use. An example config can be found at + `pytorchvideo_trainer/conf/module/knn_memory`. + num_sync_devices (int): Number of GPUs to sync batchnorm over. Only works if + the pytorch lightning trainer's sync_batchnorm parameter is set to false. 
+ """ + + def __init__( + self, + model: Any, # pyre-ignore[2] + loss: Any, # pyre-ignore[2] + optim: Any, # pyre-ignore[2] + metrics: List[Any], # pyre-ignore[2] + lr_scheduler: Optional[Any] = None, # pyre-ignore[2] + modality_key: BatchKey = "video", + ensemble_method: Optional[EnsembleMethod] = None, + knn_memory: Optional[Any] = None, # pyre-ignore[2] + num_sync_devices: int = 1, + ) -> None: + super().__init__( + model=model, + loss=loss, + optim=optim, + metrics=metrics, + lr_scheduler=lr_scheduler, + modality_key=modality_key, + ensemble_method=ensemble_method, + knn_memory=knn_memory, + momentum_anneal_cosine=False, + num_sync_devices=num_sync_devices, + ) + + def training_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> None: + + self.cur_epoch_step += 1 # pyre-ignore[16] + + self.manual_zero_opt_grad() + self.manual_update_lr() + + inputs = batch[self.modality_key] # pyre-ignore[6] + partial_loss = 0.0 + for i in range(len(inputs) - 1): + y_hat = self(inputs[i : i + 2]) + loss = self.loss(y_hat) + self.manual_backward(loss) + partial_loss += loss.detach() + + partial_loss /= len(inputs) - 1 + self.log("Losses/train_loss", partial_loss, on_step=True, on_epoch=True) + + if self.knn_memory is not None: + # pyre-ignore[29] + self.knn_memory.update(y_hat[0], batch["video_index"]) + + self.manual_opt_step() + + +@dataclass +class SimCLRModuleConf(ModuleConf): + _target_: str = get_class_name_str(SimCLRModule) + model: Any = MISSING # pyre-ignore[4] + loss: Any = MISSING # pyre-ignore[4] + optim: Any = MISSING # pyre-ignore[4] + metrics: List[Any] = MISSING # pyre-ignore[4] + lr_scheduler: Optional[Any] = None # pyre-ignore[4] + modality_key: str = "video" + ensemble_method: Optional[str] = None + num_sync_devices: Optional[int] = 1 + knn_memory: Optional[Any] = None # pyre-ignore[4] + + +cs = ConfigStore() +cs.store( + group="schema/module", + name="simclr_module_conf", + node=SimCLRModuleConf, + package="module", +) diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py new file mode 100644 index 00000000..8b737327 --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/ssl_helper.py @@ -0,0 +1,473 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import math +from typing import Optional, List, Any, Dict, Callable + +import numpy as np +import pytorchvideo_trainer.module.distributed_utils as du +import torch +import torch.nn as nn +import torch.nn.functional as F +from hydra.utils import instantiate +from pytorch_lightning.trainer import Trainer +from pytorchvideo_trainer.module.video_classification import ( + VideoClassificationModule, + EnsembleMethod, + BatchKey, + Batch, +) + + +def create_mlp_util( + dim_in: int, + dim_out: int, + mlp_dim: int, + num_layers: int, + norm: Callable, # pyre-ignore[24] + bias: bool = True, + xavier_init: bool = True, +) -> nn.Module: + """ + A utility method for creating the MLP that gets attached to the SSL + bacbone network either in the form of the projector or predictor. + + Consists of multiple squences of "Linear -> Norm -> Relu" layers. + + Args: + dim_in (int): Input dimension size to the MLP. + dim_out (int): Output dimension size of MLP. + mlp_dim (int): Dimentions size for the inner layers of MLP. + num_layers (int): Number of layer in the MLP. + norm (callabe): Type of normalization to apply between layers. 
+ Examples include BatchNorm, SyncBatchNorm, etc + bias (bool): If set true, enables bias for the final layer. + xavier_init (bool): If set to true, performs Xavier weight + initialization for all linear layers. + """ + if num_layers == 1: + return nn.Linear(dim_in, dim_out) + + b = False if norm is not None else bias + mlp_layers = [nn.Linear(dim_in, mlp_dim, bias=b)] + mlp_layers[-1].xavier_init = xavier_init + for i in range(1, num_layers): + if norm: + mlp_layers.append(norm(mlp_dim)) + mlp_layers.append(nn.ReLU(inplace=True)) + if i == num_layers - 1: + d = dim_out + b = bias + else: + d = mlp_dim + mlp_layers.append(nn.Linear(mlp_dim, d, bias=b)) + mlp_layers[-1].xavier_init = xavier_init + return nn.Sequential(*mlp_layers) + + +def create_classification_model_from_ssl_checkpoint( + ssl_checkpoint_path: str, + checkpoint_type: str, + mlp: Optional[nn.Module] = None, + detach_backbone: bool = False, +) -> nn.Module: + + """ + A utlity function for extracting the bacbone from the PyTorch Lightning's + SSL checkpoints. Used for supervided finetuning the SSL pre-trained models + in video classification task. + + Extracts bacbone from the checkpoints of the SimCLR, BYOL and MoCoV2 SSL + tasks and attaches the given MLP to the backbone. + + Args: + ssl_checkpoint_path (str): Path to the lightning checkpoint for the + said SSL task. + checkpoint_type (str): Type of the SSL task the checkpoint belongs to. + Should be one of ["simclr, "byol", "mocov_v2"] + mlp (nn.Module): If specified, the MLP module to attach to the bacbone + for the supervised finetuning phase. + detach_bacbone: If true, detaches bacbone and no gradient are tracked and + updated for the bacbone. Only updates the MLP weights during finetuning. + + Returns: + model (SSLFineTuningModel): Returns an instance of `SSLFineTuningModel`, + consisting of bacbone and mlp. + """ + + if checkpoint_type == "simclr": + from pytorchvideo_trainer.module.simclr import SimCLRModule as M + + lightning_module = M.load_from_checkpoint(ssl_checkpoint_path) + backbone = lightning_module.model.backbone[0] + elif checkpoint_type == "byol": + from pytorchvideo_trainer.module.byol import BYOLModule as M + + lightning_module = M.load_from_checkpoint(ssl_checkpoint_path) + backbone = lightning_module.model.backbone[0] + elif checkpoint_type == "moco_v2": + from pytorchvideo_trainer.module.moco_v2 import MOCOV2Module as M + + lightning_module = M.load_from_checkpoint(ssl_checkpoint_path) + backbone = lightning_module.model.backbone[0] + else: + raise ValueError("Incorrect SSL checkpoint type.") + + # pyre-ignore[6] + return SSLFineTuningModel(backbone, mlp, detach_backbone) + + +class SSLFineTuningModel(nn.Module): + """ + Model consisting of a backbone sequentially followed by an an MLP. + Used for supervised finetuning of the SSL pre-trained models. + + Args: + backbone (nn.Module): A model whole weights are conditionally + updated based on the betach_backbone parameter. + mlp (nn.Module): If specified, the MLP module to attach to the bacbone + for the supervised finetuning phase. + detach_bacbone: If true, detaches bacbone and no gradient are tracked and + updated for the bacbone. Only updates the MLP weights during finetuning. 
+ """ + + def __init__( + self, + backbone: nn.Module, + mlp: nn.Module, + detach_backbone: bool, + ) -> None: + super().__init__() + + self.backbone = backbone + self.mlp = mlp + self.detach_backbone = detach_backbone + + for p in self.backbone.parameters(): + p.requires_grad = False if detach_backbone else True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.backbone(x) + if self.detach_backbone: + x = x.detach() + if self.mlp is not None: + x = self.mlp(x) + return x + + +class KnnMemory(nn.Module): + """ + KNN Memory object that keeps track of the features generated by the SSL model + during the traing phase and performs nearest neighbours inference during the + test and validation phases for video classfication. + + KNN memory requires that you provide the labels and video indices for the + dataset used for the SSL training phase. + + Args: + length (int): Size of the KNN memory. Set to be equal to the training dataset size. + dim (int): Feture dimention generated by the SSL model. + momentum (float): The rate at which to update the features in memory during the SSL- + training phase. + downstream_classes (int): Number of classes in the dataset. + temperature (float): Temperature scaling to use during the inference phase. Typically, + set to the same value as the loss temperature used in SSL. + knn_k (int): Number of nearest neighbours to aggregate metrics over for inference. + deive (str): Device to store the memory module on. + """ + + def __init__( + self, + length: int, + dim: int, + momentum: float = 1.0, + downstream_classes: int = 400, + temperature: float = 1.0, + knn_k: int = 200, + device: str = "cpu", + ) -> None: + super(KnnMemory, self).__init__() + self.length = length + self.dim = dim + self.momentum = momentum + self.temperature = temperature + self.downstream_classes = downstream_classes + self.knn_k = knn_k + stdv = 1.0 / math.sqrt(dim / 3) + self.device = device + self.register_buffer( + "memory", + torch.rand(length, dim, device=self.device).mul_(2 * stdv).add_(-stdv), + ) + + def resize(self, length: int, dim: int) -> None: + """ + Resizes the memory and intialized it fresh. + + Args: + length (int): Size of the KNN memory. Set to be equal to the training + dataset size. + dim (int): Feture dimention generated by the SSL model. + """ + self.length = length + self.dim = dim + stdv = 1.0 / math.sqrt(dim / 3) + del self.memory + self.memory = ( + torch.rand(length, dim, device=self.device).mul_(2 * stdv).add_(-stdv) + ) + + @torch.no_grad() + def get(self, ind: torch.Tensor) -> torch.Tensor: + """ + Fetches features from the memory based on the video index. + + Args: + ind (int): Index of the video / clip for which to fetch the features. + """ + batch_size = ind.size(0) + selected_mem = self.memory[ind.view(-1), :] + out = selected_mem.view(batch_size, -1, self.dim) + return out + + @torch.no_grad() + def update(self, mem: torch.Tensor, ind: torch.Tensor) -> None: + """ + Peforms feature update in the memory based on the new features realized by the + SSL model. Called during the SSL training phase. + + Args: + mem (tensor): Features of the same N x C genereated by the SSL model. + N is the batch size and C is the feature dimention generated by the + SSL Model. + ind (tensor): A 1-D tensor of video indices associated the given features. 
+ """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + mem, ind = du.all_gather([mem, ind]) + mem = mem.view(mem.size(0), 1, -1) + mem_old = self.get(ind).to(mem.device) + + mem_update = mem * self.momentum + mem_old * (1 - self.momentum) + mem_update = F.normalize(mem_update, p=2, dim=1) + self.memory[ind.view(-1), :] = mem_update.squeeze().to(self.memory.device) + + @torch.no_grad() + def init_knn_labels(self, train_loader: Trainer) -> None: + """ + Called before traning, intializes the KNN Memory and resizes it based on the + labels and number of samples in the train dataloader. + + Args: + train_loader (dataloader): Trainining dataloader containing an attribute + `dataset._labeled_videos` which holds mapping from video indices to + labels. + """ + # TODO: Make sure all dataloader's have this property `dataset._labeled_videos` + self.num_imgs = len(train_loader.dataset._labeled_videos) # pyre-ignore[16] + # pyre-ignore[16] + self.train_labels = np.zeros((self.num_imgs,), dtype=np.int32) + for i in range(self.num_imgs): # pyre-ignore[6] + # pyre-ignore[29] + self.train_labels[i] = train_loader.dataset._labeled_videos[i][1]["label"] + self.train_labels = torch.LongTensor(self.train_labels).to(self.device) + if self.length != self.num_imgs: + self.resize(self.num_imgs, self.dim) # pyre-ignore[6] + + def forward(self, inputs: torch.Tensor) -> None: + pass + + @torch.no_grad() + def eval_knn(self, q_knn: torch.Tensor) -> torch.Tensor: + """ + Peforms KNN nearest neighbour aggregations and returns predictions + for the qurried features. + + Args: + q_nn (tensor): Features generated by the SSL model during the inference + phase. Expected to be of shape N x C where, N is the batch size and + C is the feature dimention generated by the SSL Model. + """ + device = q_knn.device + batch_size = q_knn.size(0) + dist = torch.einsum( + "nc,mc->nm", + q_knn.view(batch_size, -1), + self.memory.view(self.memory.size(0), -1).to(device), + ) + yd, yi = dist.topk(self.knn_k, dim=1, largest=True, sorted=True) + + K = yi.shape[1] + C = self.downstream_classes + candidates = self.train_labels.view(1, -1).expand(batch_size, -1) + candidates = candidates.to(device) + yi = yi.to(device) + retrieval = torch.gather(candidates, 1, yi) + retrieval_one_hot = torch.zeros((batch_size * K, C)).to(device) + retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) + yd_transform = (yd.clone().div_(self.temperature).exp_()).to(device) + probs = torch.mul( + retrieval_one_hot.view(batch_size, -1, C), + yd_transform.view(batch_size, -1, 1), + ) + preds = torch.sum(probs, 1) + return preds + + +class SSLBaseModule(VideoClassificationModule): + """ + The Lightning Base module supporting SimCLR, MoCo and BYOL SSL tasks. + + Args: + model (OmegaConf): An omega conf object intializing the neural-network modle. + Example configs can be found in `pytorchvideo_trainer/conf/module/model` + loss(OmegaConf): An omega conf object intializing the loss function. + Example configs can be found in `pytorchvideo_trainer/conf/module/loss` + optim (OmegaConf): An omega conf object for constructing the optimizer object. + The associated config schema can be found at + `pytorchvideo_trainer.module.optimizer.OptimizerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/optim` + metrics (OmegaConf): The metrics to track, which will be used for both train, + validation and test. 
Example configs can be found in + `pytorchvideo_trainer/conf/module/metricx` + lr_scheduler (OmegaConf): An omega conf object associated with learning rate + scheduler used during trainer. + The associated config schema can be found at + `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler` + modality_key (str): The modality key used in data processing, default: "video". + ensemble_method (str): The data ensembling method to control how we accumulate + the testing results at video level, which is optional. Users may choose from + ["sum", "max", None], If it is set to None, no data ensembling will be applied. + knn_memory (OmegaConf): An optional hydra / omeaga conf, if set, initializes KNN + Memory module to use. Example config can be found at, + `pytorchvideo_trainer/conf/module/knn_memory`. + momentum_anneal_cosine (bool): For MoCo and BYOL tasks, if set to true, cosine + anneals the momentum term used from updating the backbone-history model. + num_sync_devices (int): Number of gpus to sync bathcnorm over. Only works if + pytorch lightning trainer's sync_batchnorm parameter is to false. + """ + + def __init__( + self, + model: Any, # pyre-ignore[2] + loss: Any, # pyre-ignore[2] + optim: Any, # pyre-ignore[2] + metrics: List[Any], # pyre-ignore[2] + lr_scheduler: Optional[Any] = None, # pyre-ignore[2] + modality_key: BatchKey = "video", + ensemble_method: Optional[EnsembleMethod] = None, + knn_memory: Optional[Any] = None, # pyre-ignore[2] + momentum_anneal_cosine: bool = False, # TODO: Refactor out mmt from base class. + num_sync_devices: int = 1, + ) -> None: + super().__init__( + model=model, + loss=loss, + optim=optim, + metrics=metrics, + lr_scheduler=lr_scheduler, + modality_key=modality_key, + ensemble_method=ensemble_method, + num_sync_devices=num_sync_devices, + ) + + self.knn_memory: nn.Module = instantiate(knn_memory) + self.automatic_optimization = False + self.momentum_anneal_cosine = momentum_anneal_cosine + if self.momentum_anneal_cosine: + self.initial_mmt: float = self.model.mmt # pyre-ignore[8] + + if ensemble_method is not None: + assert ( + self.knn_memory is not None + ), "Test-Ensembling is only supported with KNN module" + + def on_fit_start(self) -> None: + """ + Called at the very beginning of fit. + If on DDP it is called on every process. + + Peforms conversion of model batchnorm layers into syncbatchnom + and intialized the KNN module using the dataloader labels. + """ + + self._convert_to_sync_bn() + if self.knn_memory is not None: + dataloader = self.trainer.datamodule.train_dataloader() + self.knn_memory.init_knn_labels(dataloader) # pyre-ignore[29] + + def _test_step_with_data_ensembling(self, batch: Batch, batch_idx: int) -> None: + """ + Operates on a single batch of data from the test set. 
+ """ + assert ( + isinstance(batch, dict) + and self.modality_key in batch + and "label" in batch + and "video_index" in batch + and self.knn_memory is not None + ), ( + f"Returned batch [{batch}] is not a map with '{self.modality_key}' and" + + "'label' and 'video_index' keys" + ) + + y_hat = self(batch[self.modality_key]) + y_hat = ( + self.knn_memory.eval_knn(y_hat) if self.knn_memory is not None else y_hat + ) + preds = torch.nn.functional.softmax(y_hat, dim=-1) + labels = batch["label"] + video_ids = torch.tensor(batch["video_index"], device=self.device) + + self._ensemble_at_video_level(preds, labels, video_ids) + + def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]: + """ + If KNN Memory is enabled, evaluates metrics using the labels of neighbours + during the validation and test phases. + """ + assert ( + isinstance(batch, dict) + and self.modality_key in batch + and ("label" in batch or self.knn_memory is None) + and phase_type in ["val", "test"] + ), ( + f"Returned batch [{batch}] is not a map with '{self.modality_key}' and" + + "'label' keys" + ) + + if self.knn_memory is not None: + y_hat = self(batch[self.modality_key]) + y_hat = self.knn_memory.eval_knn(y_hat) + pred = torch.nn.functional.softmax(y_hat, dim=-1) + metrics_result = self._compute_metrics(pred, batch["label"], phase_type) + self.log_dict(metrics_result, on_epoch=True) + + def training_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> None: + """Missing method implemented in subsequent derived SSL task modules.""" + pass + + @torch.no_grad() + def _cosine_anneal_momentum(self) -> None: + """ + For MoCo and BYOL tasks, if self.momentum_anneal_cosine set to true, + cosine anneals the momentum term used from updating the backbone-history + model. + """ + # pyre-ignore[6] + exact_epoch = float(self.cur_epoch_step) / float( + self._num_training_steps_per_epoch() + ) + exact_epoch += self.trainer.current_epoch + new_mmt = ( + 1.0 + - (1.0 - self.initial_mmt) + * ( + math.cos(math.pi * float(exact_epoch) / float(self.trainer.max_epochs)) + + 1.0 + ) + * 0.5 + ) + self.model.mmt = new_mmt # pyre-ignore[16] + self.log("MMT", new_mmt, on_step=True, prog_bar=True) diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py b/pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py new file mode 100644 index 00000000..78b7fa6e --- /dev/null +++ b/pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py @@ -0,0 +1,518 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
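[Editor's note, not part of the patch] The ssl_helper module above also covers the supervised fine-tuning path: create_classification_model_from_ssl_checkpoint pulls the backbone out of a SimCLR/BYOL/MoCo-v2 Lightning checkpoint and attaches a classification head built with create_mlp_util. A minimal sketch follows; the checkpoint path is hypothetical and the 2048-d input assumes a ResNet-50 style backbone with globally pooled features.

from pytorchvideo_trainer.module.ssl_helper import (
    create_classification_model_from_ssl_checkpoint,
    create_mlp_util,
)

# Single linear head for a 400-class dataset (num_layers=1 returns a plain nn.Linear).
head = create_mlp_util(dim_in=2048, dim_out=400, mlp_dim=2048, num_layers=1, norm=None)
model = create_classification_model_from_ssl_checkpoint(
    ssl_checkpoint_path="/path/to/byol_pretrain.ckpt",  # hypothetical path
    checkpoint_type="byol",
    mlp=head,
    detach_backbone=True,  # linear-probe style: only the head receives gradients
)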
+ +# pyre-strict + +from dataclasses import dataclass +from typing import ( + Any, + Dict, + Mapping, + List, + Optional, + Iterable, + Tuple, + Union, + TypedDict, + Literal, +) + +import pytorch_lightning as pl +import torch +from hydra.core.config_store import ConfigStore +from hydra.utils import instantiate +from iopath.common.file_io import g_pathmgr + +# @manual "//github/third-party/omry/omegaconf:omegaconf" +from omegaconf import MISSING, OmegaConf +from pytorch_lightning.utilities import rank_zero_info +from pytorchvideo_trainer.datamodule.transforms import MixVideoBatchWrapper +from pytorchvideo_trainer.module.lr_policy import set_lr, get_epoch_lr, LRSchedulerConf +from pytorchvideo_trainer.module.optimizer import ( + construct_optimizer, +) +from torch import nn +from torch.optim.lr_scheduler import _LRScheduler +from torchrecipes.core.conf import ModuleConf +from torchrecipes.core.task_base import TaskBase +from torchrecipes.utils.config_utils import get_class_name_str + + +class Batch(TypedDict): + """ + PyTorchVideo batches are dictionaries containing each modality or metadata of + the batch collated video clips. For Kinetics it has the below keys and types. + """ + + video: torch.Tensor # (B, C, T, H, W) + audio: torch.Tensor # (B, S) + label: torch.Tensor # (B, 1) + video_index: List[int] # len(video_index) == B + + +BatchKey = Literal["video", "audio", "label", "video_index"] +EnsembleMethod = Literal["sum", "max"] +Output = Optional[Dict[str, torch.Tensor]] + + +class VideoClassificationModule(TaskBase[Batch, Output, Output], pl.LightningModule): + """ + The Lightning module supporting the video classification task. + + Args: + model (OmegaConf): An omega conf object intializing the neural-network modle. + Example configs can be found in `pytorchvideo_trainer/conf/module/model` + loss(OmegaConf): An omega conf object intializing the loss function. + Example configs can be found in `pytorchvideo_trainer/conf/module/loss` + optim (OmegaConf): An omega conf object for constructing the optimizer object. + The associated config schema can be found at + `pytorchvideo_trainer.module.optimizer.OptimizerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/optim` + metrics (OmegaConf): The metrics to track, which will be used for both train, + validation and test. Example configs can be found in + `pytorchvideo_trainer/conf/module/metricx` + lr_scheduler (OmegaConf): An omega conf object associated with learning rate + scheduler used during trainer. + The associated config schema can be found at + `pytorchvideo_trainer.module.lr_policy.LRSchedulerConf`. + Example configs can be found in `pytorchvideo_trainer/conf/module/lr_scheduler` + modality_key (str): The modality key used in data processing, default: "video". + ensemble_method (str): The data ensembling method to control how we accumulate + the testing results at video level, which is optional. Users may choose from + ["sum", "max", None], If it is set to None, no data ensembling will be applied. + num_classes (int): The number of classes in the dataset. + num_sync_devices (int): Number of gpus to sync bathcnorm over. Only works if + pytorch lightning trainer's sync_batchnorm parameter is to false. + batch_transform (OmegaConf): An optional omega conf object, for constructing the + data transform method that act upon the entire mini batch. Examples include, + MixVideo transform, etc. + clip_gradient_norm (float): Performs gradient clipping if set to a positive value. 
+ Since, we use Pytorch-lightning's manual optimization approach gradient clipping + has to be be set in the lightning module instead of the Trainer object. + """ + + def __init__( + self, + model: Any, # pyre-ignore[2] + loss: Any, # pyre-ignore[2] + optim: Any, # pyre-ignore[2] + metrics: List[Any], # pyre-ignore[2] + lr_scheduler: Optional[Any] = None, # pyre-ignore[2] + modality_key: BatchKey = "video", + ensemble_method: Optional[EnsembleMethod] = None, + num_classes: int = 400, + num_sync_devices: int = 1, + batch_transform: Optional[Any] = None, # pyre-ignore[2] + clip_gradient_norm: float = 0.0, + ) -> None: + super().__init__() + self.automatic_optimization = False + + self.model: nn.Module = instantiate(model, _convert_="all") + self.loss: nn.Module = instantiate(loss) + self.batch_transform = instantiate(batch_transform) # pyre-ignore[4] + rank_zero_info(OmegaConf.to_yaml(optim)) + self.optim: torch.optim.Optimizer = construct_optimizer(self.model, optim) + self.lr_scheduler_conf: LRSchedulerConf = lr_scheduler + self.modality_key: BatchKey = modality_key + self.ensemble_method: Optional[EnsembleMethod] = ensemble_method + self.num_classes: int = num_classes + self.clip_gradient_norm = clip_gradient_norm + + self.metrics: Mapping[str, nn.Module] = { + metric_conf.name: instantiate(metric_conf.config) for metric_conf in metrics + } + + self.train_metrics: nn.ModuleDict = nn.ModuleDict() + self.val_metrics: nn.ModuleDict = nn.ModuleDict() + self.test_metrics: nn.ModuleDict = nn.ModuleDict() + + self.save_hyperparameters() + + # These are used for data ensembling in the test stage. + self.video_preds: Dict[int, torch.Tensor] = {} + self.video_labels: Dict[int, torch.Tensor] = {} + self.video_clips_cnts: Dict[int, int] = {} + + # Sync BatchNorm + self.num_sync_devices = num_sync_devices + + def setup(self, stage: Optional[str] = None) -> None: + if stage == "fit": + self.train_metrics.update(self.metrics) + self.val_metrics.update(self.metrics) + else: + self.test_metrics.update(self.metrics) + + # pyre-ignore[14]: *args, **kwargs are not torchscriptable. + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward defines the prediction/inference actions. 
+ """ + return self.model(x) + + def _num_training_steps_per_epoch(self) -> int: + """training steps per epoch inferred from datamodule and devices.""" + dataloader = self.trainer.datamodule.train_dataloader() + world_size = self.trainer.world_size + + # TODO: Make sure other dataloaders has this property + dataset_size = self.trainer.limit_train_batches + dataset_size *= len(dataloader.dataset._labeled_videos) + + # TODO: Make sure other dataloaders has this property + return dataset_size // world_size // dataloader.batch_size + + def manual_update_lr(self) -> None: + """Utility function for manually updating the optimizer learning rate""" + + opt = self.optimizers() + + if self.lr_scheduler_conf is not None: + # pyre-ignore[6] + exact_epoch = float(self.cur_epoch_step) / float( + self._num_training_steps_per_epoch() + ) + exact_epoch += self.trainer.current_epoch + lr = get_epoch_lr(exact_epoch, self.lr_scheduler_conf) + self.log("LR", lr, on_step=True, prog_bar=True) + self.log("ExactE", exact_epoch, on_step=True, prog_bar=True) + + if isinstance(opt, list): + for op in opt: + set_lr(op, lr) # pyre-ignore[6] + else: + set_lr(opt, lr) # pyre-ignore[6] + + def manual_zero_opt_grad(self) -> None: + """Utility function for zeroing optimzer gradients""" + opt = self.optimizers() + if isinstance(opt, list): + for op in opt: + op.zero_grad() # pyre-ignore[16] + else: + opt.zero_grad() + + def manual_opt_step(self) -> None: + """Utility function for manually stepping the optimzer""" + opt = self.optimizers() + if isinstance(opt, list): + for op in opt: + op.step() + else: + opt.step() + + def training_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> None: + """ + The PyTorchVideo models and transforms expect the same input shapes and + dictionary structure making this function just a matter of unwrapping the + dict and feeding it through the model/loss. + """ + self.cur_epoch_step += 1 # pyre-ignore[16] + + if self.batch_transform is not None: + batch = self.batch_transform(batch) + + self.manual_zero_opt_grad() + self.manual_update_lr() + + # Forward/backward + loss = self._step(batch, batch_idx, "train") + self.manual_backward(loss) # pyre-ignore[6] + if self.clip_gradient_norm > 0: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.clip_gradient_norm + ) + self.manual_opt_step() + + def validation_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> Dict[str, Any]: + """ + Operates on a single batch of data from the validation set. + """ + return self._step(batch, batch_idx, "val") + + def test_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> Optional[Dict[str, Any]]: + """ + Operates on a single batch of data from the test set. + """ + if self.ensemble_method: + self._test_step_with_data_ensembling(batch, batch_idx) + else: + return self._step(batch, batch_idx, "test") + + def _test_step_with_data_ensembling(self, batch: Batch, batch_idx: int) -> None: + """ + Operates on a single batch of data from the test set. 
+ """ + assert ( + isinstance(batch, dict) + and self.modality_key in batch + and "label" in batch + and "video_index" in batch + ), ( + f"Returned batch [{batch}] is not a map with '{self.modality_key}' and" + + "'label' and 'video_index' keys" + ) + + y_hat = self(batch[self.modality_key]) + preds = torch.nn.functional.softmax(y_hat, dim=-1) + labels = batch["label"] + video_ids = torch.tensor(batch["video_index"], device=self.device) + + self._ensemble_at_video_level(preds, labels, video_ids) + + def on_train_epoch_start(self) -> None: + self._reset_metrics("train") + self.cur_epoch_step = 0.0 # pyre-ignore[16] + + def on_validation_epoch_start(self) -> None: + self._reset_metrics("val") + + def on_test_epoch_start(self) -> None: + self._reset_metrics("test") + + def on_test_epoch_end(self) -> None: + """Pytorch-Lightning's method for aggregating test metrics at the end of epoch""" + if self.ensemble_method: + for video_id in self.video_preds: + self.video_preds[video_id] = ( + self.video_preds[video_id] / self.video_clips_cnts[video_id] + ) + video_preds = torch.stack(list(self.video_preds.values()), dim=0) + video_labels = torch.tensor( + list(self.video_labels.values()), + device=self.device, + ) + metrics_result = self._compute_metrics(video_preds, video_labels, "test") + self.log_dict(metrics_result) + + def _ensemble_at_video_level( + self, preds: torch.Tensor, labels: torch.Tensor, video_ids: torch.Tensor + ) -> None: + """ + Ensemble multiple predictions of the same view together. This relies on the + fact that the dataloader reads multiple clips of the same video at different + spatial crops. + """ + for i in range(preds.shape[0]): + vid_id = int(video_ids[i]) + self.video_labels[vid_id] = labels[i] + if vid_id not in self.video_preds: + self.video_preds[vid_id] = torch.zeros( + (self.num_classes), device=self.device, dtype=preds.dtype + ) + self.video_clips_cnts[vid_id] = 0 + + if self.ensemble_method == "sum": + self.video_preds[vid_id] += preds[i] + elif self.ensemble_method == "max": + self.video_preds[vid_id] = torch.max(self.video_preds[vid_id], preds[i]) + self.video_clips_cnts[vid_id] += 1 + + def configure_optimizers( + self, + ) -> Union[ + torch.optim.Optimizer, + Tuple[Iterable[torch.optim.Optimizer], Iterable[_LRScheduler]], + ]: + """Pytorch-Lightning's method for configuring optimizer""" + return self.optim + + def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]: + assert ( + isinstance(batch, dict) and self.modality_key in batch and "label" in batch + ), ( + f"Returned batch [{batch}] is not a map with '{self.modality_key}' and" + + "'label' keys" + ) + + y_hat = self(batch[self.modality_key]) + if phase_type == "train": + loss = self.loss(y_hat, batch["label"]) + self.log( + f"Losses/{phase_type}_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + else: + loss = None + + ## TODO: Move MixUP transform metrics to sperate method. 
+ if ( + phase_type == "train" + and self.batch_transform is not None + and isinstance(self.batch_transform, MixVideoBatchWrapper) + ): + _top_max_k_vals, top_max_k_inds = torch.topk( + batch["label"], 2, dim=1, largest=True, sorted=True + ) + idx_top1 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 0] + idx_top2 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 1] + y_hat = y_hat.detach() + y_hat[idx_top1] += y_hat[idx_top2] + y_hat[idx_top2] = 0.0 + batch["label"] = top_max_k_inds[:, 0] + + pred = torch.nn.functional.softmax(y_hat, dim=-1) + metrics_result = self._compute_metrics(pred, batch["label"], phase_type) + self.log_dict(metrics_result, on_epoch=True) + + return loss + + def _compute_metrics( + self, pred: torch.Tensor, label: torch.Tensor, phase_type: str + ) -> Dict[str, torch.Tensor]: + metrics_dict = getattr(self, f"{phase_type}_metrics") + metrics_result = {} + for name, metric in metrics_dict.items(): + metrics_result[f"Metrics/{phase_type}/{name}"] = metric(pred, label) + return metrics_result + + def _reset_metrics(self, phase_type: str) -> None: + metrics_dict = getattr(self, f"{phase_type}_metrics") + for _, metric in metrics_dict.items(): + metric.reset() + + def _convert_to_sync_bn(self) -> None: + """ + Converts BatchNorm into sync-batchnorm. + If pytorch lightning trainer's sync_batchnorm parameter is to true, + performs global sync-batchnorm across all nodes and gpus. Else, + if perform local sync-batchnorm acroos specified number of gpus. + """ + if ( + hasattr(self.trainer.training_type_plugin, "sync_batchnorm") + and self.trainer.training_type_plugin.sync_batchnorm + ): + print("Using Global Synch BatchNorm.") + return None + + if self.num_sync_devices > 1: + print(f"Using local Synch BatchNorm over {self.num_sync_devices} devices.") + pg = create_syncbn_process_group(self.num_sync_devices) + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + self.model, process_group=pg + ) + + def on_fit_start(self) -> None: + """ + Called at the very beginning of fit. + If on DDP it is called on every process. + """ + self._convert_to_sync_bn() + + +def create_syncbn_process_group(group_size: int) -> List[int]: + """ + Creates process groups to be used for syncbn of a give ``group_size`` and returns + process group that current GPU participates in. + + Args: + group_size (int): number of GPU's to collaborate for sync bn. group_size should + be >=2 else, no action is taken. + """ + assert ( + group_size > 1 + ), f"Invalid group size {group_size} to convert to sync batchnorm." 
+
+    world_size = torch.distributed.get_world_size()
+    assert world_size >= group_size
+    assert world_size % group_size == 0
+
+    group = None
+    for group_num in range(world_size // group_size):
+        group_ids = range(group_num * group_size, (group_num + 1) * group_size)
+        cur_group = torch.distributed.new_group(ranks=group_ids)
+        if torch.distributed.get_rank() // group_size == group_num:
+            group = cur_group
+            # Cannot break and return here: every process must take part in
+            # creating all of the subgroups.
+
+    assert group is not None
+    return group
+
+
+@dataclass
+class VideoClassificationModuleConf(ModuleConf):
+    _target_: str = get_class_name_str(VideoClassificationModule)
+    model: Any = MISSING  # pyre-ignore[4]
+    loss: Any = MISSING  # pyre-ignore[4]
+    optim: Any = MISSING  # pyre-ignore[4]
+    metrics: List[Any] = MISSING  # pyre-ignore[4]
+    lr_scheduler: Optional[Any] = None  # pyre-ignore[4]
+    modality_key: str = "video"
+    ensemble_method: Optional[str] = None
+    num_classes: int = 400
+    num_sync_devices: Optional[int] = 1
+
+
+@dataclass
+class VideoClassificationModuleConfVisionTransformer(VideoClassificationModuleConf):
+
+    batch_transform: Optional[Any] = None  # pyre-ignore[4]
+    clip_gradient_norm: float = 0.0
+
+
+cs = ConfigStore()
+cs.store(
+    group="schema/module",
+    name="video_classification_module_conf",
+    node=VideoClassificationModuleConf,
+    package="module",
+)
+
+cs.store(
+    group="schema/module",
+    name="video_classification_module_conf_vision_transformer",
+    node=VideoClassificationModuleConfVisionTransformer,
+    package="module",
+)
+
+
+def create_classification_model_from_modelzoo(
+    checkpoint_path: str,
+    model: nn.Module,
+) -> nn.Module:
+    """
+    Builds a model from a PyTorchVideo model zoo checkpoint.
+
+    An example config for using this method can be found at -
+    `pytorchvideo_trainer/conf/module/model/from_model_zoo_checkpoint.yaml`
+
+    Args:
+        checkpoint_path (str): Path to the pretrained model weights.
+        model (nn.Module): Module to load the checkpoint into.
+    Returns:
+        model (nn.Module): Returns the model with pretrained weights loaded.
+    """
+
+    with g_pathmgr.open(checkpoint_path, "rb") as f:
+        checkpoint = torch.load(f, map_location="cpu")
+    state_dict = checkpoint["model_state"]
+    model.load_state_dict(state_dict)
+    return model
+
+
+def create_classification_model_from_lightning(
+    checkpoint_path: str,
+) -> nn.Module:
+    """
+    Builds a model from a pytorchvideo_trainer PyTorch-Lightning checkpoint.
+
+    An example config for using this method can be found at -
+    `pytorchvideo_trainer/conf/module/model/from_lightning_checkpoint.yaml`
+
+    Args:
+        checkpoint_path (str): Path to the pretrained model weights.
+    Returns:
+        model (nn.Module): Returns the model with pretrained weights loaded.
+    """
+    lightning_model = VideoClassificationModule.load_from_checkpoint(checkpoint_path)
+    return lightning_model.model
diff --git a/pytorchvideo_trainer/pytorchvideo_trainer/train_app.py b/pytorchvideo_trainer/pytorchvideo_trainer/train_app.py
new file mode 100644
index 00000000..33980b57
--- /dev/null
+++ b/pytorchvideo_trainer/pytorchvideo_trainer/train_app.py
@@ -0,0 +1,300 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
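+# Train app entry point: composes the Hydra configs into a Lightning module,
+# datamodule, callbacks and logger, and can launch jobs on a cluster via submitit.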
+
+import os
+from dataclasses import dataclass, field
+from typing import Any, List, Dict, Optional, Union
+
+import hydra
+import numpy as np
+import submitit
+import torch
+from hydra.core.config_store import ConfigStore
+from omegaconf import MISSING
+from omegaconf import OmegaConf
+from omegaconf.dictconfig import DictConfig
+from pytorch_lightning import LightningModule, LightningDataModule
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorchvideo_trainer.datamodule.datamodule import (
+    VideoClassificationDataModuleConf,
+)
+from pytorchvideo_trainer.module.video_classification import (
+    VideoClassificationModuleConf,
+)
+from torchrecipes.core.base_train_app import BaseTrainApp
+from torchrecipes.core.base_train_app import TrainOutput
+from torchrecipes.core.conf import TrainerConf, TrainAppConf
+from torchrecipes.utils.config_utils import get_class_name_str
+
+
+class VideoClassificationTrainApp(BaseTrainApp):
+    """
+    This app is used to launch the video tasks (both classification and SSL).
+    It is the main point of entry for all training, validation and test phases.
+
+    The Hydra/Omega conf schema used by the train app is defined in
+    `VideoClassificationTrainAppConf`.
+
+    Args:
+        module (OmegaConf): Hydra/Omega conf object associated with the initialization of the
+            pytorch-lightning module. Supported config schemas include,
+            1. `pytorchvideo_trainer.module.video_classification.VideoClassificationModuleConf`
+            2. `pytorchvideo_trainer.module.simclr.SimCLRModuleConf`
+            3. `pytorchvideo_trainer.module.byol.BYOLModuleConf`
+            4. `pytorchvideo_trainer.module.moco_v2.MOCOV2ModuleConf`
+            and more. Example definitions of the config can be found in
+            `pytorchvideo_trainer/conf/module`.
+        trainer (OmegaConf): Hydra/Omega conf object associated with the initialization of the
+            pytorch-lightning Trainer object. The supported config schema can be found at
+            `github.com/facebookresearch/recipes/blob/main/torchrecipes/core/conf/__init__.py`
+        datamodule (OmegaConf): Hydra/Omega conf object associated with the initialization of
+            the pytorch-lightning DataModule object. The supported config schema can be found at
+            `pytorchvideo_trainer.datamodule.datamodule.VideoClassificationDataModuleConf`.
+        logger (OmegaConf): Hydra/Omega conf object associated with the initialization of the
+            pytorch-lightning TensorBoard logger object. Example configs can be found at
+            `pytorchvideo_trainer/conf/logger`.
+        callbacks (List[OmegaConf]): Hydra/Omega conf object associated with the initialization
+            of a series of pytorch-lightning Callbacks that act upon the lightning module.
+            Expects a list or iterable config object wherein each element represents the Hydra
+            conf of a single callback; thus, multiple callbacks can be loaded at a time. Example
+            configs can be found at `pytorchvideo_trainer/conf/callbacks`.
+        submitit_conf (OmegaConf): Hydra/Omega conf used by the `submitit_launcher` for
+            launching the train app.
+            Example config files can be found at
+            `pytorchvideo_trainer/conf/submitit_conf`.
+    """
+
+    def __init__(
+        self,
+        module: VideoClassificationModuleConf,
+        trainer: TrainerConf,
+        datamodule: VideoClassificationDataModuleConf,
+        logger: Any,  # pyre-ignore[2]
+        callbacks: Optional[Any] = None,  # pyre-ignore[2]
+        submitit_conf: Optional[Any] = None,  # pyre-ignore[2]
+    ) -> None:
+
+        self.logger_conf: DictConfig = logger
+        self.callbacks_conf: DictConfig = callbacks
+        self.submitit_conf: DictConfig = submitit_conf
+        # This has to happen last because it depends on the values above.
+        super().__init__(module, trainer, datamodule)
+
+    def get_data_module(self) -> Optional[LightningDataModule]:
+        """
+        Instantiate a LightningDataModule.
+        """
+        return hydra.utils.instantiate(
+            self.datamodule_conf,
+            _recursive_=False,
+        )
+
+    def get_lightning_module(self) -> LightningModule:
+        """
+        Instantiate a LightningModule.
+        """
+        return hydra.utils.instantiate(
+            self.module_conf,
+            _recursive_=False,
+        )
+
+    def get_callbacks(self) -> List[Callback]:
+        """
+        Creates the list of callbacks that is fed into the trainer.
+        You can add additional ModelCheckpoint callbacks here too.
+        """
+        callbacks = []
+        if self.trainer_conf.logger:
+            callbacks.extend(
+                [
+                    LearningRateMonitor(),
+                ]
+            )
+        if self.callbacks_conf is None:
+            return callbacks
+
+        for cb_conf in self.callbacks_conf.values():
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+        return callbacks
+
+    def _make_reproducible_conf(self) -> DictConfig:
+        conf = OmegaConf.create()
+        conf._target_ = "pytorchvideo_trainer.train_app.VideoClassificationTrainApp"
+        conf.module = self.module_conf
+        conf.trainer = self.trainer_conf
+        conf.datamodule = self.datamodule_conf
+        conf.logger = self.logger_conf
+        conf.callbacks = self.callbacks_conf
+        conf.submitit_conf = self.submitit_conf
+        return conf
+
+    def get_logger(self) -> TensorBoardLogger:
+        """
+        Creates the logger that is fed into the trainer.
+        Override this method to return a different logger for the trainer.
+        """
+        logger = hydra.utils.instantiate(
+            self.logger_conf,
+            _recursive_=False,
+        )
+
+        @rank_zero_only
+        def log_params() -> None:  # pyre-ignore[53]
+            if os.environ["PTV_TRAINER_ENV"] == "oss":
+                from iopath.common.file_io import g_pathmgr
+
+                conf_to_log = self._make_reproducible_conf()
+                conf_save_path = os.path.join(logger.log_dir, "train_app_conf.yaml")
+                if not g_pathmgr.exists(conf_save_path):
+                    with g_pathmgr.open(conf_save_path, mode="w") as f:
+                        f.write(OmegaConf.to_yaml(conf_to_log))
+            else:
+                from stl.lightning.io import filesystem
+
+                fs = filesystem.get_filesystem(logger.log_dir)
+                conf_to_log = self._make_reproducible_conf()
+                fs.makedirs(logger.log_dir, exist_ok=True)
+                conf_save_path = os.path.join(logger.log_dir, "train_app_conf.yaml")
+                if not fs.exists(conf_save_path):
+                    with fs.open(conf_save_path, mode="w") as f:
+                        f.write(OmegaConf.to_yaml(conf_to_log))
+
+        log_params()
+        return logger
+
+    def test(self) -> TrainOutput:  # pyre-ignore[15]
+        """
+        Triggers PyTorch-Lightning's testing phase.
+        """
+        trainer, _ = self._get_trainer()
+        trainer.test(self.module, datamodule=self.datamodule)
+        return TrainOutput(tensorboard_log_dir=self.root_dir)
+
+    def predict(self) -> TrainOutput:  # pyre-ignore[15]
+        """
+        Triggers PyTorch-Lightning's prediction phase.
+ """ + trainer, _ = self._get_trainer() + trainer.predict(self.module, datamodule=self.datamodule) + return TrainOutput(tensorboard_log_dir=self.root_dir) + + +def run_app_in_certain_mode( + cfg: TrainAppConf, mode: str, env: str = "oss" +) -> TrainOutput: + + os.environ["PTV_TRAINER_ENV"] = env + + rank_zero_info(OmegaConf.to_yaml(cfg)) + + # TODO: Move this to config and replace with `seed_everything` + np.random.seed(0) + torch.manual_seed(0) + app = hydra.utils.instantiate(cfg, _recursive_=False) + + if mode == "train": + rank_zero_info("MODE set to train, run train only.") + return app.train() + elif mode == "test": + rank_zero_info("MODE set to test, run test only.") + return app.test() + elif mode == "predict": + rank_zero_info("MODE set to predict, run train and predict.") + app.train() + return app.predict() + else: + # By default, run train and test + app.train() + return app.test() + + +project_defaults: List[Union[str, Dict[str, str]]] = [ + "_self_", + {"schema/module": "video_classification_module_conf"}, + {"schema/module/optim": "optim_conf"}, + {"schema/datamodule": "ptv_video_classification_data_module_conf"}, + {"datamodule/dataloader": "kinetics_classification"}, + {"logger": "ptl"}, + {"datamodule/transforms": "kinetics_classification_slow"}, + {"module/model": "slow_r50"}, + {"module/loss": "cross_entropy"}, + {"module/optim": "sgd"}, + {"module/metrics": "accuracy"}, + {"schema/trainer": "trainer"}, + {"trainer": "cpu"}, +] + + +@dataclass +class VideoClassificationTrainAppConf(TrainAppConf): + _target_: str = get_class_name_str(VideoClassificationTrainApp) + datamodule: VideoClassificationDataModuleConf = MISSING + module: VideoClassificationModuleConf = MISSING + trainer: TrainerConf = MISSING + + # pyre-fixme[4]: Attribute annotation cannot contain `Any`. + logger: Any = MISSING + + # pyre-fixme[4]: Attribute annotation cannot contain `Any`. + callbacks: Optional[Any] = None + + # pyre-fixme[4]: Attribute annotation cannot contain `Any`. + defaults: List[Any] = field(default_factory=lambda: project_defaults) + + # pyre-fixme[4]: Attribute annotation cannot contain `Any`. 
+ submitit_conf: Optional[Any] = None + + +cs = ConfigStore() +cs.store( + name="video_classification_train_app_conf", + node=VideoClassificationTrainAppConf, +) + + +@hydra.main(config_path="conf", config_name=None) +# pyre-ignore[2] +def submitit_launcher(cfg) -> None: + + print("###################### Train App Config ####################") + print(OmegaConf.to_yaml(cfg)) + print("############################################################") + + submitit_conf = cfg.get("submitit_conf", None) + logger_conf = cfg.get("logger", None) + assert submitit_conf is not None, "Missing submitit config" + + if logger_conf is not None: + assert ( + logger_conf.save_dir is not None + ), "set save_dir in logger conf to a valid path" + submitit_dir = os.path.join(logger_conf.save_dir, logger_conf.name) + else: + assert submitit_conf.log_save_dir is not None + submitit_dir = submitit_conf.log_save_dir + + submitit_dir = os.path.join(submitit_dir, "submitit_logs") + executor = submitit.AutoExecutor(folder=submitit_dir) + job_kwargs = { + "slurm_time": submitit_conf.time, + "name": cfg.logger.name if logger_conf is not None else submitit_conf.name, + "slurm_partition": submitit_conf.partition, + "gpus_per_node": cfg.trainer.gpus, + "tasks_per_node": cfg.trainer.gpus, # one task per GPU + "cpus_per_task": submitit_conf.cpus_per_task, + "nodes": cfg.trainer.num_nodes, + } + if submitit_conf.get("mem", None) is not None: + job_kwargs["slurm_mem"] = submitit_conf.mem + if submitit_conf.get("constraints", None) is not None: + job_kwargs["constraints"] = submitit_conf.constraints + + executor.update_parameters(**job_kwargs) + job = executor.submit(run_app_in_certain_mode, cfg, submitit_conf.mode) + print("Submitit Job ID:", job.job_id) + + +if __name__ == "__main__": + submitit_launcher() diff --git a/pytorchvideo_trainer/setup.py b/pytorchvideo_trainer/setup.py new file mode 100644 index 00000000..96daf53e --- /dev/null +++ b/pytorchvideo_trainer/setup.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +from setuptools import find_packages, setup + +setup( + name="pytorchvideo_trainer", + version="0.0.1", + license="Apache 2.0", + author="Facebook AI", + url="https://github.com/facebookresearch/pytorchvideo", + description="PyTorch-Lightning trainer powering PyTorchVideo models.", + python_requires=">=3.8", + install_requires=[ + "submitit", + "pytorchvideo>=0.1.5", + ], + extras_require={ + "test": ["coverage", "pytest", "opencv-python"], + "dev": [ + "opencv-python", + "black==20.8b1", + "sphinx", + "isort==4.3.21", + "flake8==3.8.1", + "flake8-bugbear", + "flake8-comprehensions", + "pre-commit", + "nbconvert", + "bs4", + "autoflake==1.4", + ], + "opencv-python": [ + "opencv-python", + ], + }, + packages=find_packages(exclude=("scripts", "tests")), +) diff --git a/pytorchvideo_trainer/tests/__init__.py b/pytorchvideo_trainer/tests/__init__.py new file mode 100644 index 00000000..5c7f19c6 --- /dev/null +++ b/pytorchvideo_trainer/tests/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. diff --git a/pytorchvideo_trainer/tests/test_conf_datamodule.py b/pytorchvideo_trainer/tests/test_conf_datamodule.py new file mode 100644 index 00000000..153be1ca --- /dev/null +++ b/pytorchvideo_trainer/tests/test_conf_datamodule.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
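+# Checks that the datamodule portion of the train app config composes with Hydra
+# and instantiates a PyTorchVideoDataModule with train/val/test transforms attached.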
+ +import unittest + +from hydra.experimental import compose, initialize_config_module +from hydra.utils import instantiate # @manual +from pytorchvideo_trainer.datamodule.datamodule import PyTorchVideoDataModule + + +class TestKineticsDataModuleConf(unittest.TestCase): + def test_init_with_hydra(self) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + test_conf = compose( + config_name="video_classification_train_app_conf", + overrides=[ + "datamodule/dataloader=kinetics_classification", + "datamodule/transforms=kinetics_classification_slow", + ], + ) + print(test_conf) + kinetics_data_module = instantiate( + test_conf.datamodule, + _recursive_=False, + ) + self.assertIsInstance(kinetics_data_module, PyTorchVideoDataModule) + self.assertIsNotNone(kinetics_data_module.transforms["train"]) + self.assertIsNotNone(kinetics_data_module.transforms["val"]) + self.assertIsNotNone(kinetics_data_module.transforms["test"]) diff --git a/pytorchvideo_trainer/tests/test_conf_module.py b/pytorchvideo_trainer/tests/test_conf_module.py new file mode 100644 index 00000000..f2f93b72 --- /dev/null +++ b/pytorchvideo_trainer/tests/test_conf_module.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import unittest + +import hydra +from hydra.experimental import compose, initialize_config_module +from pytorchvideo_trainer.module.byol import BYOLModule +from pytorchvideo_trainer.module.moco_v2 import ( + MOCOV2Module, +) +from pytorchvideo_trainer.module.simclr import ( + SimCLRModule, +) +from pytorchvideo_trainer.module.video_classification import ( + VideoClassificationModule, +) + + +class TestVideoClassificationModuleConf(unittest.TestCase): + def test_init_with_hydra(self) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + test_conf = compose( + config_name="video_classification_train_app_conf", + overrides=["module/model=slow_r50"], + ) + test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False) + self.assertIsInstance(test_module, VideoClassificationModule) + self.assertIsNotNone(test_module.model) + + +class TestVideoSimCLRModuleConf(unittest.TestCase): + def test_init_with_hydra(self) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + test_conf = compose( + config_name="simclr_train_app_conf", + ) + test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False) + self.assertIsInstance(test_module, SimCLRModule) + self.assertIsNotNone(test_module.model) + + +class TestVideoBYOLModuleConf(unittest.TestCase): + def test_init_with_hydra(self) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + test_conf = compose( + config_name="byol_train_app_conf", + ) + test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False) + self.assertIsInstance(test_module, BYOLModule) + self.assertIsNotNone(test_module.model) + + +class TestVideoMOCOV2ModuleConf(unittest.TestCase): + def test_init_with_hydra(self) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + test_conf = compose( + config_name="moco_v2_train_app_conf", + # overrides=["module/model=resnet"], + ) + test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False) + self.assertIsInstance(test_module, MOCOV2Module) + self.assertIsNotNone(test_module.model) diff --git a/pytorchvideo_trainer/tests/test_task_byol.py b/pytorchvideo_trainer/tests/test_task_byol.py new file mode 100644 index 
00000000..463dfd3b --- /dev/null +++ b/pytorchvideo_trainer/tests/test_task_byol.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +# pyre-strict +from torchrecipes.core.base_train_app import BaseTrainApp +from util import ( + create_small_kinetics_dataset, + run_locally, + tempdir, + BaseTrainAppTestCase, +) + + +class TestBYOLTrainApp(BaseTrainAppTestCase): + def get_train_app( + self, + root_dir: str, + fast_dev_run: bool = True, + logger: bool = False, + ) -> BaseTrainApp: + create_small_kinetics_dataset(root_dir) + overrides = [ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "module.knn_memory.length=50", + "module.knn_memory.knn_k=2", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "trainer.logger=false", + ] + app = self.create_app_from_hydra( + config_module="pytorchvideo_trainer.conf", + config_name="byol_train_app_conf", + overrides=overrides, + ) + trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger} + self.mock_trainer_params(app, trainer_overrides) + return app + + @run_locally + @tempdir + def test_byol_app_train_test_30_views(self, root_dir: str) -> None: + train_app = self.get_train_app( + root_dir=root_dir, fast_dev_run=False, logger=False + ) + output = train_app.train() + self.assertIsNotNone(output) + output = train_app.test() + self.assertIsNotNone(output) + + video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None) + num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10) + num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3) + self.assertIsNotNone(video_clips_cnts) + for _, sample_cnts in video_clips_cnts.items(): + self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts) diff --git a/pytorchvideo_trainer/tests/test_task_moco_v2.py b/pytorchvideo_trainer/tests/test_task_moco_v2.py new file mode 100644 index 00000000..46c55d6a --- /dev/null +++ b/pytorchvideo_trainer/tests/test_task_moco_v2.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+ +# pyre-strict +from torchrecipes.core.base_train_app import BaseTrainApp +from util import ( + create_small_kinetics_dataset, + run_locally, + tempdir, + BaseTrainAppTestCase, +) + + +class TestMOCOV2TrainApp(BaseTrainAppTestCase): + def get_train_app( + self, + root_dir: str, + fast_dev_run: bool = True, + logger: bool = False, + ) -> BaseTrainApp: + create_small_kinetics_dataset(root_dir) + overrides = [ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "module.knn_memory.length=50", + "module.knn_memory.knn_k=2", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "trainer.logger=false", + ] + + app = self.create_app_from_hydra( + config_module="pytorchvideo_trainer.conf", + config_name="moco_v2_train_app_conf", + overrides=overrides, + ) + trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger} + self.mock_trainer_params(app, trainer_overrides) + return app + + @run_locally + @tempdir + def test_moco_v2_app_train_test_30_views(self, root_dir: str) -> None: + train_app = self.get_train_app( + root_dir=root_dir, fast_dev_run=False, logger=False + ) + output = train_app.train() + self.assertIsNotNone(output) + output = train_app.test() + self.assertIsNotNone(output) + + video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None) + num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10) + num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3) + self.assertIsNotNone(video_clips_cnts) + for _, sample_cnts in video_clips_cnts.items(): + self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts) diff --git a/pytorchvideo_trainer/tests/test_task_module_all.py b/pytorchvideo_trainer/tests/test_task_module_all.py new file mode 100644 index 00000000..94e76c79 --- /dev/null +++ b/pytorchvideo_trainer/tests/test_task_module_all.py @@ -0,0 +1,129 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import unittest +from typing import Any + +import hydra +from hydra import compose, initialize_config_module +from hydra.utils import instantiate # @manual +from omegaconf import OmegaConf +from pytorch_lightning import Trainer +from pytorchvideo_trainer.datamodule.datamodule import ( + VideoClassificationDataModuleConf, +) +from pytorchvideo_trainer.train_app import ( + VideoClassificationTrainAppConf, +) +from util import create_small_kinetics_dataset, run_locally, tempdir + + +class TestMain(unittest.TestCase): + # pyre-fixme[3]: Return annotation cannot be `Any`. 
+ def get_datamodule(self, cfg: VideoClassificationDataModuleConf) -> Any: + test_data_module = instantiate( + cfg, + _recursive_=False, + ) + return test_data_module + + def train(self, cfg: VideoClassificationTrainAppConf) -> None: + print(OmegaConf.to_yaml(cfg)) + test_module = hydra.utils.instantiate(cfg.module, _recursive_=False) + test_data_module = self.get_datamodule(cfg.datamodule) + # pyre-fixme[6]: Expected `SupportsKeysAndGetItem[Variable[_KT], + # Variable[_VT]]` for 1st param but got `TrainerConf`. + trainer_params = dict(cfg.trainer) + trainer_params["logger"] = True + trainer_params["checkpoint_callback"] = False + trainer_params["fast_dev_run"] = True + pl_trainer = Trainer(**trainer_params) + pl_trainer.fit(model=test_module, datamodule=test_data_module) + + @run_locally + @tempdir + def test_train_video_model(self, root_dir: str) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + create_small_kinetics_dataset(root_dir) + # Config is relative to a module + cfg = compose( + config_name="video_classification_train_app_conf", + overrides=[ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "+module/lr_scheduler=cosine_with_warmup", + "trainer.logger=true", + ], + ) + self.assertEqual(cfg.trainer.max_epochs, 1) + + self.train(cfg) + + @run_locally + @tempdir + def test_train_video_model_simclr(self, root_dir: str) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + create_small_kinetics_dataset(root_dir) + # Config is relative to a module + cfg = compose( + config_name="simclr_train_app_conf", + overrides=[ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "module.knn_memory.length=50", + "module.knn_memory.knn_k=2", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "trainer.logger=true", + ], + ) + self.assertEqual(cfg.trainer.max_epochs, 1) + + self.train(cfg) + + @run_locally + @tempdir + def test_train_video_model_byol(self, root_dir: str) -> None: + with initialize_config_module(config_module="pytorchvideo_trainer.conf"): + create_small_kinetics_dataset(root_dir) + # Config is relative to a module + cfg = compose( + config_name="byol_train_app_conf", + overrides=[ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + 
f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "module.knn_memory.length=50", + "module.knn_memory.knn_k=2", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "trainer.logger=true", + ], + ) + self.assertEqual(cfg.trainer.max_epochs, 1) + + self.train(cfg) diff --git a/pytorchvideo_trainer/tests/test_task_simclr.py b/pytorchvideo_trainer/tests/test_task_simclr.py new file mode 100644 index 00000000..85871a66 --- /dev/null +++ b/pytorchvideo_trainer/tests/test_task_simclr.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +# pyre-strict +from torchrecipes.core.base_train_app import BaseTrainApp +from util import ( + create_small_kinetics_dataset, + run_locally, + tempdir, + BaseTrainAppTestCase, +) + + +class TestSimCLRTrainApp(BaseTrainAppTestCase): + def get_train_app( + self, + root_dir: str, + fast_dev_run: bool = True, + logger: bool = False, + ) -> BaseTrainApp: + create_small_kinetics_dataset(root_dir) + overrides = [ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "module.knn_memory.length=50", + "module.knn_memory.knn_k=2", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "trainer.logger=false", + ] + app = self.create_app_from_hydra( + config_module="pytorchvideo_trainer.conf", + config_name="simclr_train_app_conf", + overrides=overrides, + ) + trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger} + self.mock_trainer_params(app, trainer_overrides) + return app + + @run_locally + @tempdir + def test_simclr_app_train_test_30_views(self, root_dir: str) -> None: + train_app = self.get_train_app( + root_dir=root_dir, fast_dev_run=False, logger=False + ) + output = train_app.train() + self.assertIsNotNone(output) + output = train_app.test() + self.assertIsNotNone(output) + + video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None) + num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10) + num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3) + self.assertIsNotNone(video_clips_cnts) + for _, sample_cnts in video_clips_cnts.items(): + self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts) diff --git a/pytorchvideo_trainer/tests/test_task_video_classification.py b/pytorchvideo_trainer/tests/test_task_video_classification.py new file mode 100644 index 00000000..d4f86a1d --- /dev/null +++ b/pytorchvideo_trainer/tests/test_task_video_classification.py @@ 
-0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +# pyre-strict +from torchrecipes.core.base_train_app import BaseTrainApp +from util import ( + create_small_kinetics_dataset, + run_locally, + tempdir, + BaseTrainAppTestCase, +) + + +class TestVideoClassificationTrainApp(BaseTrainAppTestCase): + def get_train_app( + self, + root_dir: str, + precise_bn_num_batches: int = 0, + fast_dev_run: bool = True, + logger: bool = False, + ) -> BaseTrainApp: + create_small_kinetics_dataset(root_dir) + overrides = [ + f"datamodule.dataloader.train.dataset.data_path={root_dir}/train.csv", + f"datamodule.dataloader.val.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.test.dataset.data_path={root_dir}/val.csv", + f"datamodule.dataloader.train.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.val.dataset.video_path_prefix={root_dir}", + f"datamodule.dataloader.test.dataset.video_path_prefix={root_dir}", + "datamodule.dataloader.train.num_workers=0", + "datamodule.dataloader.val.num_workers=0", + "datamodule.dataloader.test.num_workers=0", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + "+module/lr_scheduler=cosine_with_warmup", + "trainer.logger=false", + ] + if precise_bn_num_batches > 0: + overrides.extend( + [ + "+callbacks=precise_bn", + f"callbacks.precise_bn.num_batches={precise_bn_num_batches}", + "datamodule.dataloader.train.batch_size=2", + "datamodule.dataloader.val.batch_size=2", + "datamodule.dataloader.test.batch_size=2", + ] + ) + app = self.create_app_from_hydra( + config_module="pytorchvideo_trainer.conf", + config_name="video_classification_train_app_conf", + overrides=overrides, + ) + trainer_overrides = {"fast_dev_run": fast_dev_run, "logger": logger} + self.mock_trainer_params(app, trainer_overrides) + return app + + @run_locally + @tempdir + def test_video_classification_app_train(self, root_dir: str) -> None: + train_app = self.get_train_app(root_dir=root_dir, logger=False) + output = train_app.train() + self.assertIsNotNone(output) + + @run_locally + @tempdir + def test_video_classification_app_train_with_precise_bn( + self, root_dir: str + ) -> None: + train_app = self.get_train_app( + root_dir=root_dir, precise_bn_num_batches=2, logger=False + ) + output = train_app.train() + self.assertIsNotNone(output) + + @run_locally + @tempdir + def test_video_classification_app_test(self, root_dir: str) -> None: + train_app = self.get_train_app(root_dir=root_dir) + output = train_app.test() + self.assertIsNotNone(output) + + @run_locally + @tempdir + def test_video_classification_app_test_30_views(self, root_dir: str) -> None: + train_app = self.get_train_app(root_dir=root_dir, fast_dev_run=False) + train_app.test() + video_clips_cnts = getattr(train_app.module, "video_clips_cnts", None) + num_ensemble_views = getattr(train_app.datamodule, "num_ensemble_views", 10) + num_spatial_crops = getattr(train_app.datamodule, "num_spatial_crops", 3) + self.assertIsNotNone(video_clips_cnts) + for _, sample_cnts in video_clips_cnts.items(): + self.assertEqual(num_ensemble_views * num_spatial_crops, sample_cnts) diff --git a/pytorchvideo_trainer/tests/util.py b/pytorchvideo_trainer/tests/util.py new file mode 100644 index 00000000..ac6c194a --- /dev/null +++ b/pytorchvideo_trainer/tests/util.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
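+# Shared test utilities: a tiny Kinetics-style dataset writer, decorators for
+# running tests locally in a temporary directory, and a base TrainApp test case.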
+
+import csv
+import os
+from functools import wraps
+from tempfile import TemporaryDirectory
+from typing import Any, Callable, Dict, List, Optional
+from unittest import mock
+
+import testslide
+import torch
+import torchvision.io as io
+from hydra import compose, initialize_config_module
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+from torchrecipes.core.base_train_app import BaseTrainApp, TrainOutput
+
+
+def create_small_kinetics_dataset(root_dir: str) -> None:
+    """
+    A test utility function to create a small Kinetics-like dataset.
+
+    Args:
+        root_dir (str): The directory to create the dataset in.
+            Typically, a temporary directory is used.
+    """
+    video_codec = "libx264rgb"
+    options = {"crf": "0"}
+    height: int = 250
+    width: int = 250
+    num_frames = 20
+    fps = 5
+    data = create_dummy_video_frames(num_frames, height, width)
+
+    train_data = [
+        ["a.mp4", "308"],
+        ["b.mp4", "298"],
+        ["c.mp4", "240"],
+        ["d.mp4", "363"],
+    ]
+
+    val_data = [
+        ["a.mp4", "151"],
+    ]
+
+    for i in range(4):
+        io.write_video(
+            os.path.join(root_dir, train_data[i][0]),
+            data,
+            fps=fps,
+            video_codec=video_codec,
+            options=options,
+        )
+
+    train_file = os.path.join(root_dir, "train.csv")
+    write_single_csv_file(train_file, train_data)
+
+    val_file = os.path.join(root_dir, "val.csv")
+    write_single_csv_file(val_file, val_data)
+
+
+# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
+def write_single_csv_file(file_name: str, data: List[Any]) -> None:
+    with open(file_name, "w+", newline="") as csvfile:
+        data_writer = csv.writer(
+            # pyre-fixme[6]: Expected `_Writer` for 1st param but got `TextIOWrapper`.
+            csvfile,
+            delimiter=" ",
+        )
+        for row in data:
+            data_writer.writerow(row)
+
+
+# pyre-fixme[3]
+def create_dummy_video_frames(num_frames: int, height: int, width: int):
+    y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
+    data = []
+    for i in range(num_frames):
+        xc = float(i) / num_frames
+        yc = 1 - float(i) / (2 * num_frames)
+        d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
+        data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
+    return torch.stack(data, 0)
+
+
+def run_locally(func: Callable) -> Callable:  # pyre-ignore[24]
+    """A decorator to run a unit test locally."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):  # pyre-ignore[2,3]
+        with mock.patch(
+            "torch.distributed.is_available",
+            return_value=False,
+        ):
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def tempdir(func: Callable) -> Callable:  # pyre-ignore[24]
+    """A decorator for creating a temporary directory that
+    is cleaned up after function execution."""
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):  # pyre-ignore[2,3]
+        with TemporaryDirectory() as temp:
+            return func(self, temp, *args, **kwargs)
+
+    return wrapper
+
+
+def get_mock_init_trainer_params(
+    overrides: Optional[Dict[str, Any]] = None,
+) -> Callable[..., Dict[str, Any]]:
+    """
+    Order of trainer_params setting in unit tests:
+    - First, call the original function, which sets params from the config.
+    - Then, override some params to disable the logger and checkpointing.
+    - Finally, apply any test-specific overrides.
+ """ + + def mock_init_trainer_params( + original: Callable[..., Dict[str, Any]], + ) -> Dict[str, Any]: + trainer_params = original() + + trainer_params["logger"] = False + trainer_params["checkpoint_callback"] = False + trainer_params["fast_dev_run"] = True + + if overrides: + trainer_params.update(overrides) + + return trainer_params + + return mock_init_trainer_params + + +class BaseTrainAppTestCase(testslide.TestCase): + """All Standard TrainApp unit tests should inherit from this class.""" + + def mock_trainer_params( + self, app: BaseTrainApp, overrides: Optional[Dict[str, Any]] = None + ) -> None: + self.mock_callable( + app, "_init_trainer_params", allow_private=True + ).with_wrapper(get_mock_init_trainer_params(overrides)) + + def create_app_from_hydra( + self, + config_module: str, + config_name: str, + overrides: Optional[List[str]] = None, + ) -> BaseTrainApp: + with initialize_config_module(config_module=config_module): + cfg = compose(config_name=config_name, overrides=overrides or []) + print(OmegaConf.to_yaml(cfg)) + return instantiate(cfg, _recursive_=False) + + def assert_train_output(self, output: TrainOutput) -> None: + self.assertIsNotNone(output) + # Ensure logger is set to False in test to avoid dependency on Manifold + self.assertIsNone(output.tensorboard_log_dir)