Skip to content

Commit

Permalink
Add check for mode collapse in feature representation
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed May 11, 2023
1 parent 02474d7 commit 3480ec3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
30 changes: 22 additions & 8 deletions torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightly.loss import NTXentLoss
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
Expand Down Expand Up @@ -251,21 +252,25 @@ def __init__(
# Define loss function
self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed)

def forward(self, x: Tensor) -> Tensor:
# Initialize moving average of output
self.avg_output_std = 0.0

def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Forward pass of the model.
Args:
x: Mini-batch of images.
Returns:
Output from the model.
Output from the model and backbone
"""
q = self.backbone(x)
h = self.backbone(x)
q = h
if self.hparams["version"] > 1:
q = self.projection_head(q)
if self.hparams["version"] == 3:
q = self.prediction_head(q)
return cast(Tensor, q)
return cast(Tensor, q), cast(Tensor, h)

def forward_momentum(self, x: Tensor) -> Tensor:
"""Forward pass of the momentum model.
Expand Down Expand Up @@ -309,29 +314,38 @@ def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:

m = self.hparams["moco_momentum"]
if self.hparams["version"] == 1:
q = self.forward(x1)
q, h = self.forward(x1)
with torch.no_grad():
update_momentum(self.backbone, self.backbone_momentum, m)
k = self.forward_momentum(x2)
loss = self.criterion(q, k)
elif self.hparams["version"] == 2:
q = self.forward(x1)
q, h = self.forward(x1)
with torch.no_grad():
update_momentum(self.backbone, self.backbone_momentum, m)
update_momentum(self.projection_head, self.projection_head_momentum, m)
k = self.forward_momentum(x2)
loss = self.criterion(q, k)
if self.hparams["version"] == 3:
m = cosine_schedule(self.current_epoch, self.trainer.max_epochs, m, 1)
q1 = self.forward(x1)
q2 = self.forward(x2)
q1, h1 = self.forward(x1)
q2, h2 = self.forward(x2)
with torch.no_grad():
update_momentum(self.backbone, self.backbone_momentum, m)
update_momentum(self.projection_head, self.projection_head_momentum, m)
k1 = self.forward_momentum(x1)
k2 = self.forward_momentum(x2)
loss = self.criterion(q1, k2) + self.criterion(q2, k1)

# Calculate the mean normalized standard deviation over features dimensions.
# If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything.
output = h1.detach()
output = F.normalize(output, dim=1)
output_std = torch.std(output, dim=0)
output_std = torch.mean(output_std, dim=0)
self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item()

self.log("train_ssl_std", self.avg_output_std)
self.log("train_loss", loss)

return cast(Tensor, loss)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
x: Mini-batch of images.
Returns:
Output from the backbone and projection head.
Output from the model and backbone.
"""
h = self.backbone(x) # shape of batch_size x num_features
z = self.projection_head(h)
Expand Down

0 comments on commit 3480ec3

Please sign in to comment.