-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinference.py
128 lines (105 loc) · 4.47 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import time
import cv2
import os
import glob
import argparse
import numpy as np
import torch
import timm
from utils.utils_inference import plot_kpt, plot_vertices
from utils.utils_inference import process_input_mp
from utils.util import get_config
# Checkpoint used when the caller does not supply one (kept from the original script).
_DEFAULT_CHECKPOINT_PATH = '/mnt/satan/code/myGit/3DFaceLight/MyTest/checkpoints/mobile_net/Oct21/net_39.pth'

# Window titles shared by the image and webcam demos.
_RESULT_WINDOWS = ('Input', 'Sparse alignment', 'Sparse alignment GT')


def _load_network(cfg, checkpoint_path, local_rank, is_gpu):
    """Build the MobileNetV2 regressor and load trained weights from *checkpoint_path*.

    The head outputs ``cfg.num_verts * 3`` values (presumably one 3D point per
    vertex -- TODO confirm against the training code).
    """
    # pretrained=False: the strict load_state_dict below overwrites every
    # parameter and buffer, so fetching ImageNet weights would be wasted work.
    net = timm.create_model('mobilenetv2_100', pretrained=False,
                            num_classes=cfg.num_verts * 3)
    if is_gpu:
        net = net.to(local_rank)
        map_location = 'cuda:{}'.format(local_rank)
    else:
        map_location = 'cpu'
    # map_location lets a checkpoint saved on one device load on another.
    net.load_state_dict(torch.load(checkpoint_path, map_location=map_location))
    net.eval()
    return net


def _build_results(image, vertices, kpt_filter):
    """Return [image, vertex overlay, keypoint overlay] for one frame."""
    kpt = vertices[kpt_filter, :]
    return [image,
            plot_vertices(image, vertices),
            plot_kpt(image, kpt)]


def _show_results(result_list):
    """Show the three result images in side-by-side windows."""
    for title, picture, x_pos in zip(_RESULT_WINDOWS, result_list, (0, 500, 1000)):
        cv2.imshow(title, picture)
        cv2.moveWindow(title, x_pos, 0)


def _run_on_images(cfg, net, face_detection, kpt_filter, is_gpu, save_dir):
    """Run the model on every ``*.jpg`` in ``cfg.eval_img_path``.

    Waits for a key after each image: 's' saves the concatenated result into
    *save_dir*, 'q' stops; any other key moves on to the next image.
    """
    print("[*] Processing on images in {}. Press 's' to save result.".format(cfg.eval_img_path))
    # sorted() gives a reproducible order; glob order is filesystem-dependent.
    img_paths = sorted(glob.glob(os.path.join(cfg.eval_img_path, '*.jpg')))
    for img_path in img_paths:
        print(img_path)
        img = cv2.imread(img_path)
        if img is None:  # unreadable / corrupt file
            print("[!] Could not read {}, skipping.".format(img_path))
            continue
        # NOTE(review): the path (not the decoded image) is passed here, while
        # the webcam demo passes a frame; presumably process_input_mp accepts both.
        vertices = process_input_mp(img_path, net, face_detection, cuda=is_gpu)
        if vertices is None:  # no face found
            continue
        result_list = _build_results(img, vertices, kpt_filter)
        _show_results(result_list)
        key = cv2.waitKey(0) & 0xFF
        if key == ord('q'):
            break
        if key == ord('s'):
            cv2.imwrite(os.path.join(save_dir, os.path.basename(img_path)),
                        np.concatenate(result_list, axis=1))
            print("Result saved in {}".format(save_dir))
    cv2.destroyAllWindows()


def _run_on_webcam(net, face_detection, kpt_filter, is_gpu, camera_index=0):
    """Run the model live on camera *camera_index* until 'q' is pressed or capture fails."""
    cap = cv2.VideoCapture(camera_index)
    overlays_open = False  # whether the two result windows currently exist
    start_time = time.time()
    try:
        while True:
            ok, image = cap.read()
            if not ok or image is None:
                print("[!] Failed to read a frame from the camera, stopping.")
                break
            pos = process_input_mp(image, net, face_detection, cuda=is_gpu)
            now = time.time()
            # Guard against a zero interval on coarse clocks.
            fps_str = 'FPS: %.2f' % (1 / max(now - start_time, 1e-6))
            start_time = now
            cv2.putText(image, fps_str, (25, 25),
                        cv2.FONT_HERSHEY_DUPLEX, 0.75, (0, 255, 0), 2)
            if pos is None:
                # No face: keep the live feed visible and close stale overlays.
                # Only destroy windows that exist; destroying an unknown window
                # raises cv2.error on some backends.
                cv2.imshow('Input', image)
                if overlays_open:
                    cv2.destroyWindow('Sparse alignment')
                    cv2.destroyWindow('Sparse alignment GT')
                    overlays_open = False
            else:
                _show_results(_build_results(image, pos, kpt_filter))
                overlays_open = True
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera on every exit path, including exceptions.
        cap.release()
        cv2.destroyAllWindows()


def main(config_file, local_rank=None, checkpoint_path=_DEFAULT_CHECKPOINT_PATH):
    """Run face-alignment inference on a folder of images or on a webcam stream.

    Args:
        config_file: Path passed to ``get_config``. The config must provide
            ``is_mp``, ``num_verts``, ``use_cam``, ``filtered_kpt_500`` and,
            when ``use_cam`` is false, ``eval_img_path``.
        local_rank: CUDA device index for the model. When None, it is read
            from the module-level ``args`` set by the CLI entry point if
            present, otherwise 0.
        checkpoint_path: State-dict file to load. Saved results are written
            to its directory.

    Raises:
        ValueError: if ``cfg.is_mp`` is false. The pipeline always calls
            ``process_input_mp``, which needs the MediaPipe face detector.
    """
    cfg = get_config(config_file)
    if local_rank is None:
        # Backward compatibility: the original read the global CLI ``args``.
        local_rank = getattr(globals().get('args'), 'local_rank', 0)
    # Fall back to CPU instead of crashing on machines without CUDA.
    is_gpu = torch.cuda.is_available()

    # NOTE: the original also loaded a dlib CNN detector when cfg.is_dlib was
    # set, but never used it; that dead load has been dropped.
    if not cfg.is_mp:
        raise ValueError("cfg.is_mp must be enabled: inference uses process_input_mp, "
                         "which needs the MediaPipe face detector.")
    import mediapipe as mp
    face_detection = mp.solutions.face_detection

    net = _load_network(cfg, checkpoint_path, local_rank, is_gpu)
    # Loop-invariant: indices into the predicted vertices selecting the plotted keypoints.
    kpt_filter = np.loadtxt(cfg.filtered_kpt_500).astype(int)

    if not cfg.use_cam:
        _run_on_images(cfg, net, face_detection, kpt_filter, is_gpu,
                       save_dir=os.path.dirname(checkpoint_path))
    else:
        _run_on_webcam(net, face_detection, kpt_filter, is_gpu)
if __name__ == '__main__':
    # Command-line entry point for the inference demo.
    parser = argparse.ArgumentParser(
        description='Run face-alignment inference on images or a webcam stream.')
    parser.add_argument('--configfile', default='configs/config.py', help='path to the configfile')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='CUDA device index to run the model on')
    # Print the working directory: the default config path and the detector
    # paths used by main() are relative to it.
    print(os.getcwd())
    # ``args`` stays module-level on purpose: main() reads args.local_rank from it.
    args = parser.parse_args()
    main(args.configfile)