-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added updates for shot wise feature extraction
- Loading branch information
Showing
11 changed files
with
453 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import os | ||
import pandas as pd | ||
import argparse | ||
import timm | ||
from PIL import Image | ||
from transformers import ViTFeatureExtractor, ViTForImageClassification | ||
import torch | ||
import torch.nn as nn | ||
import pandas as pd | ||
import numpy as np | ||
from tqdm import tqdm | ||
import cv2 | ||
import math | ||
import pickle | ||
|
||
# Registry of captured layer outputs, keyed by the name passed to getActivation.
activation = dict()


def getActivation(name):
    """Build a forward hook that stores a layer's output under ``name``.

    The returned callable takes the three positional arguments of a forward
    hook (module, inputs, output). Every call overwrites ``activation[name]``
    with a detached copy of ``output`` and returns ``None``.
    """
    def _record(module, inputs, output):
        detached = output.detach()
        activation[name] = detached

    return _record
|
||
def run_frame_wise_feature_inference(model,processor,filename,device,dim=768,desired_frameRate=4):
    """Extract one feature vector per sampled frame of a video file.

    Frames are read with OpenCV and every ``intfactor``-th frame is converted
    to RGB, preprocessed with ``processor`` and passed through ``model`` with
    ``output_hidden_states=True``. The position-0 token of the last hidden
    state is kept as the frame's feature vector.

    Args:
        model: Model called as ``model(**inputs, output_hidden_states=True)``
            whose output exposes ``'hidden_states'``.
        processor: Callable turning a PIL image into a dict containing
            ``'pixel_values'`` (called with ``return_tensors='pt'``).
        filename: Path of the video file to read.
        device: Torch device the pixel values are moved to.
        dim: Expected feature dimension; stacking fails if the model's
            hidden size differs.
        desired_frameRate: Target sampling rate in frames per second.

    Returns:
        ``(feature_list, frame_id)`` where ``feature_list`` is a float64
        array of shape ``(num_sampled_frames, dim)`` (``(0, dim)`` when no
        frame could be read) and ``frame_id`` is the number of frames read.
    """
    vcap = cv2.VideoCapture(filename)
    sampled_features = []
    frame_id = 0

    try:
        frameRate = vcap.get(cv2.CAP_PROP_FPS)
        # Some containers report an FPS of 0; a sampling step of 0 would make
        # the modulo below raise ZeroDivisionError, so fall back to using
        # every frame and never let the step drop below 1.
        if frameRate > 0:
            intfactor = max(1, math.ceil(frameRate / desired_frameRate))
        else:
            intfactor = 1

        while True:
            ret, frame = vcap.read()
            if not ret:
                break

            if frame_id % intfactor == 0:
                # OpenCV decodes to BGR; the processor is fed an RGB PIL image.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                inputs = processor(Image.fromarray(rgb_frame), return_tensors='pt')
                inputs['pixel_values'] = inputs['pixel_values'].to(device)

                with torch.no_grad():
                    outputs = model(**inputs, output_hidden_states=True)

                hidden_states = outputs['hidden_states']
                cls_embedding = hidden_states[-1][:, 0, :].cpu().numpy()
                sampled_features.append(cls_embedding)

            frame_id += 1
    finally:
        # Always free the decoder handle, even if inference raises.
        vcap.release()

    # Return cached GPU memory once per video instead of once per frame.
    torch.cuda.empty_cache()

    # Stack once at the end (linear instead of quadratic copying). The empty
    # float64 seed keeps the original dtype, the (0, dim) result for
    # unreadable videos, and the shape check against ``dim``.
    feature_list = np.vstack([np.zeros((0, dim)), *sampled_features])

    return feature_list, frame_id
|
||
|
||
# Command-line arguments.
parser = argparse.ArgumentParser(description='Extract vit base features from a video file')
parser.add_argument('--feature_folder', type=str, help='path to the destination feature folder')
parser.add_argument('--video_folder', type=str, help='path to the root folder containing the videos')
parser.add_argument('--model_name', type=str, required=True, help='name or path of the pretrained ViT model')
parser.add_argument('--video_type', type=str, default='shot', help="type of video input (only 'shot' is processed)")
parser.add_argument('--shot_subfolder', type=str, help='name of the shot subfolder inside the video folder')
args = parser.parse_args()

