adding AMP
karttikeya committed Feb 27, 2023
1 parent 824cf58 commit 4f32aca
Showing 3 changed files with 66 additions and 43 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -51,6 +51,17 @@ Here are the [Training/Validation Logs](https://api.wandb.ai/links/action_antici

👁️ **Note**: The relatively low accuracy is due to the difficulty of training a vision transformer (reversible or vanilla) from scratch on a small dataset like CIFAR-10. It is also likely that a much higher accuracy can be achieved with the same code using a better [chosen model design and optimization parameters](https://github.com/tysam-code/hlb-CIFAR10). The authors have done no tuning, since this repository is meant for understanding code, not pushing performance.

<h2> Mixed precision training </h2>

Mixed precision training is also supported and can be enabled by adding the `--amp True` flag to the commands above. Training progresses smoothly and reaches `80%+` validation accuracy on CIFAR-10, similar to training without AMP.
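
For example, if the baseline training command is `python main.py` (the training entry point in this repo), mixed precision training would be launched as `python main.py --amp True`, with any other flags left unchanged.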


📝 **Note**: Vanilla PyTorch AMP maintains full precision (fp32) on the weights and only uses half precision (fp16) on intermediate activations. Since reversible training already avoids storing almost all intermediate activations (see the video for an explanation), using AMP (i.e. half precision on activations) brings little additional memory saving. For example, on a 16G V100 setup, AMP improves the maximum reversible CIFAR-10 batch size from `12000` to `14500` (`~20%`). At the usual training batch size (`128`) there is a small gain in GPU training memory (about 4%).
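
For reference, here is a minimal sketch of the standard PyTorch AMP recipe (with placeholder `model`, `criterion`, `optimizer`, and `loader` objects). In this repo the `autocast` context instead lives inside the sub-block forwards in `rev.py`, so `main.py` only needs the `GradScaler` half of this pattern:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # scales the loss so fp16 gradients do not underflow

for inputs, targets in loader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    with autocast():  # forward pass with fp16 activations, fp32 weights
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then optimizer step
    scaler.update()                # adjusts the scale factor for the next step
```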

<h2> Distributed Data Parallel Training </h2>

There is no additional overhead for DDP training with reversible, which progresses the same as vanilla training. All results in the [paper](https://arxiv.org/abs/2302.04869) (also see below) were obtained in DDP setups (`>64` GPUs per run). However, implementing distributed training is beyond the scope of this repo; it can instead be found in the PySlowFast [distributed training setup](https://github.com/facebookresearch/SlowFast/blob/99a655bd533d7fddd7f79509e3dfaae811767b5c/slowfast/models/build.py#L69-L83).
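
Purely as an illustration (not part of this repo), here is a minimal sketch of how the model could be wrapped for DDP; the hypothetical `wrap_ddp` helper assumes one process per GPU launched with `torchrun`, which provides the `LOCAL_RANK` environment variable:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_ddp(model):
    # one process per GPU; torchrun sets RANK / WORLD_SIZE / LOCAL_RANK
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.to(local_rank)
    # gradients are all-reduced across processes during backward();
    # reversible backprop needs no special handling on top of this
    return DDP(model, device_ids=[local_rank])
```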

<h2> Running ImageNet, Kinetics-400 and more </h2>

For more use cases, such as reproducing the numbers from the [original paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf), see the [full code in PySlowFast](https://github.com/facebookresearch/SlowFast) that supports
24 changes: 20 additions & 4 deletions main.py
@@ -7,9 +7,10 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import GradScaler

import torchvision
import torchvision.transforms as transforms

from rev import RevViT

parser = argparse.ArgumentParser(description="PyTorch CIFAR10 Training")
@@ -53,6 +54,12 @@
type=bool,
help="whether to use reversible backpropagation or not",
)
parser.add_argument(
"--amp",
default=False,
type=bool,
help="whether to use mixed precision training or not",
)
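# note: argparse's `type=bool` turns any non-empty string into True,
# so pass `--amp True` to enable AMP and omit the flag to keep the default (False).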

args = parser.parse_args()

@@ -102,6 +109,7 @@
patch_size=args.patch_size,
image_size=args.image_size,
num_classes=args.num_classes,
enable_amp=args.amp,
)

model = model.to(device)
@@ -113,6 +121,7 @@
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
scaler = GradScaler()

# Training
def train(epoch):
@@ -122,12 +131,19 @@ def train(epoch):
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):

# We do not need to specify AMP autocast in the forward pass here since
# that is already taken care of in the forward of the individual modules.
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# standard PyTorch AMP training setup;
# the scaler also works when AMP is disabled.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

train_loss += loss.item()
_, predicted = outputs.max(1)
74 changes: 35 additions & 39 deletions rev.py
@@ -14,13 +14,11 @@ def __init__(
embed_dim=768,
n_head=8,
depth=8,
patch_size=(
2,
2,
), # this patch size is used for CIFAR-10
patch_size=(2, 2,), # this patch size is used for CIFAR-10
# --> (32 // 2)**2 = 256 sequence length
image_size=(32, 32), # CIFAR-10 image size
num_classes=10,
enable_amp=False,
):

super().__init__()
@@ -39,7 +37,11 @@ def __init__(
# is constrained inside the block code and not exposed.
self.layers = nn.ModuleList(
[
ReversibleBlock(dim=self.embed_dim, num_heads=self.n_head)
ReversibleBlock(
dim=self.embed_dim,
num_heads=self.n_head,
enable_amp=enable_amp,
)
for _ in range(self.depth)
]
)
@@ -96,10 +98,7 @@ def forward(self, x):
executing_fn = RevBackProp.apply

# This takes care of switching between vanilla backprop and rev backprop
x = executing_fn(
x,
self.layers,
)
x = executing_fn(x, self.layers,)

# aggregate across sequence length
x = x.mean(1)
@@ -125,9 +124,7 @@ class RevBackProp(Function):

@staticmethod
def forward(
ctx,
x,
layers,
ctx, x, layers,
):
"""
Reversible Forward pass.
@@ -167,10 +164,7 @@ def backward(ctx, dx):
# this is recomputing both the activations and the gradients wrt
# those activations.
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
Y_1=X_1,
Y_2=X_2,
dY_1=dX_1,
dY_2=dX_2,
Y_1=X_1, Y_2=X_2, dY_1=dX_1, dY_2=dX_2,
)
# final input gradient to be passed backward to the patchification layer
dx = torch.cat([dX_1, dX_2], dim=-1)
@@ -186,21 +180,19 @@ class ReversibleBlock(nn.Module):
See Section 3.3.2 in paper for details.
"""

def __init__(
self,
dim,
num_heads,
):
def __init__(self, dim, num_heads, enable_amp):
"""
Block is composed entirely of function F (Attention
sub-block) and G (MLP sub-block) including layernorm.
"""
super().__init__()
# F and G can be arbitrary functions, here we use
# simple attention and MLP sub-blocks using vanilla attention.
self.F = AttentionSubBlock(dim=dim, num_heads=num_heads)
self.F = AttentionSubBlock(
dim=dim, num_heads=num_heads, enable_amp=enable_amp
)

self.G = MLPSubblock(dim=dim)
self.G = MLPSubblock(dim=dim, enable_amp=enable_amp)

# note that since all functions are deterministic, and we are
# not using any stochastic elements such as dropout, we do
@@ -234,11 +226,7 @@ def forward(self, X_1, X_2):
return Y_1, Y_2

def backward_pass(
self,
Y_1,
Y_2,
dY_1,
dY_2,
self, Y_1, Y_2, dY_1, dY_2,
):
"""
equation for activation recomputation:
@@ -323,9 +311,7 @@ class MLPSubblock(nn.Module):
"""

def __init__(
self,
dim,
mlp_ratio=4, # standard for ViTs
self, dim, mlp_ratio=4, enable_amp=False, # standard for ViTs
):

super().__init__()
@@ -337,9 +323,18 @@ def __init__(
nn.GELU(),
nn.Linear(dim * mlp_ratio, dim),
)
self.enable_amp = enable_amp

def forward(self, x):
return self.mlp(self.norm(x))

# The reason for implementing autocast inside the forward pass, instead of
# in the main training logic, is the implicit forward pass during memory
# efficient gradient backpropagation. In the backward pass, the activations
# need to be recomputed, and if the forward ran under mixed precision, the
# recomputation must run under mixed precision too. This cannot be handled
# by an autocast context in the main training logic.
with torch.cuda.amp.autocast(enabled=self.enable_amp):
return self.mlp(self.norm(x))


class AttentionSubBlock(nn.Module):
@@ -349,9 +344,7 @@ class AttentionSubBlock(nn.Module):
"""

def __init__(
self,
dim,
num_heads,
self, dim, num_heads, enable_amp=False,
):

super().__init__()
@@ -362,12 +355,15 @@ def __init__(
# Note that the complexity of the attention module is not a concern
# since it is used as a black-box F block in the reversible logic and
# can be arbitrary.
self.attn = MHA(dim, num_heads, batch_first = True)
self.attn = MHA(dim, num_heads, batch_first=True)
self.enable_amp = enable_amp

def forward(self, x):
x = self.norm(x)
out, _ = self.attn(x, x, x)
return out
# See MLP fwd pass for explanation.
with torch.cuda.amp.autocast(enabled=self.enable_amp):
x = self.norm(x)
out, _ = self.attn(x, x, x)
return out


def main():