-
Notifications
You must be signed in to change notification settings - Fork 412
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PytorchVideo - Lightning Training pipeline (#158)
Summary: Pull Request resolved: #158 Pull Request resolved: fairinternal/pytorchvideo#47 1. Support Video classification 2. Support Video SSL - SimCLR, BYOL, MoCo Reviewed By: haooooooqi Differential Revision: D33431232 fbshipit-source-id: 47ad9c35d45e4c8f9ac95e497dd7b582cb4084a9
- Loading branch information
1 parent
2679d5c
commit 874d27c
Showing
77 changed files
with
5,714 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | ||
|
||
__version__ = "0.1.3" | ||
__version__ = "0.1.5" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
## PyTorchVideo Trainer | ||
|
||
A [PyTorch-Lightning]() based trainer supporting PytorchVideo models and dataloaders for various video understanding tasks. | ||
|
||
Currently supported tasks include: | ||
|
||
- Video Action Recognition: ResNet's, SlowFast Models, X3D models and MViT | ||
- Video Self-Supervised Learning: SimCLR, BYOL, MoCo | ||
- (Planned) Video Action Detection | ||
|
||
## Installation | ||
|
||
These instructions assumes that both pytorch and torchvision are already installed | ||
using the instructions in [INSTALL.md](https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md#requirements) | ||
|
||
Install the required additional dependency `recipes` by running the following command, | ||
``` | ||
pip install "git+https://github.com/facebookresearch/recipes.git" | ||
``` | ||
|
||
Post that, install PyTorchVideo Trainer by running, | ||
``` | ||
git clone https://github.com/facebookresearch/pytorchvideo.git | ||
cd pytorchvideo/pytorchvideo_trainer | ||
pip install -e . | ||
# For developing and testing | ||
pip install -e . [test,dev] | ||
``` | ||
|
||
## Testing | ||
|
||
Before running the tests, please ensure that you installed the necessary additional test dependencies. | ||
|
||
Use the the following command to run the tests: | ||
``` | ||
# From the current directory | ||
python -m unittest discover -v -s ./tests | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | ||
|
||
|
||
def register_components() -> None: | ||
""" | ||
Calls register_components() for all subfolders so we can register | ||
subcomponents to Hydra's ConfigStore. | ||
""" | ||
import pytorchvideo_trainer.datamodule.datamodule # noqa | ||
import pytorchvideo_trainer.module.byol # noqa | ||
import pytorchvideo_trainer.module.lr_policy # noqa | ||
import pytorchvideo_trainer.module.moco_v2 # noqa | ||
import pytorchvideo_trainer.module.optimizer # noqa | ||
import pytorchvideo_trainer.module.simclr # noqa | ||
import pytorchvideo_trainer.module.video_classification # noqa | ||
import pytorchvideo_trainer.train_app # noqa |
8 changes: 8 additions & 0 deletions
8
pytorchvideo_trainer/pytorchvideo_trainer/callbacks/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | ||
|
||
from .precise_batchnorm import PreciseBn # noqa | ||
|
||
|
||
__all__ = [ | ||
"PreciseBn", | ||
] |
70 changes: 70 additions & 0 deletions
70
pytorchvideo_trainer/pytorchvideo_trainer/callbacks/precise_batchnorm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
|
||
from typing import Generator | ||
|
||
import torch | ||
from fvcore.nn.precise_bn import update_bn_stats | ||
from pytorch_lightning.callbacks import Callback | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.trainer.trainer import Trainer | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
class PreciseBn(Callback): | ||
""" | ||
Recompute and update the batch norm stats to make them more precise. During | ||
training both BN stats and the weight are changing after every iteration, so | ||
the running average can not precisely reflect the actual stats of the | ||
current model. | ||
In this callaback, the BN stats are recomputed with fixed weights, to make | ||
the running average more precise during Training Phase. Specifically, it | ||
computes the true average of per-batch mean/variance instead of the | ||
running average. See Sec. 3 of the paper "Rethinking Batch in BatchNorm" | ||
for details. | ||
""" | ||
|
||
def __init__(self, num_batches: int) -> None: | ||
""" | ||
Args: | ||
num_batches (int): Number of steps / mini-batches to | ||
perform to sample for updating the precise batchnorm | ||
stats. | ||
""" | ||
self.num_batches = num_batches | ||
|
||
def _get_precise_bn_loader( | ||
self, data_loader: DataLoader, pl_module: LightningModule | ||
) -> Generator[torch.Tensor, None, None]: | ||
for batch in data_loader: | ||
inputs = batch[pl_module.modality_key] | ||
if isinstance(inputs, list): | ||
inputs = [x.to(pl_module.device) for x in inputs] | ||
else: | ||
inputs = inputs.to(pl_module.device) | ||
yield inputs | ||
|
||
def on_train_epoch_end( | ||
self, | ||
trainer: Trainer, | ||
pl_module: LightningModule, | ||
) -> None: | ||
""" | ||
Called at the end of every epoch only during the training | ||
phase. | ||
Args: | ||
trainer (Trainer): A PyTorch-Lightning trainer object. | ||
pl_module (LightningModule): A PyTorch-Lightning module. | ||
Typically supported modules include - | ||
pytorchvideo_trainer.module.VideoClassificationModule, etc. | ||
""" | ||
# pyre-ignore[16] | ||
dataloader = trainer.datamodule.train_dataloader() | ||
precise_bn_loader = self._get_precise_bn_loader( | ||
data_loader=dataloader, pl_module=pl_module | ||
) | ||
update_bn_stats( | ||
model=pl_module.model, # pyre-ignore[6] | ||
data_loader=precise_bn_loader, | ||
num_iters=self.num_batches, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | ||
|
||
import torchrecipes.core.conf # noqa | ||
|
||
# Components to register with this config | ||
from pytorchvideo_trainer import register_components | ||
|
||
register_components() |
28 changes: 28 additions & 0 deletions
28
pytorchvideo_trainer/pytorchvideo_trainer/conf/byol_train_app_conf.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp | ||
|
||
defaults: | ||
- schema/module: byol_module_conf | ||
- schema/module/optim: optim_conf | ||
- schema/datamodule: ptv_video_classification_data_module_conf | ||
- datamodule/dataloader: kinetics_contrastive | ||
- logger: ptl | ||
- datamodule/transforms: kinetics_contrastive | ||
- module/knn_memory: kinetics_k400 | ||
- module/model: slow_r50_byol | ||
- module/loss: similarity | ||
- module/optim: sgd_ssl | ||
- module/metrics: accuracy | ||
- schema/trainer: trainer | ||
- trainer: cpu | ||
- callbacks: null | ||
- _self_ | ||
trainer: | ||
sync_batchnorm: false # set this to true for training | ||
|
||
module: | ||
momentum_anneal_cosine: true | ||
|
||
hydra: | ||
searchpath: | ||
- pkg://pytorchvideo_trainer.conf | ||
- pkg://torchrecipes.core.conf |
3 changes: 3 additions & 0 deletions
3
pytorchvideo_trainer/pytorchvideo_trainer/conf/callbacks/precise_bn.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
precise_bn: | ||
_target_: pytorchvideo_trainer.callbacks.precise_batchnorm.PreciseBn | ||
num_batches: null |
72 changes: 72 additions & 0 deletions
72
pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_mvit_16x4.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp | ||
|
||
defaults: | ||
- schema/module: video_classification_module_conf_vision_transformer | ||
- schema/module/optim: optim_conf | ||
- schema/datamodule: ptv_video_classification_data_module_conf | ||
- datamodule/dataloader: kinetics_classification | ||
- logger: ptl | ||
- datamodule/transforms: kinetics_classification_mvit_16x4 | ||
- module/model: mvit_base_16x4 | ||
- module/loss: soft_cross_entropy | ||
- module/optim: adamw | ||
- module/metrics: accuracy | ||
- module/lr_scheduler: cosine_with_warmup | ||
- schema/trainer: trainer | ||
- trainer: multi_gpu | ||
- _self_ | ||
|
||
module: | ||
clip_gradient_norm: 1.0 | ||
ensemble_method: "sum" | ||
lr_scheduler: | ||
max_iters: 200 | ||
warmup_start_lr: 1.6e-05 | ||
warmup_iters: 30 | ||
cosine_after_warmup: true | ||
cosine_end_lr: 1.6e-05 | ||
optim: | ||
lr: 0.0016 | ||
weight_decay: 0.05 | ||
method: adamw | ||
zero_weight_decay_1d_param: true | ||
batch_transform: | ||
_target_: pytorchvideo_trainer.datamodule.transforms.MixVideoBatchWrapper | ||
mixup_alpha: 0.8 | ||
cutmix_prob: 0.5 | ||
cutmix_alpha: 1.0 | ||
label_smoothing: 0.1 | ||
|
||
datamodule: | ||
dataloader: | ||
train: | ||
batch_size: 2 | ||
dataset: | ||
clip_sampler: | ||
clip_duration: 2.13 | ||
collate_fn: | ||
_target_: pytorchvideo_trainer.datamodule.collators.build_collator_from_name | ||
name: multiple_samples_collate | ||
val: | ||
batch_size: 8 | ||
dataset: | ||
clip_sampler: | ||
clip_duration: 2.13 | ||
test: | ||
batch_size: 8 | ||
dataset: | ||
clip_sampler: | ||
clip_duration: 2.13 | ||
|
||
trainer: | ||
num_nodes: 16 | ||
gpus: 8 | ||
log_gpu_memory: null | ||
max_epochs: 200 | ||
sync_batchnorm: False | ||
replace_sampler_ddp: False | ||
|
||
hydra: | ||
searchpath: | ||
- pkg://pytorchvideo_trainer.conf | ||
- pkg://torchrecipes.core.conf |
46 changes: 46 additions & 0 deletions
46
pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slow_8x8_r50.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp | ||
|
||
defaults: | ||
- schema/module: video_classification_module_conf | ||
- schema/module/optim: optim_conf | ||
- schema/datamodule: ptv_video_classification_data_module_conf | ||
- datamodule/dataloader: kinetics_classification | ||
- logger: ptl | ||
- datamodule/transforms: kinetics_classification_slow | ||
- module/model: slow_r50 | ||
- module/loss: cross_entropy | ||
- module/optim: sgd | ||
- module/metrics: accuracy | ||
- module/lr_scheduler: cosine_with_warmup | ||
- schema/trainer: trainer | ||
- trainer: multi_gpu | ||
- callbacks: precise_bn | ||
- _self_ | ||
|
||
module: | ||
ensemble_method: "sum" | ||
lr_scheduler: | ||
max_iters: 196 | ||
warmup_start_lr: 0.01 | ||
warmup_iters: 34 | ||
optim: | ||
lr: 0.8 | ||
nesterov: true | ||
|
||
callbacks: | ||
precise_bn: | ||
num_batches: 200 | ||
|
||
trainer: | ||
num_nodes: 8 | ||
gpus: 8 | ||
log_gpu_memory: null | ||
max_epochs: 196 | ||
sync_batchnorm: False | ||
replace_sampler_ddp: False | ||
|
||
|
||
hydra: | ||
searchpath: | ||
- pkg://pytorchvideo_trainer.conf | ||
- pkg://torchrecipes.core.conf |
46 changes: 46 additions & 0 deletions
46
pytorchvideo_trainer/pytorchvideo_trainer/conf/classification_slowfast_8x8_r50.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
_target_: pytorchvideo_trainer.train_app.VideoClassificationTrainApp | ||
|
||
defaults: | ||
- schema/module: video_classification_module_conf | ||
- schema/module/optim: optim_conf | ||
- schema/datamodule: ptv_video_classification_data_module_conf | ||
- datamodule/dataloader: kinetics_classification | ||
- logger: ptl | ||
- datamodule/transforms: kinetics_classification_slowfast | ||
- module/model: slowfast_r50 | ||
- module/loss: cross_entropy | ||
- module/optim: sgd | ||
- module/metrics: accuracy | ||
- module/lr_scheduler: cosine_with_warmup | ||
- schema/trainer: trainer | ||
- trainer: multi_gpu | ||
- callbacks: precise_bn | ||
- _self_ | ||
|
||
module: | ||
ensemble_method: "sum" | ||
lr_scheduler: | ||
max_iters: 196 | ||
warmup_start_lr: 0.01 | ||
warmup_iters: 34 | ||
optim: | ||
lr: 0.8 | ||
nesterov: true | ||
|
||
callbacks: | ||
precise_bn: | ||
num_batches: 200 | ||
|
||
trainer: | ||
num_nodes: 8 | ||
gpus: 8 | ||
log_gpu_memory: null | ||
max_epochs: 196 | ||
sync_batchnorm: False | ||
replace_sampler_ddp: False | ||
|
||
|
||
hydra: | ||
searchpath: | ||
- pkg://pytorchvideo_trainer.conf | ||
- pkg://torchrecipes.core.conf |
Oops, something went wrong.