From 7416a199a65366a8f427bdb0c2e7f17295859d47 Mon Sep 17 00:00:00 2001
From: maximzubkov
Date: Wed, 5 May 2021 18:18:17 +0300
Subject: [PATCH 1/6] Add remove momentum updating from val step and add separate val queue

---
 .../self_supervised/moco/moco2_module.py | 34 ++++++++++++-------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py
index 832866b1cc..67993acda2 100644
--- a/pl_bolts/models/self_supervised/moco/moco2_module.py
+++ b/pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -121,6 +121,12 @@ def __init__(
 
         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
 
+        # create the validation queue
+        self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
+        self.queue = nn.functional.normalize(self.val_queue, dim=0)
+
+        self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))
+
     def init_encoders(self, base_encoder):
         """
         Override to add your own encoders
@@ -142,21 +148,21 @@ def _momentum_update_key_encoder(self):
             param_k.data = param_k.data * em + param_q.data * (1. - em)
 
     @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys, queue_ptr, queue):
         # gather keys before updating queue
         if self.trainer.use_ddp or self.trainer.use_ddp2:
             keys = concat_all_gather(keys)
 
         batch_size = keys.shape[0]
 
-        ptr = int(self.queue_ptr)
+        ptr = int(queue_ptr)
         assert self.hparams.num_negatives % batch_size == 0  # for simplicity
 
         # replace the keys at ptr (dequeue and enqueue)
-        self.queue[:, ptr:ptr + batch_size] = keys.T
+        queue[:, ptr:ptr + batch_size] = keys.T
         ptr = (ptr + batch_size) % self.hparams.num_negatives  # move pointer
 
-        self.queue_ptr[0] = ptr
+        queue_ptr[0] = ptr
 
     @torch.no_grad()
     def _batch_shuffle_ddp(self, x):  # pragma: no cover
@@ -205,11 +211,12 @@ def _batch_unshuffle_ddp(self, x, idx_unshuffle):  # pragma: no cover
 
         return x_gather[idx_this]
 
-    def forward(self, img_q, img_k):
+    def forward(self, img_q, img_k, queue):
         """
         Input:
             im_q: a batch of query images
            im_k: a batch of key images
+            queue: a queue from which to pick negative samples
         Output:
             logits, targets
         """
@@ -220,7 +227,6 @@ def forward(self, img_q, img_k):
 
         # compute key features
         with torch.no_grad():  # no gradient to keys
-            self._momentum_update_key_encoder()  # update the key encoder
 
             # shuffle for making use of BN
             if self.trainer.use_ddp or self.trainer.use_ddp2:
@@ -238,7 +244,7 @@ def forward(self, img_q, img_k):
         # positive logits: Nx1
         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
         # negative logits: NxK
-        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+        l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()])
 
         # logits: Nx(1+K)
         logits = torch.cat([l_pos, l_neg], dim=1)
@@ -250,10 +256,7 @@ def forward(self, img_q, img_k):
         labels = torch.zeros(logits.shape[0], dtype=torch.long)
         labels = labels.type_as(logits)
 
-        # dequeue and enqueue
-        self._dequeue_and_enqueue(k)
-
-        return logits, labels
+        return logits, labels, k
 
     def training_step(self, batch, batch_idx):
         # in STL10 we pass in both lab+unl for online ft
@@ -264,7 +267,10 @@ def training_step(self, batch, batch_idx):
 
         (img_1, img_2), _ = batch
 
-        output, target = self(img_q=img_1, img_k=img_2)
+        self._momentum_update_key_encoder()  # update the key encoder
+        output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue)
+        self._dequeue_and_enqueue(keys, queue=self.queue, queue_ptr=self.queue_ptr)  # dequeue and enqueue
+
         loss = F.cross_entropy(output.float(), target.long())
 
         acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
@@ -282,7 +288,9 @@ def validation_step(self, batch, batch_idx):
 
         (img_1, img_2), labels = batch
 
-        output, target = self(img_q=img_1, img_k=img_2)
+        output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue)
+        self._dequeue_and_enqueue(keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr)  # dequeue and enqueue
+
         loss = F.cross_entropy(output, target.long())
 
         acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))

From 0f4b8ed88422c8ac953b61493b58a9b3074e251d Mon Sep 17 00:00:00 2001
From: maximzubkov
Date: Wed, 5 May 2021 18:18:17 +0300
Subject: [PATCH 2/6] Remove momentum updating from val step and add separate val queue

---
 .../self_supervised/moco/moco2_module.py | 34 ++++++++++++-------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py
index 832866b1cc..67993acda2 100644
--- a/pl_bolts/models/self_supervised/moco/moco2_module.py
+++ b/pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -121,6 +121,12 @@ def __init__(
 
         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
 
+        # create the validation queue
+        self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
+        self.queue = nn.functional.normalize(self.val_queue, dim=0)
+
+        self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))
+
     def init_encoders(self, base_encoder):
         """
         Override to add your own encoders
@@ -142,21 +148,21 @@ def _momentum_update_key_encoder(self):
             param_k.data = param_k.data * em + param_q.data * (1. - em)
 
     @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys, queue_ptr, queue):
         # gather keys before updating queue
         if self.trainer.use_ddp or self.trainer.use_ddp2:
             keys = concat_all_gather(keys)
 
         batch_size = keys.shape[0]
 
-        ptr = int(self.queue_ptr)
+        ptr = int(queue_ptr)
         assert self.hparams.num_negatives % batch_size == 0  # for simplicity
 
         # replace the keys at ptr (dequeue and enqueue)
-        self.queue[:, ptr:ptr + batch_size] = keys.T
+        queue[:, ptr:ptr + batch_size] = keys.T
         ptr = (ptr + batch_size) % self.hparams.num_negatives  # move pointer
 
-        self.queue_ptr[0] = ptr
+        queue_ptr[0] = ptr
 
     @torch.no_grad()
     def _batch_shuffle_ddp(self, x):  # pragma: no cover
@@ -205,11 +211,12 @@ def _batch_unshuffle_ddp(self, x, idx_unshuffle):  # pragma: no cover
 
         return x_gather[idx_this]
 
-    def forward(self, img_q, img_k):
+    def forward(self, img_q, img_k, queue):
         """
         Input:
             im_q: a batch of query images
            im_k: a batch of key images
+            queue: a queue from which to pick negative samples
         Output:
             logits, targets
         """
@@ -220,7 +227,6 @@ def forward(self, img_q, img_k):
 
         # compute key features
         with torch.no_grad():  # no gradient to keys
-            self._momentum_update_key_encoder()  # update the key encoder
 
             # shuffle for making use of BN
             if self.trainer.use_ddp or self.trainer.use_ddp2:
@@ -238,7 +244,7 @@ def forward(self, img_q, img_k):
         # positive logits: Nx1
         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
         # negative logits: NxK
-        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+        l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()])
 
         # logits: Nx(1+K)
         logits = torch.cat([l_pos, l_neg], dim=1)
@@ -250,10 +256,7 @@ def forward(self, img_q, img_k):
         labels = torch.zeros(logits.shape[0], dtype=torch.long)
         labels = labels.type_as(logits)
 
-        # dequeue and enqueue
-        self._dequeue_and_enqueue(k)
-
-        return logits, labels
+        return logits, labels, k
 
     def training_step(self, batch, batch_idx):
         # in STL10 we pass in both lab+unl for online ft
@@ -264,7 +267,10 @@ def training_step(self, batch, batch_idx):
 
         (img_1, img_2), _ = batch
 
-        output, target = self(img_q=img_1, img_k=img_2)
+        self._momentum_update_key_encoder()  # update the key encoder
+        output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue)
+        self._dequeue_and_enqueue(keys, queue=self.queue, queue_ptr=self.queue_ptr)  # dequeue and enqueue
+
         loss = F.cross_entropy(output.float(), target.long())
 
         acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
@@ -282,7 +288,9 @@ def validation_step(self, batch, batch_idx):
 
         (img_1, img_2), labels = batch
 
-        output, target = self(img_q=img_1, img_k=img_2)
+        output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue)
+        self._dequeue_and_enqueue(keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr)  # dequeue and enqueue
+
         loss = F.cross_entropy(output, target.long())
 
         acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))

From e4c3153022bbaad232e35c8dc9d778ef6689eec4 Mon Sep 17 00:00:00 2001
From: jirka
Date: Mon, 10 May 2021 11:16:31 +0200
Subject: [PATCH 3/6] chlog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5064c5c528..9ec35f1ca3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
+- Removed momentum updating from val step and add separate val queue ([#631](https://github.com/PyTorchLightning/lightning-bolts/pull/631))
+
 
 ### Deprecated
 

From 370f23e60eee61bfad799ddab6f9c410323b6285 Mon Sep 17 00:00:00 2001
From: maximzubkov
Date: Mon, 24 May 2021 14:25:07 +0300
Subject: [PATCH 4/6] Fix val queue init

---
 pl_bolts/models/self_supervised/moco/moco2_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py
index 67993acda2..2c7bb11992 100644
--- a/pl_bolts/models/self_supervised/moco/moco2_module.py
+++ b/pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -123,7 +123,7 @@ def __init__(
         # create the validation queue
         self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
-        self.queue = nn.functional.normalize(self.val_queue, dim=0)
+        self.val_queue = nn.functional.normalize(self.val_queue, dim=0)
 
         self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))
 
     def init_encoders(self, base_encoder):

From a53449a43d9964892e1105bbee4d1f684671fc51 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Thu, 17 Jun 2021 17:17:19 +0200
Subject: [PATCH 5/6] v0.3.4 & changelog

---
 CHANGELOG.md          | 18 +++++-------------
 pl_bolts/__about__.py |  2 +-
 2 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 93db4fe211..1c8300df9e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,28 +5,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [unReleased] - 2021-MM-DD
-
-### Added
-
+## [0.3.4] - 2021-06-17
 
 ### Changed
 
 - Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629))
-
 - Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637))
-
-### Deprecated
-
-
-### Removed
-
-
 
 ### Fixed
 
 - Fixed the MNIST download giving HTTP 503 ([#633](https://github.com/PyTorchLightning/lightning-bolts/pull/633))
+- Fixed type annotation of `ExperienceSource.__iter__` ([#645](https://github.com/PyTorchLightning/lightning-bolts/pull/645))
+- Fixed `pretrained_urls` on Windows ([#652](https://github.com/PyTorchLightning/lightning-bolts/pull/652))
+- Fixed logistic regression ([#655](https://github.com/PyTorchLightning/lightning-bolts/pull/655), [#664](https://github.com/PyTorchLightning/lightning-bolts/pull/664))
+- Fixed double softmax in `SSLEvaluator` ([#663](https://github.com/PyTorchLightning/lightning-bolts/pull/663))
 
 
 ## [0.3.3] - 2021-04-17

diff --git a/pl_bolts/__about__.py b/pl_bolts/__about__.py
index d8228549a8..824fb9b0e3 100644
--- a/pl_bolts/__about__.py
+++ b/pl_bolts/__about__.py
@@ -1,4 +1,4 @@
-__version__ = '0.4.0dev'
+__version__ = '0.3.4'
 __author__ = 'PyTorchLightning et al.'
 __author_email__ = 'name@pytorchlightning.ai'
 __license__ = 'Apache-2.0'

From 087639e7f804a30e200d4be35eb886b5ebfcbdf9 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta
Date: Sun, 27 Jun 2021 17:44:03 +0900
Subject: [PATCH 6/6] Update changelog

---
 CHANGELOG.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 52dd3f3d9e..a247597da4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,11 +24,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed the MNIST download giving HTTP 404 with torchvision>=0.9.1 ([#674](https://github.com/PyTorchLightning/lightning-bolts/pull/674))
 
+- Removed momentum updating from val step and add separate val queue ([#631](https://github.com/PyTorchLightning/lightning-bolts/pull/631))
+
+
 ## [0.3.4] - 2021-06-17
 
 ### Changed
-
 - Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629))
 - Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637))
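
Taken together, the series moves the momentum update and the dequeue/enqueue call out of forward() and into training_step / validation_step, and gives validation its own val_queue / val_queue_ptr so an evaluation pass no longer mutates the training negatives or the key encoder. Below is a stripped-down, framework-free sketch of that flow; it is illustrative only — the nn.Linear encoders, emb_dim, num_negatives, batch_size and the plain-function structure are hypothetical stand-ins, not the Lightning module changed above.

import torch
from torch import nn
from torch.nn import functional as F

emb_dim, num_negatives, batch_size = 128, 1024, 32

encoder_q = nn.Linear(512, emb_dim)  # query encoder (trained)
encoder_k = nn.Linear(512, emb_dim)  # key encoder (momentum copy)
encoder_k.load_state_dict(encoder_q.state_dict())

# one circular buffer and pointer per phase, mirroring queue / val_queue
queues = {
    "train": F.normalize(torch.randn(emb_dim, num_negatives), dim=0),
    "val": F.normalize(torch.randn(emb_dim, num_negatives), dim=0),
}
ptrs = {"train": 0, "val": 0}


@torch.no_grad()
def momentum_update(m: float = 0.999) -> None:
    """EMA update of the key encoder; called only from the training phase."""
    for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        p_k.data = p_k.data * m + p_q.data * (1.0 - m)


@torch.no_grad()
def dequeue_and_enqueue(keys: torch.Tensor, phase: str) -> None:
    """Overwrite the oldest columns of the chosen queue with the new keys."""
    ptr = ptrs[phase]
    assert num_negatives % keys.shape[0] == 0  # same simplification as MoCo
    queues[phase][:, ptr:ptr + keys.shape[0]] = keys.T
    ptrs[phase] = (ptr + keys.shape[0]) % num_negatives


def step(img_q: torch.Tensor, img_k: torch.Tensor, phase: str) -> torch.Tensor:
    """Shared logic of training_step / validation_step after the refactor."""
    if phase == "train":
        momentum_update()  # only training moves the EMA key encoder
    q = F.normalize(encoder_q(img_q), dim=1)
    with torch.no_grad():  # no gradient to keys
        k = F.normalize(encoder_k(img_k), dim=1)
    l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
    l_neg = torch.einsum("nc,ck->nk", [q, queues[phase].clone().detach()])
    logits = torch.cat([l_pos, l_neg], dim=1) / 0.07  # softmax temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long)
    dequeue_and_enqueue(k, phase)  # each phase refreshes only its own queue
    return F.cross_entropy(logits, labels)


loss_train = step(torch.randn(batch_size, 512), torch.randn(batch_size, 512), "train")
loss_val = step(torch.randn(batch_size, 512), torch.randn(batch_size, 512), "val")
print(loss_train.item(), loss_val.item())

Keeping a separate buffer and pointer per phase is what lets validation still score against realistic negatives while leaving the training queue and the key encoder's EMA untouched, which is the behaviour the series is after.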