Skip to content

Commit

Permalink
PytorchVideo - Lightning Training pipeline (#158)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #158

Pull Request resolved: fairinternal/pytorchvideo#47

1. Support Video classification
2. Support Video SSL - SimCLR, BYOL, MoCo

Reviewed By: haooooooqi

Differential Revision: D33431232

fbshipit-source-id: 47ad9c35d45e4c8f9ac95e497dd7b582cb4084a9
  • Loading branch information
Kalyan Vasudev Alwala authored and facebook-github-bot committed Jan 19, 2022
1 parent 2679d5c commit 874d27c
Show file tree
Hide file tree
Showing 77 changed files with 5,714 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pytorchvideo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

__version__ = "0.1.3"
__version__ = "0.1.5"
9 changes: 8 additions & 1 deletion pytorchvideo/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ def create_acoustic_bottleneck_block(
)


def _trivial_sum(x, y):
    """
    Return ``x + y``.

    Utility function used in lieu of a lambda, since lambdas are not
    picklable while a named module-level function is.

    Args:
        x: left operand; any object supporting ``+`` with ``y``.
        y: right operand.

    Returns:
        The result of ``x + y``.
    """
    return x + y


def create_res_block(
*,
# Bottleneck Block configs.
Expand All @@ -324,7 +331,7 @@ def create_res_block(
dim_out: int,
bottleneck: Callable,
use_shortcut: bool = False,
branch_fusion: Callable = lambda x, y: x + y,
branch_fusion: Callable = _trivial_sum,
# Conv configs.
conv_a_kernel_size: Tuple[int] = (3, 1, 1),
conv_a_stride: Tuple[int] = (2, 1, 1),
Expand Down
39 changes: 39 additions & 0 deletions pytorchvideo_trainer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
## PyTorchVideo Trainer

A [PyTorch-Lightning](https://github.com/Lightning-AI/pytorch-lightning) based trainer supporting PyTorchVideo models and dataloaders for various video understanding tasks.

Currently supported tasks include:

- Video Action Recognition: ResNets, SlowFast models, X3D models and MViT
- Video Self-Supervised Learning: SimCLR, BYOL, MoCo
- (Planned) Video Action Detection

## Installation

These instructions assume that both pytorch and torchvision are already installed
using the instructions in [INSTALL.md](https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md#requirements).

Install the required additional dependency `recipes` by running the following command:
```
pip install "git+https://github.com/facebookresearch/recipes.git"
```

After that, install PyTorchVideo Trainer by running:
```
git clone https://github.com/facebookresearch/pytorchvideo.git
cd pytorchvideo/pytorchvideo_trainer
pip install -e .
# For developing and testing
pip install -e ".[test,dev]"
```

## Testing

Before running the tests, please ensure that you installed the necessary additional test dependencies.

Use the following command to run the tests:
```
# From the current directory
python -m unittest discover -v -s ./tests
```
16 changes: 16 additions & 0 deletions pytorchvideo_trainer/pytorchvideo_trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.


def register_components() -> None:
    """
    Import the trainer's sub-modules so that their components can be
    registered with Hydra's ConfigStore.

    The modules are imported purely for their import-time side effects; the
    imported names are never used here, hence the ``# noqa`` markers. The
    imports are kept inside the function so that nothing is loaded until a
    caller explicitly requests registration.
    """
    import pytorchvideo_trainer.datamodule.datamodule  # noqa
    import pytorchvideo_trainer.module.byol  # noqa
    import pytorchvideo_trainer.module.lr_policy  # noqa
    import pytorchvideo_trainer.module.moco_v2  # noqa
    import pytorchvideo_trainer.module.optimizer  # noqa
    import pytorchvideo_trainer.module.simclr  # noqa
    import pytorchvideo_trainer.module.video_classification  # noqa
    import pytorchvideo_trainer.train_app  # noqa
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from .precise_batchnorm import PreciseBn # noqa


__all__ = [
"PreciseBn",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Facebook, Inc. and its affiliates.

from typing import Generator

import torch
from fvcore.nn.precise_bn import update_bn_stats
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader


class PreciseBn(Callback):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training both BN stats and the weight are changing after every iteration, so
    the running average can not precisely reflect the actual stats of the
    current model.
    In this callback, the BN stats are recomputed with fixed weights, to make
    the running average more precise during Training Phase. Specifically, it
    computes the true average of per-batch mean/variance instead of the
    running average. See Sec. 3 of the paper "Rethinking Batch in BatchNorm"
    for details.
    """

    def __init__(self, num_batches: int) -> None:
        """
        Args:
            num_batches (int): Number of steps / mini-batches to sample when
                recomputing the precise batchnorm stats. Passed unchanged as
                ``num_iters`` to ``fvcore.nn.precise_bn.update_bn_stats``.
        """
        super().__init__()
        self.num_batches = num_batches

    def _get_precise_bn_loader(
        self, data_loader: DataLoader, pl_module: LightningModule
    ) -> Generator[torch.Tensor, None, None]:
        """
        Lazily yield only the model inputs of each batch, moved to the
        module's device.

        Args:
            data_loader (DataLoader): loader whose batches are indexable by
                ``pl_module.modality_key``.
            pl_module (LightningModule): module providing ``modality_key``
                and ``device``.

        Yields:
            The entry stored under ``pl_module.modality_key`` for each batch:
            either a single tensor or a list of tensors, each moved to
            ``pl_module.device``.
        """
        for batch in data_loader:
            inputs = batch[pl_module.modality_key]
            # Some batches carry a list of tensors under the modality key;
            # move each element individually in that case.
            if isinstance(inputs, list):
                inputs = [x.to(pl_module.device) for x in inputs]
            else:
                inputs = inputs.to(pl_module.device)
            yield inputs

    def on_train_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """
        Called at the end of every epoch only during the training
        phase. Recomputes the BN statistics of ``pl_module.model`` using
        batches drawn from the datamodule's training dataloader.

        Args:
            trainer (Trainer): A PyTorch-Lightning trainer object.
            pl_module (LightningModule): A PyTorch-Lightning module.
                Typically supported modules include -
                pytorchvideo_trainer.module.VideoClassificationModule, etc.
        """
        # pyre-ignore[16]
        dataloader = trainer.datamodule.train_dataloader()
        precise_bn_loader = self._get_precise_bn_loader(
            data_loader=dataloader, pl_module=pl_module
        )
        # NOTE(review): update_bn_stats appears to expect the loader to yield
        # at least `num_iters` batches -- confirm that `num_batches` never
        # exceeds the number of training batches per epoch.
        update_bn_stats(
            model=pl_module.model,  # pyre-ignore[6]
            data_loader=precise_bn_loader,
            num_iters=self.num_batches,
        )
8 changes: 8 additions & 0 deletions pytorchvideo_trainer/pytorchvideo_trainer/conf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torchrecipes.core.conf # noqa

# Components to register with this config
from pytorchvideo_trainer import register_components

register_components()
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Hydra config that instantiates VideoClassificationTrainApp using the BYOL
# module schema (byol_module_conf) and the slow_r50_byol model option.
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp

defaults:
- schema/module: byol_module_conf
- schema/module/optim: optim_conf
- schema/datamodule: ptv_video_classification_data_module_conf
- datamodule/dataloader: kinetics_contrastive
- logger: ptl
- datamodule/transforms: kinetics_contrastive
- module/knn_memory: kinetics_k400
- module/model: slow_r50_byol
- module/loss: similarity
- module/optim: sgd_ssl
- module/metrics: accuracy
- schema/trainer: trainer
- trainer: cpu
# No callbacks option is selected for this config.
- callbacks: null
# _self_ is listed last so the values defined in this file override the
# options selected above.
- _self_
trainer:
sync_batchnorm: false # set this to true for training

module:
momentum_anneal_cosine: true

hydra:
# Additional config search paths: this package's conf directory and
# torchrecipes' core conf.
searchpath:
- pkg://pytorchvideo_trainer.conf
- pkg://torchrecipes.core.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Callback option that instantiates the PreciseBn callback.
precise_bn:
_target_: pytorchvideo_trainer.callbacks.precise_batchnorm.PreciseBn
# NOTE(review): PreciseBn.__init__ annotates num_batches as int, so this null
# placeholder is presumably meant to be overridden by the config that selects
# this option -- confirm every user sets callbacks.precise_bn.num_batches.
num_batches: null
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Hydra config that instantiates VideoClassificationTrainApp with the
# vision-transformer module schema and the mvit_base_16x4 model option.
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp

defaults:
- schema/module: video_classification_module_conf_vision_transformer
- schema/module/optim: optim_conf
- schema/datamodule: ptv_video_classification_data_module_conf
- datamodule/dataloader: kinetics_classification
- logger: ptl
- datamodule/transforms: kinetics_classification_mvit_16x4
- module/model: mvit_base_16x4
- module/loss: soft_cross_entropy
- module/optim: adamw
- module/metrics: accuracy
- module/lr_scheduler: cosine_with_warmup
- schema/trainer: trainer
- trainer: multi_gpu
# _self_ is listed last so the values defined in this file override the
# options selected above.
- _self_

module:
clip_gradient_norm: 1.0
ensemble_method: "sum"
lr_scheduler:
# NOTE(review): max_iters has the same value as trainer.max_epochs (200)
# below; presumably the two must be changed together -- confirm.
max_iters: 200
warmup_start_lr: 1.6e-05
warmup_iters: 30
cosine_after_warmup: true
cosine_end_lr: 1.6e-05
optim:
lr: 0.0016
weight_decay: 0.05
method: adamw
zero_weight_decay_1d_param: true
# Batch-level transform built from MixVideoBatchWrapper with the mixup /
# cutmix / label-smoothing settings listed below.
batch_transform:
_target_: pytorchvideo_trainer.datamodule.transforms.MixVideoBatchWrapper
mixup_alpha: 0.8
cutmix_prob: 0.5
cutmix_alpha: 1.0
label_smoothing: 0.1

datamodule:
dataloader:
train:
batch_size: 2
dataset:
clip_sampler:
clip_duration: 2.13
# Collate function is looked up by name via build_collator_from_name.
collate_fn:
_target_: pytorchvideo_trainer.datamodule.collators.build_collator_from_name
name: multiple_samples_collate
val:
batch_size: 8
dataset:
clip_sampler:
clip_duration: 2.13
test:
batch_size: 8
dataset:
clip_sampler:
clip_duration: 2.13

trainer:
num_nodes: 16
gpus: 8
log_gpu_memory: null
max_epochs: 200
sync_batchnorm: False
replace_sampler_ddp: False

hydra:
# Additional config search paths: this package's conf directory and
# torchrecipes' core conf.
searchpath:
- pkg://pytorchvideo_trainer.conf
- pkg://torchrecipes.core.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Hydra config that instantiates VideoClassificationTrainApp with the
# video-classification module schema and the slow_r50 model option.
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp

defaults:
- schema/module: video_classification_module_conf
- schema/module/optim: optim_conf
- schema/datamodule: ptv_video_classification_data_module_conf
- datamodule/dataloader: kinetics_classification
- logger: ptl
- datamodule/transforms: kinetics_classification_slow
- module/model: slow_r50
- module/loss: cross_entropy
- module/optim: sgd
- module/metrics: accuracy
- module/lr_scheduler: cosine_with_warmup
- schema/trainer: trainer
- trainer: multi_gpu
- callbacks: precise_bn
# _self_ is listed last so the values defined in this file override the
# options selected above.
- _self_

module:
ensemble_method: "sum"
lr_scheduler:
# NOTE(review): max_iters has the same value as trainer.max_epochs (196)
# below; presumably the two must be changed together -- confirm.
max_iters: 196
warmup_start_lr: 0.01
warmup_iters: 34
optim:
lr: 0.8
nesterov: true

callbacks:
precise_bn:
# Sets num_batches for the precise_bn callback option selected in the
# defaults list above.
num_batches: 200

trainer:
num_nodes: 8
gpus: 8
log_gpu_memory: null
max_epochs: 196
sync_batchnorm: False
replace_sampler_ddp: False


hydra:
# Additional config search paths: this package's conf directory and
# torchrecipes' core conf.
searchpath:
- pkg://pytorchvideo_trainer.conf
- pkg://torchrecipes.core.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Hydra config that instantiates VideoClassificationTrainApp with the
# video-classification module schema and the slowfast_r50 model option.
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp

defaults:
- schema/module: video_classification_module_conf
- schema/module/optim: optim_conf
- schema/datamodule: ptv_video_classification_data_module_conf
- datamodule/dataloader: kinetics_classification
- logger: ptl
- datamodule/transforms: kinetics_classification_slowfast
- module/model: slowfast_r50
- module/loss: cross_entropy
- module/optim: sgd
- module/metrics: accuracy
- module/lr_scheduler: cosine_with_warmup
- schema/trainer: trainer
- trainer: multi_gpu
- callbacks: precise_bn
# _self_ is listed last so the values defined in this file override the
# options selected above.
- _self_

module:
ensemble_method: "sum"
lr_scheduler:
# NOTE(review): max_iters has the same value as trainer.max_epochs (196)
# below; presumably the two must be changed together -- confirm.
max_iters: 196
warmup_start_lr: 0.01
warmup_iters: 34
optim:
lr: 0.8
nesterov: true

callbacks:
precise_bn:
# Sets num_batches for the precise_bn callback option selected in the
# defaults list above.
num_batches: 200

trainer:
num_nodes: 8
gpus: 8
log_gpu_memory: null
max_epochs: 196
sync_batchnorm: False
replace_sampler_ddp: False


hydra:
# Additional config search paths: this package's conf directory and
# torchrecipes' core conf.
searchpath:
- pkg://pytorchvideo_trainer.conf
- pkg://torchrecipes.core.conf
Loading

0 comments on commit 874d27c

Please sign in to comment.