-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdata.py
30 lines (26 loc) · 953 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import pickle
import numpy as np
from tqdm import tqdm
import cv2
def load_data(face_img_path, aus_pkl_path, *, img_size=128, num_aus=17, au_scale=5):
    """Load face images and their Action Unit (AU) labels into NumPy arrays.

    Every entry of ``face_img_path`` is read in sorted filename order,
    converted from OpenCV's BGR channel order to RGB, and rescaled from
    [0, 255] to [-1, 1]. Each image's AU vector is looked up in the pickled
    dict at ``aus_pkl_path`` by the part of the filename before the first
    ``'.'`` and divided by ``au_scale``.

    :param face_img_path: folder containing the face images; every entry in
        the folder is treated as an image file.
    :param aus_pkl_path: path of the pickled ``{name: au_vector}`` dict
        (``'aus.pkl'``).
    :param img_size: expected height and width of every image (default 128).
    :param num_aus: expected length of each AU vector (default 17).
    :param au_scale: divisor applied to every AU value (default 5; presumably
        maps AU intensities on a 0-5 scale into [0, 1] -- confirm against the
        dataset).
    :return: ``(imgs, aus)`` where ``imgs`` is a float32 RGB array of shape
        ``[n, img_size, img_size, 3]`` with values in [-1, 1], and ``aus`` is
        a float32 array of shape ``[n, num_aus]``.
    :raises OSError: if a file in the folder cannot be decoded as an image.
    :raises ValueError: if an image is not ``img_size x img_size`` or an AU
        vector does not fit ``num_aus`` entries.
    :raises KeyError: if an image has no entry in the AU dict.
    """
    imgs_names = sorted(os.listdir(face_img_path))
    # NOTE: unpickling can execute arbitrary code; only load aus.pkl files
    # that come from a trusted source.
    with open(aus_pkl_path, 'rb') as f:
        aus_dict = pickle.load(f)

    n = len(imgs_names)
    imgs = np.zeros((n, img_size, img_size, 3), dtype=np.float32)
    aus = np.zeros((n, num_aus), dtype=np.float32)
    # Wrap the list (not an enumerate object) so tqdm knows the total count.
    for i, img_name in enumerate(tqdm(imgs_names)):
        img_path = os.path.join(face_img_path, img_name)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f"could not read image: {img_path}")
        if img.shape[:2] != (img_size, img_size):
            raise ValueError(
                f"expected a {img_size}x{img_size} image, "
                f"got {img.shape[1]}x{img.shape[0]}: {img_path}"
            )
        img = img[:, :, ::-1]  # BGR -> RGB
        imgs[i] = img / 127.5 - 1  # rescale [0, 255] -> [-1, 1]
        # NOTE(review): the key is everything before the FIRST '.', so a name
        # like 'a.b.jpg' maps to 'a'; kept as-is to match existing aus.pkl keys.
        # np.asarray lets the stored AU vector be a list as well as an array.
        aus[i] = np.asarray(aus_dict[img_name.split('.')[0]]) / au_scale
    return imgs, aus