Added updates for shot wise feature extraction

digbose92 committed Mar 19, 2023
1 parent f51ca48 commit 0cd6f64
Showing 11 changed files with 453 additions and 40 deletions.
27 changes: 25 additions & 2 deletions README.md
@@ -29,7 +29,7 @@ conda create --prefix <path> python=3.8
* Install pytorch using the following command:

```bash
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
```
* Install additional requirements using the following:

@@ -52,14 +52,37 @@ pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
```
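
* Once CLIP is installed, image features can be extracted along the following lines (a minimal sketch using the documented `clip.load` / `model.encode_image` API; the image path is a placeholder):

```python
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# preprocess a single image and encode it into a CLIP embedding
image = preprocess(Image.open("path/to/frame.png")).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)  # (1, 512) for ViT-B/32
```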

## Extracting transcripts using whisper-X
## Transcript extraction

### Extracting transcripts using Whisper

* Install Whisper using the following command:
```bash
pip install -U openai-whisper
```
* While instantiating the model, provide the download root path for the model weights:
```python
import whisper
model=whisper.load_model("large", download_root="path/to/download/model")
```
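
* Transcription then follows the standard Whisper API (a minimal sketch; the audio path is a placeholder):

```python
# transcribe an audio file with the loaded model
result = model.transcribe("path/to/audio.wav")
print(result["text"])
```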

### Extracting transcripts using Whisper-X

* Follow the instructions listed in [Whisper-X](https://github.com/m-bain/whisperX) for installation:

```bash
pip install git+https://github.com/m-bain/whisperx.git
```
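
* After installation, transcription with word-level alignment looks roughly like the following (a sketch based on the Whisper-X README around the time of this commit; the exact functions, such as `whisperx.load_align_model` and `whisperx.align`, may differ across versions, and the audio path is a placeholder):

```python
import whisperx

device = "cuda"
audio_file = "path/to/audio.wav"

# transcribe with the base Whisper model
model = whisperx.load_model("large", device)
result = model.transcribe(audio_file)

# align the transcript to obtain word-level timestamps
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

print(result_aligned["word_segments"])
```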

## Feature extraction

* Go to the `feature_extraction` folder. To extract shot-level features using vision transformers, run the following (a sketch for loading the saved features follows the command):

```bash
CUDA_VISIBLE_DEVICES=3 python extract_vit_features.py --feature_folder <destination vit features> --video_folder <base folder containing the shots> --model_name google/vit-base-patch16-224 --video_type shot --shot_subfolder <type of shot here>
```
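
The script writes one pickle file per video that maps each shot filename to a `(num_sampled_frames, 768)` array of ViT CLS embeddings. A minimal sketch for loading the saved features (the paths are placeholders):

```python
import pickle

# load the per-video shot feature dictionary written by extract_vit_features.py
with open("<destination vit features>/<video_name>.pkl", "rb") as f:
    shot_dict = pickle.load(f)

for shot_name, features in shot_dict.items():
    # features: (num_sampled_frames, 768) numpy array of ViT CLS embeddings
    print(shot_name, features.shape)
```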

## TODOs

* LSTM on the CLIP features (variable length) and MHA baselines
122 changes: 122 additions & 0 deletions feature_extraction/extract_vit_features.py
@@ -0,0 +1,122 @@
import os
import argparse
import math
import pickle

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification

activation = {}
def getActivation(name):
    # forward hook that stores the detached output of the hooked module under `name`
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

def run_frame_wise_feature_inference(model, processor, filename, device, dim=768, desired_frameRate=4):
    # sample frames at roughly desired_frameRate fps and extract the ViT CLS embedding of each
    vcap = cv2.VideoCapture(filename)
    frameRate = vcap.get(cv2.CAP_PROP_FPS)
    intfactor = math.ceil(frameRate / desired_frameRate)
    feature_list = np.zeros((0, dim))
    frame_id = 0

    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        if frame_id % intfactor == 0:
            # OpenCV decodes frames as BGR; convert to RGB before passing to the processor
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            inputs = processor(frame, return_tensors='pt')
            inputs['pixel_values'] = inputs['pixel_values'].to(device)

            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)

            hidden_states = outputs['hidden_states']
            # CLS token embedding from the last hidden layer
            cls_embedding = hidden_states[-1][:, 0, :].cpu().numpy()
            feature_list = np.vstack([feature_list, cls_embedding])  # append the frame features

            torch.cuda.empty_cache()
        frame_id = frame_id + 1

    vcap.release()
    return feature_list, frame_id


# argparse arguments
parser = argparse.ArgumentParser(description='Extract ViT base features from a video file')
parser.add_argument('--feature_folder', type=str, help='path to the destination feature folder')
parser.add_argument('--video_folder', type=str, help='path to the base folder containing the videos')
parser.add_argument('--model_name', type=str, help='Hugging Face model name (e.g. google/vit-base-patch16-224)')
parser.add_argument('--video_type', type=str, default='shot', help='type of video unit to process (e.g. shot)')
parser.add_argument('--shot_subfolder', type=str, help='name of the shot subfolder inside the video folder')
args = parser.parse_args()

model_name = args.model_name
shot_subfolder = args.shot_subfolder
video_type = args.video_type
video_folder = args.video_folder

# load the ViT model and processor from the Hugging Face hub
print('Loading model')
model = ViTForImageClassification.from_pretrained(model_name)
model.config.return_dict = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

processor = ViTFeatureExtractor.from_pretrained(model_name)
print('Loaded model')

# optional: register a forward hook to capture intermediate activations
# h1 = model.pre_logits.register_forward_hook(getActivation('pre_logits'))

if video_type == 'shot':
    # folder containing per-video shot subfolders
    shot_folder_name = os.path.join(video_folder, shot_subfolder)
    video_file_list = os.listdir(shot_folder_name)

    # create the destination feature folder if it does not exist
    os.makedirs(args.feature_folder, exist_ok=True)

    for video_file in tqdm(video_file_list):

        video_subfolder = os.path.join(shot_folder_name, video_file)
        destination_file = os.path.join(args.feature_folder, video_file + '.pkl')

        shot_list = os.listdir(video_subfolder)  # list of shots for this video
        shot_dict = dict()

        for shot_file in tqdm(shot_list):

            shot_filename = os.path.join(video_subfolder, shot_file)
            feat_list, _ = run_frame_wise_feature_inference(model, processor, shot_filename, device)
            shot_dict[shot_file] = feat_list

        # save the per-shot features for this video
        with open(destination_file, 'wb') as f:
            pickle.dump(shot_dict, f)

18 changes: 9 additions & 9 deletions feature_extraction/torchaudio_examples_check.py
@@ -22,13 +22,13 @@

n_frames = fbank.shape[0]
difference = max_length - n_frames
print(n_frames,fbank.shape)
# # pad or truncate, depending on difference
# if difference > 0:
# pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
# fbank = pad_module(fbank)
# elif difference < 0:
# fbank = fbank[0:max_length, :]
# fbank = fbank.numpy()

# pad or truncate, depending on difference
if difference > 0:
pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
fbank = pad_module(fbank)
elif difference < 0:
fbank = fbank[0:max_length, :]
fbank = fbank.numpy()

print(fbank.shape)
# print(fbank.shape)


143 changes: 132 additions & 11 deletions notebooks/ads_distribution_data_tone_social_message.ipynb

Large diffs are not rendered by default.

@@ -0,0 +1,48 @@
import os
import argparse

import torch
import whisper
from tqdm import tqdm

# argparse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, help='Whisper model name (e.g. large)')
parser.add_argument('--download_root_folder', type=str, help='Download root path for the model weights')
parser.add_argument('--source_folder', type=str, help='Source folder for the audio files')
parser.add_argument('--save_folder', type=str, help='Save folder for the extracted transcripts')
args = parser.parse_args()

model_name = args.model_name
download_root_folder = args.download_root_folder
source_folder = args.source_folder
save_folder = args.save_folder

# load the model on the available device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = whisper.load_model(model_name, download_root=download_root_folder, device=device)

wav_file_names = os.listdir(source_folder)
wav_file_names = [os.path.join(source_folder, i) for i in wav_file_names]

for wav_file in tqdm(wav_file_names):
    file_key = os.path.splitext(os.path.basename(wav_file))[0]
    save_file_path = os.path.join(save_folder, file_key + '.txt')

    # skip audio files that already have a transcript
    if not os.path.exists(save_file_path):
        result = model.transcribe(wav_file)

        # save the transcript text to a file
        with open(save_file_path, 'w') as f:
            f.write(result['text'])

1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ bertopic
importlib-resources
transformers
scenedetect==0.6.1
timm
