extract_vit_features.py

import os
import argparse
import math
import pickle

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Storage for intermediate activations captured by forward hooks
activation = {}


def getActivation(name):
    # Returns a hook (with the signature register_forward_hook expects) that
    # stores the detached output of the hooked module under `name`
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
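
# A minimal sketch of how getActivation could be wired up, kept as a comment so
# the script's behavior is unchanged (the layer choice here is an assumption,
# not something this script does):
#
#   handle = model.vit.encoder.layer[-1].register_forward_hook(getActivation('last_block'))
#   _ = model(**inputs)   # forward pass fills activation['last_block']
#   handle.remove()       # detach the hook once the activation is captured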


def run_frame_wise_feature_inference(model, processor, filename, device, dim=768, desired_frameRate=4):
    """Sample roughly `desired_frameRate` frames per second from a video and
    return the CLS embeddings of the sampled frames as an (N, dim) array,
    along with the total number of frames read."""
    vcap = cv2.VideoCapture(filename)
    frameRate = vcap.get(cv2.CAP_PROP_FPS)
    # Keep every intfactor-th frame; guard against videos with missing FPS metadata
    intfactor = max(1, math.ceil(frameRate / desired_frameRate))
    feature_list = np.zeros((0, dim))
    frame_id = 0
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        if frame_id % intfactor == 0:
            # OpenCV decodes frames as BGR; the ViT processor expects RGB PIL images
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            inputs = processor(frame, return_tensors='pt')
            inputs['pixel_values'] = inputs['pixel_values'].to(device)
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)
            hidden_states = outputs['hidden_states']
            # CLS token embedding from the last hidden layer
            cls_embedding = hidden_states[-1][:, 0, :].cpu().numpy()
            feature_list = np.vstack([feature_list, cls_embedding])
            torch.cuda.empty_cache()
        frame_id += 1
    vcap.release()
    return feature_list, frame_id
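
# A minimal standalone-use sketch (hypothetical video path; assumes a ViT-base
# checkpoint whose hidden size matches the dim=768 default):
#
#   feats, total_frames = run_frame_wise_feature_inference(
#       model, processor, 'clip.mp4', device)
#   # feats.shape == (num_sampled_frames, 768); total_frames counts all decoded frames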

# Command-line arguments
parser = argparse.ArgumentParser(description='Extract ViT-base features from video files')
parser.add_argument('--feature_folder', type=str, help='path to the destination feature folder')
parser.add_argument('--video_folder', type=str, help='path to the source video folder')
parser.add_argument('--model_name', type=str, help='name of the pretrained ViT model to load')
parser.add_argument('--video_type', type=str, default='shot', help='type of video input (only "shot" is handled)')
parser.add_argument('--shot_subfolder', type=str, help='name of the shot subfolder inside the video folder')
parser.add_argument('--shot_file_list', type=str, help='path to the list of already-processed shot files')
args = parser.parse_args()

model_name = args.model_name
shot_subfolder = args.shot_subfolder
video_type = args.video_type
video_folder = args.video_folder
shot_file_list = args.shot_file_list
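
# Example invocation (a sketch; the paths and model name below are assumptions,
# not taken from the repository):
#
#   python extract_vit_features.py \
#       --feature_folder features/ \
#       --video_folder videos/ \
#       --model_name google/vit-base-patch16-224 \
#       --shot_subfolder shots \
#       --shot_file_list processed_shots.txt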

# Load the pretrained ViT model and its preprocessing pipeline from transformers
print('Loading model')
model = ViTForImageClassification.from_pretrained(model_name)
model.config.return_dict = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
processor = ViTFeatureExtractor.from_pretrained(model_name)
print('Loaded model')
# Hook registration kept (commented out) from the original:
# h1 = model.pre_logits.register_forward_hook(getActivation('pre_logits'))

if video_type == 'shot':
    # Read the list of shot files that have already been processed
    with open(shot_file_list, 'r') as f:
        shot_filenames = f.readlines()
    shot_filenames = [x.strip().split('/')[-1] for x in shot_filenames]
    shot_folder_name = os.path.join(video_folder, shot_subfolder)
    # Each entry in the shot folder is a subfolder holding the shots of one video
    video_file_list = os.listdir(shot_folder_name)
    for video_file in tqdm(video_file_list):
        video_subfolder = os.path.join(shot_folder_name, video_file)
        pkl_filename = video_file + '.pkl'
        if pkl_filename in shot_filenames:
            print('Already processed', pkl_filename)
        else:
            destination_file = os.path.join(args.feature_folder, pkl_filename)
            if not os.path.exists(destination_file):
                shot_list = os.listdir(video_subfolder)  # shots of this video
                shot_dict = dict()
                for shot_file in tqdm(shot_list):
                    shot_filename = os.path.join(video_subfolder, shot_file)
                    feat_list, _ = run_frame_wise_feature_inference(model, processor, shot_filename, device)
                    shot_dict[shot_file] = feat_list
                # Save one pickle of {shot filename: feature array} per video
                with open(destination_file, 'wb') as f:
                    pickle.dump(shot_dict, f)
            else:
                print('Already exists', pkl_filename)
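
# Loading a saved feature file back (a sketch; 'video1.pkl' is a hypothetical name):
#
#   with open(os.path.join(args.feature_folder, 'video1.pkl'), 'rb') as f:
#       shot_dict = pickle.load(f)   # {shot filename: (num_sampled_frames, 768) array}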