Skip to content

Commit

Permalink
Add docstrings and fix VIVIT examples (huggingface#25628)
Browse files Browse the repository at this point in the history
* fix docstrings and examples

* docstring update

* add missing whitespace
  • Loading branch information
Geometrein authored and parambharat committed Sep 26, 2023
1 parent c6ff584 commit 8e14d67
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 6 deletions.
11 changes: 11 additions & 0 deletions src/transformers/models/git/convert_git_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,17 @@ def prepare_video():
np.random.seed(0)

def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
"""
Sample a given number of frame indices from the video.
Args:
clip_len (`int`): Total number of frames to sample.
frame_sample_rate (`int`): Sample every n-th frame.
seg_len (`int`): Maximum allowed index of sample's last frame.
Returns:
indices (`List[int]`): List of sampled frame indices
"""
converted_len = int(clip_len * frame_sample_rate)
end_idx = np.random.randint(converted_len, seg_len)
start_idx = end_idx - converted_len
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/models/git/modeling_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/models/timesformer/modeling_timesformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down Expand Up @@ -730,6 +739,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/models/videomae/modeling_videomae.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down Expand Up @@ -1008,6 +1017,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down
31 changes: 25 additions & 6 deletions src/transformers/models/vivit/modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand All @@ -547,8 +556,8 @@ def forward(
>>> container = av.open(file_path)
>>> # sample 32 frames
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container=container, indices=indices)
>>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
>>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
Expand Down Expand Up @@ -639,8 +648,9 @@ def forward(
```python
>>> import av
>>> import numpy as np
>>> import torch
>>> from transformers import VivitImageProcessor, VivitModel
>>> from transformers import VivitImageProcessor, VivitForVideoClassification
>>> from huggingface_hub import hf_hub_download
>>> np.random.seed(0)
Expand Down Expand Up @@ -668,6 +678,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand All @@ -683,8 +702,8 @@ def forward(
>>> container = av.open(file_path)
>>> # sample 32 frames
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=len(videoreader))
>>> video = videoreader.get_batch(indices).asnumpy()
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
>>> video = read_video_pyav(container=container, indices=indices)
>>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
>>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
Expand All @@ -698,7 +717,7 @@ def forward(
>>> # model predicts one of the 400 Kinetics-400 classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
eating spaghetti
LABEL_116
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/x_clip/modeling_x_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down Expand Up @@ -1423,6 +1432,15 @@ def get_video_features(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down Expand Up @@ -1531,6 +1549,15 @@ def forward(
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
... '''
... Sample a given number of frame indices from the video.
... Args:
... clip_len (`int`): Total number of frames to sample.
... frame_sample_rate (`int`): Sample every n-th frame.
... seg_len (`int`): Maximum allowed index of sample's last frame.
... Returns:
... indices (`List[int]`): List of sampled frame indices
... '''
... converted_len = int(clip_len * frame_sample_rate)
... end_idx = np.random.randint(converted_len, seg_len)
... start_idx = end_idx - converted_len
Expand Down

0 comments on commit 8e14d67

Please sign in to comment.