interpolation added for TVP. #30863

Merged 12 commits on Jun 7, 2024
Changes from 7 commits
141 changes: 119 additions & 22 deletions src/transformers/models/tvp/modeling_tvp.py
@@ -183,8 +183,36 @@ def __init__(self, config):
self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings

def add_2d_positional_embeddings(self, grid):
def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method interpolates the pre-trained position embeddings, so the model can be used on collections of
higher-resolution images (high-resolution videos).

"""
# if height dimension is to be interpolated
if height > self.max_grid_row_position_embeddings:
h0 = height / self.max_grid_row_position_embeddings
else:
h0 = 1
# if width dimension is to be interpolated
if width > self.max_grid_col_position_embeddings:
w0 = width / self.max_grid_col_position_embeddings
else:
w0 = 1
embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width)
embedding = nn.functional.interpolate(
embedding,
scale_factor=(h0, w0),
mode="bicubic",
align_corners=False,
)
embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim)
return embedding
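For quick reference, here is a minimal standalone sketch of what the interpolation above does to the positional grid; the 100x100 pre-trained grid, the hidden size of 128, and the target size are illustrative assumptions, not values taken from this diff:

import torch
import torch.nn as nn

# illustrative sizes: a 100x100 pre-trained positional grid with hidden_dim=128
max_rows, max_cols, hidden_dim = 100, 100, 128
height, width = 125, 150  # target grid larger than the pre-trained one (exact scale factors on purpose)

# (1, max_rows, max_cols, hidden_dim), as produced by summing the row and column embeddings
pos = torch.randn(1, max_rows, max_cols, hidden_dim)

h0 = height / max_rows if height > max_rows else 1
w0 = width / max_cols if width > max_cols else 1

pos = pos.permute(0, 3, 1, 2)  # (1, hidden_dim, max_rows, max_cols)
pos = nn.functional.interpolate(pos, scale_factor=(h0, w0), mode="bicubic", align_corners=False)
pos = pos.permute(0, 2, 3, 1)  # (1, height, width, hidden_dim)
assert pos.shape == (1, height, width, hidden_dim)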

def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding=False):
"""
Args:
grid: (batch_size, height, width, hidden_dim)
@@ -194,18 +222,48 @@ def add_2d_positional_embeddings(self, grid):
batch_size, height, width, hidden_dim = grid.shape

# add row-wise position embeddings
row_position_ids = torch.arange(height, dtype=torch.long, device=grid.device) # (height, )
row_position_embeddings = self.row_position_embeddings(row_position_ids) # (height, hidden_dim)
row_shape = (1,) * (len(grid.shape) - 3) + (height, 1, hidden_dim) # (1, height, 1, hidden_dim)
grid = grid + row_position_embeddings.view(*row_shape) # broadcast automatically

# (height, )
row_position_ids = torch.arange(
min(self.max_grid_row_position_embeddings, height), dtype=torch.long, device=grid.device
)
# (height, hidden_dim)
row_position_embeddings = self.row_position_embeddings(row_position_ids)
# (1, height, 1, hidden_dim)
row_shape = (1,) * (len(grid.shape) - 3) + (
min(self.max_grid_row_position_embeddings, height),
1,
hidden_dim,
)
# (1, height, 1, hidden_dim)
row_position_embeddings = row_position_embeddings.view(*row_shape)
# add column-wise position embeddings
col_position_ids = torch.arange(width, dtype=torch.long, device=grid.device) # (width, )
col_position_embeddings = self.col_position_embeddings(col_position_ids) # (width, hidden_dim)
col_shape = (batch_size, 1, width, hidden_dim) # (1, 1, width, hidden_dim)
return grid + col_position_embeddings.view(*col_shape) # broadcast automatically
# (width, )
col_position_ids = torch.arange(
min(self.max_grid_col_position_embeddings, width), dtype=torch.long, device=grid.device
)
# (width, hidden_dim)
col_position_embeddings = self.col_position_embeddings(col_position_ids)
# (1, 1, width, hidden_dim)
col_shape = (
batch_size,
1,
min(self.max_grid_col_position_embeddings, width),
hidden_dim,
)
# (1, 1, width, hidden_dim)
col_position_embeddings = col_position_embeddings.view(*col_shape)
# (1, height, width, hidden_dim)
positional_embeddings = row_position_embeddings + col_position_embeddings
# This interpolation is triggered ONLY when the input image is larger than 1600x1600 in either
# dimension (height or width), given self.max_grid_(row|col)_position_embeddings == 100
if interpolate_pos_encoding and (
height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
):
grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
else:
grid = grid + positional_embeddings
return grid
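For intuition, a minimal sketch of how the row-wise and column-wise position embeddings broadcast onto the grid; the sizes below are illustrative, not taken from the model config:

import torch

row = torch.randn(1, 4, 1, 8)   # (1, height, 1, hidden_dim)
col = torch.randn(2, 1, 6, 8)   # (batch_size, 1, width, hidden_dim)
grid = torch.randn(2, 4, 6, 8)  # (batch_size, height, width, hidden_dim)
out = grid + (row + col)        # both terms broadcast to (2, 4, 6, 8)
assert out.shape == (2, 4, 6, 8)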

def forward(self, grid):
def forward(self, grid, interpolate_pos_encoding=False):
Collaborator (suggested change):
def forward(self, grid, interpolate_pos_encoding=False):
def forward(self, grid, interpolate_pos_encoding: bool = False):

"""
Args:
grid: Array of shape (batch_size, num_frames, height, width, num_channels).
@@ -219,7 +277,7 @@ def forward(self, grid):
batch_size, num_frames, height, width, num_channels = grid.shape
# temporal mean pooling, (batch_size, height, width, hidden_size)
grid = grid.mean(1)
grid = self.add_2d_positional_embeddings(grid)
grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
# image token sequence, (batch_size, height*width, num_channels)
visual_tokens = grid.view(batch_size, -1, num_channels)
visual_tokens_shape = visual_tokens.shape[:-1]
@@ -576,6 +634,9 @@ def _init_weights(self, module):

return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

interpolate_pos_encoding (`bool`, *Defaults* True):
Collaborator (suggested change):
interpolate_pos_encoding (`bool`, *Defaults* True):
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):

Whether to interpolate the pre-trained image pad prompter encodings.
"""


@@ -629,7 +690,6 @@ def __init__(self, config):
self.num_frames = config.num_frames
self.max_img_size = config.max_img_size
self.visual_prompter_apply = config.visual_prompter_apply

self.base_size = config.max_img_size - config.visual_prompt_size * 2
self.pad_up = nn.Parameter(
torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
Expand Down Expand Up @@ -660,20 +720,51 @@ def __init__(self, config):
)
)

def forward(self, pixel_values):
def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method interpolates the pre-trained pad weights, so the model can be used on collections of
higher-resolution images (high-resolution videos).

"""

# compute the scale factors from the height and width of the original image with respect to config.max_img_size
h0, w0 = height / self.max_img_size, width / self.max_img_size

batch, num_frames, channels, prompt_height, prompt_width = prompt.shape

# merge the batch and num_frames dimensions (i.e. (b, frames, c, h, w) -> (b*frames, c, h, w)) to apply bicubic interpolation
prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
prompt = nn.functional.interpolate(
prompt,
scale_factor=(h0, w0),
mode="bicubic",
align_corners=False,
)
# reshape back to (batch, frames, channels, height, width), where height and width are the new interpolated dimensions
prompt = prompt.reshape(batch, num_frames, channels, height, width)
return prompt
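A shape-only sketch of the pad interpolation above; the sizes are illustrative and chosen so the scale factors are exact:

import torch

prompt = torch.randn(1, 2, 3, 224, 224)    # (batch, num_frames, channels, max_img_size, max_img_size)
height, width = 448, 336                   # target resolution: scale factors 2.0 and 1.5
flat = prompt.reshape(1 * 2, 3, 224, 224)  # merge the batch and frame dims for interpolate
flat = torch.nn.functional.interpolate(
    flat, scale_factor=(height / 224, width / 224), mode="bicubic", align_corners=False
)
prompt = flat.reshape(1, 2, 3, height, width)
assert prompt.shape == (1, 2, 3, 448, 336)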

def forward(self, pixel_values, interpolate_pos_encoding=False):
h, w = (
(pixel_values.shape[-2], pixel_values.shape[-1])
if interpolate_pos_encoding
else (self.max_img_size, self.max_img_size)
)
if self.visual_prompter_apply not in ("add", "remove", "replace"):
raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
if self.visual_prompter_apply in ("replace", "remove"):
visual_prompt_mask = torch.ones(
[self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
)
visual_prompt_mask = torch.ones([h, w], dtype=pixel_values.dtype, device=pixel_values.device)
pixel_values *= visual_prompt_mask
if self.visual_prompter_apply in ("replace", "add"):
base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)

prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
prompt = torch.cat(pixel_values.size(0) * [prompt])
pixel_values = pixel_values + prompt.to(pixel_values.dtype)
if interpolate_pos_encoding:
pixel_values = pixel_values + self.interpolate_pad_encoding(prompt, h, w).to(pixel_values.dtype)
else:
pixel_values = pixel_values + prompt.to(pixel_values.dtype)
return pixel_values
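For intuition, a small sketch of how the four learned pads concatenate around the zero base into a full frame-sized prompt, following the two torch.cat calls above; the sizes are illustrative assumptions:

import torch

num_frames, channels, vp_size, max_size = 2, 3, 16, 224
base_size = max_size - 2 * vp_size  # 192

pad_up = torch.randn(1, num_frames, channels, vp_size, max_size)
pad_down = torch.randn(1, num_frames, channels, vp_size, max_size)
pad_left = torch.randn(1, num_frames, channels, base_size, vp_size)
pad_right = torch.randn(1, num_frames, channels, base_size, vp_size)
base = torch.zeros(1, num_frames, channels, base_size, base_size)

prompt = torch.cat([pad_left, base, pad_right], dim=4)  # (1, num_frames, channels, base_size, max_size)
prompt = torch.cat([pad_up, prompt, pad_down], dim=3)   # (1, num_frames, channels, max_size, max_size)
assert prompt.shape == (1, num_frames, channels, max_size, max_size)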


@@ -728,6 +819,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
):
r"""
Returns:
Expand All @@ -746,13 +838,17 @@ def forward(
>>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
```"""
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
pixel_values = self.vision_model(self.visual_prompter(pixel_values))
pixel_values = self.vision_model(
self.visual_prompter(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
)
# (batch_size, sequence_length, hidden_size)
text_embedding_output = self.embeddings(input_ids=input_ids)
# (batch_size, visual_sequence_length, hidden_size)
visual_embedding_output = self.visual_embeddings(pixel_values)
visual_embedding_output = self.visual_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)

if attention_mask is not None:
# (batch_size, visual_sequence_length)
visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
@@ -831,6 +927,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
):
r"""
labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
Expand Down Expand Up @@ -859,9 +956,9 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
pooler_output = outputs[1]

logits = self.video_grounding_head(pooler_output)

loss = None
36 changes: 35 additions & 1 deletion tests/models/tvp/test_modeling_tvp.py
@@ -224,7 +224,7 @@ def prepare_img():
class TvpModelIntegrationTests(unittest.TestCase):
@cached_property
def default_image_processor(self):
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp") if is_vision_available() else None
return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp")

def test_inference_no_head(self):
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
@@ -265,3 +265,37 @@ def test_inference_with_head(self):
assert outputs.logits.shape == expected_shape
expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))

def test_interpolate_inference_no_head(self):
model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

image_processor = self.default_image_processor
image = prepare_img() # 480X640
encoding = image_processor(images=image, return_tensors="pt", do_resize=False, do_pad=False)
input_ids = torch.tensor([[1, 2]])
attention_mask = torch.tensor([[1, 1]])
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
encoding.to(torch_device)

with torch.no_grad():
outputs = model(**encoding, interpolate_pos_encoding=True)

expected_shape = torch.Size((1, 1212, 128))
assert outputs.last_hidden_state.shape == expected_shape

def test_interpolate_inference_with_head(self):
model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

image_processor = self.default_image_processor
image = prepare_img() # 480X640
encoding = image_processor(images=image, return_tensors="pt", do_resize=False, do_pad=False)
input_ids = torch.tensor([[1, 2]])
attention_mask = torch.tensor([[1, 1]])
encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
encoding.to(torch_device)

with torch.no_grad():
outputs = model(**encoding, interpolate_pos_encoding=True)

expected_shape = torch.Size((1, 2))
assert outputs.logits.shape == expected_shape
Collaborator:

As this has a task specific head, we should test the shape of the final hidden states here instead

Contributor Author:
just a clarification, the model does not return the last hidden state by default.
Should I..

  1. Add last_hidden_state parameter in TvpVideoGroundingOutput in order to obtain the last_hidden_state.
    OR
  2. Use output_hidden_states=True and do assert outputs.hidden_states[-1].shape == expected_shape

Collaborator:
Option 2 :)
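For concreteness, Option 2 could look roughly like this in the with-head test; the expected hidden-state shape is copied from the no-head test above and should be treated as illustrative for the head variant:

def test_interpolate_inference_with_head(self):
    model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)

    image_processor = self.default_image_processor
    image = prepare_img()  # 480X640
    encoding = image_processor(images=image, return_tensors="pt", do_resize=False, do_pad=False)
    input_ids = torch.tensor([[1, 2]])
    attention_mask = torch.tensor([[1, 1]])
    encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
    encoding.to(torch_device)

    with torch.no_grad():
        outputs = model(**encoding, interpolate_pos_encoding=True, output_hidden_states=True)

    # shape of the final hidden states; (1, 1212, 128) mirrors the no-head test and is illustrative here
    assert outputs.hidden_states[-1].shape == torch.Size((1, 1212, 128))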