model_name = args.model_name
shot_subfolder = args.shot_subfolder
video_type = args.video_type
video_folder = args.video_folder

# Fail fast, before the model is loaded, when shot mode lacks a path it needs;
# otherwise the error only surfaces later as a TypeError in os.path.join.
if video_type == 'shot':
    missing_flags = [
        flag
        for flag, value in (
            ('--feature_folder', args.feature_folder),
            ('--video_folder', video_folder),
            ('--shot_subfolder', shot_subfolder),
        )
        if value is None
    ]
    if missing_flags:
        parser.error('missing required arguments for shot mode: ' + ', '.join(missing_flags))

# Load the Hugging Face ViT classifier and its matching preprocessor.
print('Loading model')
model = ViTForImageClassification.from_pretrained(model_name)
model.config.return_dict = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

processor = ViTFeatureExtractor.from_pretrained(model_name)

print('Loaded model')

if video_type == 'shot':
    # Layout: <video_folder>/<shot_subfolder>/<video>/<shot files>
    shot_folder_name = os.path.join(video_folder, shot_subfolder)

    # The destination may not exist yet; without this the pickle dump fails
    # only after a whole video has been processed.
    os.makedirs(args.feature_folder, exist_ok=True)

    # Sorted so runs are reproducible (os.listdir order is arbitrary).
    video_file_list = sorted(os.listdir(shot_folder_name))

    for video_file in tqdm(video_file_list):
        video_subfolder = os.path.join(shot_folder_name, video_file)
        if not os.path.isdir(video_subfolder):
            # Skip stray files so listing them does not raise NotADirectoryError.
            continue

        destination_file = os.path.join(args.feature_folder, video_file + '.pkl')

        # Map each shot file name to its (num_sampled_frames, dim) feature array.
        shot_dict = dict()
        for shot_file in tqdm(sorted(os.listdir(video_subfolder))):
            shot_filename = os.path.join(video_subfolder, shot_file)
            feat_list, _ = run_frame_wise_feature_inference(model, processor, shot_filename, device)
            shot_dict[shot_file] = feat_list

        # One pickle per video, holding all of its shots.
        with open(destination_file, 'wb') as f:
            pickle.dump(shot_dict, f)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
134 changes: 116 additions & 18 deletions
134
notebooks/.ipynb_checkpoints/ads_distribution_data_tone_social_message-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
143 changes: 132 additions & 11 deletions
143
notebooks/ads_distribution_data_tone_social_message.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 48 additions & 0 deletions
48
preprocess_scripts/audio_transcription_extraction/audio_transcript_extraction_whisper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import pandas as pd | ||
import argparse | ||
import whisper | ||
import torch | ||
from tqdm import tqdm | ||
|
||
# Command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, required=True, help="Whisper model name (e.g. 'large')")
parser.add_argument('--download_root_folder', type=str, help='Download path root for the model weights')
parser.add_argument('--source_folder', type=str, required=True, help='Source folder for the audio files')
parser.add_argument('--save_folder', type=str, required=True, help='Save folder for the extracted transcripts')

args = parser.parse_args()
model_name = args.model_name
download_root_folder = args.download_root_folder
source_folder = args.source_folder
save_folder = args.save_folder

# Choose the device first and pass it to the loader; previously it was
# computed but never used.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = whisper.load_model(model_name, device=device, download_root=download_root_folder)

# The transcript folder may not exist yet.
os.makedirs(save_folder, exist_ok=True)

# Sorted so runs are reproducible (os.listdir order is arbitrary).
wav_file_names = [os.path.join(source_folder, i) for i in sorted(os.listdir(source_folder))]

for wav_file in tqdm(wav_file_names):
    # basename is portable, unlike splitting the path on "/".
    file_key = os.path.splitext(os.path.basename(wav_file))[0]
    save_file_name = file_key + ".txt"
    save_file_path = os.path.join(save_folder, save_file_name)

    # Resume support: skip audio files that already have a transcript.
    if os.path.exists(save_file_path):
        continue

    result = model.transcribe(wav_file)

    # Explicit UTF-8 so non-ASCII transcripts do not depend on the platform's
    # default encoding.
    with open(save_file_path, 'w', encoding='utf-8') as f:
        f.write(result['text'])
|
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ bertopic | |
importlib-resources | ||
transformers | ||
scenedetect==0.6.1 | ||
timm |