From faa136b51ec4ec5858e5b0ae40eb7ef89a88b475 Mon Sep 17 00:00:00 2001 From: ad Date: Sun, 6 Oct 2024 11:21:43 +0200 Subject: [PATCH] optimize --- .../models/git/convert_git_to_pytorch.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/git/convert_git_to_pytorch.py b/src/transformers/models/git/convert_git_to_pytorch.py index 008eb1fcef1..ce14ef352c3 100644 --- a/src/transformers/models/git/convert_git_to_pytorch.py +++ b/src/transformers/models/git/convert_git_to_pytorch.py @@ -222,21 +222,19 @@ def sample_frame_indices(clip_len, frame_sample_rate, seg_len): file_path = hf_hub_download(repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset") with av.open(file_path) as container: - video_stream = next(s for s in container.streams if s.type == 'video') - total_frames = video_stream.frames - - # sample 6 frames + video = container.streams.video[0] + total_frames = video.frames indices = sample_frame_indices(clip_len=6, frame_sample_rate=4, seg_len=total_frames) frames = [] + container.seek(indices[0]) for i, frame in enumerate(container.decode(video=0)): - if i in indices: - frames.append(frame.to_ndarray(format='rgb24')) if len(frames) == 6: break - - video = np.stack(frames) - return video + if i in indices: + frames.append(frame.to_ndarray(format='rgb24')) + + return np.stack(frames) @torch.no_grad()