interpolation added for TVP. (#30863)
* Update TVP model to interpolate pre-trained image pad prompter encodings

* feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding

* added required comments

* docstring and argument fix

* doc fixes and test case fix suggested in review.

* variable typo fix

* styling and name fixes for padding interpolation flag.
bhuvanmdev authored and itazap committed Jun 17, 2024
1 parent 5782d7d commit 2a84256
Showing 2 changed files with 144 additions and 23 deletions.
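For reference, the new `interpolate_pos_encoding` flag can be exercised roughly as in the tests added below. This is a minimal sketch based on the test code in this commit; the dummy PIL image is a stand-in, not something the diff provides:

```python
import torch
from PIL import Image
from transformers import TvpImageProcessor, TvpModel

# Sketch only: the checkpoint and processor kwargs mirror the tests added in this commit.
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")
image_processor = TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp")

image = Image.new("RGB", (640, 480))  # stand-in for the 480x640 test image
# Keep the original resolution instead of resizing/padding to config.max_img_size.
encoding = image_processor(images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False)
encoding.update({"input_ids": torch.tensor([[1, 2]]), "attention_mask": torch.tensor([[1, 1]])})

with torch.no_grad():
    # interpolate_pos_encoding=True bicubically resizes both the pad prompter and
    # the 2D positional embeddings to match the larger input resolution.
    outputs = model(**encoding, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)
```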
127 changes: 105 additions & 22 deletions src/transformers/models/tvp/modeling_tvp.py
@@ -193,34 +193,81 @@ def __init__(self, config):
self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings

def add_2d_positional_embeddings(self, grid):
def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
Interpolate the pre-trained 2D positional embeddings so the model can be used on collections of
higher-resolution images (high-resolution videos).
"""
h0 = w0 = 1
# if height dimension is to be interpolated
if height > self.max_grid_row_position_embeddings:
h0 = height / self.max_grid_row_position_embeddings
# if width dimension is to be interpolated
if width > self.max_grid_col_position_embeddings:
w0 = width / self.max_grid_col_position_embeddings
embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width)
embedding = nn.functional.interpolate(
embedding,
scale_factor=(h0, w0),
mode="bicubic",
align_corners=False,
)
embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim)
return embedding

def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
"""
Args:
grid: (batch_size, height, width, hidden_dim)
interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
Returns:
grid with 2D positional embeddings added: (batch_size, height, width, hidden_dim)
"""
batch_size, height, width, hidden_dim = grid.shape

# add row-wise position embeddings
row_position_ids = torch.arange(height, dtype=torch.long, device=grid.device) # (height, )
row_position_embeddings = self.row_position_embeddings(row_position_ids) # (height, hidden_dim)
row_shape = (1,) * (len(grid.shape) - 3) + (height, 1, hidden_dim) # (1, height, 1, hidden_dim)
grid = grid + row_position_embeddings.view(*row_shape) # broadcast automatically
# (height, )
row_height = min(self.max_grid_row_position_embeddings, height)
row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
# (height, hidden_dim)
row_position_embeddings = self.row_position_embeddings(row_position_ids)
row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
# (1, row_height, 1, hidden_dim)
row_position_embeddings = row_position_embeddings.view(*row_shape)

# add column-wise position embeddings
col_position_ids = torch.arange(width, dtype=torch.long, device=grid.device) # (width, )
col_position_embeddings = self.col_position_embeddings(col_position_ids) # (width, hidden_dim)
col_shape = (batch_size, 1, width, hidden_dim) # (1, 1, width, hidden_dim)
return grid + col_position_embeddings.view(*col_shape) # broadcast automatically
row_width = min(self.max_grid_col_position_embeddings, width)
col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
# (width, hidden_dim)
col_position_embeddings = self.col_position_embeddings(col_position_ids)
col_shape = (batch_size, 1, row_width, hidden_dim)
# (batch_size, 1, row_width, hidden_dim)
col_position_embeddings = col_position_embeddings.view(*col_shape)
# (batch_size, height, width, hidden_dim)
positional_embeddings = row_position_embeddings + col_position_embeddings

# This interpolation is triggered ONLY when the input image is larger than the maximum grid of position embeddings in either dimension
if interpolate_pos_encoding and (
height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
):
grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
else:
grid = grid + positional_embeddings
return grid

def forward(self, grid):
def forward(self, grid, interpolate_pos_encoding: bool = False):
"""
Args:
grid: Array of shape (batch_size, num_frames, height, width, num_channels).
It contains processed frames extracted from videos and is generated by the TVP image processor. Note that
num_frames can be 1.
interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
Returns:
embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
@@ -229,7 +276,7 @@ def forward(self, grid):
batch_size, num_frames, height, width, num_channels = grid.shape
# temporal mean pooling, (batch_size, height, width, hidden_size)
grid = grid.mean(1)
grid = self.add_2d_positional_embeddings(grid)
grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
# image token sequence, (batch_size, height*width, num_channels)
visual_tokens = grid.view(batch_size, -1, num_channels)
visual_tokens_shape = visual_tokens.shape[:-1]
@@ -586,6 +633,9 @@ def _init_weights(self, module):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained image pad prompter encodings and positional encodings.
"""


@@ -639,7 +689,6 @@ def __init__(self, config):
self.num_frames = config.num_frames
self.max_img_size = config.max_img_size
self.visual_prompter_apply = config.visual_prompter_apply

self.base_size = config.max_img_size - config.visual_prompt_size * 2
self.pad_up = nn.Parameter(
torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
@@ -670,19 +719,49 @@ def __init__(self, config):
)
)

def forward(self, pixel_values):
def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
Interpolate the pre-trained pad prompter weights so the model can be used on collections of
higher-resolution images (high-resolution videos).
"""

# create scale factors from the height and width of the original image with respect to config.max_img_size
h0, w0 = height / self.max_img_size, width / self.max_img_size

batch, num_frames, channels, prompt_height, prompt_width = prompt.shape

# reshape the batch and num_frames dimensions into one, i.e. (b, frames, c, h, w) -> (b*frames, c, h, w), to apply bicubic interpolation
prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
prompt = nn.functional.interpolate(
prompt,
scale_factor=(h0, w0),
mode="bicubic",
align_corners=False,
)
# reshape back to (batch, frames, channels, height, width), where height and width are the new interpolated dimensions
prompt = prompt.reshape(batch, num_frames, channels, height, width)
return prompt

def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
height, width = (
(pixel_values.shape[-2], pixel_values.shape[-1])
if interpolate_pad_encoding
else (self.max_img_size, self.max_img_size)
)
if self.visual_prompter_apply not in ("add", "remove", "replace"):
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
if self.visual_prompter_apply in ("replace", "remove"):
visual_prompt_mask = torch.ones(
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
)
visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
pixel_values *= visual_prompt_mask
if self.visual_prompter_apply in ("replace", "add"):
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)

prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
prompt = torch.cat(pixel_values.size(0) * [prompt])
if interpolate_pad_encoding:
prompt = self.interpolate_pad_encoding(prompt, height, width)
pixel_values = pixel_values + prompt.to(pixel_values.dtype)
return pixel_values

@@ -738,6 +817,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
):
r"""
Returns:
@@ -756,13 +836,17 @@ def forward(
>>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
```"""
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
pixel_values = self.vision_model(self.visual_prompter(pixel_values))
pixel_values = self.vision_model(
self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
)
# (batch_size, sequence_length, hidden_size)
text_embedding_output = self.embeddings(input_ids=input_ids)
# (batch_size, visual_sequence_length, hidden_size)
visual_embedding_output = self.visual_embeddings(pixel_values)
visual_embedding_output = self.visual_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)

if attention_mask is not None:
# (batch_size, visual_sequence_length)
visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
@@ -791,7 +875,6 @@ def forward(
pooled_output = self.dropout(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
@@ -841,6 +924,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
):
r"""
labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
@@ -869,9 +953,9 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
pooler_output = outputs[1]

logits = self.video_grounding_head(pooler_output)

loss = None
@@ -884,7 +968,6 @@ def forward(
+ self.config.distance_loss_weight * loss_dict["distance"]
+ self.config.duration_loss_weight * loss_dict["duration"]
)

if not return_dict:
outputs = (logits,) + outputs[2:]
if loss is not None:
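As a side note, the core of the positional-embedding change above is a bicubic resize of the 2D embedding grid. Below is a minimal standalone sketch of that step; the function name and example sizes are illustrative and not part of the model code:

```python
import torch
import torch.nn as nn


def interpolate_2d_embeddings(embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Bicubically stretch a (1, rows, cols, hidden_dim) embedding grid to (height, width)."""
    rows, cols = embedding.shape[1], embedding.shape[2]
    h0 = height / rows if height > rows else 1.0
    w0 = width / cols if width > cols else 1.0
    embedding = embedding.permute(0, 3, 1, 2)  # (batch, hidden_dim, rows, cols)
    embedding = nn.functional.interpolate(embedding, scale_factor=(h0, w0), mode="bicubic", align_corners=False)
    return embedding.permute(0, 2, 3, 1)  # (batch, height, width, hidden_dim)


# e.g. a pre-trained 8x8 grid stretched to a 12x16 feature map
stretched = interpolate_2d_embeddings(torch.randn(1, 8, 8, 128), height=12, width=16)
print(stretched.shape)  # torch.Size([1, 12, 16, 128])
```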
40 changes: 39 additions & 1 deletion tests/models/tvp/test_modeling_tvp.py
@@ -256,7 +256,7 @@ def prepare_img():
class TvpModelIntegrationTests(unittest.TestCase):
@cached_property
def default_image_processor(self):
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") if is_vision_available() else None
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp")

def test_inference_no_head(self):
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
@@ -297,3 +297,41 @@ def test_inference_with_head(self):
assert outputs.logits.shape == expected_shape
expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))

def test_interpolate_inference_no_head(self):
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

image_processor = self.default_image_processor
image = prepare_img() # 480X640
encoding = image_processor(
images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
)
input_ids = torch.tensor([[1, 2]])
attention_mask = torch.tensor([[1, 1]])
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
encoding.to(torch_device)

with torch.no_grad():
outputs = model(**encoding, interpolate_pos_encoding=True)

expected_shape = torch.Size((1, 1212, 128))
assert outputs.last_hidden_state.shape == expected_shape

def test_interpolate_inference_with_head(self):
model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

image_processor = self.default_image_processor
image = prepare_img() # 480X640
encoding = image_processor(
images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
)
input_ids = torch.tensor([[1, 2]])
attention_mask = torch.tensor([[1, 1]])
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
encoding.to(torch_device)

with torch.no_grad():
outputs = model(**encoding, interpolate_pos_encoding=True, output_hidden_states=True)

expected_shape = torch.Size((1, 1212, 128))
assert outputs.hidden_states[-1].shape == expected_shape
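Regarding the sequence lengths asserted above: after temporal mean pooling, `TvpVisualInputEmbedding` flattens the (height, width) grid into height*width visual tokens, so keeping the original resolution with interpolation enabled yields a longer visual sequence than the resized path. A rough sketch of that flattening, with made-up dimensions:

```python
import torch

# Made-up dimensions; only the pooling/flattening pattern follows TvpVisualInputEmbedding.forward.
batch_size, num_frames, height, width, channels = 1, 1, 15, 20, 128
grid = torch.randn(batch_size, num_frames, height, width, channels)

grid = grid.mean(1)  # temporal mean pooling -> (batch_size, height, width, channels)
visual_tokens = grid.view(batch_size, -1, channels)  # -> (batch_size, height*width, channels)
print(visual_tokens.shape)  # torch.Size([1, 300, 128])
```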
