From 2a84256644d061fd46270fb2d2d543db06f2be2d Mon Sep 17 00:00:00 2001 From: BHUVAN M <121122109+bhuvanmdev@users.noreply.github.com> Date: Fri, 7 Jun 2024 23:14:16 +0530 Subject: [PATCH] interpolation added for TVP. (#30863) * Update TVP model to interpolate pre-trained image pad prompter encodings * feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding * added required comments * Update TVP model to interpolate pre-trained image pad prompter encodings * feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding * added required comments * docstring and argument fix * doc fixes and test case fix suggested in review. * varibale typo fix * styling and name fixes for padding interpolation flag. --- src/transformers/models/tvp/modeling_tvp.py | 127 ++++++++++++++++---- tests/models/tvp/test_modeling_tvp.py | 40 +++++- 2 files changed, 144 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index ba9acdbbcf93f7..ec00eee928617f 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -193,34 +193,81 @@ def __init__(self, config): self.token_type_embeddings = nn.Embedding(1, config.hidden_size) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings + self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings - def add_2d_positional_embeddings(self, grid): + def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high + resolution images (high resolution videos). + + """ + h0 = w0 = 1 + # if height dimension is to be interpolated + if height > self.max_grid_row_position_embeddings: + h0 = height / self.max_grid_row_position_embeddings + # if width dimension is to be interpolated + if width > self.max_grid_col_position_embeddings: + w0 = width / self.max_grid_col_position_embeddings + embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width) + embedding = nn.functional.interpolate( + embedding, + scale_factor=(h0, w0), + mode="bicubic", + align_corners=False, + ) + embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim) + return embedding + + def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False): """ Args: grid: (batch_size, height, width, hidden_dim) + interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. Returns: grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim) """ batch_size, height, width, hidden_dim = grid.shape # add row-wise position embeddings - row_position_ids = torch.arange(height, dtype=torch.long, device=grid.device) # (height, ) - row_position_embeddings = self.row_position_embeddings(row_position_ids) # (height, hidden_dim) - row_shape = (1,) * (len(grid.shape) - 3) + (height, 1, hidden_dim) # (1, height, 1, hidden_dim) - grid = grid + row_position_embeddings.view(*row_shape) # broadcast automatically + # (height, ) + row_height = min(self.max_grid_row_position_embeddings, height) + row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device) + # (height, hidden_dim) + row_position_embeddings = self.row_position_embeddings(row_position_ids) + row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim) + # (batch_size, height, 1, hidden_dim) + row_position_embeddings = row_position_embeddings.view(*row_shape) # add column-wise position embeddings - col_position_ids = torch.arange(width, dtype=torch.long, device=grid.device) # (width, ) - col_position_embeddings = self.col_position_embeddings(col_position_ids) # (width, hidden_dim) - col_shape = (batch_size, 1, width, hidden_dim) # (1, 1, width, hidden_dim) - return grid + col_position_embeddings.view(*col_shape) # broadcast automatically + row_width = min(self.max_grid_col_position_embeddings, width) + col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device) + # (width, hidden_dim) + col_position_embeddings = self.col_position_embeddings(col_position_ids) + col_shape = (batch_size, 1, row_width, hidden_dim) + # (batch_size, 1, width, hidden_dim) + col_position_embeddings = col_position_embeddings.view(*col_shape) + # (batch_size, height, width, hidden_dim) + positional_embeddings = row_position_embeddings + col_position_embeddings + + # This interpolation gets triggered ONLY when the input image dim is larger in any dimenstion than the original position embeddings + if interpolate_pos_encoding and ( + height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings + ): + grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width) + else: + grid = grid + positional_embeddings + return grid - def forward(self, grid): + def forward(self, grid, interpolate_pos_encoding: bool = False): """ Args: grid: Array of shape (batch_size, num_frames, height, width, num_channels). It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note, num_frames can be 1 + interpolate_pos_encoding: (bool, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. Returns: embeddings: The embedding of grid with size (batch_size, height*width, num_channels) @@ -229,7 +276,7 @@ def forward(self, grid): batch_size, num_frames, height, width, num_channels = grid.shape # temporal mean pooling, (batch_size, height, width, hidden_size) grid = grid.mean(1) - grid = self.add_2d_positional_embeddings(grid) + grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding) # image token sequence, (batch_size, height*width, num_channels) visual_tokens = grid.view(batch_size, -1, num_channels) visual_tokens_shape = visual_tokens.shape[:-1] @@ -586,6 +633,9 @@ def _init_weights(self, module): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained image pad prompter encodings and positional encodings. """ @@ -639,7 +689,6 @@ def __init__(self, config): self.num_frames = config.num_frames self.max_img_size = config.max_img_size self.visual_prompter_apply = config.visual_prompter_apply - self.base_size = config.max_img_size - config.visual_prompt_size * 2 self.pad_up = nn.Parameter( torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size]) @@ -670,19 +719,49 @@ def __init__(self, config): ) ) - def forward(self, pixel_values): + def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high + resolution images (high resolution videos). + + """ + + # creates scale factor from height and width of original image wrt to the config.max_img_size + h0, w0 = height / self.max_img_size, width / self.max_img_size + + batch, num_frames, channels, prompt_height, prompt_width = prompt.shape + + # reshaping the batch and num_frames dimension into a single one (i.e (b,frames,c,h,w)-->(b*frames,c,h,w)), to apply bicubic interpolation + prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width) + prompt = nn.functional.interpolate( + prompt, + scale_factor=(h0, w0), + mode="bicubic", + align_corners=False, + ) + # reversing back to (batch,frames,channels,height,width), where height and width is the new interpolated height and width + prompt = prompt.reshape(batch, num_frames, channels, height, width) + return prompt + + def forward(self, pixel_values, interpolate_pad_encoding: bool = False): + height, width = ( + (pixel_values.shape[-2], pixel_values.shape[-1]) + if interpolate_pad_encoding + else (self.max_img_size, self.max_img_size) + ) if self.visual_prompter_apply not in ("add", "remove", "replace"): raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}") if self.visual_prompter_apply in ("replace", "remove"): - visual_prompt_mask = torch.ones( - [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device - ) + visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device) pixel_values *= visual_prompt_mask if self.visual_prompter_apply in ("replace", "add"): base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device) + prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4) prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3) prompt = torch.cat(pixel_values.size(0) * [prompt]) + if interpolate_pad_encoding: + prompt = self.interpolate_pad_encoding(prompt, height, width) pixel_values = pixel_values + prompt.to(pixel_values.dtype) return pixel_values @@ -738,6 +817,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ): r""" Returns: @@ -756,13 +836,17 @@ def forward( >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask) ```""" return_dict = return_dict if return_dict is not None else self.config.return_dict - # Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features. - pixel_values = self.vision_model(self.visual_prompter(pixel_values)) + pixel_values = self.vision_model( + self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding) + ) # (batch_size, sequence_length, hidden_size) text_embedding_output = self.embeddings(input_ids=input_ids) # (batch_size, visual_sequence_length, hidden_size) - visual_embedding_output = self.visual_embeddings(pixel_values) + visual_embedding_output = self.visual_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) + if attention_mask is not None: # (batch_size, visual_sequence_length) visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2]) @@ -791,7 +875,6 @@ def forward( pooled_output = self.dropout(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -841,6 +924,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ): r""" labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*): @@ -869,9 +953,9 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, ) pooler_output = outputs[1] - logits = self.video_grounding_head(pooler_output) loss = None @@ -884,7 +968,6 @@ def forward( + self.config.distance_loss_weight * loss_dict["distance"] + self.config.duration_loss_weight * loss_dict["duration"] ) - if not return_dict: outputs = (logits,) + outputs[2:] if loss is not None: diff --git a/tests/models/tvp/test_modeling_tvp.py b/tests/models/tvp/test_modeling_tvp.py index a7db94c1a48969..b8f5fa48226c03 100644 --- a/tests/models/tvp/test_modeling_tvp.py +++ b/tests/models/tvp/test_modeling_tvp.py @@ -256,7 +256,7 @@ def prepare_img(): class TvpModelIntegrationTests(unittest.TestCase): @cached_property def default_image_processor(self): - return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") if is_vision_available() else None + return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") def test_inference_no_head(self): model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device) @@ -297,3 +297,41 @@ def test_inference_with_head(self): assert outputs.logits.shape == expected_shape expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4)) + + def test_interpolate_inference_no_head(self): + model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() # 480X640 + encoding = image_processor( + images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False + ) + input_ids = torch.tensor([[1, 2]]) + attention_mask = torch.tensor([[1, 1]]) + encoding.update({"input_ids": input_ids, "attention_mask": attention_mask}) + encoding.to(torch_device) + + with torch.no_grad(): + outputs = model(**encoding, interpolate_pos_encoding=True) + + expected_shape = torch.Size((1, 1212, 128)) + assert outputs.last_hidden_state.shape == expected_shape + + def test_interpolate_inference_with_head(self): + model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device) + + image_processor = self.default_image_processor + image = prepare_img() # 480X640 + encoding = image_processor( + images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False + ) + input_ids = torch.tensor([[1, 2]]) + attention_mask = torch.tensor([[1, 1]]) + encoding.update({"input_ids": input_ids, "attention_mask": attention_mask}) + encoding.to(torch_device) + + with torch.no_grad(): + outputs = model(**encoding, interpolate_pos_encoding=True, output_hidden_states=True) + + expected_shape = torch.Size((1, 1212, 128)) + assert outputs.hidden_states[-1].shape == expected_shape