Skip to content

Commit

Permalink
added interpolation for vitmae model in pytorch as well as tf. (huggi…
Browse files Browse the repository at this point in the history
…ngface#30732)

* added interpolation for vitmae model in pytorch as well as tf.

* Update modeling_vit_mae.py

irregular import fixed

* small changes and proper formatting

* changes suggested in review.

* modified decoder interpolate_func

* arguments and docstring fix

* Apply suggestions from code review

doc fixes

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
bhuvanmdev and amyeroberts authored May 24, 2024
1 parent a3cdff4 commit e5103a7
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 60 deletions.
164 changes: 133 additions & 31 deletions src/transformers/models/vit_mae/modeling_tf_vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,38 @@ def build(self, input_shape=None):
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)

def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
    """
    This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
    resolution images.

    The class-token position embedding (index 0 along the sequence axis) is kept unchanged. The remaining patch
    position embeddings are reshaped to their pre-trained square grid and resized with bicubic interpolation to a
    `(height // patch_size, width // patch_size)` grid.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings: Rank-3 tensor; only its last dimension (the embedding size) is read here.
        height: Height of the input image in pixels.
        width: Width of the input image in pixels.

    Returns:
        `tf.Tensor`: `self.position_embeddings` unchanged when the target patch grid equals the pre-trained grid,
        otherwise the class-token embedding concatenated with the resized patch embeddings along axis 1.
    """
    _, _, dim = shape_list(embeddings)

    # position_embeddings holds one extra leading entry for the class token
    _, num_positions, _ = shape_list(self.position_embeddings)
    num_positions -= 1
    pretrained_grid = int(math.sqrt(num_positions))

    h0 = height // self.config.patch_size
    w0 = width // self.config.patch_size

    # Compare patch grids directly. The sequence length of `embeddings` is not a reliable signal: the caller passes
    # the patch embeddings before the class token is prepended, so deriving the patch count as `seq_len - 1` never
    # matches `num_positions` and the identity shortcut would never be taken.
    if h0 == pretrained_grid and w0 == pretrained_grid:
        return self.position_embeddings

    class_pos_embed = self.position_embeddings[:, :1]
    patch_pos_embed = self.position_embeddings[:, 1:]

    # tf.image.resize works on [batch, height, width, channels]; lay the patches out on their 2D grid first
    patch_pos_embed = tf.image.resize(
        images=tf.reshape(patch_pos_embed, shape=(1, pretrained_grid, pretrained_grid, dim)),
        size=(h0, w0),
        method="bicubic",
    )

    # flatten the resized grid back to a sequence and put the class-token embedding in front again
    patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
    return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)

def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
Expand Down Expand Up @@ -281,17 +313,23 @@ def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):

return sequence_unmasked, mask, ids_restore

def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
embeddings = self.patch_embeddings(pixel_values)

def call(
self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if interpolate_pos_encoding:
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embeddings = self.position_embeddings
# add position embeddings w/o cls token
embeddings = embeddings + self.position_embeddings[:, 1:, :]
embeddings = embeddings + position_embeddings[:, 1:, :]

# masking: length -> length * config.mask_ratio
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_token = self.cls_token + position_embeddings[:, :1, :]
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
embeddings = tf.concat([cls_tokens, embeddings], axis=1)

Expand Down Expand Up @@ -329,15 +367,17 @@ def __init__(self, config: ViTMAEConfig, **kwargs):
name="projection",
)

