diff --git a/README.md b/README.md
index 9f3b6c0..c9df70d 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,17 @@ Here are the [Training/Validation Logs](https://api.wandb.ai/links/action_antici
 
 👁️ **Note**: The relatively low accuracy is due to the difficulty of training a vision transformer (reversible or vanilla) from scratch on small datasets like CIFAR-10. It is also likely that a much higher accuracy can be achieved with the same code, using a better [chosen model design and optimization parameters](https://github.com/tysam-code/hlb-CIFAR10). The authors have done no tuning, since this repository is meant for understanding code, not pushing performance.
 
+<details>
+<summary>
+Mixed precision training
+</summary>
+
+Mixed precision training is also supported and can be enabled by adding the `--amp True` flag to the above commands. Training progresses smoothly and achieves `80%+` validation accuracy on CIFAR-10, similar to training without AMP.
+
+📝 **Note**: Vanilla PyTorch AMP keeps the weights in full precision (fp32) and uses half precision (fp16) only on intermediate activations. Since reversible training already avoids caching almost all intermediate activations (see the video for an explanation), using AMP (i.e., half precision on activations) brings little additional memory saving. For example, on a 16G V100 setup, AMP improves the maximum reversible CIFAR-10 batch size from `12000` to `14500` (`~20%`). At the usual training batch size (`128`) there is only a small gain in GPU training memory (about 4%).
+
+</details>
+
 <details>
 <summary>
 Distributed Data Parallel Training
 </summary>
 
+There are no additional overheads for DDP training with reversible backpropagation, which progresses the same as vanilla training. All results in the [paper](https://arxiv.org/abs/2302.04869) (also see below) are obtained in DDP setups (`>64` GPUs per run). However, implementing distributed training is beyond the purpose of this repo; it can instead be found in the PySlowFast [distributed training setup](https://github.com/facebookresearch/SlowFast/blob/99a655bd533d7fddd7f79509e3dfaae811767b5c/slowfast/models/build.py#L69-L83).
+
 </details>
 
 <details>
 <summary>
 Running ImageNet, Kinetics-400 and more
 </summary>
 
 For more use cases such as reproducing numbers from the [original paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf), see the [full code in PySlowFast](https://github.com/facebookresearch/SlowFast) that supports
diff --git a/main.py b/main.py
index 29366cd..5956ee2 100644
--- a/main.py
+++ b/main.py
@@ -7,9 +7,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from torch.cuda.amp import GradScaler
+
 import torchvision
 import torchvision.transforms as transforms
-
 from rev import RevViT
 
 parser = argparse.ArgumentParser(description="PyTorch CIFAR10 Training")
@@ -53,6 +54,12 @@
     type=bool,
     help="whether to use reversible backpropagation or not",
 )
+parser.add_argument(
+    "--amp",
+    default=False,
+    type=bool,
+    help="whether to use mixed precision training or not",
+)
 
 args = parser.parse_args()
 
@@ -102,6 +109,7 @@
     patch_size=args.patch_size,
     image_size=args.image_size,
     num_classes=args.num_classes,
+    enable_amp=args.amp,
 )
 
 model = model.to(device)
@@ -113,6 +121,7 @@
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(model.parameters(), lr=args.lr)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
+scaler = GradScaler()
 
 # Training
 def train(epoch):
@@ -122,12 +131,19 @@ def train(epoch):
     correct = 0
     total = 0
     for batch_idx, (inputs, targets) in enumerate(trainloader):
+
+        # We do not need an AMP autocast context around the forward pass here,
+        # since that is already taken care of in the forward of the individual modules.
        inputs, targets = inputs.to(device), targets.to(device)
-        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
-        loss.backward()
-        optimizer.step()
+
+        # standard PyTorch AMP training setup
+        # (the scaler also works without AMP training)
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
 
         train_loss += loss.item()
         _, predicted = outputs.max(1)
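The README note on memory above can be sanity-checked with a short standalone script. The sketch below is illustrative only and not part of this patch: it assumes a CUDA device, uses the `RevViT` defaults from `rev.py`, and defines a hypothetical helper `peak_memory_mb`. It runs a single training step with and without `enable_amp`, mirroring the updated training loop in `main.py`, and reports the peak allocated GPU memory.

```python
# Illustrative sketch (not part of the patch): rough peak-memory comparison
# of reversible training with and without AMP. Assumes a CUDA device and
# the default RevViT configuration from rev.py.
import torch
import torch.nn.functional as F

from rev import RevViT


def peak_memory_mb(enable_amp, batch_size=128):
    """Run one training step and return peak allocated CUDA memory in MiB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    model = RevViT(enable_amp=enable_amp).cuda()
    scaler = torch.cuda.amp.GradScaler()

    images = torch.randn(batch_size, 3, 32, 32, device="cuda")
    labels = torch.randint(0, 10, (batch_size,), device="cuda")

    # same pattern as the updated train() loop in main.py
    loss = F.cross_entropy(model(images), labels)
    scaler.scale(loss).backward()

    return torch.cuda.max_memory_allocated() / 2**20


if __name__ == "__main__":
    for amp in (False, True):
        print(f"enable_amp={amp}: peak memory ~{peak_memory_mb(amp):.0f} MiB")
```

At this batch size the two numbers should be close, consistent with the small (about 4%) gain quoted in the README note above.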
diff --git a/rev.py b/rev.py
index 4caaf1e..aa463c6 100644
--- a/rev.py
+++ b/rev.py
@@ -14,13 +14,11 @@ def __init__(
         embed_dim=768,
         n_head=8,
         depth=8,
-        patch_size=(
-            2,
-            2,
-        ),  # this patch size is used for CIFAR-10
+        patch_size=(2, 2,),  # this patch size is used for CIFAR-10
         # --> (32 // 2)**2 = 256 sequence length
         image_size=(32, 32),  # CIFAR-10 image size
         num_classes=10,
+        enable_amp=False,
     ):
         super().__init__()
 
@@ -39,7 +37,11 @@ def __init__(
         # is constrained inside the block code and not exposed.
         self.layers = nn.ModuleList(
             [
-                ReversibleBlock(dim=self.embed_dim, num_heads=self.n_head)
+                ReversibleBlock(
+                    dim=self.embed_dim,
+                    num_heads=self.n_head,
+                    enable_amp=enable_amp,
+                )
                 for _ in range(self.depth)
             ]
         )
@@ -96,10 +98,7 @@ def forward(self, x):
             executing_fn = RevBackProp.apply
 
         # This takes care of switching between vanilla backprop and rev backprop
-        x = executing_fn(
-            x,
-            self.layers,
-        )
+        x = executing_fn(x, self.layers,)
 
         # aggregate across sequence length
         x = x.mean(1)
@@ -125,9 +124,7 @@ class RevBackProp(Function):
 
     @staticmethod
     def forward(
-        ctx,
-        x,
-        layers,
+        ctx, x, layers,
     ):
         """
         Reversible Forward pass.
@@ -167,10 +164,7 @@ def backward(ctx, dx):
             # this is recomputing both the activations and the gradients wrt
             # those activations.
             X_1, X_2, dX_1, dX_2 = layer.backward_pass(
-                Y_1=X_1,
-                Y_2=X_2,
-                dY_1=dX_1,
-                dY_2=dX_2,
+                Y_1=X_1, Y_2=X_2, dY_1=dX_1, dY_2=dX_2,
             )
         # final input gradient to be passed backward to the patchification layer
         dx = torch.cat([dX_1, dX_2], dim=-1)
@@ -186,11 +180,7 @@ class ReversibleBlock(nn.Module):
     See Section 3.3.2 in paper for details.
     """
 
-    def __init__(
-        self,
-        dim,
-        num_heads,
-    ):
+    def __init__(self, dim, num_heads, enable_amp):
         """
         Block is composed entirely of function F (Attention sub-block)
         and G (MLP sub-block) including layernorm.
@@ -198,9 +188,11 @@ def __init__(
         super().__init__()
         # F and G can be arbitrary functions, here we use
         # simple attention and MLP sub-blocks using vanilla attention.
-        self.F = AttentionSubBlock(dim=dim, num_heads=num_heads)
+        self.F = AttentionSubBlock(
+            dim=dim, num_heads=num_heads, enable_amp=enable_amp
+        )
 
-        self.G = MLPSubblock(dim=dim)
+        self.G = MLPSubblock(dim=dim, enable_amp=enable_amp)
 
         # note that since all functions are deterministic, and we are
         # not using any stochastic elements such as dropout, we do
@@ -234,11 +226,7 @@ def forward(self, X_1, X_2):
         return Y_1, Y_2
 
     def backward_pass(
-        self,
-        Y_1,
-        Y_2,
-        dY_1,
-        dY_2,
+        self, Y_1, Y_2, dY_1, dY_2,
     ):
         """
         equation for activation recomputation:
@@ -323,9 +311,7 @@ class MLPSubblock(nn.Module):
     """
 
     def __init__(
-        self,
-        dim,
-        mlp_ratio=4,  # standard for ViTs
+        self, dim, mlp_ratio=4, enable_amp=False,  # standard for ViTs
     ):
         super().__init__()
 
@@ -337,9 +323,18 @@ def __init__(
             nn.GELU(),
             nn.Linear(dim * mlp_ratio, dim),
         )
+        self.enable_amp = enable_amp
 
     def forward(self, x):
-        return self.mlp(self.norm(x))
+
+        # The reason for applying autocast inside the forward pass, instead
+        # of in the main training logic, is the implicit forward pass during
+        # memory-efficient gradient backpropagation. In the backward pass,
+        # the activations need to be recomputed, and if the forward ran with
+        # mixed precision, the recomputation must as well. This cannot be
+        # handled by an autocast context in the main training logic alone.
+        with torch.cuda.amp.autocast(enabled=self.enable_amp):
+            return self.mlp(self.norm(x))
 
 
 class AttentionSubBlock(nn.Module):
@@ -349,9 +344,7 @@ class AttentionSubBlock(nn.Module):
     """
 
     def __init__(
-        self,
-        dim,
-        num_heads,
+        self, dim, num_heads, enable_amp=False,
     ):
         super().__init__()
 
@@ -362,12 +355,15 @@ def __init__(
         # Note that the complexity of the attention module is not a concern
         # since it is used as a black-box F block in the reversible logic and
         # can be arbitrary.
-        self.attn = MHA(dim, num_heads, batch_first = True)
+        self.attn = MHA(dim, num_heads, batch_first=True)
+        self.enable_amp = enable_amp
 
     def forward(self, x):
-        x = self.norm(x)
-        out, _ = self.attn(x, x, x)
-        return out
+        # See the MLP forward pass above for an explanation of the autocast.
+        with torch.cuda.amp.autocast(enabled=self.enable_amp):
+            x = self.norm(x)
+            out, _ = self.attn(x, x, x)
+            return out
 
 
 def main():
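The comments added to `MLPSubblock.forward` and `AttentionSubBlock.forward` capture the key design point: the reversible backward pass re-runs these sub-block forwards to recompute activations, and that recomputation happens outside any autocast context set up in the training loop. The minimal sketch below is illustrative only (it is not part of the patch; the toy `AmpMLP` module is an assumption, and a CUDA device is required). It shows that autocast carried inside a module's `forward` also covers later calls, such as a recomputation under `torch.no_grad()`.

```python
# Illustrative sketch (not part of the patch): the autocast-inside-forward
# pattern used by MLPSubblock / AttentionSubBlock, shown on a toy module.
import torch
import torch.nn as nn


class AmpMLP(nn.Module):
    """Toy stand-in for MLPSubblock: autocast travels with the module."""

    def __init__(self, dim=64, enable_amp=True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU())
        self.enable_amp = enable_amp

    def forward(self, x):
        # Autocast is applied here rather than in the training loop, so any
        # extra forward call (e.g. the activation recomputation inside
        # RevBackProp.backward) runs in the same precision as the original.
        with torch.cuda.amp.autocast(enabled=self.enable_amp):
            return self.mlp(self.norm(x))


block = AmpMLP().cuda()
x = torch.randn(8, 64, device="cuda")

y = block(x)                 # original forward, under module-level autocast
with torch.no_grad():        # stand-in for the recomputation phase
    y_recomputed = block(x)  # no outer autocast needed; still mixed precision

print(y.dtype, y_recomputed.dtype)  # torch.float16 torch.float16
```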