diff --git a/README.md b/README.md
index 9f3b6c0..c9df70d 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,17 @@ Here are the [Training/Validation Logs](https://api.wandb.ai/links/action_antici
👁️ **Note**: The relatively low accuracy is due to the difficulty of training a vision transformer (reversible or vanilla) from scratch on small datasets like CIFAR-10. It is also likely that a much higher accuracy can be achieved with the same code using a better [choice of model design and optimization parameters](https://github.com/tysam-code/hlb-CIFAR10). The authors have done no tuning, since this repository is meant for understanding code, not pushing performance.
+
+Mixed precision training
+
+Mixed precision training is also supported and can be enabled by adding the `--amp True` flag to the above commands. Training progresses smoothly and achieves `80%+` validation accuracy on CIFAR-10, similar to training without AMP.
+
+
+📝 **Note**: Vanilla PyTorch AMP keeps the weights in full precision (fp32) and uses half precision (fp16) only for intermediate activations. Since reversible training already avoids storing almost all intermediate activations (see the video for an explanation), using AMP (i.e., half precision on activations) brings little additional memory saving. For example, on a 16G V100 setup, AMP raises the maximum reversible CIFAR-10 batch size from `12000` to `14500` (`~20%`). At the usual training batch size (`128`), the saving in GPU training memory is small (about 4%).
+
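+The memory numbers above can be roughly sanity-checked with a snippet like the one below. This is a hypothetical sketch (not part of the repo): it assumes the default `RevViT` configuration from `rev.py` and a single CUDA device, and it measures only one forward/backward step rather than the full training loop in `main.py`.
+
+```python
+# Hypothetical sketch: compare peak GPU memory for a single training step
+# with and without AMP, using the default RevViT configuration from rev.py.
+import torch
+
+from rev import RevViT
+
+
+def peak_memory_mb(enable_amp, batch_size=128):
+    # fresh peak-memory statistics for this measurement
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+    model = RevViT(enable_amp=enable_amp).cuda()
+    scaler = torch.cuda.amp.GradScaler(enabled=enable_amp)
+    x = torch.randn(batch_size, 3, 32, 32, device="cuda")  # CIFAR-10 sized input
+    y = torch.randint(0, 10, (batch_size,), device="cuda")
+    loss = torch.nn.functional.cross_entropy(model(x), y)
+    scaler.scale(loss).backward()
+    return torch.cuda.max_memory_allocated() / 2**20
+
+
+print("fp32 peak MB:", peak_memory_mb(enable_amp=False))
+print("amp  peak MB:", peak_memory_mb(enable_amp=True))
+```
+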
+Distributed Data Parallel Training
+
+There is no additional overhead for DDP training with the reversible model; it progresses the same as vanilla training. All results in the [paper](https://arxiv.org/abs/2302.04869) (also see below) were obtained in DDP setups (`>64` GPUs per run). However, implementing distributed training is beyond the scope of this repo; see instead the PySlowFast [distributed training setup](https://github.com/facebookresearch/SlowFast/blob/99a655bd533d7fddd7f79509e3dfaae811767b5c/slowfast/models/build.py#L69-L83).
+
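+For reference, here is a minimal sketch of what DDP wrapping could look like for this model. It is hypothetical and not part of this repo: it assumes a standard `torchrun` launch that sets `LOCAL_RANK`, and it omits the data loading, distributed sampler, and training loop.
+
+```python
+# Hypothetical sketch: standard PyTorch DDP wrapping around RevViT,
+# assuming torchrun has set LOCAL_RANK for each spawned process.
+import os
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from rev import RevViT
+
+dist.init_process_group(backend="nccl")
+local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(local_rank)
+
+model = RevViT(enable_amp=True).cuda()
+# As noted above, reversible training behaves like vanilla training under DDP:
+# parameter gradients are populated during RevBackProp's backward pass, so
+# DDP's usual gradient synchronization applies.
+model = DDP(model, device_ids=[local_rank])
+```
+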
Running ImageNet, Kinetics-400 and more
For more use cases, such as reproducing numbers from the [original paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf), see the [full code in PySlowFast](https://github.com/facebookresearch/SlowFast) that supports
diff --git a/main.py b/main.py
index 29366cd..5956ee2 100644
--- a/main.py
+++ b/main.py
@@ -7,9 +7,10 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
+from torch.cuda.amp import GradScaler
+
import torchvision
import torchvision.transforms as transforms
-
from rev import RevViT
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 Training")
@@ -53,6 +54,12 @@
type=bool,
help="whether to use reversible backpropagation or not",
)
+parser.add_argument(
+ "--amp",
+ default=False,
+ type=bool,
+ help="whether to use mixed precision training or not",
+)
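+# NOTE: argparse's type=bool converts any non-empty string (including "False")
+# to True, so pass `--amp True` to enable AMP and omit the flag to keep it disabled.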
args = parser.parse_args()
@@ -102,6 +109,7 @@
patch_size=args.patch_size,
image_size=args.image_size,
num_classes=args.num_classes,
+ enable_amp=args.amp,
)
model = model.to(device)
@@ -113,6 +121,7 @@
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
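+# GradScaler guards against fp16 gradient underflow by scaling the loss before
+# backward and unscaling the gradients inside scaler.step(). With AMP disabled,
+# the scaling cancels out, so fp32 training is unaffected.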
+scaler = GradScaler()
# Training
def train(epoch):
@@ -122,12 +131,19 @@ def train(epoch):
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
+
+ # We do not need to specify AMP autocast in forward pass here since
+ # that is taken care of already in the forward of individual modules.
inputs, targets = inputs.to(device), targets.to(device)
- optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
- loss.backward()
- optimizer.step()
+
+ # standard pytorch AMP training setup
+ # scaler also works without amp training.
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
train_loss += loss.item()
_, predicted = outputs.max(1)
diff --git a/rev.py b/rev.py
index 4caaf1e..aa463c6 100644
--- a/rev.py
+++ b/rev.py
@@ -14,13 +14,11 @@ def __init__(
embed_dim=768,
n_head=8,
depth=8,
- patch_size=(
- 2,
- 2,
- ), # this patch size is used for CIFAR-10
+ patch_size=(2, 2,), # this patch size is used for CIFAR-10
# --> (32 // 2)**2 = 256 sequence length
image_size=(32, 32), # CIFAR-10 image size
num_classes=10,
+ enable_amp=False,
):
super().__init__()
@@ -39,7 +37,11 @@ def __init__(
# is constrained inside the block code and not exposed.
self.layers = nn.ModuleList(
[
- ReversibleBlock(dim=self.embed_dim, num_heads=self.n_head)
+ ReversibleBlock(
+ dim=self.embed_dim,
+ num_heads=self.n_head,
+ enable_amp=enable_amp,
+ )
for _ in range(self.depth)
]
)
@@ -96,10 +98,7 @@ def forward(self, x):
executing_fn = RevBackProp.apply
# This takes care of switching between vanilla backprop and rev backprop
- x = executing_fn(
- x,
- self.layers,
- )
+ x = executing_fn(x, self.layers,)
# aggregate across sequence length
x = x.mean(1)
@@ -125,9 +124,7 @@ class RevBackProp(Function):
@staticmethod
def forward(
- ctx,
- x,
- layers,
+ ctx, x, layers,
):
"""
Reversible Forward pass.
@@ -167,10 +164,7 @@ def backward(ctx, dx):
# this is recomputing both the activations and the gradients wrt
# those activations.
X_1, X_2, dX_1, dX_2 = layer.backward_pass(
- Y_1=X_1,
- Y_2=X_2,
- dY_1=dX_1,
- dY_2=dX_2,
+ Y_1=X_1, Y_2=X_2, dY_1=dX_1, dY_2=dX_2,
)
# final input gradient to be passed backward to the patchification layer
dx = torch.cat([dX_1, dX_2], dim=-1)
@@ -186,11 +180,7 @@ class ReversibleBlock(nn.Module):
See Section 3.3.2 in paper for details.
"""
- def __init__(
- self,
- dim,
- num_heads,
- ):
+ def __init__(self, dim, num_heads, enable_amp):
"""
Block is composed entirely of function F (Attention
sub-block) and G (MLP sub-block) including layernorm.
@@ -198,9 +188,11 @@ def __init__(
super().__init__()
# F and G can be arbitrary functions, here we use
# simple attention and MLP sub-blocks using vanilla attention.
- self.F = AttentionSubBlock(dim=dim, num_heads=num_heads)
+ self.F = AttentionSubBlock(
+ dim=dim, num_heads=num_heads, enable_amp=enable_amp
+ )
- self.G = MLPSubblock(dim=dim)
+ self.G = MLPSubblock(dim=dim, enable_amp=enable_amp)
# note that since all functions are deterministic, and we are
# not using any stochastic elements such as dropout, we do
@@ -234,11 +226,7 @@ def forward(self, X_1, X_2):
return Y_1, Y_2
def backward_pass(
- self,
- Y_1,
- Y_2,
- dY_1,
- dY_2,
+ self, Y_1, Y_2, dY_1, dY_2,
):
"""
equation for activation recomputation:
@@ -323,9 +311,7 @@ class MLPSubblock(nn.Module):
"""
def __init__(
- self,
- dim,
- mlp_ratio=4, # standard for ViTs
+        self, dim, mlp_ratio=4, enable_amp=False,  # mlp_ratio=4 is standard for ViTs
):
super().__init__()
@@ -337,9 +323,18 @@ def __init__(
nn.GELU(),
nn.Linear(dim * mlp_ratio, dim),
)
+ self.enable_amp = enable_amp
def forward(self, x):
- return self.mlp(self.norm(x))
+
+ # The reason for implementing autocast inside forward loop instead
+ # in the main training logic is the implicit forward pass during
+ # memory efficient gradient backpropagation. In backward pass, the
+ # activations need to be recomputed, and if the forward has happened
+ # with mixed precision, the recomputation must also be so. This cannot
+ # be handled with the autocast setup in main training logic.
+ with torch.cuda.amp.autocast(enabled=self.enable_amp):
+ return self.mlp(self.norm(x))
class AttentionSubBlock(nn.Module):
@@ -349,9 +344,7 @@ class AttentionSubBlock(nn.Module):
"""
def __init__(
- self,
- dim,
- num_heads,
+ self, dim, num_heads, enable_amp=False,
):
super().__init__()
@@ -362,12 +355,15 @@ def __init__(
# Note that the complexity of the attention module is not a concern
# since it is used blackbox as F block in the reversible logic and
# can be arbitrary.
- self.attn = MHA(dim, num_heads, batch_first = True)
+ self.attn = MHA(dim, num_heads, batch_first=True)
+ self.enable_amp = enable_amp
def forward(self, x):
- x = self.norm(x)
- out, _ = self.attn(x, x, x)
- return out
+ # See MLP fwd pass for explanation.
+ with torch.cuda.amp.autocast(enabled=self.enable_amp):
+ x = self.norm(x)
+ out, _ = self.attn(x, x, x)
+ return out
def main():