def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(
self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
if tf.executing_eagerly():
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the"
" configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
Expand Down Expand Up @@ -741,9 +781,13 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
embedding_output, mask, ids_restore = self.embeddings(
pixel_values=pixel_values, training=training, noise=noise
pixel_values=pixel_values,
training=training,
noise=noise,
interpolate_pos_encoding=interpolate_pos_encoding,
)

# Prepare head mask if needed
Expand Down Expand Up @@ -874,6 +918,9 @@ class TFViTMAEPreTrainedModel(TFPreTrainedModel):
training (`bool`, *optional*, defaults to `False``):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the position encodings at the encoder and decoder.
"""


Expand Down Expand Up @@ -902,6 +949,7 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
r"""
Returns:
Expand Down Expand Up @@ -931,6 +979,7 @@ def call(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)

return outputs
Expand Down Expand Up @@ -1004,17 +1053,50 @@ def build(self, input_shape=None):
with tf.name_scope(layer.name):
layer.build(None)

def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
    """
    Decoder-side variant of the ViT-MAE position-encoding interpolation. It resizes the pre-trained decoder
    position encodings to the sequence length of `embeddings`, to be able to use the model on higher resolution
    images.

    The class-token position embedding is left untouched. The patch position embeddings are treated as a
    single-row image and resized along the sequence axis with bicubic interpolation.

    Source:
    https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

    Args:
        embeddings: Tensor whose second dimension is the target sequence length, including the class token.

    Returns:
        `tf.Tensor`: The class-token embedding concatenated, along axis 1, with the resized patch embeddings.
    """
    # decoder_pos_embed: [1, num_patches + 1, hidden_size]; only the hidden size is needed
    _, _, hidden_dim = shape_list(self.decoder_pos_embed)

    # drop one position for the class token, which is re-attached below without interpolation
    target_len = shape_list(embeddings)[1] - 1

    cls_part = self.decoder_pos_embed[:, :1, :]
    patch_part = self.decoder_pos_embed[:, 1:, :]

    # view the patch embeddings as a 1 x N image so tf.image.resize can stretch them to 1 x target_len
    as_image = tf.reshape(patch_part, shape=(1, 1, -1, hidden_dim))
    resized = tf.image.resize(images=as_image, size=(1, target_len), method="bicubic")

    # back to [1, target_len, hidden_size], then put the class token in front
    resized = tf.reshape(tensor=resized, shape=(1, -1, hidden_dim))
    return tf.concat(values=(cls_part, resized), axis=1)

def call(
self,
hidden_states,
ids_restore,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
interpolate_pos_encoding=False,
):
# embed tokens
x = self.decoder_embed(hidden_states)

# append mask tokens to sequence
mask_tokens = tf.tile(
self.mask_token,
Expand All @@ -1023,10 +1105,12 @@ def call(
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token

if interpolate_pos_encoding:
decoder_pos_embed = self.interpolate_pos_encoding(x)
else:
decoder_pos_embed = self.decoder_pos_embed
# add pos embed
hidden_states = x + self.decoder_pos_embed

hidden_states = x + decoder_pos_embed
# apply Transformer layers (blocks)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
Expand Down Expand Up @@ -1083,11 +1167,13 @@ def get_input_embeddings(self):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError

def patchify(self, pixel_values):
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
Pixel values.
interpolate_pos_encoding (`bool`, default `False`):
interpolation flag passed during the forward pass.
Returns:
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Expand All @@ -1099,11 +1185,12 @@ def patchify(self, pixel_values):
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

# sanity checks
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
message="Make sure the pixel values have a squared size",
)
if not interpolate_pos_encoding:
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
message="Make sure the pixel values have a squared size",
)
tf.debugging.assert_equal(
shape_list(pixel_values)[1] % patch_size,
0,
Expand All @@ -1119,51 +1206,61 @@ def patchify(self, pixel_values):

# patchify
batch_size = shape_list(pixel_values)[0]
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
num_patches_h = shape_list(pixel_values)[1] // patch_size
num_patches_w = shape_list(pixel_values)[2] // patch_size
patchified_pixel_values = tf.reshape(
pixel_values,
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
(batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
)
return patchified_pixel_values

def unpatchify(self, patchified_pixel_values):
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
"""
Args:
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
original_image_size (`Tuple[int, int]`, *optional*):
Original image size.
Returns:
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
Pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
original_image_size = (
original_image_size
if original_image_size is not None
else (self.config.image_size, self.config.image_size)
)
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
# sanity check
tf.debugging.assert_equal(
num_patches_one_direction * num_patches_one_direction,
num_patches_h * num_patches_w,
shape_list(patchified_pixel_values)[1],
message="Make sure that the number of patches can be squared",
message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
)

# unpatchify
batch_size = shape_list(patchified_pixel_values)[0]
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
(batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
)
return pixel_values

def forward_loss(self, pixel_values, pred, mask):
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
Expand All @@ -1172,11 +1269,13 @@ def forward_loss(self, pixel_values, pred, mask):
Predicted pixel values.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
interpolate_pos_encoding (`bool`, *optional*, default `False`):
interpolation flag passed during the forward pass.
Returns:
`tf.Tensor`: Pixel reconstruction loss.
"""
target = self.patchify(pixel_values)
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
Expand All @@ -1201,6 +1300,7 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
r"""
Returns:
Expand Down Expand Up @@ -1234,16 +1334,18 @@ def call(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)

latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask

decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
# [batch_size, num_patches, patch_size**2*3]
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
logits = decoder_outputs.logits

loss = self.forward_loss(pixel_values, logits, mask)
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)

if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
Expand Down
Loading

0 comments on commit e5103a7

Please sign in to comment.