Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First version of ArcFace and MTCNN #40

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@ Open [postdoc position](https://mycore.core-cloud.net/public.php?service=files&t

## Installation

See https://github.com/onnx/models/blob/master/models/face_recognition/ArcFace/arcface_inference.ipynb to install ArcFace and MTCNN models

In Python:
```python
import sys
import mxnet as mx
import os
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from mxnet.contrib.onnx.onnx2mx.import_model import import_model

PATH_DATA=''

# Download onnx model
mx.test_utils.download(dirname=PATH_DATA, url='https://s3.amazonaws.com/onnx-model-zoo/arcface/resnet100.onnx')

for i in range(4):
mx.test_utils.download(dirname=PATH_DATA+'mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}-0001.params'.format(i+1))
mx.test_utils.download(dirname=PATH_DATA+'mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}-symbol.json'.format(i+1))
mx.test_utils.download(dirname=PATH_DATA+'mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}.caffemodel'.format(i+1))
mx.test_utils.download(dirname=PATH_DATA+'mtcnn-model', url='https://s3.amazonaws.com/onnx-model-zoo/arcface/mtcnn-model/det{}.prototxt'.format(i+1))
```

Create a new `conda` environment:

```bash
Expand Down Expand Up @@ -44,3 +68,5 @@ $ jupyter notebook --notebook-dir="pyannote-video/doc"
## Documentation

No proper documentation for the time being...

When you launch `pyannote_face.py extract`, the arguments are the movie, the `.track.txt` file, the path to the `mtcnn-model` folder, the `resnet100.onnx` model, the `.landmarks.txt` file and the `.embedding.txt` file.
210 changes: 156 additions & 54 deletions pyannote/video/face/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,102 @@

# AUTHORS
# Herve BREDIN - http://herve.niderb.fr
# Benjamin MAURICE - [email protected]

"""Face processing"""

import numpy as np
import dlib
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import sys
import numpy as np
import mxnet as mx
import os
import sklearn
from sklearn.preprocessing import normalize
from pyannote.video.face.mtcnn_detector import MtcnnDetector
from skimage import transform as trans
from mxnet.contrib.onnx.onnx2mx.import_model import import_model

#print(os.path.dirname(os.path.abspath(__file__)))
#print(os.getcwd())

# NOTE(review): this constant is not referenced anywhere in the visible part
# of this module. Its name suggests the smallest face size handled by the
# former dlib detector -- confirm whether it is still needed.
DLIB_SMALLEST_FACE = 36

def get_model(ctx, model):
    """Load an ONNX model into a bound MXNet module

    Parameters
    ----------
    ctx :
        MXNet context handed to `mx.mod.Module` (e.g. `mx.cpu()`).
    model : str
        Path to the ONNX model, handed to `import_model`.

    Returns
    -------
    module : mx.mod.Module
        Module bound to a single 'data' input of shape (1, 3, 112, 112),
        with the parameters of the ONNX model already set.
    """
    height, width = 112, 112

    # Import ONNX model
    symbol, arg_params, aux_params = import_model(model)

    # Define the network and bind its parameters
    module = mx.mod.Module(symbol=symbol, context=ctx, label_names=None)
    module.bind(data_shapes=[('data', (1, 3, height, width))])
    module.set_params(arg_params, aux_params)
    return module

def preprocess(img, bbox=None, landmark=None, **kwargs):
    """Crop (and, when landmarks are given, align) a face

    Parameters
    ----------
    img : numpy array
        Image indexed as img[row, column, channel].
    bbox : sequence, optional
        Face bounding box; only the first four entries are used, as
        (left, top, right, bottom). Ignored when `landmark` is given.
        When both `bbox` and `landmark` are None, a center crop that leaves
        out 6.25% of the image on each side is used instead.
    landmark : numpy array, optional
        Five (x, y) facial landmarks, matched against the five reference
        points hard-coded below.
    image_size : str, optional (keyword)
        Output size, either 'height,width' or a single value used for both.
        Height must be 112 and width must be 112 or 96.
        Required when `landmark` is given.
    margin : int, optional (keyword)
        Total margin (in pixels) added around the bounding box, half on each
        side. Defaults to 44. Only used when `landmark` is None.

    Returns
    -------
    face : numpy array
        Warped face of size `image_size` when `landmark` is given.
        Otherwise the cropped region, resized to `image_size` if provided.

    Raises
    ------
    AssertionError
        If `image_size` is not a supported size, or is missing while
        `landmark` is given.
    """
    image_size = []
    str_image_size = kwargs.get('image_size', '')

    # Assert input shape
    if len(str_image_size) > 0:
        image_size = [int(x) for x in str_image_size.split(',')]
        if len(image_size) == 1:
            image_size = [image_size[0], image_size[0]]
        assert len(image_size) == 2
        assert image_size[0] == 112
        # The width (index 1) is what must be checked here: the reference
        # landmarks below only exist for widths 96 and 112. Testing index 0
        # again would make this assertion always pass.
        assert image_size[1] == 112 or image_size[1] == 96

    # Do alignment using landmark points
    if landmark is not None:
        assert len(image_size) == 2
        # Reference landmark positions for a 112x96 (height x width) face
        src = np.array([
            [30.2946, 51.6963],
            [65.5318, 51.5014],
            [48.0252, 71.7366],
            [33.5493, 92.3655],
            [62.7299, 92.2041]], dtype=np.float32)
        if image_size[1] == 112:
            # Shift reference points horizontally for the 16 pixels wider output
            src[:, 0] += 8.0
        dst = landmark.astype(np.float32)
        tform = trans.SimilarityTransform()
        tform.estimate(dst, src)
        # Keep the 2x3 affine part expected by cv2.warpAffine
        M = tform.params[0:2, :]
        # cv2 expects the output size as (width, height)
        return cv2.warpAffine(img, M, (image_size[1], image_size[0]),
                              borderValue=0.0)

    # No landmark points available: crop using the bounding box.
    # If no bounding box is available either, use a center crop.
    if bbox is None:
        det = np.zeros(4, dtype=np.int32)
        det[0] = int(img.shape[1] * 0.0625)
        det[1] = int(img.shape[0] * 0.0625)
        det[2] = img.shape[1] - det[0]
        det[3] = img.shape[0] - det[1]
    else:
        det = bbox

    # Enlarge the box by half the margin on each side, clipped to the image
    margin = kwargs.get('margin', 44)
    bb = np.zeros(4, dtype=np.int32)
    bb[0] = np.maximum(det[0] - margin / 2, 0)
    bb[1] = np.maximum(det[1] - margin / 2, 0)
    bb[2] = np.minimum(det[2] + margin / 2, img.shape[1])
    bb[3] = np.minimum(det[3] + margin / 2, img.shape[0])

    ret = img[bb[1]:bb[3], bb[0]:bb[2], :]
    if len(image_size) > 0:
        ret = cv2.resize(ret, (image_size[1], image_size[0]))
    return ret

def get_feature(model, aligned):
    """Compute the normalized embedding of one aligned face

    Parameters
    ----------
    model :
        Bound module, as returned by `get_model`.
    aligned : numpy array
        Aligned face; a leading batch axis is added before it is fed to
        the model.

    Returns
    -------
    embedding : numpy array
        First output of the model, passed through `normalize` and flattened.
    """
    # Add the batch axis and wrap the face into a one-element data batch
    batch = mx.io.DataBatch(
        data=(mx.nd.array(np.expand_dims(aligned, axis=0)),))

    model.forward(batch, is_train=False)
    raw = model.get_outputs()[0].asnumpy()

    return normalize(raw).flatten()


class Face(object):
"""Face processing"""
Expand All @@ -44,50 +131,92 @@ def __init__(self, landmarks=None, embedding=None):
Parameters
----------
landmarks : str
Path to dlib's 68 facial landmarks predictor model.
Path to MTCNN facial landmarks predictor model.
embedding : str
Path to dlib's face embedding model.
Path to ArcFace face embedding model.
"""
super(Face, self).__init__()

# face detection
self.face_detector_ = dlib.get_frontal_face_detector()
# Determine and set context
if len(mx.test_utils.list_gpus())==0:
ctx = mx.cpu()
else:
ctx = mx.gpu(0)

# landmark detection
# face detection
if landmarks is not None:
self.shape_predictor_ = dlib.shape_predictor(landmarks)
det_threshold = [0.6,0.7,0.8]
self.face_detector_ = MtcnnDetector(model_folder=landmarks, ctx=ctx, num_worker=1, accurate_landmark = True, threshold=det_threshold)

# face embedding
if embedding is not None:
self.face_recognition_ = dlib.face_recognition_model_v1(embedding)

def iterfaces(self, rgb):
"""Iterate over all detected faces"""
for face in self.face_detector_(rgb, 1):
yield face

def get_landmarks(self, rgb, face):
return self.shape_predictor_(rgb, face)
#return np.float32([(p.x, p.y) for p in landmarks.parts()])

def get_embedding(self, rgb, landmarks):
embedding = self.face_recognition_.compute_face_descriptor(
rgb, landmarks)
return embedding
self.face_recognition_ = get_model(ctx , embedding)

def get_debug(self, image, face, landmarks):
    """Return face with overlaid landmarks

    Parameters
    ----------
    image : numpy array
        Full frame.
    face :
        Either a bounding box whose first four entries are
        (left, top, right, bottom), as yielded by `iter_data`, or an
        object exposing `left()`, `top()`, `right()` and `bottom()`.
    landmarks :
        Either an iterable of (x, y) points or an object exposing `parts()`
        whose items have `x` and `y` attributes.

    Returns
    -------
    debug : numpy array
        Face crop with one green mark per landmark, resized to
        (self.size, self.size).
    """
    copy = image.copy()

    # Accept both plain (x, y) points and dlib-style shapes
    if hasattr(landmarks, 'parts'):
        points = [(p.x, p.y) for p in landmarks.parts()]
    else:
        points = [(p[0], p[1]) for p in landmarks]

    for x, y in points:
        # cv2 drawing functions require integer pixel coordinates
        x, y = int(x), int(y)
        cv2.rectangle(copy, (x, y), (x, y), (0, 255, 0), 2)

    # Accept both bounding box arrays and dlib-style rectangles
    if hasattr(face, 'top'):
        left, top = face.left(), face.top()
        right, bottom = face.right(), face.bottom()
    else:
        left, top, right, bottom = (int(v) for v in face[:4])

    # Negative coordinates would be interpreted as "from the end" by slicing
    left, top = max(left, 0), max(top, 0)
    copy = copy[top:bottom, left:right]

    # NOTE(review): `self.size` is not set in the visible part of the class
    # -- confirm it is defined before this method is called.
    copy = cv2.resize(copy, (self.size, self.size))
    return copy

def iter_data(self, rgb, return_landmarks=False, return_embedding=False,
              return_debug=False):
    """Iterate over all faces detected in one frame

    Parameters
    ----------
    rgb : numpy array
        Frame handed to the face detector.
    return_landmarks : bool
        Whether to yield landmarks. Defaults to False.
    return_embedding : bool
        Whether to yield embedding. Defaults to False.
    return_debug : bool
        Whether to yield debugging image. Defaults to False.

    Yields
    ------
    face, or (face, [landmarks, ] [embedding, ] [debug, ])
        `face` alone when nothing else is asked for; otherwise a tuple
        starting with `face`, followed by the requested items in that order.
        Nothing is yielded when the detector finds no face.
    """
    detection = self.face_detector_.detect_face(rgb, det_type=0)
    if detection is None:
        return

    bbox, points = detection
    if bbox.shape[0] == 0:
        return

    only_faces = not (return_landmarks or return_embedding or return_debug)

    for id_f, face in enumerate(bbox):

        # yield face if nothing else is asked for
        if only_faces:
            yield face
            continue

        # compute landmarks: detector output is (x1, ..., xk, y1, ..., yk),
        # turned here into [(x1, y1), ..., (xk, yk)]
        face_points = points[id_f]
        half = int(len(face_points) / 2)
        landmarks = np.asarray(
            [(face_points[k], face_points[k + half]) for k in range(half)])

        # always return face as first tuple element
        result = (face, )

        # append landmarks
        if return_landmarks:
            result += (landmarks, )

        # compute and append embedding
        if return_embedding:
            five_points = points[id_f, :].reshape((2, 5)).T
            # generate the aligned 112x112 face expected by the model
            aligned = preprocess(rgb, face, five_points,
                                 image_size='112,112')
            # NOTE(review): channels are swapped as if the input were BGR,
            # although the argument is named `rgb` -- confirm input order.
            aligned = cv2.cvtColor(aligned, cv2.COLOR_BGR2RGB)
            # (height, width, channel) -> (channel, height, width)
            aligned = np.transpose(aligned, (2, 0, 1))
            result += (get_feature(self.face_recognition_, aligned), )

        # compute and append debug image
        if return_debug:
            result += (self.get_debug(rgb, face, landmarks), )

        yield result

def iterfaces(self, rgb):
    """Iterate over all detected faces

    Delegates to `iter_data` with its default flags, so only the detected
    faces themselves are yielded (no landmarks, embedding or debug image).
    """
    return self.iter_data(rgb)

def __call__(self, rgb, return_landmarks=False, return_embedding=False,
return_debug=False):
return_debug=False):
"""Iterate over all faces

Parameters
Expand All @@ -102,31 +231,4 @@ def __call__(self, rgb, return_landmarks=False, return_embedding=False,
Whether to yield debugging image. Defaults to False.
"""

for face in self.iterfaces(rgb):

# yield face if nothing else is asked for
if not (return_landmarks or return_embedding or return_debug):
yield face
continue

# always return face as first tuple element
result = (face, )

# compute landmarks
landmarks = self.get_landmarks(rgb, face)

# append landmarks
if return_landmarks:
result = result + (landmarks, )

# compute and append embedding
if return_embedding:
embedding = self.get_embedding(rgb, landmarks)
result = result + (embedding, )

# compute and append debug image
if return_debug:
debug = self.get_debug(rgb, face, landmarks)
result = result + (debug, )

yield result
return self.iter_data(rgb, return_landmarks, return_embedding, return_debug)
Loading