diff --git a/SatViT_model.py b/SatViT_model.py index 804863b..09a21ab 100644 --- a/SatViT_model.py +++ b/SatViT_model.py @@ -11,8 +11,7 @@ class SatViT(nn.Module): def __init__(self, - in_dim, - out_dim, + io_dim=3840, num_patches=256, encoder_dim=768, encoder_depth=12, @@ -41,6 +40,7 @@ def __init__(self, # If the encoder and decoder have different model widths (dim) we need to apply a linear projection from the # encoder to the decoder. If the models have equal width, no projection is needed. self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) + self.decoder = BaseTransformer(dim=decoder_dim, depth=decoder_depth, num_heads=decoder_num_heads, @@ -57,8 +57,9 @@ def __init__(self, self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # Input and output maps - self.linear_input = nn.Linear(in_dim, encoder_dim) - self.linear_output = nn.Linear(decoder_dim, out_dim) + self.linear_input = nn.Linear(io_dim, encoder_dim) + self.linear_output = nn.Linear(decoder_dim, io_dim) + self.norm_pix_loss = True def random_masking(self, x, mask_ratio): """ @@ -119,12 +120,17 @@ def forward_decoder(self, x, ids_restore): return self.linear_output(x) def forward_loss(self, imgs, pred, mask): + if self.norm_pix_loss: + mean = imgs.mean(dim=-1, keepdim=True) + var = imgs.var(dim=-1, keepdim=True) + imgs = (imgs - mean) / (var + 1.e-6)**.5 + loss = (pred - imgs) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches return loss - + def encode(self, images_patches): """ We encode full images (i.e., no masking) by linearly projecting image patches, adding position embeddings, @@ -134,12 +140,7 @@ def encode(self, images_patches): return self.encoder(patch_encodings) def forward(self, patch_encodings, mask_ratio=0.75): - """ - *** Masked Autoencoding Pre-training *** - We encode a portion of image patches (1 - mask_ratio), then use the encoded representations of visible patches - to predict all hidden patches. - """ latent, mask, ids_restore = self.forward_encoder(patch_encodings, mask_ratio) - pred = self.forward_decoder(latent, ids_restore) # (BSZ, num_patches, io_dim) + pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] loss = self.forward_loss(patch_encodings, pred, mask) return loss, pred, mask