-
-
Notifications
You must be signed in to change notification settings - Fork 192
/
Copy pathembedder_inference.py
77 lines (55 loc) · 2.59 KB
/
embedder_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Main"""
import torch
from dataset.video_extraction_conversion import select_frames, select_images_frames, generate_cropped_landmarks
from network.blocks import *
from network.model import Embedder
import face_alignment
import numpy as np
from params.params import path_to_chkpt
"""Hyperparameters and config"""
# Prefer the first CUDA GPU, but fall back to the CPU so the script still
# runs on machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")

# Output archives for the averaged embeddings (one per input source).
path_to_e_hat_video = 'e_hat_video.tar'
path_to_e_hat_images = 'e_hat_images.tar'

# Input sources: a video file and a directory of still images.
path_to_video = 'test_vid.mp4'
path_to_images = 'examples/fine_tuning/test_images'

# Passed to select_frames() for the video input.
# NOTE(review): presumably the number of frames sampled from the video -- confirm
# against dataset.video_extraction_conversion.
T = 32

# The landmark detector runs on the same device as the embedder so the two
# never disagree about CUDA availability.
face_aligner = face_alignment.FaceAlignment(
    face_alignment.LandmarksType._2D,
    flip_input=False,
    device=str(device),
)
"""Loading Embedder input"""
def _build_embedder_input(frames):
    """Crop `frames` around the face and return them as a scaled tensor on `device`.

    The frames are passed through generate_cropped_landmarks(), stacked into a
    single float tensor, have axes 2 and 4 swapped, are moved to `device`, and
    are divided by 255.

    Shapes below follow the comments in the original script and are not
    verified here: T,2,256,256,3 before the transpose, T,2,3,256,256 after.

    NOTE(review): transpose(2, 4) exchanges axes 2 and 4, so if the input is
    (..., H, W, C) the result is (..., C, W, H), i.e. the spatial axes are
    swapped as well. Left unchanged because the checkpoint presumably expects
    this layout -- confirm before replacing it with a permute.
    """
    cropped = generate_cropped_landmarks(frames, pad=50, face_aligner=face_aligner)
    stacked = torch.from_numpy(np.array(cropped)).type(dtype=torch.float)  # T,2,256,256,3
    return stacked.transpose(2, 4).to(device) / 255  # T,2,3,256,256


# Video input: frames selected from the video file.
frame_mark_video = _build_embedder_input(select_frames(path_to_video, T))
f_lm_video = frame_mark_video.unsqueeze(0)  # add leading batch axis: 1,T,2,3,256,256

# Image input: frames read from the image directory.
frame_mark_images = _build_embedder_input(select_images_frames(path_to_images))
f_lm_images = frame_mark_images.unsqueeze(0)  # add leading batch axis: 1,T,2,3,256,256
# Build the embedder on the target device.
# NOTE(review): the meaning of the 256 argument is defined in network.model --
# presumably the input image size; confirm there.
E = Embedder(256).to(device)
"""Loading from past checkpoint"""
# The checkpoint is deserialized onto the CPU; load_state_dict() then copies
# the stored weights into the model's existing parameters.
# NOTE(review): torch.load unpickles the file, so only point path_to_chkpt at
# checkpoints from a trusted source.
checkpoint = torch.load(path_to_chkpt, map_location=cpu)
E.load_state_dict(checkpoint['E_state_dict'])
# Switch to evaluation mode for inference.
E.eval()
"""Inference"""
def _average_embedding(embedder, f_lm):
    """Run `embedder` on every frame pair in `f_lm` and average over the frame axis.

    `f_lm` is indexed as (B, T, 2, C, H, W). The batch and frame axes are
    folded together so each frame pair becomes one sample, the two slices along
    the pair axis are fed to `embedder` as separate arguments, and the outputs
    are regrouped per batch item and averaged over T.

    NOTE(review): presumably index 0 of the pair axis is the frame and index 1
    the landmark image -- confirm against the Embedder signature.

    Returns a tensor of shape (B, *embedder_output_shape[1:]); per the original
    comments the embedder yields (BxT, 512, 1), giving (B, 512, 1).
    """
    frames_per_item = f_lm.shape[1]
    f_lm_compact = f_lm.view(-1, *f_lm.shape[-4:])  # BxT,2,C,H,W
    e_vectors = embedder(f_lm_compact[:, 0], f_lm_compact[:, 1])  # BxT,512,1
    # Regroup by batch item using the embedder's own trailing dimensions
    # instead of a hard-coded 512,1, so a different embedding size still works.
    e_vectors = e_vectors.view(-1, frames_per_item, *e_vectors.shape[1:])  # B,T,512,1
    return e_vectors.mean(dim=1)


# No gradients are needed for inference.
with torch.no_grad():
    # Average encoding vector for the video frames.
    e_hat_video = _average_embedding(E, f_lm_video)
    # Average encoding vector for the still images.
    e_hat_images = _average_embedding(E, f_lm_images)
print('Saving e_hat...')
# Each averaged embedding is written to its own archive under the same
# 'e_hat' key. The tensors are saved as computed, without moving them to
# another device first.
for _embedding, _archive_path in (
    (e_hat_video, path_to_e_hat_video),
    (e_hat_images, path_to_e_hat_images),
):
    torch.save({'e_hat': _embedding}, _archive_path)
print('...Done saving')