
Commit

Update SatViT_model.py
antofuller authored Oct 26, 2022
1 parent a903a78 · commit e9534e6
Showing 1 changed file with 12 additions and 11 deletions.
SatViT_model.py
@@ -11,8 +11,7 @@
 
 class SatViT(nn.Module):
     def __init__(self,
-                 in_dim,
-                 out_dim,
+                 io_dim=3840,
                  num_patches=256,
                  encoder_dim=768,
                  encoder_depth=12,
@@ -41,6 +40,7 @@ def __init__(self,
         # If the encoder and decoder have different model widths (dim) we need to apply a linear projection from the
         # encoder to the decoder. If the models have equal width, no projection is needed.
         self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
+
         self.decoder = BaseTransformer(dim=decoder_dim,
                                        depth=decoder_depth,
                                        num_heads=decoder_num_heads,
@@ -57,8 +57,9 @@ def __init__(self,
         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
 
         # Input and output maps
-        self.linear_input = nn.Linear(in_dim, encoder_dim)
-        self.linear_output = nn.Linear(decoder_dim, out_dim)
+        self.linear_input = nn.Linear(io_dim, encoder_dim)
+        self.linear_output = nn.Linear(decoder_dim, io_dim)
+        self.norm_pix_loss = True
 
     def random_masking(self, x, mask_ratio):
         """
@@ -119,12 +120,17 @@ def forward_decoder(self, x, ids_restore):
         return self.linear_output(x)
 
     def forward_loss(self, imgs, pred, mask):
+        if self.norm_pix_loss:
+            mean = imgs.mean(dim=-1, keepdim=True)
+            var = imgs.var(dim=-1, keepdim=True)
+            imgs = (imgs - mean) / (var + 1.e-6)**.5
+
         loss = (pred - imgs) ** 2
         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
 
         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
         return loss
 
     def encode(self, images_patches):
         """
         We encode full images (i.e., no masking) by linearly projecting image patches, adding position embeddings,
@@ -134,12 +140,7 @@ def encode(self, images_patches):
         return self.encoder(patch_encodings)
 
     def forward(self, patch_encodings, mask_ratio=0.75):
-        """
-        *** Masked Autoencoding Pre-training ***
-        We encode a portion of image patches (1 - mask_ratio), then use the encoded representations of visible patches
-        to predict all hidden patches.
-        """
         latent, mask, ids_restore = self.forward_encoder(patch_encodings, mask_ratio)
-        pred = self.forward_decoder(latent, ids_restore)  # (BSZ, num_patches, io_dim)
+        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
         loss = self.forward_loss(patch_encodings, pred, mask)
         return loss, pred, mask
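
Note: the added norm_pix_loss branch normalizes each patch's target pixels to zero mean and unit variance before taking the masked MSE, as in masked-autoencoder pre-training. A minimal standalone sketch of the same computation (the names normalize_targets and masked_mse are mine, not from the repo):

    import torch

    def normalize_targets(imgs, eps=1.e-6):
        # Per-patch normalization, mirroring the norm_pix_loss branch:
        # zero mean and unit variance over each patch's pixel dimension.
        mean = imgs.mean(dim=-1, keepdim=True)
        var = imgs.var(dim=-1, keepdim=True)
        return (imgs - mean) / (var + eps) ** 0.5

    def masked_mse(pred, target, mask):
        # Squared error averaged within each patch, then averaged over
        # masked patches only (mask is 1 for removed, 0 for visible).
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # (BSZ, num_patches)
        return (loss * mask).sum() / mask.sum()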

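Note: with io_dim replacing the separate in_dim/out_dim, the input projection and the reconstruction head now share one dimensionality (default 3840; if patches are 16x16 pixels, that would correspond to 15 channels, though the patch layout is an assumption here, not stated in the diff). A hypothetical pre-training step against the new signature (batch size and tensors are illustrative):

    import torch
    from SatViT_model import SatViT

    model = SatViT(io_dim=3840, num_patches=256)

    # Flattened image patches: (BSZ, num_patches, io_dim)
    patch_encodings = torch.randn(2, 256, 3840)

    # Masked-autoencoding step: masked reconstruction loss, per-patch
    # predictions, and the binary mask of hidden patches.
    loss, pred, mask = model(patch_encodings, mask_ratio=0.75)

    # Full-image encoding (no masking) via the encode() helper.
    latents = model.encode(patch_encodings)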