From 8e14d6701dd0d902e2d28a5d5a6acb7c457ead5f Mon Sep 17 00:00:00 2001 From: Tigran Khachatryan <65066173+Geometrein@users.noreply.github.com> Date: Sat, 26 Aug 2023 22:08:47 +0300 Subject: [PATCH] Add docstrings and fix VIVIT examples (#25628) * fix docstrings and examples * docstring update * add missing whitespace --- .../models/git/convert_git_to_pytorch.py | 11 +++++++ src/transformers/models/git/modeling_git.py | 9 ++++++ .../timesformer/modeling_timesformer.py | 18 +++++++++++ .../models/videomae/modeling_videomae.py | 18 +++++++++++ .../models/vivit/modeling_vivit.py | 31 +++++++++++++++---- .../models/x_clip/modeling_x_clip.py | 27 ++++++++++++++++ 6 files changed, 108 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/git/convert_git_to_pytorch.py b/src/transformers/models/git/convert_git_to_pytorch.py index e089ec89854f96..5dde4da15e5195 100644 --- a/src/transformers/models/git/convert_git_to_pytorch.py +++ b/src/transformers/models/git/convert_git_to_pytorch.py @@ -200,6 +200,17 @@ def prepare_video(): np.random.seed(0) def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + """ + Sample a given number of frame indices from the video. + + Args: + clip_len (`int`): Total number of frames to sample. + frame_sample_rate (`int`): Sample every n-th frame. + seg_len (`int`): Maximum allowed index of sample's last frame. + + Returns: + indices (`List[int]`): List of sampled frame indices + """ converted_len = int(clip_len * frame_sample_rate) end_idx = np.random.randint(converted_len, seg_len) start_idx = end_idx - converted_len diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 80eaca01e9463f..eafb72f3523ff2 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1465,6 +1465,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index ad5120ab575c7c..5fd932ece6640b 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -601,6 +601,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len @@ -730,6 +739,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index c62d0c4632cb68..07c32d14929037 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -612,6 +612,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len @@ -1008,6 +1017,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index fe633a5f205fb6..ff1bc003ad38f4 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -532,6 +532,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len @@ -547,8 +556,8 @@ def forward( >>> container = av.open(file_path) >>> # sample 32 frames - >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=len(videoreader)) - >>> video = videoreader.get_batch(indices).asnumpy() + >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames) + >>> video = read_video_pyav(container=container, indices=indices) >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400") >>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400") @@ -639,8 +648,9 @@ def forward( ```python >>> import av >>> import numpy as np + >>> import torch - >>> from transformers import VivitImageProcessor, VivitModel + >>> from transformers import VivitImageProcessor, VivitForVideoClassification >>> from huggingface_hub import hf_hub_download >>> np.random.seed(0) @@ -668,6 +678,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len @@ -683,8 +702,8 @@ def forward( >>> container = av.open(file_path) >>> # sample 32 frames - >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=len(videoreader)) - >>> video = videoreader.get_batch(indices).asnumpy() + >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames) + >>> video = read_video_pyav(container=container, indices=indices) >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400") >>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400") @@ -698,7 +717,7 @@ def forward( >>> # model predicts one of the 400 Kinetics-400 classes >>> predicted_label = logits.argmax(-1).item() >>> print(model.config.id2label[predicted_label]) - eating spaghetti + LABEL_116 ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index d6f9bf9d818f3a..da7eddff8df838 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -1105,6 +1105,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len @@ -1423,6 +1432,15 @@ def get_video_features( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len @@ -1531,6 +1549,15 @@ def forward( >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' ... converted_len = int(clip_len * frame_sample_rate) ... end_idx = np.random.randint(converted_len, seg_len) ... start_idx = end_idx - converted_len