From 0a4f0ae599f97f33a3d89bd934f2c0ef71503b28 Mon Sep 17 00:00:00 2001
From: Annika Brundyn <42869932+annikabrundyn@users.noreply.github.com>
Date: Wed, 5 Aug 2020 13:08:59 -0400
Subject: [PATCH] BYOL implementation (#144)

* byol wip

* add blank lines

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* verify implementation

* add l2 normalization
---
 docs/source/index.rst                         |   1 +
 docs/source/self_supervised_callbacks.rst     |  15 ++
 docs/source/self_supervised_models.rst        |   6 +
 pl_bolts/callbacks/self_supervised.py         |  62 +++++
 pl_bolts/models/self_supervised/__init__.py   |   1 +
 .../models/self_supervised/byol/__init__.py   |   0
 .../self_supervised/byol/byol_module.py       | 219 ++++++++++++++++++
 .../models/self_supervised/byol/models.py     |  39 ++++
 .../callbacks/test_param_update_callbacks.py  |  31 +++
 tests/models/test_self_supervised.py          |  17 +-
 10 files changed, 390 insertions(+), 1 deletion(-)
 create mode 100644 docs/source/self_supervised_callbacks.rst
 create mode 100644 pl_bolts/callbacks/self_supervised.py
 create mode 100644 pl_bolts/models/self_supervised/byol/__init__.py
 create mode 100644 pl_bolts/models/self_supervised/byol/byol_module.py
 create mode 100644 pl_bolts/models/self_supervised/byol/models.py
 create mode 100644 tests/callbacks/test_param_update_callbacks.py

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7e359313e9..3b8179df14 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -20,6 +20,7 @@ PyTorch-Lightning-Bolts documentation
 
    callbacks
    info_callbacks
+   self_supervised_callbacks
    variational_callbacks
    vision_callbacks
 
diff --git a/docs/source/self_supervised_callbacks.rst b/docs/source/self_supervised_callbacks.rst
new file mode 100644
index 0000000000..8f9083f4f1
--- /dev/null
+++ b/docs/source/self_supervised_callbacks.rst
@@ -0,0 +1,15 @@
+.. role:: hidden
+    :class: hidden-section
+
+Self-supervised Callbacks
+=========================
+Useful callbacks for self-supervised learning models.
+
+---------------
+
+BYOLMAWeightUpdate
+------------------
+The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).
+
+.. autoclass:: pl_bolts.callbacks.self_supervised.BYOLMAWeightUpdate
+    :noindex:
diff --git a/docs/source/self_supervised_models.rst b/docs/source/self_supervised_models.rst
index 1b60beeefa..3797bbcbc8 100644
--- a/docs/source/self_supervised_models.rst
+++ b/docs/source/self_supervised_models.rst
@@ -92,6 +92,12 @@ AMDIM
 .. autoclass:: pl_bolts.models.self_supervised.AMDIM
     :noindex:
 
+BYOL
+^^^^
+
+.. autoclass:: pl_bolts.models.self_supervised.BYOL
+    :noindex:
+
 CPC (V2)
 ^^^^^^^^
 
diff --git a/pl_bolts/callbacks/self_supervised.py b/pl_bolts/callbacks/self_supervised.py
new file mode 100644
index 0000000000..fc85d7c543
--- /dev/null
+++ b/pl_bolts/callbacks/self_supervised.py
@@ -0,0 +1,62 @@
+import torch
+import math
+import pytorch_lightning as pl
+
+
+class BYOLMAWeightUpdate(pl.Callback):
+
+    def __init__(self, initial_tau=0.996):
+        """
+        Weight update rule from BYOL.
+
+        Your model should have:
+
+            - self.online_network
+            - self.target_network
+
+        Updates the target_network params using an exponential moving average update rule weighted by tau.
+        BYOL claims this keeps the online_network from collapsing.
+
+        .. note:: Automatically increases tau from `initial_tau` to 1.0 with every training step
+
+        Example::
+
+            from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
+
+            # model must have 2 attributes
+            model = Model()
+            model.online_network = ...
+            model.target_network = ...
+
+            # make sure to set max_steps in Trainer
+            trainer = Trainer(callbacks=[BYOLMAWeightUpdate()], max_steps=1000)
+
+        Args:
+            initial_tau: starting tau. Auto-updates with every training step
+        """
+        super().__init__()
+        self.initial_tau = initial_tau
+        self.current_tau = initial_tau
+
+    def on_batch_end(self, trainer, pl_module):
+
+        if pl_module.training:
+            # get networks
+            online_net = pl_module.online_network
+            target_net = pl_module.target_network
+
+            # update weights
+            self.update_weights(online_net, target_net)
+
+            # update tau after
+            self.current_tau = self.update_tau(pl_module, trainer)
+
+    def update_tau(self, pl_module, trainer):
+        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / trainer.max_steps) + 1) / 2
+        return tau
+
+    def update_weights(self, online_net, target_net):
+        # apply MA weight update (only params with 'weight' in their name; biases are left unchanged)
+        for (name, online_p), (_, target_p) in zip(online_net.named_parameters(), target_net.named_parameters()):
+            if 'weight' in name:
+                target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
diff --git a/pl_bolts/models/self_supervised/__init__.py b/pl_bolts/models/self_supervised/__init__.py
index 82601032f1..c8485265fb 100644
--- a/pl_bolts/models/self_supervised/__init__.py
+++ b/pl_bolts/models/self_supervised/__init__.py
@@ -19,6 +19,7 @@
 """
 from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
+from pl_bolts.models.self_supervised.byol.byol_module import BYOL
 from pl_bolts.models.self_supervised.cpc.cpc_module import CPCV2
 from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
 from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2
diff --git a/pl_bolts/models/self_supervised/byol/__init__.py b/pl_bolts/models/self_supervised/byol/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py
new file mode 100644
index 0000000000..45055f2a04
--- /dev/null
+++ b/pl_bolts/models/self_supervised/byol/byol_module.py
@@ -0,0 +1,219 @@
+from argparse import ArgumentParser
+from copy import deepcopy
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
+from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
+from pl_bolts.optimizers.layer_adaptive_scaling import LARS
+from pl_bolts.models.self_supervised.byol.models import SiameseArm
+from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
+
+
+class BYOL(pl.LightningModule):
+    def __init__(self,
+                 datamodule: pl.LightningDataModule = None,
+                 data_dir: str = './',
+                 learning_rate: float = 0.00006,
+                 weight_decay: float = 0.0005,
+                 input_height: int = 32,
+                 batch_size: int = 32,
+                 num_workers: int = 4,
+                 optimizer: str = 'lars',
+                 lr_sched_step: float = 30.0,
+                 lr_sched_gamma: float = 0.5,
+                 lars_momentum: float = 0.9,
+                 lars_eta: float = 0.001,
+                 loss_temperature: float = 0.5,
+                 **kwargs):
+        """
+        PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL)
+        <https://arxiv.org/abs/2006.07733>`_
+
+        Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
+        Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
+        Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.
+
+        Model implemented by:
+            - `Annika Brundyn <https://github.com/annikabrundyn>`_
+
+        .. warning:: Work in progress. This implementation is still being verified.
+
+        TODOs:
+            - add cosine scheduler
+            - verify on CIFAR-10
+            - verify on STL-10
+            - pre-train on imagenet
+
+        Example:
+
+            >>> from pl_bolts.models.self_supervised import BYOL
+            ...
+            >>> model = BYOL()
+
+        Train::
+
+            trainer = Trainer()
+            trainer.fit(model)
+
+        CLI command::
+
+            # cifar10
+            python byol_module.py --gpus 1
+
+            # imagenet
+            python byol_module.py
+                --gpus 8
+                --dataset imagenet2012
+                --data_dir /path/to/imagenet/
+                --meta_dir /path/to/folder/with/meta.bin/
+                --batch_size 32
+
+        Args:
+            datamodule: The datamodule
+            data_dir: directory to store data
+            learning_rate: the learning rate
+            weight_decay: optimizer weight decay
+            input_height: image input height
+            batch_size: the batch size
+            num_workers: number of workers
+            optimizer: optimizer name
+            lr_sched_step: step for learning rate scheduler
+            lr_sched_gamma: gamma for learning rate scheduler
+            lars_momentum: momentum parameter for the LARS optimizer
+            lars_eta: eta (trust coefficient) for the LARS optimizer
+            loss_temperature: temperature for the loss (not used by the current MSE loss)
+        """
+        super().__init__()
+        self.save_hyperparameters()
+
+        # init default datamodule
+        if datamodule is None:
+            datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size)
+            datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
+            datamodule.val_transforms = SimCLREvalDataTransform(input_height)
+
+        self.datamodule = datamodule
+
+        self.online_network = SiameseArm()
+        self.target_network = deepcopy(self.online_network)
+
+        self.weight_callback = BYOLMAWeightUpdate()
+
+    def on_batch_end(self):
+        # Add callback for user automatically since it's key to BYOL weight update
+        self.weight_callback.on_batch_end(self.trainer, self)
+
+    def forward(self, x):
+        y, _, _ = self.online_network(x)
+        return y
+
+    def shared_step(self, batch, batch_idx):
+        (img_1, img_2), y = batch
+
+        # Image 1 to image 2 loss
+        y1, z1, h1 = self.online_network(img_1)
+        with torch.no_grad():
+            y2, z2, h2 = self.target_network(img_2)
+        # L2 normalize
+        h1_norm = F.normalize(h1, p=2, dim=1)
+        z2_norm = F.normalize(z2, p=2, dim=1)
+        loss_a = F.mse_loss(h1_norm, z2_norm)
+
+        # Image 2 to image 1 loss
+        y1, z1, h1 = self.online_network(img_2)
+        with torch.no_grad():
+            y2, z2, h2 = self.target_network(img_1)
+        # L2 normalize
+        h1_norm = F.normalize(h1, p=2, dim=1)
+        z2_norm = F.normalize(z2, p=2, dim=1)
+        loss_b = F.mse_loss(h1_norm, z2_norm)
+
+        # Final loss
+        total_loss = loss_a + loss_b
+
+        return loss_a, loss_b, total_loss
+
+    def training_step(self, batch, batch_idx):
+        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
+
+        # log results
+        result = pl.TrainResult(minimize=total_loss)
+        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss})
+
+        return result
+
+    def validation_step(self, batch, batch_idx):
+        loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
+
+        # log results
+        result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss)
+        result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'val_loss': total_loss})
+
+        return result
+
+    def configure_optimizers(self):
+        optimizer = LARS(self.parameters(), lr=self.hparams.learning_rate)
+        # TODO: add scheduler - cosine decay
+        return optimizer
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--online_ft', action='store_true', help='run online finetuner')
+        parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10')
+
+        (args, _) = parser.parse_known_args()
+
+        # Data
+        parser.add_argument('--data_dir', type=str, default='.')
+
+        # Training
+        parser.add_argument('--optimizer', choices=['adam', 'lars'], default='lars')
+        parser.add_argument('--batch_size', type=int, default=2)
+        parser.add_argument('--learning_rate', type=float, default=1.0)
+        parser.add_argument('--lars_momentum', type=float, default=0.9)
+        parser.add_argument('--lars_eta', type=float, default=0.001)
+        parser.add_argument('--lr_sched_step', type=float, default=30, help='lr scheduler step')
+        parser.add_argument('--lr_sched_gamma', type=float, default=0.5, help='lr scheduler gamma')
+        parser.add_argument('--weight_decay', type=float, default=1e-4)
+
+        # Model
+        parser.add_argument('--loss_temperature', type=float, default=0.5)
+        parser.add_argument('--num_workers', default=4, type=int)
+        parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')
+
+        return parser
+
+
+if __name__ == '__main__':
+    parser = ArgumentParser()
+
+    # trainer args
+    parser = pl.Trainer.add_argparse_args(parser)
+
+    # model args
+    parser = BYOL.add_model_specific_args(parser)
+    args = parser.parse_args()
+
+    # pick data
+    datamodule = None
+    if args.dataset == 'stl10':
+        datamodule = STL10DataModule.from_argparse_args(args)
+        datamodule.train_dataloader = datamodule.train_dataloader_mixed
+        datamodule.val_dataloader = datamodule.val_dataloader_mixed
+
+        (c, h, w) = datamodule.size()
+        datamodule.train_transforms = SimCLRTrainDataTransform(h)
+        datamodule.val_transforms = SimCLREvalDataTransform(h)
+
+    elif args.dataset == 'imagenet2012':
+        datamodule = ImagenetDataModule.from_argparse_args(args, image_size=196)
+        (c, h, w) = datamodule.size()
+        datamodule.train_transforms = SimCLRTrainDataTransform(h)
+        datamodule.val_transforms = SimCLREvalDataTransform(h)
+
+    model = BYOL(**args.__dict__, datamodule=datamodule)
+
+    trainer = pl.Trainer.from_argparse_args(args, max_steps=10000)
+    trainer.fit(model)
diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py
new file mode 100644
index 0000000000..0d422270b6
--- /dev/null
+++ b/pl_bolts/models/self_supervised/byol/models.py
@@ -0,0 +1,39 @@
+from torch import nn
+from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_dim = input_dim
+        self.model = nn.Sequential(
+            nn.Linear(input_dim, hidden_size, bias=False),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, output_dim, bias=True))
+
+    def forward(self, x):
+        x = self.model(x)
+        return x
+
+
+class SiameseArm(nn.Module):
+    def __init__(self, encoder=None):
+        super().__init__()
+
+        if encoder is None:
+            encoder = torchvision_ssl_encoder('resnet50')
+        # Encoder
+        self.encoder = encoder
+        # Projector
+        self.projector = MLP()
+        # Predictor
+        self.predictor = MLP(input_dim=256)
+
+    def forward(self, x):
+        y = self.encoder(x)[0]
+        y = y.view(y.size(0), -1)
+        z = self.projector(y)
+        h = self.predictor(z)
+        return y, z, h
diff --git a/tests/callbacks/test_param_update_callbacks.py b/tests/callbacks/test_param_update_callbacks.py
new file mode 100644
index 0000000000..da865f1926
--- /dev/null
+++ b/tests/callbacks/test_param_update_callbacks.py
@@ -0,0 +1,31 @@
+import torch
+from torch import nn
+from copy import deepcopy
+from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
+
+
+def test_byol_ma_weight_update_callback(tmpdir):
+    a = nn.Linear(100, 10)
+    b = deepcopy(a)
+    a_original = deepcopy(a)
+    b_original = deepcopy(b)
+
+    # make sure a params and b params are the same
+    assert torch.equal(next(iter(a.parameters()))[0], next(iter(b.parameters()))[0])
+
+    # fake weight update
+    opt = torch.optim.SGD(a.parameters(), lr=0.1)
+    y = a(torch.randn(3, 100))
+    loss = y.sum()
+    loss.backward()
+    opt.step()
+    opt.zero_grad()
+
+    # make sure a did in fact update
+    assert not torch.equal(next(iter(a_original.parameters()))[0], next(iter(a.parameters()))[0])
+
+    # do update via callback
+    cb = BYOLMAWeightUpdate(0.8)
+    cb.update_weights(a, b)
+
+    assert not torch.equal(next(iter(b_original.parameters()))[0], next(iter(b.parameters()))[0])
diff --git a/tests/models/test_self_supervised.py b/tests/models/test_self_supervised.py
index 95679d1a5d..6487e4c282 100644
--- a/tests/models/test_self_supervised.py
+++ b/tests/models/test_self_supervised.py
@@ -1,6 +1,6 @@
 import pytorch_lightning as pl
 
-from pl_bolts.models.self_supervised import CPCV2, AMDIM, MocoV2, SimCLR
+from pl_bolts.models.self_supervised import CPCV2, AMDIM, MocoV2, SimCLR, BYOL
 from pl_bolts.datamodules import TinyCIFAR10DataModule, CIFAR10DataModule
 from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
 from tests import reset_seed
@@ -24,6 +24,21 @@ def test_cpcv2(tmpdir):
     assert loss > 0
 
 
+def test_byol(tmpdir):
+    reset_seed()
+
+    datamodule = CIFAR10DataModule(data_dir=tmpdir, num_workers=0, batch_size=2)
+    datamodule.train_transforms = CPCTrainTransformsCIFAR10()
+    datamodule.val_transforms = CPCEvalTransformsCIFAR10()
+
+    model = BYOL(data_dir=tmpdir, batch_size=2, datamodule=datamodule)
+    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=2)
+    trainer.fit(model)
+    loss = trainer.progress_bar_dict['loss']
+
+    assert float(loss) < 1.0
+
+
 def test_amdim(tmpdir):
     reset_seed()
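
A side note on the final `add l2 normalization` commit above: once the outputs are L2-normalized, the `F.mse_loss` call in `shared_step` is proportional to the loss written in the BYOL paper, `2 - 2 * cosine_similarity`. The snippet below is only an illustrative sanity check of that identity (the tensors, shapes, and names are made up and are not part of the patch)::

    import torch
    import torch.nn.functional as F

    # hypothetical online predictions (h) and target projections (z): batch of 4, 256-dim
    h = torch.randn(4, 256)
    z = torch.randn(4, 256)

    # what byol_module.shared_step computes after L2 normalization
    mse = F.mse_loss(F.normalize(h, p=2, dim=1), F.normalize(z, p=2, dim=1))

    # the loss as written in the BYOL paper
    paper_loss = (2 - 2 * F.cosine_similarity(h, z, dim=1)).mean()

    # identical up to a factor of the feature dimension, since mse_loss averages over all elements
    assert torch.allclose(mse * h.size(1), paper_loss, atol=1e-4)

Similarly, the cosine schedule in `BYOLMAWeightUpdate.update_tau` moves tau from `initial_tau` toward 1.0 over `max_steps`; a standalone sketch of the same formula (the step values here are arbitrary)::

    import math

    initial_tau, max_steps = 0.996, 1000
    for step in (0, 500, 1000):
        tau = 1 - (1 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1) / 2
        print(step, round(tau, 4))   # 0.996 -> 0.998 -> 1.0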